In [20]:
%load_ext autoreload
%autoreload 2

In [21]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import yaml
from torch_geometric.datasets import Planetoid
from tqdm import tqdm

import wandb
from GCN import GCN, EdgePrediction
from utils import (
    build_adj_mat,
    build_classifier_batch,
    build_edge_pred_datasets,
    compute_A_hat,
)

# --- Configuration -----------------------------------------------------------
# All hyper-parameters come from the "GCN" section of the YAML config so that a
# run is fully described by that file (and mirrored to Weights & Biases below).
with open("configEdgePred.yaml", "r") as file:
    config = yaml.safe_load(file)["GCN"]

# Width of every GCN layer and of the edge-prediction head.
# NOTE(review): the original cell used `hidden_dim` without ever assigning it,
# so it only ran with a variable left over in the kernel. The fallback of 64
# matches the parameter counts in the recorded torchinfo summary; confirm and
# add an explicit `hidden_dim` entry to configEdgePred.yaml.
# `setdefault` writes the value back into `config` so the tracked run records
# the width that was actually used.
hidden_dim = config.setdefault("hidden_dim", 64)

wandb.init(project="gnn-from-scratch", config=config)

negative_samples_factor = config["negative_samples_factor"]
# YAML has no tuple type, so each dimension entry is loaded as a list; convert
# the entries to tuples here.
dimensions = [tuple(dim) for dim in config["dimensions"]]
batch_size = config["batch_size"]
lr = config["lr"]
n_epochs = config["n_epochs"]
n_train = config["n_train"]
n_val = config["n_val"]
n_test = config["n_test"]
dropout = config["dropout"]
n_batches = config["n_batches"]
weight_decay = config["weight_decay"]
dataset_name = config["dataset"]
hits_k_rank = config["HITS@K_rank"]
hits_k_positive_samples = config["HITS@K_positive_samples"]
hits_k_negative_samples = config["HITS@K_negative_samples_factor"]

# --- Data --------------------------------------------------------------------
# NOTE(review): per the PyG documentation, `num_train_per_class`, `num_val` and
# `num_test` only take effect with `split="random"`; with the default split
# they are ignored. The edge-level splits used in this notebook are built
# separately below.
dataset = Planetoid(
    "./data/", dataset_name, num_train_per_class=n_train, num_val=n_val, num_test=n_test
)

data = dataset[0]  # there is only one graph

# Input feature size of the first GCN layer, taken from the node feature
# matrix (`data.x` has one row per node and one column per feature).
# The original cell referenced `node_dim` without defining it.
node_dim = data.x.shape[1]

# Train / validation / test edge sets for the link-prediction task.
# NOTE(review): the exact meaning of n_train / n_val / n_test here is defined
# by `build_edge_pred_datasets` in utils.py — confirm there.
train_edge_index, val_edge_index, test_edge_index = build_edge_pred_datasets(
    data, n_train, n_val, n_test
)

# --- Model -------------------------------------------------------------------
# The GCN encoder outputs `hidden_dim`-sized node embeddings, which is the
# embedding size the edge-prediction head is built for.
gcn = GCN(
    input_dim=node_dim,
    hidden_dim=hidden_dim,
    output_dim=hidden_dim,
    n_layers=3,
    dropout=dropout,
)

edge_pred = EdgePrediction(embedding_dim=hidden_dim)

# --- Optimisation ------------------------------------------------------------
# BCEWithLogitsLoss applies the sigmoid internally, so the model is expected to
# output raw scores (logits).
loss_fn = nn.BCEWithLogitsLoss()

# The encoder and the head each get their own optimizer and plateau scheduler,
# configured identically.
optimizer_gcn = optim.Adam(gcn.parameters(), lr=lr, weight_decay=weight_decay)
optimizer_edge_pred = optim.Adam(
    edge_pred.parameters(), lr=lr, weight_decay=weight_decay
)
# NOTE(review): `verbose` is deprecated in recent PyTorch releases; drop it once
# the minimum supported PyTorch version for this project is confirmed.
scheduler_gcn = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_gcn, mode="min", factor=0.2, patience=5, verbose=True
)
scheduler_edge_pred = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_edge_pred, mode="min", factor=0.2, patience=5, verbose=True
)

# Computed once up front and stored on the graph object.
# NOTE(review): presumably the normalised adjacency used by the GCN layers —
# confirm against `compute_A_hat` in utils.py.
data.A_hat = compute_A_hat(data.x, data.edge_index)




In [None]:
from torchinfo import summary

# Display the per-layer parameter counts of the GCN encoder built in the setup
# cell (the summary object is the cell's last expression, so it is rendered as
# the cell output).
summary(gcn)

Layer (type:depth-idx)                   Param #
GCN                                      --
├─ModuleList: 1-1                        --
│    └─GCNLayer: 2-1                     95,808
│    └─GCNLayer: 2-2                     8,192
│    └─GCNLayer: 2-3                     8,192
├─Dropout: 1-2                           --
Total params: 112,192
Trainable params: 112,192
Non-trainable params: 0