In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModel, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

class EdgeNetwork(nn.Module):
    """Map input texts to per-token edge probabilities.

    A pretrained Hugging Face backbone encodes the tokenized text; a linear
    head followed by a sigmoid turns every hidden state into ``num_edges``
    independent probabilities.
    """

    def __init__(self, llm_backbone_name, num_edges):
        """Load the backbone and tokenizer and build the edge head.

        Args:
            llm_backbone_name: Identifier handed to both
                ``AutoModel.from_pretrained`` and
                ``AutoTokenizer.from_pretrained``.
            num_edges: Output width of the linear head (probabilities
                produced per hidden state).
        """
        super().__init__()
        self.llm_backbone = AutoModel.from_pretrained(llm_backbone_name)
        self.tokenizer = AutoTokenizer.from_pretrained(llm_backbone_name)
        self.linear = nn.Linear(self.llm_backbone.config.hidden_size, num_edges)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_text):
        """Tokenize ``input_text`` and return edge probabilities.

        Args:
            input_text: Batch of strings to tokenize (padded to the longest
                item and truncated to the tokenizer's limit).

        Returns:
            Tensor of sigmoid outputs whose last dimension is ``num_edges``.
            The leading dimensions are those of the backbone's first output
            (for Hugging Face base models: ``(batch, seq_len)``).
        """
        encoded = self.tokenizer(
            input_text, padding=True, truncation=True, return_tensors='pt'
        )

        # The tokenizer builds CPU tensors; move them to wherever the model
        # lives so the forward pass also works after ``.to('cuda')``.
        device = self.linear.weight.device
        input_ids = encoded['input_ids'].to(device)

        # Without the attention mask the backbone attends to padding tokens
        # as if they were real input, which corrupts every padded sequence.
        attention_mask = encoded.get('attention_mask')
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        llm_output = self.llm_backbone(
            input_ids=input_ids, attention_mask=attention_mask
        )[0]
        edge_logits = self.linear(llm_output)
        return self.sigmoid(edge_logits)

def reinforcement_loss(edge_probs, target_graph):
    """Differentiable loss comparing pairwise edge scores with a target graph.

    Pairwise scores are the Gram matrix of ``edge_probs`` over its last two
    dimensions, so entry ``(i, j)`` is the dot product of rows ``i`` and
    ``j``. The loss is the sum of:

    * the squared difference between those scores and ``target_graph`` on the
      strict upper triangle (edges ``i -> j`` with ``i < j``), and
    * the total mass on the diagonal, i.e. self-loops, which an acyclic graph
      cannot contain.

    Args:
        edge_probs: Tensor of shape ``(num_nodes, k)`` or, batched,
            ``(batch, num_nodes, k)``.
        target_graph: ``(num_nodes, num_nodes)`` binary matrix marking the
            edges that should be present.

    Returns:
        Scalar tensor that carries gradients back to ``edge_probs``.

    NOTE(review): the exact objective is an assumption drawn from the
    surrounding comments -- confirm it matches the intended training signal.
    """
    num_nodes = target_graph.size(0)

    # transpose(-2, -1) equals .t() for 2-D input and additionally accepts a
    # leading batch dimension, which the broadcastable identity below expects.
    graph_matrix = torch.matmul(edge_probs, edge_probs.transpose(-2, -1))

    target = target_graph.to(device=graph_matrix.device, dtype=graph_matrix.dtype)
    eye = torch.eye(
        num_nodes, device=graph_matrix.device, dtype=graph_matrix.dtype
    ).unsqueeze(0)

    # Squared error rather than a difference of masked sums: the expressions
    # triu(G, 1) * T and triu(G * T, 1) are elementwise identical, so
    # subtracting one from the other always yields zero.
    edge_mismatch = torch.triu((graph_matrix - target) ** 2, diagonal=1).sum()

    # Summed rather than thresholded with `> 0`: a boolean count is
    # piecewise constant and therefore provides no gradient.
    self_loop_mass = (graph_matrix * eye).sum()

    return edge_mismatch + self_loop_mass

In [3]:
# Example usage
llm_backbone_name = "bert-base-uncased"
num_edges = 10  # Replace with the actual number of possible edges
num_epochs = 10  # Replace with the desired number of training epochs
edge_network = EdgeNetwork(llm_backbone_name, num_edges)
# NOTE(review): lr=0.001 is high for fine-tuning a pretrained transformer;
# values around 2e-5 to 5e-5 are more typical -- confirm before a real run.
optimizer = optim.Adam(edge_network.parameters(), lr=0.001)

# Training data (placeholders -- fill these in before running). They do not
# change between epochs, so they are built once instead of on every iteration.
input_text = [...]  # Your input text list
target_graph = torch.tensor(...)  # Your target graph tensor

# `from_pretrained` returns the backbone in eval mode; switch the whole
# network to train mode so dropout is active while fine-tuning.
edge_network.train()

# Training loop
for epoch in range(num_epochs):
    optimizer.zero_grad()
    edge_probs = edge_network(input_text)
    loss = reinforcement_loss(edge_probs, target_graph)
    loss.backward()
    optimizer.step()

config.json: 100%|██████████| 570/570 [00:00<00:00, 313kB/s]
  torch.utils._pytree._register_pytree_node(
model.safetensors:  17%|█▋        | 73.4M/440M [00:04<00:23, 15.5MB/s]


KeyboardInterrupt: 