In [1]:
import collections, itertools
import os, sys
import random
from collections import Counter

import networkx as nx
import numpy as np
import pandas as pd
import regex as re

# Cut the working directory at the first "/heterogeneous..." component, then append the
# project folder name, so the notebook resolves the project root from any sub-directory.
_cwd_prefix = re.sub("/heterogeneous.*$", '', os.getcwd())
PROJ_PATH = os.path.join(_cwd_prefix, 'heterogeneous_subgraph_representation_for_team_discovery')

# Make the modules under <project>/src importable.
_src_dir = os.path.join(PROJ_PATH, 'src')
sys.path.insert(1, _src_dir)

from train_config import *
from datasets import SubgraphDataset

from sklearn.cluster import KMeans
from sklearn.decomposition import PCA, TruncatedSVD

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
%matplotlib inline
plt.style.use('ggplot')

In [2]:
class MySubGraphs():
    '''
    Build one base graph per conference and sample random node sets ("subgraphs") from it.

    Args:
        PA_edges: edge table with 'source' and 'target' columns; 'source' is matched against paper ids
        PT_edges: edge table with 'source' and 'target' columns; 'source' is matched against paper ids
        PC_edges: edge table with 'source' and 'target' columns; 'target' is compared to the keys of cid2cname
        cid2cname: dictionary to map conference id to conference name which is used as the label of subgraph
        n_subgraphs: number of subgraph for each label
        n_nodes_in_subgraph: number of nodes in each subgraph
        use_venue (default=True): whether the paper-conference edges are part of the base graph
        tvt (default: 60, 20, 20): ratio for training, validate and testing set
        shuffle (default=True): whether to shuffle the order of subgraphs
    Return:
     A subgraph object with (after calling sample_subgraphs):
        sub_G: list of subgraphs
        sub_G_label: list of subgraphs label
        mask: dataset indicator (whether data used for training/validation/testing)
    '''
    def __init__(
        self, 
        PA_edges, 
        PT_edges,
        PC_edges,
        cid2cname,
        n_subgraphs,
        n_nodes_in_subgraph,
        use_venue=True,
        tvt=(60, 20, 20),
        shuffle=True):
        
        assert sum(tvt) == 100, 'Sum of TVT should be 100!'
        
        self.PA_edges = PA_edges
        self.PT_edges = PT_edges
        self.PC_edges = PC_edges
        self.edges = pd.concat([self.PA_edges, self.PT_edges, self.PC_edges])
        
        self.cid2cname = cid2cname
        self.n_subgraphs = n_subgraphs
        self.n_nodes_in_subgraph = n_nodes_in_subgraph
        self.use_venue = use_venue
        # Store a private list so the instance never shares state with the caller's object.
        self.tvt = list(tvt)
        self.shuffle = shuffle
        self.n_groups = len(cid2cname)
        
    def create_base_graph(self):
        '''
        Create a base graph for each label before sampling subgraphs
            - cid: conference id
            - cname: conference name
        Return
            - G_dict (dict): conference name -> networkx graph built from the edges of that conference's papers
        '''
        G_dict = {}
        for cid, cname in self.cid2cname.items():
            print(f'- Creating base graph for group {cname}')
            # papers that have an edge to this conference
            PC_edges_filtered = self.PC_edges[self.PC_edges['target'] == cid]
            papers_filtered = PC_edges_filtered['source'].unique()
            
            # paper-author edges of those papers
            PA_edges_filtered = self.PA_edges[self.PA_edges['source'].isin(papers_filtered)]
            
            # paper-term edges of those papers (must come from PT_edges, not PA_edges,
            # otherwise the base graph contains no term nodes at all)
            PT_edges_filtered = self.PT_edges[self.PT_edges['source'].isin(papers_filtered)]
            
            # create networkx graph
            if self.use_venue:
                df_edges_filtered = pd.concat([PC_edges_filtered, PA_edges_filtered, PT_edges_filtered])
            else:
                df_edges_filtered = pd.concat([PA_edges_filtered, PT_edges_filtered])
            G_dict[cname] = nx.from_pandas_edgelist(df_edges_filtered)
            
        return G_dict
    
    def get_subgraphs_randomly(self, G_dict):
        """
        Randomly generates subgraphs of (at most) n_nodes_in_subgraph nodes per base graph.
        The random module is re-seeded with 0 on every call.
        Args
            - G_dict (dict): label -> graph-like object exposing a `nodes` collection
        Return
            - sub_G: list of subgraphs, where each subgraph is a list of nodes
            - sub_G_label: label of each subgraph
            (both are tuples instead of lists when shuffling was applied)
        """

        sub_G = []
        sub_G_label = []
        random.seed(0)
        
        print(f'- Sampling {self.n_subgraphs*self.n_groups} subgraphs')
        for cname, G in G_dict.items():
            print(f'- Sampling {self.n_subgraphs} subgraphs for group {cname}')
            # random.sample needs a sequence; a node view is set-like and is rejected
            # by Python >= 3.11, so materialise it once per graph.
            population = list(G.nodes)
            n_nodes = min(len(population), self.n_nodes_in_subgraph)
            for _ in range(self.n_subgraphs):
                sub_G.append(random.sample(population, n_nodes))
                sub_G_label.append(cname)
                
        # zip(*...) cannot be unpacked into two names when there is nothing to shuffle.
        if self.shuffle and sub_G:
            tmp = list(zip(sub_G, sub_G_label))
            random.shuffle(tmp)
            sub_G, sub_G_label = zip(*tmp)
        return sub_G, sub_G_label
     
    def get_train_val_test(self):
        '''
        Build the split indicator list: 0 for train, 1 for val, 2 for test.
        Sizes follow the tvt percentages applied to n_subgraphs * n_groups; the test
        split takes whatever is left after rounding train and val down.
        '''
        self.n_samples = self.n_subgraphs * self.n_groups
        n_train = int(self.n_samples * self.tvt[0] / 100)
        n_val = int(self.n_samples * self.tvt[1] / 100)
        n_test = self.n_samples - n_train - n_val
        mask = [0] * n_train + [1] * n_val + [2] * n_test
        return mask
    
    def sample_subgraphs(self):
        '''Run the full pipeline and store G_dict, sub_G, sub_G_label and mask on the instance.'''
        print('Creating base graphs')
        self.G_dict = self.create_base_graph()
        print('Sampling subgraphs')
        self.sub_G, self.sub_G_label = self.get_subgraphs_randomly(self.G_dict)
        self.mask = self.get_train_val_test()  

        
class MySubGraphs_v2():
    '''
    Build one subgraph per paper: the paper itself, its terms and (optionally) its authors.

    Args:
        PA_edges: edge table with 'source' and 'target' columns; unique sources are kept as paper nodes, unique targets as author nodes
        PT_edges: edge table with 'source' and 'target' columns; unique targets are kept as term nodes
        PC_edges: edge table with 'source' and 'target' columns; unique targets are kept as conference nodes
        cid2cname: dictionary to map conference id to conference name which is used as the label of subgraph
        ds2pids: dictionary with the keys 'train', 'val' and 'test', each holding the paper ids of that split
        remove_pa_in_test (default=True): whether to leave authors out of test-paper subgraphs and to drop the paper-author edges of test papers
        shuffle (default=True): whether to shuffle the order of subgraphs
        seed (default=0): seed passed to random.seed before sampling
    Return:
     A subgraph object with (after calling sample_subgraphs):
        sub_G: list of subgraphs
        sub_G_label: list of subgraphs label
        mask_G: dataset indicator (0 train / 1 val / 2 test)
        id_G: paper id each subgraph was built from
        df_node_types: node id / node type table
        df_edges: paper-term edges followed by the retained paper-author edges
    '''
    def __init__(
        self, 
        PA_edges,
        PT_edges,
        PC_edges, 
        cid2cname,
        ds2pids,
        remove_pa_in_test=True,
        shuffle=True,
        seed=0,
    ):
        
        self.PA_edges = PA_edges
        self.PT_edges = PT_edges
        self.PC_edges = PC_edges
        self.paper_nodes = list(PA_edges['source'].unique())
        self.author_nodes = list(PA_edges['target'].unique())
        self.term_nodes = list(PT_edges['target'].unique())
        self.conference_nodes = list(PC_edges['target'].unique())
        
        self.cid2cname = cid2cname # conference id to conference name
        self.ds2pids = ds2pids
        self.train_pids = ds2pids['train']
        self.val_pids = ds2pids['val']
        self.test_pids = ds2pids['test']
        self.remove_pa_in_test = remove_pa_in_test
        self.shuffle = shuffle
        self.seed = seed
        self.n_groups = len(cid2cname)
        print(f'Number of train: {len(self.train_pids)}')
        print(f'Number of val: {len(self.val_pids)}')
        print(f'Number of test: {len(self.test_pids)}')
        
    def get_subgraphs_based_on_paper(self):
        """
        Generate one subgraph per paper in paper_nodes.
        Each subgraph is authors + [paper] + terms; authors are left out for test papers
        when remove_pa_in_test is set. The label joins the names of the paper's conferences
        with '-' (dashes inside a single name are removed first).
        Papers that belong to none of the train/val/test id lists are skipped so that the
        four returned sequences always stay aligned.
        Return
            - sub_G: list of subgraphs, where each subgraph is a list of nodes
            - sub_G_label: label of each subgraph
            - mask_G: 0 / 1 / 2 for train / val / test
            - id_G: paper id of each subgraph
            (tuples instead of lists when shuffling was applied)
        """
        random.seed(self.seed)
        
        sub_G = []
        sub_G_label = []
        mask_G = []
        id_G = []

        # Set membership keeps the per-paper split lookup constant-time.
        train_pids = set(self.train_pids)
        val_pids = set(self.val_pids)
        test_pids = set(self.test_pids)

        # Group the targets by paper once instead of re-filtering every edge table per paper.
        # unique() keeps first-appearance order, like filtering the table for one paper would.
        terms_by_paper = self.PT_edges.groupby('source', sort=False)['target'].unique().to_dict()
        authors_by_paper = self.PA_edges.groupby('source', sort=False)['target'].unique().to_dict()
        cids_by_paper = self.PC_edges.groupby('source', sort=False)['target'].unique().to_dict()
        
        print(f'- Sampling {len(self.paper_nodes)} subgraphs')
        n_skipped = 0
        for pid in self.paper_nodes:
            if pid in train_pids:
                split = 0
            elif pid in val_pids:
                split = 1
            elif pid in test_pids:
                split = 2
            else:
                # No split information: appending the subgraph without a mask entry
                # would shift every later mask by one position.
                n_skipped += 1
                continue

            cids = cids_by_paper.get(pid, ())
            # Use the mapping given to this instance, not a notebook-level global.
            cnames = '-'.join([self.cid2cname[i].replace('-', '') for i in cids])
            
            sampled_paper_nodes = [pid]
            sampled_term_nodes = list(terms_by_paper.get(pid, ()))
            sampled_author_nodes = list(authors_by_paper.get(pid, ()))
         
            if pid in test_pids and self.remove_pa_in_test:
                sampled_nodes = sampled_paper_nodes + sampled_term_nodes
            else:        
                sampled_nodes = sampled_author_nodes + sampled_paper_nodes + sampled_term_nodes

            sub_G.append(sampled_nodes)
            sub_G_label.append(cnames)
            mask_G.append(split)
            id_G.append(pid) # for prediction phase

        if n_skipped:
            print(f'-- Skipped {n_skipped} papers that are not in train/val/test')

        # zip(*...) cannot be unpacked into four names when there is nothing to shuffle.
        if self.shuffle and sub_G:
            print('-- Shuffle data')
            tmp = list(zip(sub_G, sub_G_label, mask_G, id_G))
            random.shuffle(tmp)
            # stable sort: train first, then val, then test, shuffled within each split
            tmp = sorted(tmp, key=lambda x: x[2])
            sub_G, sub_G_label, mask_G, id_G = zip(*tmp)
        return sub_G, sub_G_label, mask_G, id_G
    
    def get_node_types(self):
        '''
        Build the node-type table. Ids 0..N-1 are assigned in blocks by count:
        authors first, then papers, then terms (conference nodes are not included).
        '''
        # author --> paper --> term
        no_author = len(self.author_nodes)
        no_paper = len(self.paper_nodes)
        no_term = len(self.term_nodes)
        node_id = range(no_author + no_paper + no_term)
        node_type = [0] * no_author + [1] * no_paper + [2] * no_term
        nodetype_mapping = {
            0: 'author',
            1: 'paper',
            2: 'term',
        }
        df_node_types = pd.DataFrame({'node_id': node_id, 'node_type': node_type})
        df_node_types['node_type_name'] = df_node_types['node_type'].map(nodetype_mapping)
        return df_node_types

    def get_edges(self):
        '''
        Return the paper-term edges followed by the paper-author edges.
        With remove_pa_in_test, paper-author edges whose paper is a test paper are dropped.
        The retained paper-author edges are also stored in PA_edges_filtered.
        '''
        if self.remove_pa_in_test:
            # Papers live in the 'source' column of PA_edges ('target' holds authors),
            # so test paper ids have to be matched against 'source'.
            self.PA_edges_filtered = self.PA_edges[~self.PA_edges['source'].isin(self.test_pids)]
        else:
            self.PA_edges_filtered = self.PA_edges.copy()
        df_edges = pd.concat([self.PT_edges, self.PA_edges_filtered])
        print(f'Number of edges {df_edges.shape[0]}')
        return df_edges
        
    def sample_subgraphs(self):
        '''Run the full pipeline and store sub_G, sub_G_label, mask_G, id_G, df_node_types and df_edges.'''
        print('Start sampling subgraphs ...')
        self.sub_G, self.sub_G_label, self.mask_G, self.id_G = self.get_subgraphs_based_on_paper()
        
        print('Start getting nodes ...')
        self.df_node_types = self.get_node_types()
        
        print('Start getting edges ...')
        self.df_edges = self.get_edges()
                 
def write_subgraph(sub_f, sub_G, sub_G_label, mask):
    """
    Dump subgraphs to a text file, one subgraph per line.

    Each line holds the dash-joined node ids, the label and the split name
    ("train", "val" or "test"), separated by tabs and followed by a tab and a newline.
    Subgraphs without nodes, and subgraphs whose mask value is none of 0/1/2, are not written.
    Args
        - sub_f (str): path of the file to write
        - sub_G (list of lists): list of subgraphs, where each subgraph is a list of nodes
        - sub_G_label (list): subgraph labels
        - mask (list): 0 if subgraph is in train set, 1 if in val set, 2 if in test set
    """
    split_codes = ((0, "train"), (1, "val"), (2, "test"))

    with open(sub_f, "w") as fout:
        for nodes, label, split_code in zip(sub_G, sub_G_label, mask):
            node_strs = list(map(str, nodes))
            if not node_strs:
                continue
            for code, split_name in split_codes:
                if split_code == code:
                    # "\n" is the last joined field, hence the tab in front of the newline
                    fout.write("\t".join(["-".join(node_strs), str(label), split_name, "\n"]))
                    break

In [3]:
# Per-fold train/test indices (the keys are printed below).
folds = pd.read_pickle('/media/HSGNN/dataset/Train_Test_indices_V2.2.pkl')

# Entity-id mapping table, read as strings.
entityID_map = np.genfromtxt('/media/HSGNN/dataset/V2_2/entity_id_mapping.csv', delimiter=",", dtype=str)

# Skip the first row and keep only the rows whose column 1 is "paper":
# int(column 2) -> int(column 3).
paperIDmap = dict(
    (int(row[2]), int(row[3]))
    for row in entityID_map[1:]
    if row[1] == "paper"
)
print(folds.keys())

dict_keys([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])


In [None]:
# Build the train/val/test paper-id lists for every fold.
seed = 3579
random.seed(seed)  # seeded once; the shuffles of all folds draw from the same stream
val_ratio = 0.05
n_folds = 10
save_ds2pids = False  # set to True to write ds2pids.pkl into each fold directory
for i in range(1, n_folds+1):
    print('############################################################')
    print(f'Prepare tvt file for fold: {i}')
    experiment = f'fold_{i}'
    train_ids_origin = folds[i]['Train']
    test_ids_origin = folds[i]['Test']
    # translate the original ids through paperIDmap
    trainval_ids = [paperIDmap[idx] for idx in train_ids_origin]
    test_pids = [paperIDmap[idx] for idx in test_ids_origin]
    random.shuffle(trainval_ids)
    n_trainval = len(trainval_ids)
    n_val = int(n_trainval * val_ratio)
    # Split at an explicit index: slicing with -n_val would hand everything to the
    # validation set (and leave train empty) whenever n_val is 0.
    n_train = n_trainval - n_val
    train_pids = trainval_ids[:n_train]
    val_pids = trainval_ids[n_train:]
    
    # Sanity checks with hard-coded values: the three splits together must have
    # 10674 distinct ids spanning exactly 1840 .. 1840+10674-1.
    assert min(train_pids+val_pids+test_pids) == 1840
    assert max(train_pids+val_pids+test_pids) == (1840+10674-1)
    assert len(set(train_pids+val_pids+test_pids)) == 10674
    
    fold_dir = os.path.join(PROJ_PATH, 'dataset', experiment)
    # also creates missing parent directories and tolerates an existing one
    os.makedirs(fold_dir, exist_ok=True)
    print(f'Number of train: {len(train_pids)}')
    print(f'Number of val: {len(val_pids)}')
    print(f'Number of test: {len(test_pids)}')
    ds2pids = {
        'train': train_pids,
        'val': val_pids,
        'test': test_pids,
    }
    if save_ds2pids:
        fname = os.path.join(fold_dir, 'ds2pids.pkl')
        print(f'Save to: {fname}')
        pd.to_pickle(ds2pids, fname)

In [None]:
# Build the per-paper subgraphs, node types and edge list for every fold.
n_folds = 10
seed = 3579
remove_pa_in_test = True
save_outputs = False  # set to True to write the per-fold files listed at the end of the loop
DATA_DIR = '/media/HSGNN/dataset/V2_2'

# The edge tables and the conference-name mapping do not depend on the fold,
# so they are read once instead of once per iteration.
PA_edges = pd.read_csv(os.path.join(DATA_DIR, 'PA_edges.csv'), index_col=None)
PC_edges = pd.read_csv(os.path.join(DATA_DIR, 'PC_edges.csv'), index_col=None)
PT_edges = pd.read_csv(os.path.join(DATA_DIR, 'PT_edges.csv'), index_col=None)
cid2cname = pd.read_pickle(os.path.join(PROJ_PATH, 'dataset/dblp_v8/cid2cname.pkl'))

for i in range(1, n_folds+1):
    print('############################################################')
    print(f'Processing fold {i}')
    experiment = f'fold_{i}'

    # fold-specific train/val/test paper ids
    ds2pids = pd.read_pickle(os.path.join(PROJ_PATH, 'dataset', experiment, 'ds2pids.pkl')) 

    subgraph = MySubGraphs_v2(
        PA_edges,
        PT_edges,
        PC_edges, 
        cid2cname,
        ds2pids,
        remove_pa_in_test=remove_pa_in_test,
        shuffle=True,
        seed=seed,
    )
    subgraph.sample_subgraphs()
    if save_outputs:
        sub_G, sub_G_label, mask_G = subgraph.sub_G, subgraph.sub_G_label, subgraph.mask_G
        df_node_types, df_edges = subgraph.df_node_types, subgraph.df_edges
        id_G = subgraph.id_G
        save_path = os.path.join(PROJ_PATH, 'dataset', experiment, 'subgraphs.pth')
        print(f'Save subgraph to {save_path}')
        write_subgraph(save_path, sub_G, sub_G_label, mask_G)

        save_path = os.path.join(PROJ_PATH, 'dataset', experiment, 'node_types.csv')
        print(f'Save node types to {save_path}')
        df_node_types.to_csv(save_path, index=False)

        save_path = os.path.join(PROJ_PATH, 'dataset', experiment, 'edge_list.txt')
        print(f'Save edges to {save_path}')
        df_edges.to_csv(save_path, header=None, index=None, sep=' ')

        save_path = os.path.join(PROJ_PATH, 'dataset', experiment, 'id.pkl')
        print(f'Save id to {save_path}')
        pd.to_pickle(id_G, save_path)