In [97]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# (such as the local GCN and utils modules used below) are re-imported before
# each cell executes and edits to them take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [98]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import yaml
from torch_geometric.datasets import Planetoid
from tqdm import tqdm

import wandb
from GCN import GCN, EdgePrediction
from utils import build_adj_mat, build_edge_pred_datasets, compute_A_hat

# Load the hyperparameters of this experiment from the "GCN" section of the
# YAML configuration file.
with open("configEdgePred.yaml", "r") as file:
    config = yaml.safe_load(file)
    config = config["GCN"]

# Seed every RNG this notebook draws from (NumPy for batch sampling, torch for
# weight initialisation and dropout) so that Restart & Run All reproduces the
# same numbers. The seed is stored in the config so it is logged with the run;
# an existing "seed" entry in the YAML file takes precedence over the default.
config.setdefault("seed", 0)
seed = config["seed"]
np.random.seed(seed)
torch.manual_seed(seed)

wandb.init(project="gnn-from-scratch", config=config)

# Unpack the hyperparameters into module-level names used by the cells below.
node_dim = config["node_dim"]
hidden_dim = config["hidden_dim"]
batch_size = config["batch_size"]
lr = config["lr"]
n_epochs = config["n_epochs"]
n_train = config["n_train"]
n_val = config["n_val"]
n_test = config["n_test"]
dropout = config["dropout"]
weight_decay = config["weight_decay"]
dataset_name = config["dataset"]

# NOTE(review): per the torch_geometric docs, num_train_per_class / num_val /
# num_test are only honoured when split="random"; with the default split they
# appear to be ignored. Confirm whether the node masks matter for this task.
dataset = Planetoid(
    "./data/", dataset_name, num_train_per_class=n_train, num_val=n_val, num_test=n_test
)

data = dataset[0]  # there is only one graph

# Split the graph's edges into train / validation / test edge sets for the
# edge prediction task (the exact splitting rule lives in utils).
train_edge_index, val_edge_index, test_edge_index = build_edge_pred_datasets(
    data, n_train, n_val, n_test
)

# The GCN outputs one hidden_dim-sized embedding per node; the edge predictor
# consumes embeddings of that size.
gcn = GCN(
    input_dim=node_dim,
    hidden_dim=hidden_dim,
    output_dim=hidden_dim,
    n_layers=3,
    dropout=dropout,
)

edge_pred = EdgePrediction(embedding_dim=hidden_dim)

# One optimizer per module, both sharing the same learning rate and weight decay.
loss_fn = nn.CrossEntropyLoss()
optimizer_gcn = optim.Adam(gcn.parameters(), lr=lr, weight_decay=weight_decay)
optimizer_edge_pred = optim.Adam(
    edge_pred.parameters(), lr=lr, weight_decay=weight_decay
)

# Matrix passed to the GCN alongside the node features; presumably the
# normalised adjacency matrix -- confirm in utils.compute_A_hat.
data.A_hat = compute_A_hat(data.x, data.edge_index)

# Put both modules in training mode (dropout active) and clear stale gradients.
gcn.train()
edge_pred.train()
optimizer_edge_pred.zero_grad()
optimizer_gcn.zero_grad()

# No matter what edges will be compared, apply the GCN to the whole graph
node_embeddings = gcn(data.x, data.A_hat)

In [99]:
def build_classifier_batch(
    edge_index, node_embeddings, batch_size: int, negative_sampling: int
) -> torch.Tensor:
    """Build a batch of positive and negative examples for the edge prediction task.

    Every example is the concatenation of the embeddings of the two endpoint
    nodes of a (real or fake) edge.

    Args:
        edge_index: integer tensor with two rows; column ``j`` holds the source
            and target node of the j-th known edge.
        node_embeddings: tensor whose first dimension indexes nodes
            (at least 2-D, i.e. one embedding vector per node).
        batch_size: number of positive examples, drawn without replacement
            from the columns of ``edge_index``.
        negative_sampling: number of negative examples to sample for each
            positive example.

    Returns:
        A tensor with ``batch_size * (1 + negative_sampling)`` rows. The first
        ``batch_size`` rows are the positive examples (in the order the chosen
        edges appear in ``edge_index``); the remaining rows are the negatives.
        No labels are returned.

    Raises:
        ValueError: if ``batch_size`` exceeds the number of edges (raised by
            ``np.random.choice``), if ``negative_sampling`` is negative, or if
            the requested number of negative pairs cannot be found.
    """
    if negative_sampling < 0:
        raise ValueError("negative_sampling must be non-negative")

    n_nodes = node_embeddings.shape[0]
    sources = edge_index[0].tolist()
    targets = edge_index[1].tolist()

    def _pair_features(us, vs):
        # Gather both endpoints in one indexing operation each and join them
        # along the feature axis; autograd flows through the indexing.
        us = torch.as_tensor(us, dtype=torch.long, device=node_embeddings.device)
        vs = torch.as_tensor(vs, dtype=torch.long, device=node_embeddings.device)
        return torch.cat([node_embeddings[us], node_embeddings[vs]], dim=1)

    # Start by building the positive examples. Sorting keeps the chosen edges
    # in their edge_index order.
    batch_indices = np.sort(
        np.random.choice(edge_index.shape[1], batch_size, replace=False)
    )
    positive_samples = _pair_features(
        [sources[i] for i in batch_indices], [targets[i] for i in batch_indices]
    )

    # Now add negative examples: random node pairs that are neither self-loops
    # nor known edges (in either direction). Sampling is done WITH replacement
    # across pairs, so the request is not limited by the number of nodes.
    n_negative = batch_size * negative_sampling
    known_edges = set(zip(sources, targets))
    negative_pairs = []
    if n_negative > 0:
        if n_nodes < 2:
            raise ValueError("need at least two nodes to sample negative edges")
        max_rounds = 100  # bounds the rejection loop on (near-)complete graphs
        for _ in range(max_rounds):
            missing = n_negative - len(negative_pairs)
            if missing <= 0:
                break
            # Oversample so that a single round usually suffices.
            candidates = np.random.randint(0, n_nodes, size=(2 * missing, 2))
            for u, v in candidates.tolist():
                if u == v or (u, v) in known_edges or (v, u) in known_edges:
                    continue
                negative_pairs.append((u, v))
                if len(negative_pairs) == n_negative:
                    break
        if len(negative_pairs) < n_negative:
            raise ValueError(
                "could not sample enough negative edges; the graph may be too dense"
            )

    negative_samples = _pair_features(
        [u for u, _ in negative_pairs], [v for _, v in negative_pairs]
    )

    return torch.cat([positive_samples, negative_samples])

In [100]:
# Smoke test: 4 positive edges plus 4 * 10 negatives should give 44 rows,
# each row being the two endpoint embeddings concatenated.
build_classifier_batch(train_edge_index, node_embeddings, 4, 10).shape

torch.Size([44, 32])

In [104]:
# Scratch check: applying argsort twice turns values into ranks; with
# descending=True on the inner call the largest value gets rank 0.
torch.argsort(torch.argsort(torch.tensor([1, 2, 4, 5, 3]), descending=True))

tensor([4, 3, 1, 0, 2])