In [2]:
import os
import torch
from torch.nn import Embedding, ModuleList, Linear, Sigmoid
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

In [3]:
# File locations. The loaders below join DATA_FOLDER with each file name.
DATA_FOLDER = '../datasets'
# Default file read by load_data.
# NOTE(review): the name looks like a typo for DATA_FILENAME — confirm before renaming.
DATE_FILENAME = 'movie_transformed.pt'
# Default file read by load_node_embeddings (a file name, joined with DATA_FOLDER).
NODE_EMB_PATH = 'node_embeddings.pt'

# utils
def get_pos_edges(batch_edge_index, all_pos_edges):
    """Keep only the edges of ``batch_edge_index`` that occur in ``all_pos_edges``.

    Both arguments are two-row edge lists (row 0 = source node, row 1 = target
    node, one column per edge). An edge is kept only when its complete
    (source, target) column appears in ``all_pos_edges``; direction matters.

    Args:
        batch_edge_index: Two-row tensor of candidate edges.
        all_pos_edges: Two-row tensor of known positive edges.

    Returns:
        Two-row tensor with the matching columns of ``batch_edge_index`` in
        their original order (zero columns when nothing matches).
    """
    known_edges = all_pos_edges.t()  # one row per known edge

    # `link in known_edges` is NOT a row-membership test: Tensor.__contains__
    # evaluates `(link == known_edges).any()`, which is True as soon as a single
    # endpoint matches in its column. Require both endpoints of the same row to
    # match instead.
    mask = [
        bool((known_edges == link).all(dim=1).any())
        for link in batch_edge_index.t()
    ]
    keep = torch.tensor(mask, dtype=torch.bool, device=batch_edge_index.device)

    return torch.stack(
        [batch_edge_index[0][keep], batch_edge_index[1][keep]],
        dim=0,
    )

def load_data(folder=DATA_FOLDER, filename=DATE_FILENAME):
    """Load the saved graph object and collect its positive edges from all splits.

    Args:
        folder: Directory that contains the serialized graph.
        filename: Name of the serialized graph file inside ``folder``.

    Returns:
        ``(data, all_pos_edges)`` where ``data`` is the object returned by
        ``torch.load`` and ``all_pos_edges`` is the train, validation and test
        positive edge indices concatenated (in that order) along dim 1.
    """
    # torch.load unpickles the file, so only point this at trusted data.
    graph = torch.load(os.path.join(folder, filename))

    split_edges = (
        graph.train_pos_edge_index,
        graph.val_pos_edge_index,
        graph.test_pos_edge_index,
    )
    return graph, torch.cat(split_edges, dim=1)

def load_node_embeddings(folder=DATA_FOLDER, filename=NODE_EMB_PATH):
    """Load the serialized node embeddings.

    Args:
        folder: Directory that contains the embeddings file.
        filename: Name of the embeddings file inside ``folder``.

    Returns:
        Whatever object ``torch.load`` deserializes from ``folder/filename``.
    """
    embeddings_path = os.path.join(folder, filename)
    return torch.load(embeddings_path)

# Load the graph once. `all_pos_edges` is also read as a notebook-level global
# inside GCN.forward and inference.
data, all_pos_edges = load_data()
# Must exist before the GCN class is defined: its constructor uses this value
# as the default for the `node_embeddings` argument.
node_embeddings = load_node_embeddings()

In [7]:
class GCN(torch.nn.Module):
    """GCN encoder followed by a linear link scorer.

    Node features are concatenated with pre-computed node embeddings and passed
    through a stack of ``GCNConv`` layers. Each candidate node pair is scored
    by a linear layer applied to the element-wise product of the two endpoint
    representations.

    Args:
        hidden_channels: Output width of every ``GCNConv`` layer. It is also
            counted into the first layer's input width, so the concatenated
            input (``node_embeddings`` columns + ``x`` columns) must add up to
            ``hidden_channels + feature_dim`` (``+ graph_stat_dim`` when
            ``use_graph_stat`` is True).
        num_layers: Number of ``GCNConv`` layers; one layer is always created.
        use_graph_stat: When True the first layer's input width is increased
            by ``graph_stat_dim``.
        dropout: Dropout probability applied after every layer but the last.
        feature_dim: Number of columns counted for the node features ``x``.
        graph_stat_dim: Extra input columns counted when ``use_graph_stat``.
        node_embeddings: Tensor indexed by node id in ``forward``. Defaults to
            the notebook-level ``node_embeddings`` as it was when this class
            was defined. It is stored as a plain attribute, so it is not part
            of the state dict.
        num_actors: Number of rows of the ``actor_emb`` table. Defaults to
            1942, the value that used to be hard-coded.
    """

    def __init__(self, hidden_channels, num_layers, use_graph_stat=False,
                 dropout=0.5, feature_dim = 2, graph_stat_dim = 3, node_embeddings = node_embeddings,
                 num_actors=1942):
        super().__init__()
        self.use_graph_stat = use_graph_stat
        self.node_embeddings = node_embeddings
        # Not used by forward(); kept so the module's state-dict keys stay the
        # same for checkpoints that were saved with this layer present.
        self.actor_emb = Embedding(num_actors, hidden_channels)

        # Input width of the first conv: embedding part + raw features
        # (+ graph statistics when enabled).
        initial_channels = hidden_channels + feature_dim
        if self.use_graph_stat:
            initial_channels += graph_stat_dim

        self.convs = ModuleList()
        self.convs.append(GCNConv(initial_channels, hidden_channels))
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))

        self.dropout = dropout
        self.lin = Linear(hidden_channels, 1)

    def reset_parameters(self):
        """Re-initialise every learnable sub-module.

        The conv layers, the scoring layer ``lin`` and ``actor_emb`` are all
        reset, so no trained weights survive a call to this method.
        """
        for conv in self.convs:
            conv.reset_parameters()
        self.actor_emb.reset_parameters()
        self.lin.reset_parameters()

    def forward(self, x, node_ids, edge_index, edge_label_index, edge_weight=None):
        """Compute one logit per node pair in ``edge_label_index``.

        Args:
            x: Node feature matrix, concatenated after the node embeddings.
            node_ids: Indices used to look up rows of ``self.node_embeddings``.
            edge_index: Two-row edge list used for message passing.
            edge_label_index: Two-row index tensor; column ``i`` holds the
                (source, target) node pair to score.
            edge_weight: Optional edge weights forwarded to every conv layer.

        Returns:
            Tensor with one row per scored pair and a single column (raw
            logits; no sigmoid is applied here).
        """
        x = torch.cat([self.node_embeddings[node_ids], x], dim=-1)

        # Relies on the notebook-level helper `get_pos_edges` and the global
        # `all_pos_edges`.
        pos_edges = get_pos_edges(edge_index, all_pos_edges)
        # NOTE(review): `edge_weight` is passed unfiltered together with the
        # filtered `pos_edges`; if any edge is dropped the two no longer line
        # up — confirm callers only pass weights when nothing is filtered.
        for conv in self.convs[:-1]:
            x = conv(x, pos_edges, edge_weight)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        # NOTE(review): the last layer uses the unfiltered `edge_index` while
        # the earlier layers use `pos_edges` — confirm this is intentional.
        x = self.convs[-1](x, edge_index, edge_weight)

        x_src = x[edge_label_index[0]]
        x_dst = x[edge_label_index[1]]
        # Element-wise product of the endpoint representations, then a linear
        # layer down to a single logit.
        return self.lin(x_src * x_dst)

In [8]:
def inference(model: GCN, node_pair: tuple, THRESHOLD: float=0.5):
    """Score one candidate link and report whether it is a known positive edge.

    Uses the notebook-level ``data`` (node features, node count) and
    ``all_pos_edges`` (message-passing edges and the lookup for ``observed``).

    Args:
        model: Model to evaluate. It is switched to eval mode as a side effect.
        node_pair: ``(source, target)`` node ids of the link to score.
        THRESHOLD: Probability at or above which the link is predicted to exist.

    Returns:
        ``(observed, model_inferred)``, both 0/1 ints. ``observed`` is 1 when
        ``(source, target)`` appears, in that direction, in ``all_pos_edges``.
        ``model_inferred`` is 1 when the sigmoid of the model's logit is
        ``>= THRESHOLD``.
    """
    torch.manual_seed(0)
    model.eval()

    source, target = node_pair[0], node_pair[1]
    node_ids = torch.arange(data.num_nodes)
    edge_label_index = torch.tensor([[source], [target]])

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = model(data.x, node_ids, all_pos_edges, edge_label_index)
        probability = torch.sigmoid(logits.view(-1))[0]
    model_inferred = int((probability >= THRESHOLD).item())

    # NOTE(review): only the (source, target) direction is checked — confirm
    # whether a stored (target, source) edge should also count as observed.
    is_same_edge = (all_pos_edges[0] == source) & (all_pos_edges[1] == target)
    observed = int(is_same_edge.any().item())

    return observed, model_inferred

In [11]:
# Run inference for a single candidate link.
# TODO: change `node_pair` to the (source, target) node ids you want to score.
node_pair = (2, 40)

# Checkpoint location: the state dict is read from MODEL_PATH/MODEL_NAME.
MODEL_PATH = './model_states'
MODEL_NAME = 'model_checkpoint_final_model'
# Constructor arguments for GCN. The layer shapes they produce have to match
# the saved state dict for load_state_dict below to succeed.
HIDDEN_CHANNEL = 64
NUM_LAYERS = 2
USE_GRAPH_STAT = False
# Dropout is inactive during scoring: inference() calls model.eval().
DROPOUT = 0.3
model = GCN(HIDDEN_CHANNEL, NUM_LAYERS, USE_GRAPH_STAT, DROPOUT)
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
model_state = torch.load(os.path.join(MODEL_PATH, MODEL_NAME))
model.load_state_dict(model_state)

# observed: 1 if the pair is a known positive edge; model_inferred: 1 if the
# model's probability for the pair reaches the threshold.
observed, model_inferred = inference(model, node_pair)
print(f'observed: {observed} -> inferred: {model_inferred}')

observed: 0 -> inferred: 0
