In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.nn import GATConv

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load and split datasets
# Mini-batch size shared by all three loaders; change it here only.
BATCH_SIZE = 32

# Each helper's result is unpacked as a (train, val, test) triple.
expert_train, expert_val, expert_test = split_expert_questions(question_train_and_test)
synthetic_train, synthetic_val, synthetic_test = split_synthetic_questions(question_synthetic)

# Every split combines the expert and the synthetic questions.
# NOTE(review): `+` assumes the splits support concatenation (e.g. lists) — confirm.
train_dataset = expert_train + synthetic_train
val_dataset = expert_val + synthetic_val
test_dataset = expert_test + synthetic_test

# Only the training loader shuffles; validation and test keep dataset order.
# NOTE(review): `torch_geometric.data.DataLoader` is deprecated in PyG 2.x in
# favour of `torch_geometric.loader.DataLoader` — confirm the installed version.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
# Define GAT model
class GAT(torch.nn.Module):
    """Two-layer graph attention network with node- and graph-level outputs.

    ``forward`` returns one ``out_channels``-wide vector per node.
    ``encode_questions`` and ``encode_articles`` mean-pool those node vectors
    into graph-level embeddings; these are the two methods the retrieval-style
    ``train`` loop in this notebook calls on the model. Mean pooling is a
    simple default readout, not something the rest of the notebook mandates.

    Args:
        in_channels: Width of the input node features (``data.x``).
        out_channels: Width of the vectors produced by the second layer.
        hidden_channels: Per-head output width of the first attention layer.
        heads: Number of attention heads in the first layer. Their outputs
            are concatenated, so the second layer receives
            ``hidden_channels * heads`` features.
        dropout: Probability used for the feature dropout in ``forward`` and
            passed as the ``dropout`` argument of both ``GATConv`` layers.
    """

    def __init__(self, in_channels, out_channels, hidden_channels=32, heads=8, dropout=0.6):
        super().__init__()
        self.dropout = dropout
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1, concat=True, dropout=dropout)

    def forward(self, data):
        """Return per-node embeddings for ``data``.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            The output of the second attention layer, one row per node.
        """
        x, edge_index = data.x, data.edge_index
        # Dropout is only active while the module is in training mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.conv2(x, edge_index)

    def encode_questions(self, batch):
        """Return one mean-pooled embedding per graph in ``batch``.

        Args:
            batch: Object exposing ``x`` and ``edge_index`` and, when it holds
                several graphs, a ``batch`` vector assigning every node to
                its graph (and optionally ``num_graphs``).

        Returns:
            Tensor of shape ``(num_graphs, out_channels)``. Without a
            ``batch`` vector the input is treated as a single graph and the
            result has shape ``(1, out_channels)``.
        """
        node_embeddings = self.forward(batch)
        graph_index = getattr(batch, 'batch', None)
        if graph_index is None:
            return node_embeddings.mean(dim=0, keepdim=True)

        num_graphs = getattr(batch, 'num_graphs', None)
        if num_graphs is None:
            num_graphs = int(graph_index.max()) + 1

        # Sum node vectors per graph, then divide by each graph's node count.
        sums = node_embeddings.new_zeros(num_graphs, node_embeddings.size(-1))
        sums = sums.index_add_(0, graph_index, node_embeddings)
        # clamp avoids 0/0 for a graph that contributes no nodes.
        counts = torch.bincount(graph_index, minlength=num_graphs).clamp(min=1)
        return sums / counts.unsqueeze(-1).to(sums.dtype)

    def encode_articles(self, article):
        """Return a single mean-pooled embedding for one article graph.

        Args:
            article: Object exposing ``x`` and ``edge_index`` for one graph.

        Returns:
            1-D tensor of length ``out_channels``, so that a list of article
            embeddings can be combined with ``torch.stack`` into a
            ``(num_articles, out_channels)`` matrix.
        """
        return self.forward(article).mean(dim=0)

In [None]:
# Run on the GPU when CUDA is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the network first, then move its parameters to the chosen device.
# NOTE(review): 301 is presumably the node-feature width of the graphs — confirm.
model = GAT(in_channels=301, out_channels=128)
model = model.to(device)

# Adam with L2 weight decay over every model parameter.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.005,
    weight_decay=5e-4,
)

# def train():
#     model.train()
#     total_loss = 0
#     for data in train_loader:
#         data = data.to(device)
#         optimizer.zero_grad()
#         out = model(data)
#         loss = F.nll_loss(out, data.y)
#         loss.backward()
#         optimizer.step()
#         total_loss += loss.item()
#     return total_loss / len(train_loader)

# def validate():
#     model.eval()
#     total_loss = 0
#     for data in val_loader:
#         data = data.to(device)
#         out = model(data)
#         loss = F.nll_loss(out, data.y)
#         total_loss += loss.item()
#     return total_loss / len(val_loader)

# for epoch in range(1, 201):
#     train_loss = train()
#     val_loss = validate()
#     print('Epoch: {:03d}, Train Loss: {:.5f}, Validation Loss: {:.5f}'.format(epoch, train_loss, val_loss))

def train(loader, articles_graphs):
    """Run one training epoch and return the mean per-batch loss.

    Uses the notebook-level ``model``, ``optimizer``, ``device`` and
    ``compute_loss``.

    Args:
        loader: Iterable of batches. Each batch must support ``.to(device)``
            and expose ``y``, a tensor of article ids.
        articles_graphs: Container indexed by the integer ids found in
            ``batch.y``; every entry must support ``.to(device)``.

    Returns:
        The sum of the batch losses divided by the number of batches, or
        ``0.0`` when ``loader`` yields no batches.
    """
    model.train()

    total_loss = 0.0
    num_batches = 0

    for batch in loader:
        optimizer.zero_grad()
        batch = batch.to(device)

        # Compute question embeddings
        question_embeddings = model.encode_questions(batch)

        # Compute article embeddings, one per distinct article id in the batch.
        # torch.unique returns the ids sorted, which fixes the column order
        # of the similarity matrix below.
        article_embeddings = torch.stack([
            model.encode_articles(articles_graphs[idx].to(device))
            for idx in torch.unique(batch.y).tolist()
        ])

        # Compute similarity scores: rows are questions, columns are articles.
        similarity_matrix = torch.matmul(question_embeddings, article_embeddings.t())

        # Compute loss and update model
        # NOTE(review): compute_loss is defined outside this cell and receives
        # the raw ids in batch.y, not column positions — confirm it maps ids
        # to the sorted-unique column order used above.
        loss = compute_loss(similarity_matrix, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    # Counting batches while iterating avoids ZeroDivisionError on an empty
    # loader and does not require the loader to implement __len__.
    return total_loss / num_batches if num_batches else 0.0



In [None]:
def test():
    model.eval()
    correct = 0
    total = 0
    for data in test_loader:
        data = data.to(device)
        out = model(data)
        pred = out.argmax(dim=1)
        correct += pred.eq(data.y).sum().item()
        total += data.num_graphs
    return correct / total

# Score the test split once, then report it with five decimal places.
test_acc = test()
accuracy_line = 'Test Accuracy: {:.5f}'.format(test_acc)
print(accuracy_line)