In [1]:
%load_ext autoreload
%autoreload 2

import os, sys
import numpy as np
import pandas as pd
import glob
from collections import Counter
import collections, itertools
import networkx as nx
import regex as re

# Repository root: cut everything from the first "/heterogeneous..." component off the
# current working directory, then re-append the repository folder name.
PROJ_PATH = os.path.join(
    re.sub("/heterogeneous.*$", '', os.getcwd()),
    'heterogeneous_subgraph_representation_for_team_discovery',
)
# Make <repo>/src importable, inserted right after sys.path[0].
sys.path.insert(1, os.path.join(PROJ_PATH, 'src'))

from train_config import *
from datasets import SubgraphDataset

from sklearn.cluster import KMeans
from sklearn.decomposition import PCA, TruncatedSVD

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
%matplotlib inline
plt.style.use('ggplot')

import torch
import train as tr
import subgraph_utils
from pathlib import Path
from src import HSGNN
from torch.utils.data import DataLoader, Dataset
from datasets import SubgraphDataset

class Namespace:
    """Lightweight attribute bag: each keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

In [2]:
class InferenceAgent:
    """Restore a trained HSGNN checkpoint and embed every subgraph of a task.

    Usage: ``agent = InferenceAgent(pred_args=...); agent.inference()``. The
    embeddings are then available as ``agent.output``: one list of floats per
    subgraph, ordered train -> val -> test.
    """

    # Splits for which the model exposes ``<split>_*`` attributes.
    _SPLITS = ('train', 'val', 'test')
    # Middle parts of the per-split ``<split>_<key>_cc_embed`` attributes, in the
    # positional order they are passed to ``myforward`` / ``myforward_2``.
    _CC_EMBED_KEYS = ('N_I', 'N_B', 'S_I', 'S_B', 'P_I', 'P_B')

    def __init__(self,
                 common_args=None,
                 pred_args=None,
                 device='cuda',
                 dropout=False,
                ):
        """Build model + trainer for ``pred_args['task']`` and load the checkpoint weights.

        Args:
            common_args: overrides merged on top of the default run options below.
            pred_args: must provide 'task', 'restoreModelPath' and 'restoreModelName';
                its keys win over ``common_args`` on clashes. Defaults to the
                'dblp_v82' task with empty checkpoint path/name.
            device: torch device the model and batches are moved to; checkpoint
                tensors are mapped onto it as well.
            dropout: if True, inference uses ``model.myforward``; otherwise
                ``model.myforward_2``.
        """
        import random  # seeded below; imported locally so this cell does not rely on a star-import

        # None sentinels instead of mutable dict defaults.
        if common_args is None:
            common_args = {}
        if pred_args is None:
            pred_args = {'task': 'dblp_v82',
                         'restoreModelPath': '',
                         'restoreModelName': ''}

        default_common_args = {
            "max_epochs": 300,
            "tb_dir": os.path.join(PROJ_PATH, 'dataset/tensorboard/'),
            'trial': None,
            'runTest': True,
            "no_checkpointing": False,
            "no_save": True,
            "debug_mode": False,
            "subset_data": False,
            "noTrain": True,
            "log_path": None,
        }
        # Precedence, lowest to highest: defaults, common_args, pred_args.
        args = Namespace(**{**default_common_args, **common_args, **pred_args})
        args.tb_name = 'S_' + args.task + '_optuna'
        args.config_path = os.path.join(PROJ_PATH, 'HSGNN/config_files/', args.task, 'config.json')

        self.model, self.hyperparameters = tr.build_model(args)
        self.trainer, _, _ = tr.build_trainer(args, self.hyperparameters)

        # Seed every RNG with the training seed and force deterministic cuDNN kernels.
        seed = self.hyperparameters['seed']
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        self.dropout = dropout
        self.device = device
        # Map checkpoint tensors onto the requested device (a hard-coded
        # ``storage.cuda()`` would make device='cpu' fail on CPU-only hosts).
        checkpoint = torch.load(Path(args.restoreModelPath) / args.restoreModelName,
                                map_location=torch.device(device))
        self.model_dict = self.model.state_dict()
        # Drop checkpoint entries this model does not have. load_state_dict stays strict,
        # so a model parameter that is missing from the checkpoint still raises.
        self.pretrain_dict = {k: v for k, v in checkpoint['state_dict'].items() if k in self.model_dict}
        self.model.load_state_dict(self.pretrain_dict)
        self.model.to(self.device)

    def _build_loader(self, split):
        """Wrap the model's ``<split>_*`` data in a DataLoader holding the whole split in one batch."""
        model = self.model
        dataset = SubgraphDataset(
            getattr(model, f'{split}_sub_G'), getattr(model, f'{split}_sub_G_label'),
            getattr(model, f'{split}_cc_ids'), getattr(model, f'{split}_N_border'),
            getattr(model, f'{split}_neigh_pos_similarities'),
            getattr(model, f'{split}_int_struc_similarities'),
            getattr(model, f'{split}_bor_struc_similarities'),
            model.multilabel, model.multilabel_binarizer)
        return DataLoader(
            dataset, batch_size=len(dataset), shuffle=False, collate_fn=model._pad_collate)

    def train_dataset(self):
        """Run ``model.prepare_data()`` and return a single-batch loader over the train split."""
        print('TRAIN DATASET')
        self.model.prepare_data()
        return self._build_loader('train')

    def val_dataset(self):
        """Return a single-batch loader over the val split.

        Does not prepare data itself; presumably relies on ``prepare_data()``
        from ``train_dataset()`` having populated the ``val_*`` attributes.
        """
        print('VAL DATASET')
        return self._build_loader('val')

    def test_dataset(self):
        """Run ``model.prepare_test_data()`` and return a single-batch loader over the test split."""
        self.model.prepare_test_data()
        print('TEST DATASET')
        return self._build_loader('test')

    def single_dataset_inference(self, loader, dataset_type='train'):
        """Run the forward pass over EVERY batch of ``loader`` and return the embeddings.

        Args:
            loader: iterable of collated batches, i.e. dicts with keys 'subgraph_ids',
                'cc_ids', 'subgraph_idx', 'NP_sim', 'I_S_sim' and 'B_S_sim'.
            dataset_type: 'train', 'val' or 'test'; selects which
                ``<split>_*_cc_embed`` model attributes feed the forward pass.

        Returns:
            list of per-sample embeddings (nested Python lists), in loader order.
            An empty loader yields an empty list.

        Raises:
            ValueError: if ``dataset_type`` is not one of the known splits.
        """
        if dataset_type not in self._SPLITS:
            raise ValueError(f"dataset_type must be one of {self._SPLITS}, got {dataset_type!r}")
        cc_embeds = [getattr(self.model, f'{dataset_type}_{key}_cc_embed')
                     for key in self._CC_EMBED_KEYS]
        forward = self.model.myforward if self.dropout else self.model.myforward_2

        output = []
        # Accumulate over all batches rather than keeping only the last one.
        for batch in loader:
            embedding = forward(
                dataset_type, *cc_embeds,
                batch['subgraph_ids'].to(self.device),
                batch['cc_ids'].to(self.device),
                batch['subgraph_idx'].to(self.device),
                batch['NP_sim'],  # passed through as-is (not moved to device)
                batch['I_S_sim'].to(self.device),
                batch['B_S_sim'].to(self.device),
            )
            output.extend(embedding.cpu().detach().numpy().tolist())
        print(f'Number of samples {len(output)}')
        return output

    def inference(self):
        """Embed train, val and test (in that order) under ``no_grad`` and store the list in ``self.output``."""
        # train_dataset() must be built first: it runs model.prepare_data().
        train_ds = self.train_dataset()
        val_ds = self.val_dataset()
        test_ds = self.test_dataset()
        self.model.eval()
        output = []
        with torch.no_grad():
            output += self.single_dataset_inference(train_ds, dataset_type='train')
            output += self.single_dataset_inference(val_ds, dataset_type='val')
            output += self.single_dataset_inference(test_ds, dataset_type='test')
        self.output = output
        
def save_embedding(output, ds_name, output_path, save_name='pid2vec_128.pkl', data_dir=None):
    """Pair each embedding with its id from ``<data_dir>/<ds_name>/id.pkl`` and optionally pickle the map.

    Args:
        output: embeddings, in the same order as the ids stored in ``id.pkl``.
        ds_name: dataset folder name (the fold task name in this notebook).
        output_path: directory the pickle is written to; created with parents if missing.
        save_name: file name of the pickle, or ``None`` to skip writing.
        data_dir: root holding the dataset folders; defaults to ``<PROJ_PATH>/dataset``.

    Returns:
        dict mapping id -> embedding.

    Raises:
        AssertionError: if the number of embeddings differs from the number of ids.
    """
    if data_dir is None:
        data_dir = os.path.join(PROJ_PATH, 'dataset')
    # NOTE(review): unpickling can execute arbitrary code; only load trusted files.
    pid = pd.read_pickle(os.path.join(data_dir, ds_name, 'id.pkl'))

    assert len(output) == len(pid), (
        f'{len(output)} embeddings but {len(pid)} ids for dataset {ds_name!r}')
    pid2vec = dict(zip(pid, output))

    if save_name is not None:
        # makedirs (not mkdir) so a missing parent directory does not crash the save.
        os.makedirs(output_path, exist_ok=True)
        save_path = os.path.join(output_path, save_name)
        print(f'Save to {save_path}')
        pd.to_pickle(pid2vec, save_path)
    return pid2vec

## DBLP

In [3]:
# Checkpoint root derived from PROJ_PATH (as InferenceAgent's tb_dir is) instead of a
# user-specific absolute path, so the notebook runs from any checkout of the repository.
TENSORBOARD_DIR = os.path.join(PROJ_PATH, 'dataset', 'tensorboard')

pred_args_fold_1 = {
    "task": 'fold_1',
    "restoreModelPath": os.path.join(TENSORBOARD_DIR, 'fold_1_optuna', 'version_5'),
    "restoreModelName": 'epoch=6-val_micro_f1=0.31-val_acc=0.31-val_auroc=0.86.ckpt',
}

pred_args_fold_2 = {
    "task": 'fold_2',
    "restoreModelPath": os.path.join(TENSORBOARD_DIR, 'fold_2_optuna', 'version_4'),
    "restoreModelName": 'epoch=2-val_micro_f1=0.34-val_acc=0.34-val_auroc=0.87.ckpt',
}

pred_args_fold_3 = {
    "task": 'fold_3',
    "restoreModelPath": os.path.join(TENSORBOARD_DIR, 'fold_3_optuna', 'version_8'),
    "restoreModelName": 'epoch=3-val_micro_f1=0.34-val_acc=0.34-val_auroc=0.89.ckpt',
}

pred_args_fold_4 = {
    "task": 'fold_4',
    "restoreModelPath": os.path.join(TENSORBOARD_DIR, 'fold_4_optuna', 'version_5'),
    "restoreModelName": 'epoch=4-val_micro_f1=0.32-val_acc=0.32-val_auroc=0.88.ckpt',
}

pred_args_fold_5 = {
    "task": 'fold_5',
    "restoreModelPath": os.path.join(TENSORBOARD_DIR, 'fold_5_optuna', 'version_6'),
    "restoreModelName": 'epoch=4-val_micro_f1=0.30-val_acc=0.30-val_auroc=0.87.ckpt',
}

pred_args_fold_6 = {
    "task": 'fold_6',
    "restoreModelPath": os.path.join(TENSORBOARD_DIR, 'fold_6_optuna', 'version_6'),
    "restoreModelName": 'epoch=3-val_micro_f1=0.32-val_acc=0.32-val_auroc=0.89.ckpt',
}

pred_args_fold_7 = {
    "task": 'fold_7',
    "restoreModelPath": os.path.join(TENSORBOARD_DIR, 'fold_7_optuna', 'version_3'),
    "restoreModelName": 'epoch=4-val_micro_f1=0.31-val_acc=0.31-val_auroc=0.88.ckpt',
}

pred_args_fold_8 = {
    "task": 'fold_8',
    "restoreModelPath": os.path.join(TENSORBOARD_DIR, 'fold_8_optuna', 'version_4'),
    "restoreModelName": 'epoch=2-val_micro_f1=0.34-val_acc=0.34-val_auroc=0.87.ckpt',
}

pred_args_fold_9 = {
    "task": 'fold_9',
    "restoreModelPath": os.path.join(TENSORBOARD_DIR, 'fold_9_optuna', 'version_9'),
    "restoreModelName": 'epoch=3-val_micro_f1=0.34-val_acc=0.34-val_auroc=0.88.ckpt',
}

pred_args_fold_10 = {
    "task": 'fold_10',
    "restoreModelPath": os.path.join(TENSORBOARD_DIR, 'fold_10_optuna', 'version_9'),
    "restoreModelName": 'epoch=4-val_micro_f1=0.33-val_acc=0.33-val_auroc=0.88.ckpt',
}


In [None]:
ls_args = [
    pred_args_fold_1, pred_args_fold_2, pred_args_fold_3,
    pred_args_fold_4, pred_args_fold_5, pred_args_fold_6,
    pred_args_fold_7, pred_args_fold_8, pred_args_fold_9,
    pred_args_fold_10,
]

# Loop-invariant parts of the output file name M=<model tag>_D=<task>_S=<size tag>_F=<checkpoint stem>.pkl
output_path = '/media/HSGNN/output'
M = 'singlepaper'
S = '128'

for fold_args in ls_args:
    print('#####################################################################################')
    pred_args = fold_args.copy()  # defensive copy: keep the per-fold config dicts untouched
    # Drop the previous fold's agent first; otherwise its model stays alive while the next one loads.
    agent = None
    agent = InferenceAgent(pred_args=pred_args, dropout=True)
    agent.inference()
    output = agent.output
    D = pred_args['task']
    F = pred_args['restoreModelName'].split('.ckpt')[0]
    save_name = 'M={}_D={}_S={}_F={}.pkl'.format(M, D, S, F)
    pid2vec = save_embedding(output, ds_name=D, output_path=output_path, save_name=save_name)

## IMDB

In [3]:
# Checkpoint root derived from PROJ_PATH (as InferenceAgent's tb_dir is) instead of a
# user-specific absolute path, so the notebook runs from any checkout of the repository.
TENSORBOARD_DIR = os.path.join(PROJ_PATH, 'dataset', 'tensorboard')

pred_args_fold_1 = {
    "task": 'imdb_fold_1',
    "restoreModelPath": os.path.join(TENSORBOARD_DIR, 'imdb_fold_1', 'version_2532099'),
    "restoreModelName": 'epoch=71-val_micro_f1=0.99-val_acc=0.95-val_auroc=1.00.ckpt',
}

pred_args_fold_2 = {
    "task": 'imdb_fold_2',
    "restoreModelPath": os.path.join(TENSORBOARD_DIR, 'imdb_fold_2', 'version_8961335'),
    "restoreModelName": 'epoch=95-val_micro_f1=0.98-val_acc=0.94-val_auroc=1.00.ckpt',
}

pred_args_fold_3 = {
    "task": 'imdb_fold_3',
    "restoreModelPath": os.path.join(TENSORBOARD_DIR, 'imdb_fold_3', 'version_4916345'),
    "restoreModelName": 'epoch=73-val_micro_f1=0.98-val_acc=0.94-val_auroc=1.00.ckpt',
}

pred_args_fold_4 = {
    "task": 'imdb_fold_4',
    "restoreModelPath": os.path.join(TENSORBOARD_DIR, 'imdb_fold_4', 'version_5017442'),
    "restoreModelName": 'epoch=59-val_micro_f1=0.99-val_acc=0.95-val_auroc=1.00.ckpt',
}

pred_args_fold_5 = {
    "task": 'imdb_fold_5',
    "restoreModelPath": os.path.join(TENSORBOARD_DIR, 'imdb_fold_5', 'version_7655999'),
    "restoreModelName": 'epoch=73-val_micro_f1=0.98-val_acc=0.93-val_auroc=1.00.ckpt',
}

pred_args_fold_6 = {
    "task": 'imdb_fold_6',
    "restoreModelPath": os.path.join(TENSORBOARD_DIR, 'imdb_fold_6', 'version_8773398'),
    "restoreModelName": 'epoch=82-val_micro_f1=0.99-val_acc=0.95-val_auroc=1.00.ckpt',
}

pred_args_fold_7 = {
    "task": 'imdb_fold_7',
    "restoreModelPath": os.path.join(TENSORBOARD_DIR, 'imdb_fold_7', 'version_7815933'),
    "restoreModelName": 'epoch=96-val_micro_f1=0.99-val_acc=0.94-val_auroc=1.00.ckpt',
}

pred_args_fold_8 = {
    "task": 'imdb_fold_8',
    "restoreModelPath": os.path.join(TENSORBOARD_DIR, 'imdb_fold_8', 'version_9421881'),
    "restoreModelName": 'epoch=90-val_micro_f1=0.98-val_acc=0.90-val_auroc=1.00.ckpt',
}

pred_args_fold_9 = {
    "task": 'imdb_fold_9',
    "restoreModelPath": os.path.join(TENSORBOARD_DIR, 'imdb_fold_9', 'version_5985613'),
    "restoreModelName": 'epoch=61-val_micro_f1=0.99-val_acc=0.94-val_auroc=1.00.ckpt',
}

pred_args_fold_10 = {
    "task": 'imdb_fold_10',
    "restoreModelPath": os.path.join(TENSORBOARD_DIR, 'imdb_fold_10', 'version_4329351'),
    "restoreModelName": 'epoch=93-val_micro_f1=0.99-val_acc=0.96-val_auroc=1.00.ckpt',
}


In [None]:
ls_args = [
    pred_args_fold_1, pred_args_fold_2, pred_args_fold_3,
    pred_args_fold_4, pred_args_fold_5, pred_args_fold_6,
    pred_args_fold_7, pred_args_fold_8, pred_args_fold_9,
    pred_args_fold_10,
]

# Loop-invariant parts of the output file name M=<model tag>_D=<task>_S=<size tag>_F=<checkpoint stem>.pkl
output_path = '/media/HSGNN/output'
M = 'singlepaper'
S = '128'

for fold_args in ls_args:
    print('#####################################################################################')
    pred_args = fold_args.copy()  # defensive copy: keep the per-fold config dicts untouched
    # Drop the previous fold's agent first; otherwise its model stays alive while the next one loads.
    agent = None
    agent = InferenceAgent(pred_args=pred_args, dropout=True)
    agent.inference()
    output = agent.output
    D = pred_args['task']
    F = pred_args['restoreModelName'].split('.ckpt')[0]
    save_name = 'M={}_D={}_S={}_F={}.pkl'.format(M, D, S, F)
    pid2vec = save_embedding(output, ds_name=D, output_path=output_path, save_name=save_name)