In [44]:
import torch
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.nn import GATConv
from torch_geometric.data import Batch


from sklearn.model_selection import train_test_split

In [36]:
def split_questions(questions, train_ratio=0.8, val_ratio=0.1):
    stratify_labels = [q.y[:, 0].item() if q.y.dim() > 1 else q.y.item() for q in questions]
    
    try:
        train_questions, test_questions = train_test_split(questions, test_size=1 - train_ratio, random_state=42, stratify=stratify_labels)
    except ValueError:
        print("Stratification not possible, splitting without stratification.")
        train_questions, test_questions = train_test_split(questions, test_size=1 - train_ratio, random_state=42, stratify=None)
    
    val_ratio_adjusted = val_ratio / (1 - train_ratio)
    
    try:
        train_questions, val_questions = train_test_split(train_questions, test_size=val_ratio_adjusted, random_state=42, stratify=[q.y[:, 0].item() if q.y.dim() > 1 else q.y.item() for q in train_questions])
    except ValueError:
        print("Stratification not possible, splitting without stratification.")
        train_questions, val_questions = train_test_split(train_questions, test_size=val_ratio_adjusted, random_state=42, stratify=None)

    return train_questions, val_questions, test_questions

In [16]:
q_test_graphs = torch.load('../../local_datasets/bsard_extra/q_test_graphs.pt')
q_train_graphs = torch.load('../../local_datasets/bsard_extra/q_train_graphs.pt')
q_synth_graphs = torch.load('../../local_datasets/bsard_extra/q_synth_graphs.pt')
articles_graphs = torch.load('../../local_datasets/bsard_extra/article_graphs.pt')

In [19]:
expert_questions = q_train_graphs + q_test_graphs

In [37]:
# Load and split datasets
expert_train, expert_val, expert_test = split_questions(expert_questions)
synthetic_train, synthetic_val, synthetic_test = split_questions(q_synth_graphs)

# Create PyTorch DataLoaders for synthetic dataset
synthetic_train_loader = DataLoader(synthetic_train, batch_size=32, shuffle=True)
synthetic_val_loader = DataLoader(synthetic_val, batch_size=32, shuffle=False)

# Create PyTorch DataLoaders for expert-annotated dataset
expert_train_loader = DataLoader(expert_train, batch_size=32, shuffle=True)
expert_val_loader = DataLoader(expert_val, batch_size=32, shuffle=False)

Stratification not possible, splitting without stratification.
Stratification not possible, splitting without stratification.




In [63]:
# Define GAT model
class GAT(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GAT, self).__init__()
        self.conv1 = GATConv(in_channels, 64, heads=8, dropout=0.6)
        self.conv2 = GATConv(64 * 8, out_channels, heads=1, concat=True, dropout=0.6)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x
    
def compute_loss(similarity_matrix, y, article_sizes):
    # Filter out article labels that don't exist in current batch
    unique_labels = torch.unique(y)
    print(len(unique_labels))
    article_sizes = [article_sizes[label] for label in unique_labels]
    print(article_sizes)
    y = torch.tensor([unique_labels.tolist().index(label.item()) for label in y])

    pos_indices = []
    pos_similarities = []
    for i, label in enumerate(y):
        start_index = sum(article_sizes[:label])
        end_index = start_index + article_sizes[label]
        pos_indices.extend(list(range(start_index, end_index)))
        pos_similarities.extend([similarity_matrix[i, j] for j in range(start_index, end_index)])

    # Sample negatives with negative sampling
    neg_indices = torch.ones_like(similarity_matrix, dtype=torch.bool)
    neg_indices[torch.arange(y.size(0)), pos_indices] = 0
    neg_similarities = similarity_matrix[neg_indices].view(y.size(0), -1)

    # Compute loss
    loss = torch.clamp(neg_similarities - pos_similarities.view(-1, 1) + 1, min=0).mean()
    return loss

In [64]:
num_features = synthetic_train[0].x.shape[1]
num_classes = len(articles_graphs)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GAT(num_features, num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

# def train():
#     model.train()
#     total_loss = 0
#     for data in train_loader:
#         data = data.to(device)
#         optimizer.zero_grad()
#         out = model(data)
#         loss = F.nll_loss(out, data.y)
#         loss.backward()
#         optimizer.step()
#         total_loss += loss.item()
#     return total_loss / len(train_loader)

# def validate():
#     model.eval()
#     total_loss = 0
#     for data in val_loader:
#         data = data.to(device)
#         out = model(data)
#         loss = F.nll_loss(out, data.y)
#         total_loss += loss.item()
#     return total_loss / len(val_loader)

# for epoch in range(1, 201):
#     train_loss = train()
#     val_loss = validate()
#     print('Epoch: {:03d}, Train Loss: {:.5f}, Validation Loss: {:.5f}'.format(epoch, train_loss, val_loss))

def train(loader, articles_graphs):
    model.train()

    total_loss = 0

    for batch in loader:
        optimizer.zero_grad()
        batch = batch.to(device)

        # Compute question embeddings
        question_embeddings = model(batch)

        # Compute article embeddings
        article_embeddings = []
        article_sizes = []
        for idx in torch.unique(batch.y).tolist():
            article = articles_graphs[idx].to(device)
            article_embedding = model(article)
            article_embeddings.append(article_embedding)
            article_sizes.append(article_embedding.size(0))

        # Concatenate article embeddings
        article_embeddings = torch.cat(article_embeddings, dim=0)

        # Compute similarity scores
        similarity_matrix = torch.matmul(question_embeddings, article_embeddings.t())

        print(len(article_sizes))
        # Compute loss and update model
        loss = compute_loss(similarity_matrix, batch.y, article_sizes)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(loader)

In [65]:
for epoch in range(2):
    loss = train(synthetic_train_loader, articles_graphs)
    print(f"Epoch: {epoch+1}, Loss: {loss:.4f}")

32
32


IndexError: list index out of range