In [191]:
from utils import *
import os
import scipy.sparse as ssp
from torch_geometric.data import InMemoryDataset

import numpy as np
import torch
from torch.nn import ModuleList, Linear, Embedding, Sigmoid
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

In [192]:
# setting
DATA_PATH = '../datasets/movie_transformed.pt'


In [193]:
# utils
def load_and_process_data(node_pair: tuple, path: str = DATA_PATH):
    data = torch.load(path)
    data.edge_index = torch.cat([
        data.train_pos_edge_index,
        data.val_pos_edge_index,
        data.test_pos_edge_index
        ],
        dim=1
    )
    edge_weight = torch.ones(data.edge_index.size(1), dtype=int)
    A = ssp.csr_matrix(
        (edge_weight, (data.edge_index[0], data.edge_index[1])), 
        shape=(data.num_nodes, data.num_nodes)
    )
    node_pair = torch.tensor(node_pair)
    edge = torch.tensor([[node_pair[0]], [node_pair[1]]])
    NUM_HOPS = 1
    NODE_LABEL = 'drnl'
    RATIO_PER_HOP = 1.0
    MAX_NODES_PER_HOP = None
    DIRECTED = None
    y = 1
    A_csc = None
    pos_list = extract_enclosing_subgraphs(
        edge, A, data.x, y, NUM_HOPS, NODE_LABEL, 
        RATIO_PER_HOP, MAX_NODES_PER_HOP, DIRECTED, A_csc)

    dataset = InMemoryDataset.collate(pos_list)[0]
    return data.edge_index, dataset
    

In [194]:
class GCN(torch.nn.Module):
    def __init__(self, hidden_channels, num_layers, max_z, train_dataset, 
                 use_feature=False, node_embedding=None, dropout=0.5):
        super(GCN, self).__init__()
        self.use_feature = use_feature
        self.node_embedding = node_embedding
        self.max_z = max_z
        self.z_embedding = Embedding(self.max_z, hidden_channels)

        self.convs = ModuleList()
        initial_channels = hidden_channels
        if self.use_feature:
            initial_channels += train_dataset.num_features
        if self.node_embedding is not None:
            initial_channels += node_embedding.embedding_dim
        self.convs.append(GCNConv(initial_channels, hidden_channels))
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))

        self.dropout = dropout
        self.lin1 = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, z, edge_index, batch, x=None, edge_weight=None, node_id=None, inference=False):
        z_emb = self.z_embedding(z)
        if z_emb.ndim == 3:  # in case z has multiple integer labels
            z_emb = z_emb.sum(dim=1)
        if self.use_feature and x is not None:
            x = torch.cat([z_emb, x.to(torch.float)], 1)
        else:
            x = z_emb
        if self.node_embedding is not None and node_id is not None:
            n_emb = self.node_embedding(node_id)
            x = torch.cat([x, n_emb], 1)
        for conv in self.convs[:-1]:
            x = conv(x, edge_index, edge_weight)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index, edge_weight)
        if not inference:  # center pooling
            _, center_indices = np.unique(batch.cpu().numpy(), return_index=True)
            x_src = x[center_indices]
            x_dst = x[center_indices + 1]
            x = (x_src * x_dst)
            x = F.relu(self.lin1(x))
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.lin2(x)
        else:  
            x_src = x[0]
            x_dst = x[1]
            x = (x_src * x_dst)
            x = F.relu(self.lin1(x))
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.lin2(x)

        return x

In [195]:
def inference(model: GCN, dataset: torch.tensor, node_pair: tuple, all_pos_edges: torch.tensor, THRESHOLD: float=0.5):
    torch.manual_seed(0)

    model.eval()
    logits = model(dataset.z, dataset.edge_index, dataset.batch, dataset.x, dataset.edge_weight, inference=True)
    logits = Sigmoid()(logits.view(-1))
    logits[logits >= THRESHOLD] = 1
    logits[logits < THRESHOLD] = 0
    logits = logits.cpu().detach()
    model_inferred = int(logits[0].item())
    observed = 0
    mask_s = all_pos_edges[0] == node_pair[0]
    mask_t = all_pos_edges[1] == node_pair[1]
    mask = mask_s * mask_t
    if any(mask):
        observed = 1
    
    return observed, model_inferred
    

In [200]:
# start inference
# TODO: change it!
node_pair = (522, 55)

MODEL_PATH = 'model_states'
MODEL_NAME = 'SEAL_model_checkpoint_final_model'
HIDDEN_CHANNEL = 256
NUM_LAYERS = 3
MAX_Z = 1000
USE_FEATURE = True
NODE_EMBEDDINGS=None
all_pos_edges, dataset = load_and_process_data(node_pair)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load and initialize model
model = GCN(HIDDEN_CHANNEL, NUM_LAYERS, MAX_Z, dataset, 
            USE_FEATURE, node_embedding=NODE_EMBEDDINGS).to(device)
model_state = torch.load(os.path.join(MODEL_PATH, MODEL_NAME))
model.load_state_dict(model_state)

observed, model_inferred = inference(model, dataset, node_pair, all_pos_edges)
print()
print(f'observed: {observed} -> inferred: {model_inferred}')


100%|██████████| 1/1 [00:00<00:00, 329.51it/s]


observed: 0 -> inferred: 1





In [197]:
import json
d_path = '../datasets/actor_id_dict.json'
with open(d_path, 'r') as f:
    actors = json.load(f)
print(actors['MichelleYeoh'])
print('Freeman', actors['MorganFreeman'])

# print(actors['AnneHathaway'])
# print(actors['RyanGosling'])
print('Leo', actors['LeonardoDiCaprio'])
print('Jim', actors['JimCarrey'])
print('Ryan', actors['RyanGosling'])
print('Hanks', actors['TomHanks'])
print()


# action
print('Jolie', actors['AngelinaJolie'])
print('Tom', actors['TomCruise'])
print('Jason', actors['JasonStatham'])
print('Vin', actors['VinDiesel'])
print('Rock', actors['DwayneJohnson'])
print('Gal', actors['GalGadot'])


963
Freeman 434
Leo 43
Jim 86
Ryan 165
Hanks 169

Jolie 332
Tom 51
Jason 208
Vin 615
Rock 354
Gal 1316


In [198]:
list(actors.keys())[30:50]

['JulieWalters',
 'MakrandDeshpande',
 'SachiinJoshi',
 'PaulWalker',
 'SteveZahn',
 'ChristianBale',
 'HeathLedger',
 'ElijahWood',
 'IanMcKellen',
 'MatthewMcConaughey',
 'AnneHathaway',
 'AdrienBrody',
 'ThomasKretschmann',
 'LeonardoDiCaprio',
 'JosephGordon-Levitt',
 'AntonioBanderas',
 'SalmaHayek',
 'HonorKneafsey',
 'EvaWhittaker',
 'HaleyLuRichardson']

In [183]:
sorted([name if name[:3] == 'Tom' else '' for name in sorted(list(actors.keys()))], reverse=True)
# sorted(list(actors.keys()))[1][0]

['TommyLeeJones',
 'TommyKnight',
 'TomWilkinson',
 'TomSturridge',
 'TomSizemore',
 'TomNoonan',
 'TomMcGrath',
 'TomKenny',
 'TomKane',
 'TomHolland',
 'TomHiddleston',
 'TomHardy',
 'TomHanks',
 'TomCruise',
 'TomCourtenay',
 'TomBateman',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '