# Setup

In [1]:
import random
import numpy as np
from tqdm import tqdm
import os
from datetime import datetime

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
import torch
from torch.nn import (ModuleList, Linear, Embedding, ReLU, 
                      LazyBatchNorm1d as LBN)
from torch.nn import BCEWithLogitsLoss, Sigmoid
import torch.nn.functional as F
from torch_geometric.utils import subgraph
from torch_geometric.data import Data
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import GCNConv

In [4]:
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score

In [5]:
def set_seed(seed: int=42) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    print("**** The seed has been initialized ****")
    return None
set_seed()

**** The seed has been initialized ****


In [6]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data Preprocessing

[*Ref: A tour of PyG’s data loaders*](https://medium.com/stanford-cs224w/a-tour-of-pygs-data-loaders-9f2384e48f8f)

[*MovieLens & LinkNeighborLoader*](https://colab.research.google.com/drive/1GrAxHyZCZ13jpTkMy9vVO_v_U9nHDdvB)

In [123]:
#TODO:change the root path
folder = '../SEAL_OGB/dataset/movie_actor'
filename = 'movie_transformed.pt'
data = torch.load(os.path.join(folder, filename))
all_pos_edges = torch.cat([
    data.train_pos_edge_index,
    data.val_pos_edge_index,
    data.test_pos_edge_index
    ],
    dim=1
)
all_pos_edges.size()

torch.Size([2, 1313])

In [124]:
data

Data(x=[1942, 2], x_stat=[1942, 5], train_neg_edge_index=[2, 1051], val_pos_edge_index=[2, 131], test_pos_edge_index=[2, 131], train_pos_edge_index=[2, 1051], train_neg_adj_mask=[1942, 1942], val_neg_edge_index=[2, 131], test_neg_edge_index=[2, 131])

# Data Preprocessing

In [125]:
# utils
def get_pos_edges(batch_edge_index, all_pos_edges):
    mask = []
    for link in batch_edge_index.t():
        if link in all_pos_edges.t():
            mask.append(True)
        else:
            mask.append(False)
    batch_pos_edges = torch.stack([
        batch_edge_index[0][mask],
        batch_edge_index[1][mask]
        ], dim=0
        )
    return batch_pos_edges

def get_edge_mask(edge_index, all_pos_edges):
    mask = torch.zeros(1313, dtype=bool)
    for link in edge_index.t():
        tmp = (all_pos_edges.t() == link)
        idx = int((tmp[:, 0] * tmp[:, 1]).nonzero(as_tuple=True)[0])
        mask[idx] = True
    return mask

train_mask = get_edge_mask(data.train_pos_edge_index, all_pos_edges)
test_mask = get_edge_mask(data.test_pos_edge_index, all_pos_edges)

## Node2Vec
[*Ref: node2vec pytorch geometric*](https://juejin.cn/s/node2vec%20pytorch%20geometric)

In [126]:
from torch_geometric.nn import Node2Vec
torch.manual_seed(0)
EMBEDDING_DIM = 64

# Initialize Node2Vec
model = Node2Vec(all_pos_edges, embedding_dim=EMBEDDING_DIM, walk_length=10,
                 context_size=3, walks_per_node=20, num_negative_samples=1,
                 p=0.8, q=0.3, sparse=True)

# Train model
loader = model.loader(batch_size=128, shuffle=True)
optimizer = torch.optim.SparseAdam(model.parameters(), lr=0.01)
model.train()
for epoch in range(1, 101):
    total_loss = 0
    for pos_rw, neg_rw in loader:
        optimizer.zero_grad()
        loss = model.loss(pos_rw, neg_rw)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print('Epoch: {:02d}, Loss: {:.4f}'.format(epoch, total_loss / len(loader)))

z = model()
node_embeddings = z.detach().cpu()
node_embeddings.size()

Epoch: 01, Loss: 3.5137
Epoch: 02, Loss: 3.0573
Epoch: 03, Loss: 2.6701
Epoch: 04, Loss: 2.3868
Epoch: 05, Loss: 2.1373
Epoch: 06, Loss: 1.9191
Epoch: 07, Loss: 1.7372
Epoch: 08, Loss: 1.5801
Epoch: 09, Loss: 1.4490
Epoch: 10, Loss: 1.3358
Epoch: 11, Loss: 1.2353
Epoch: 12, Loss: 1.1495
Epoch: 13, Loss: 1.0781
Epoch: 14, Loss: 1.0193
Epoch: 15, Loss: 0.9683
Epoch: 16, Loss: 0.9314
Epoch: 17, Loss: 0.8960
Epoch: 18, Loss: 0.8691
Epoch: 19, Loss: 0.8473
Epoch: 20, Loss: 0.8307
Epoch: 21, Loss: 0.8154
Epoch: 22, Loss: 0.8032
Epoch: 23, Loss: 0.7939
Epoch: 24, Loss: 0.7859
Epoch: 25, Loss: 0.7798
Epoch: 26, Loss: 0.7737
Epoch: 27, Loss: 0.7710
Epoch: 28, Loss: 0.7672
Epoch: 29, Loss: 0.7638
Epoch: 30, Loss: 0.7602
Epoch: 31, Loss: 0.7583
Epoch: 32, Loss: 0.7564
Epoch: 33, Loss: 0.7544
Epoch: 34, Loss: 0.7537
Epoch: 35, Loss: 0.7512
Epoch: 36, Loss: 0.7496
Epoch: 37, Loss: 0.7490
Epoch: 38, Loss: 0.7485
Epoch: 39, Loss: 0.7472
Epoch: 40, Loss: 0.7451
Epoch: 41, Loss: 0.7454
Epoch: 42, Loss:

torch.Size([1942, 64])

In [151]:
torch.save(node_embeddings, 'node_embeddings')

## For Dataloader

In [127]:
def prepare_data(data: Data, mode: str) -> Data:
    d = data.clone().detach()
    if mode == 'train':
        d.edge_index = torch.cat([
            data.train_pos_edge_index,
            data.train_neg_edge_index
            ],
            dim=1
        )
        d.edge_label_index = d.edge_index.clone()
        pos_cnt = data.train_pos_edge_index.size(1)
        neg_cnt = data.train_neg_edge_index.size(1)
        d.edge_label = torch.cat([torch.ones(pos_cnt), torch.zeros(neg_cnt)])
    elif mode == 'valid':
        d.edge_index = torch.cat([
            data.val_pos_edge_index,
            data.val_neg_edge_index
            ],
            dim=1
        )
        d.edge_label_index = d.edge_index.clone()
        pos_cnt = data.val_pos_edge_index.size(1)
        neg_cnt = data.val_neg_edge_index.size(1)
        d.edge_label = torch.cat([torch.ones(pos_cnt), torch.zeros(neg_cnt)])
    elif mode == 'test':
        d.edge_index = torch.cat([
            data.test_pos_edge_index,
            data.test_neg_edge_index
            ],
            dim=1
        )
        d.edge_label_index = d.edge_index.clone()
        pos_cnt = data.test_pos_edge_index.size(1)
        neg_cnt = data.test_neg_edge_index.size(1)
        d.edge_label = torch.cat([torch.ones(pos_cnt), torch.zeros(neg_cnt)])
    else: # all
        d.edge_index = torch.cat([
            data.train_pos_edge_index,
            data.val_pos_edge_index,
            data.test_pos_edge_index,
            data.train_neg_edge_index,
            data.val_neg_edge_index,
            data.test_neg_edge_index
            ],
            dim=1
        )
        d.edge_label_index = d.edge_index.clone()
        pos_cnt = data.train_pos_edge_index.size(1) + data.val_pos_edge_index.size(1) + data.test_pos_edge_index.size(1)
        neg_cnt = data.train_neg_edge_index.size(1) + data.val_neg_edge_index.size(1) + data.test_neg_edge_index.size(1)
        d.edge_label = torch.cat([torch.ones(pos_cnt), torch.zeros(neg_cnt)])
    
    # clean output
    del d.train_pos_edge_index
    del d.train_neg_edge_index
    del d.val_pos_edge_index
    del d.val_neg_edge_index
    del d.test_pos_edge_index
    del d.test_neg_edge_index
    del d.train_neg_adj_mask

    return d

## Model Hyper-Parameter Setting

In [129]:
#TODO: Model Hyper-Parameters
# dataloader
BATCH_SIZE = 64
NUM_NEIGHBORS = [5, 5]

# GCN
HIDDEN_CHANNEL = 64
NUM_LAYERS = 2
USE_GRAPH_STAT = False
DROPOUT = 0.3

# Optimization
EPOCHS = 300
LR = 1e-4
PATIENCE = 80

MODEL_PATH = './model_states'

In [116]:
torch.manual_seed(0)
train_data = prepare_data(data, 'train')
valid_data = prepare_data(data, 'valid')
test_data = prepare_data(data, 'test')

print(train_data)
print(valid_data)
print(test_data)

train_loader = LinkNeighborLoader(train_data,
                                  num_neighbors=NUM_NEIGHBORS,
                                  batch_size=BATCH_SIZE,
                                  edge_label=train_data.edge_label,
                                  edge_label_index=train_data.edge_label_index,
                                  # these are only set for train!
                                  shuffle=True)
val_loader   = LinkNeighborLoader(valid_data,
                                  num_neighbors=NUM_NEIGHBORS,
                                  batch_size=BATCH_SIZE,
                                  edge_label=valid_data.edge_label,
                                  edge_label_index=valid_data.edge_label_index,
                                  shuffle=True)
test_loader  = LinkNeighborLoader(test_data,
                                  num_neighbors=NUM_NEIGHBORS,
                                  batch_size=BATCH_SIZE,
                                  edge_label=test_data.edge_label,
                                  edge_label_index=test_data.edge_label_index,
                                  shuffle=True)

Data(x=[1942, 2], x_stat=[1942, 5], edge_index=[2, 2102], edge_label_index=[2, 2102], edge_label=[2102])
Data(x=[1942, 2], x_stat=[1942, 5], edge_index=[2, 262], edge_label_index=[2, 262], edge_label=[262])
Data(x=[1942, 2], x_stat=[1942, 5], edge_index=[2, 262], edge_label_index=[2, 262], edge_label=[262])


# Model

In [117]:
class GCN(torch.nn.Module):
    def __init__(self, hidden_channels, num_layers, use_graph_stat=False, 
                 dropout=0.5, feature_dim = 2, graph_stat_dim = 3, node_embeddings = node_embeddings):
        super(GCN, self).__init__()
        self.use_graph_stat = use_graph_stat
        self.node_embeddings = node_embeddings
        self.actor_emb = Embedding(1942, hidden_channels)

        initial_channels = hidden_channels
        initial_channels += feature_dim
        # if graph stats included
        if self.use_graph_stat:
            initial_channels += graph_stat_dim

        # self.batchnorm = LBN(feature_dim + graph_stat_dim) if use_graph_stat else LBN(feature_dim)
        self.convs = ModuleList()
        self.convs.append(GCNConv(initial_channels, hidden_channels))
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))

        self.dropout = dropout
        # self.lin1 = Linear(hidden_channels, hidden_channels)
        # self.lin2 = Linear(hidden_channels, 1)
        self.lin = Linear(hidden_channels, 1)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, node_ids, edge_index, edge_label_index, edge_weight=None):
        # x = self.batchnorm(x)
        # x = torch.cat([self.actor_emb(node_ids), x], dim=-1)
        x = torch.cat([self.node_embeddings[node_ids], x], dim=-1)
        
        pos_edges = get_pos_edges(edge_index, all_pos_edges)
        # edge_index, _ = subgraph(node_ids, edge_index)
        for conv in self.convs[:-1]:
            x = conv(x, pos_edges, edge_weight)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index, edge_weight)
        
        x_src = x[edge_label_index[0]]
        x_dst = x[edge_label_index[1]]
        x = (x_src * x_dst) # node pair
        # x = F.relu(self.lin1(x))
        # x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin(x)
        
        return x

### Train


In [118]:
# Model Config
model = GCN(HIDDEN_CHANNEL, NUM_LAYERS, USE_GRAPH_STAT, DROPOUT)
parameters = list(model.parameters())
optimizer = torch.optim.Adam(params=parameters, lr=LR)
loss_fn = BCEWithLogitsLoss()

In [119]:
def train(model, train_loader=train_loader, val_loader=val_loader, verbose=True):
    print(f"***** Epochs: {EPOCHS}, Batch_size: {BATCH_SIZE}, Use_graph_stat: {USE_GRAPH_STAT}, Num_neighbors: {NUM_NEIGHBORS} *****")
    print(f"***** Num_layers: {NUM_LAYERS}, Hidden_channels: {HIDDEN_CHANNEL}, Dropout: {DROPOUT}, lr: {LR} *****\n")
    set_seed()
    torch.manual_seed(0)

    model.train()
    THRESHOLD = 0.5
    trigger = 0
    best_valid_auc = 0
    for epoch_num in range(1, EPOCHS+1):
        train_losses = []
        y_pred_tensors = torch.tensor([])
        y_true_tensors = torch.tensor([])

        total_loss = 0
        pbar = tqdm(train_loader, ncols=70, disable=(not verbose))
        for batch_data in pbar:
            batch_data = batch_data.to(DEVICE)
            optimizer.zero_grad()
            x = data.x_stat if USE_GRAPH_STAT else data.x #TODO: use all data for message passing

            # node_ids = torch.LongTensor(range(data.num_nodes))
            logits = model(batch_data.x, batch_data.n_id, batch_data.edge_index, batch_data.edge_label_index)
            loss = loss_fn(logits.view(-1), batch_data.edge_label)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            logits = Sigmoid()(logits.view(-1))
            logits[logits >= THRESHOLD] = 1
            logits[logits < THRESHOLD] = 0
            y_pred_tensors = torch.cat([y_pred_tensors, logits.cpu()])
            y_true_tensors = torch.cat([y_true_tensors, batch_data.edge_label.cpu()])

        avg_loss = round(total_loss / len(train_loader), 3)
        train_losses.append(avg_loss)
        train_auc = roc_auc_score(y_true_tensors.detach(), y_pred_tensors.detach())

        # ======================================================
        # Set the state of the model to "Evaluation"
        model.eval()

        valid_true = torch.tensor([])
        valid_pred = torch.tensor([])

        # Validation
        with torch.no_grad():
            for valid_batch_data in tqdm(val_loader, ncols=70, disable=(not verbose)):
                valid_batch_data = valid_batch_data.to(DEVICE)
                x = valid_batch_data.x_stat if USE_GRAPH_STAT else valid_batch_data.x
                logits = model(x, valid_batch_data.n_id, valid_batch_data.edge_index, valid_batch_data.edge_label_index)

                logits = Sigmoid()(logits.view(-1))
                logits[logits >= THRESHOLD] = 1
                logits[logits < THRESHOLD] = 0
                valid_pred = torch.cat([valid_pred, logits.cpu()])
                valid_true = torch.cat([valid_true, valid_batch_data.edge_label.cpu()])
            
            valid_auc = roc_auc_score(valid_true.detach(), valid_pred.detach())
        
        if epoch_num % 10 == 0:
            print(
                f"Epochs: {epoch_num}  | Train Loss: {avg_loss: .3f} | Train AUC: {train_auc: .3f} | Valid AUC: {valid_auc: .3f}  --best: {best_valid_auc: .3f}"
            )
        
        # early-stopping
        if valid_auc <= best_valid_auc:
            trigger += 1
        else:
            trigger = 0
            best_valid_auc = valid_auc
        if trigger == PATIENCE:
            print(f"early stopping at epoch: {epoch_num}")
            break

    # save model
    time = datetime.now().strftime('%m-%d_%H:%M')
    torch.save(model.state_dict(), os.path.join(MODEL_PATH, f'model_checkpoint_{round(best_valid_auc, 3)}'))

    return model


In [120]:
def test(model, test_loader=test_loader):
    torch.manual_seed(0)
    model.eval()

    test_loss = 0
    test_true = torch.tensor([])
    test_pred = torch.tensor([])
    # Validation
    THRESHOLD = 0.5
    with torch.no_grad():
        for test_batch_data in test_loader:
            test_batch_data = test_batch_data.to(DEVICE)
            x = test_batch_data.x_stat if USE_GRAPH_STAT else test_batch_data.x
            logits = model(x, test_batch_data.n_id, test_batch_data.edge_index, test_batch_data.edge_label_index)
            loss = loss_fn(logits.view(-1), test_batch_data.edge_label)
            test_loss += loss

            logits = Sigmoid()(logits.view(-1))
            logits[logits >= THRESHOLD] = 1
            logits[logits < THRESHOLD] = 0
            test_pred = torch.cat([test_pred, logits.cpu()])
            test_true = torch.cat([test_true, test_batch_data.edge_label.cpu()])
        
        test_auc = roc_auc_score(test_true.detach(), test_pred.detach())
        avg_loss = test_loss / len(test_loader)
    print(f"Test Loss: {avg_loss: .3f} | Test AUC: {test_auc: .3f}")

    return test_true, test_pred


In [122]:
# Start training!
GCN_model = train(model, verbose=False)
print('='*20, '\n')

# Start testing!
test_true, test_pred = test(GCN_model)

***** Epochs: 300, Batch_size: 64, Use_graph_stat: False, Num_neighbors: [5, 5] *****
***** Num_layers: 2, Hidden_channels: 64, Dropout: 0.3, lr: 0.0001 *****

**** The seed has been initialized ****
Epochs: 10  | Train Loss:  0.593 | Train AUC:  0.667 | Valid AUC:  0.649  --best:  0.668
Epochs: 20  | Train Loss:  0.594 | Train AUC:  0.661 | Valid AUC:  0.668  --best:  0.668
Epochs: 30  | Train Loss:  0.591 | Train AUC:  0.668 | Valid AUC:  0.672  --best:  0.668
Epochs: 40  | Train Loss:  0.591 | Train AUC:  0.660 | Valid AUC:  0.653  --best:  0.672
Epochs: 50  | Train Loss:  0.594 | Train AUC:  0.664 | Valid AUC:  0.656  --best:  0.672
Epochs: 60  | Train Loss:  0.592 | Train AUC:  0.662 | Valid AUC:  0.645  --best:  0.672
Epochs: 70  | Train Loss:  0.596 | Train AUC:  0.658 | Valid AUC:  0.653  --best:  0.672
Epochs: 80  | Train Loss:  0.596 | Train AUC:  0.662 | Valid AUC:  0.653  --best:  0.672
Epochs: 90  | Train Loss:  0.590 | Train AUC:  0.659 | Valid AUC:  0.649  --best:  0.672

In [121]:
# Current Best
# n2v: p=0.8, q=0.3, walk_length=10, walks=20, context=3
# ***** Epochs: 300, Batch_size: 64, Use_graph_stat: False, Num_neighbors: [5, 5] *****
# ***** Num_layers: 2, Hidden_channels: 64, Dropout: 0.3, lr: 0.0001 *****

model = GCN(HIDDEN_CHANNEL, NUM_LAYERS, USE_GRAPH_STAT, DROPOUT)
model_state = torch.load(os.path.join(MODEL_PATH, 'model_checkpoint_best_0.676'))
model.load_state_dict(model_state)
test_pred, test_true = test(model)

Test Loss:  0.571 | Test AUC:  0.676


In [151]:
# # Current Best
# # n2v: p=0.8, q=0.3, walk_length=10, walks=20, context=3
# # ***** Epochs: 300, Batch_size: 128, Use_graph_stat: False, Num_neighbors: [5, 5] *****
# # ***** Num_layers: 2, Hidden_channels: 64, Dropout: 0.3, lr: 0.0001 *****

# model = GCN(HIDDEN_CHANNEL, NUM_LAYERS, USE_GRAPH_STAT, DROPOUT)
# model_state = torch.load(os.path.join(MODEL_PATH, 'model_checkpoint_best_0.668'))
# model.load_state_dict(model_state)
# test_pred, test_true = test(model)

Test Loss:  0.556 | Test AUC:  0.668


In [146]:
# # Current Best
# # n2v: p=0.8, q=0.3, walk_length=10, walks=20, context=3
# # ***** Epochs: 300, Batch_size: 128, Use_graph_stat: False, Num_neighbors: [100, 100] *****
# # ***** Num_layers: 2, Hidden_channels: 64, Dropout: 0.3, lr: 0.005 *****

# model = GCN(HIDDEN_CHANNEL, NUM_LAYERS, USE_GRAPH_STAT, DROPOUT)
# model_state = torch.load(os.path.join(MODEL_PATH, 'model_checkpoint_best_0.653'))
# model.load_state_dict(model_state)
# test_pred, test_true = test(model)

Test Loss:  2.577 | Test AUC:  0.653


In [147]:
# # Current Best
# # n2v: p=0.8, q=0.3, walk_length=10, walks=20, context=3
# # ***** Epochs: 300, Batch_size: 128, Use_graph_stat: False, Num_neighbors: [5, 5, 5] *****
# # ***** Num_layers: 2, Hidden_channels: 64, Dropout: 0.3, lr: 0.0001 *****

# model = GCN(HIDDEN_CHANNEL, NUM_LAYERS, USE_GRAPH_STAT, DROPOUT)
# model_state = torch.load(os.path.join(MODEL_PATH, 'model_checkpoint_best_0.641'))
# model.load_state_dict(model_state)
# test_pred, test_true = test(model)

Test Loss:  0.558 | Test AUC:  0.641


In [149]:
# # Current Best
# # n2v: p=0.8, q=0.3, walk_length=10, walks=20, context=3
# # ***** Epochs: 300, Batch_size: 128, Use_graph_stat: False, Num_neighbors: [5, 5, 5] *****
# # ***** Num_layers: 2, Hidden_channels: 64, Dropout: 0.3, lr: 0.0001 *****

# model = GCN(64, 2, USE_GRAPH_STAT, 0.3)
# model_state = torch.load(os.path.join(MODEL_PATH, 'model_checkpoint_best_0.626'))
# model.load_state_dict(model_state)
# test_pred, test_true = test(model)

Test Loss:  0.563 | Test AUC:  0.626


### Features + Graph_Stat

In [651]:
# # GCN
# HIDDEN_CHANNEL = 128
# NUM_LAYERS = 2
# USE_GRAPH_STAT = True
# DROPOUT = 0.3

# # Optimization
# EPOCHS = 300
# LR = 1e-4
# MOMENTUM = 0.9
# PATIENCE = 30

# # Model Config
# model_stat = GCN(HIDDEN_CHANNEL, NUM_LAYERS, USE_GRAPH_STAT, DROPOUT)
# parameters = list(model_stat.parameters())
# # optimizer = torch.optim.SGD(params=parameters, lr=LR, momentum=MOMENTUM)
# optimizer = torch.optim.Adam(params=parameters, lr=LR)
# loss_fn = BCEWithLogitsLoss()

In [652]:
# GCN_v2 = train(model_stat, verbose=False)

***** Epochs: 300, Batch_size: 64, Use_graph_stat: True *****
***** Num_layers: 2, Hidden_channels: 128, Dropout: 0.3, lr: 0.0001 *****

**** The seed has been initialized ****
Epochs: 10  | Train Loss:  0.654 | Train AUC:  0.562 | Valid AUC:  0.561  --best: 0.549618320610687
Epochs: 20  | Train Loss:  0.612 | Train AUC:  0.596 | Valid AUC:  0.519  --best: 0.5610687022900763
Epochs: 30  | Train Loss:  0.556 | Train AUC:  0.656 | Valid AUC:  0.542  --best: 0.5763358778625953
Epochs: 40  | Train Loss:  0.486 | Train AUC:  0.734 | Valid AUC:  0.569  --best: 0.6068702290076337
Epochs: 50  | Train Loss:  0.407 | Train AUC:  0.794 | Valid AUC:  0.550  --best: 0.6068702290076337
Epochs: 60  | Train Loss:  0.345 | Train AUC:  0.836 | Valid AUC:  0.550  --best: 0.6068702290076337
early stopping at epoch: 66


# Inference

In [130]:
# Dataloader for whole dataset
torch.manual_seed(0)
all_data = prepare_data(data, 'all')
print(all_data)

all_data_loader = LinkNeighborLoader(all_data,
                                  num_neighbors=NUM_NEIGHBORS,
                                  batch_size=BATCH_SIZE,
                                  edge_label=all_data.edge_label,
                                  edge_label_index=all_data.edge_label_index,
                                  shuffle=True)

# Model Config
model = GCN(HIDDEN_CHANNEL, NUM_LAYERS, USE_GRAPH_STAT, DROPOUT)
parameters = list(model.parameters())
optimizer = torch.optim.Adam(params=parameters, lr=LR)
loss_fn = BCEWithLogitsLoss()

# Start training!
GCN_final_model = train(model, train_loader=all_data_loader, verbose=False)
print('='*20, '\n')
print("*** training completed! ***")

Data(x=[1942, 2], x_stat=[1942, 5], edge_index=[2, 2626], edge_label_index=[2, 2626], edge_label=[2626])
***** Epochs: 300, Batch_size: 64, Use_graph_stat: False, Num_neighbors: [5, 5] *****
***** Num_layers: 2, Hidden_channels: 64, Dropout: 0.3, lr: 0.0001 *****

**** The seed has been initialized ****
Epochs: 10  | Train Loss:  0.677 | Train AUC:  0.575 | Valid AUC:  0.592  --best:  0.611
Epochs: 20  | Train Loss:  0.658 | Train AUC:  0.597 | Valid AUC:  0.618  --best:  0.645
Epochs: 30  | Train Loss:  0.663 | Train AUC:  0.594 | Valid AUC:  0.649  --best:  0.656
Epochs: 40  | Train Loss:  0.652 | Train AUC:  0.607 | Valid AUC:  0.637  --best:  0.668
Epochs: 50  | Train Loss:  0.648 | Train AUC:  0.615 | Valid AUC:  0.668  --best:  0.683
Epochs: 60  | Train Loss:  0.658 | Train AUC:  0.610 | Valid AUC:  0.672  --best:  0.691
Epochs: 70  | Train Loss:  0.641 | Train AUC:  0.619 | Valid AUC:  0.664  --best:  0.702
Epochs: 80  | Train Loss:  0.646 | Train AUC:  0.623 | Valid AUC:  0.683

In [131]:
def inference(model: GCN, node_pair: tuple, THRESHOLD: float=0.5):
    torch.manual_seed(0)
    folder = '../SEAL_OGB/dataset/movie_actor'
    filename = 'movie_transformed.pt'
    data = torch.load(os.path.join(folder, filename))
    all_pos_edges = torch.cat([
        data.train_pos_edge_index,
        data.val_pos_edge_index,
        data.test_pos_edge_index
        ],
        dim=1
    )

    model.eval()
    node_ids = torch.tensor(list(range(data.num_nodes)))
    edge_label_index = torch.tensor([[node_pair[0]], [node_pair[1]]])
    logits = model(data.x, node_ids, all_pos_edges, edge_label_index)

    logits = Sigmoid()(logits.view(-1))
    logits[logits >= THRESHOLD] = 1
    logits[logits < THRESHOLD] = 0
    logits = logits.cpu().detach()
    model_inferred = int(logits[0].item())
    observed = 0
    mask_s = all_pos_edges[0] == node_pair[0]
    mask_t = all_pos_edges[1] == node_pair[1]
    mask = mask_s * mask_t
    if any(mask):
        observed = 1
    
    return observed, model_inferred

In [146]:
# start inference
model_name = 'model_checkpoint_final_model'
HIDDEN_CHANNEL = 64
NUM_LAYERS = 2
USE_GRAPH_STAT = False
DROPOUT = 0.3
model = GCN(HIDDEN_CHANNEL, NUM_LAYERS, USE_GRAPH_STAT, DROPOUT)
model_state = torch.load(os.path.join(MODEL_PATH, model_name))
model.load_state_dict(model_state)

node_pair = (1, 30)
observed, model_inferred = inference(model, node_pair)
print(f'observed: {observed} -> inferred: {model_inferred}')

observed: 0 -> inferred: 1


In [141]:
import json
data_path = '../datasets/actor_id_dict.json'
with open(data_path, 'r') as f:
    actor_id = json.load(f)
actor_id

{'AamirKhan': 0,
 'AbhishekBachchan': 1,
 'FatimaSanaShaikh': 2,
 'HrithikRoshan': 3,
 'BárbaraMori': 4,
 'ShahRukhKhan': 5,
 'RaniMukerji': 6,
 'SalmanKhan': 7,
 'KareenaKapoorKhan': 8,
 'PreityZinta': 9,
 'IrrfanKhan': 10,
 'NimratKaur': 11,
 'NaseeruddinShah': 12,
 'LilleteDubey': 13,
 'KatrinaKaif': 14,
 'AnushkaSharma': 15,
 'Siddharth': 16,
 'RanbirKapoor': 17,
 'DeepikaPadukone': 18,
 'AkshayKumar': 19,
 'ArjunRampal': 20,
 'SumeetDarshanDobhal': 21,
 'SaifAliKhan': 22,
 'AbhayDeol': 23,
 'EmraanHashmi': 24,
 'AnilKapoor': 25,
 'SunielShetty': 26,
 'KarismaKapoor': 27,
 'BipashaBasu': 28,
 'HelenMirren': 29,
 'JulieWalters': 30,
 'MakrandDeshpande': 31,
 'SachiinJoshi': 32,
 'PaulWalker': 33,
 'SteveZahn': 34,
 'ChristianBale': 35,
 'HeathLedger': 36,
 'ElijahWood': 37,
 'IanMcKellen': 38,
 'MatthewMcConaughey': 39,
 'AnneHathaway': 40,
 'AdrienBrody': 41,
 'ThomasKretschmann': 42,
 'LeonardoDiCaprio': 43,
 'JosephGordon-Levitt': 44,
 'AntonioBanderas': 45,
 'SalmaHayek': 46,
 '