In [None]:
import argparse
import resource
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.nn as dglnn
import time
import json
from torch_geometric.utils import negative_sampling
from torch_geometric.seed import seed_everything
import torch.multiprocessing
import pandas as pd
import numpy as np
import tqdm
import networkx as nx

from combsage.combsage import IDRSAGEJK
from combsage.graphsage import SAGE

from combsage.utils import evaluate, convert_to_heterograph_group_isolates
from combsage.utils import HetEdgePredictionSampler, HomoNeighborSampler

import os

In [None]:
def load_data(year, graph_dir='citation_graphs', emb_path='scibert_embeddings.json'):
    """Load one year's citation graph as a DGL graph with SciBERT node features.

    Parameters
    ----------
    year : int or str
        Selects the GEXF file ``{graph_dir}/{year}.gexf``.
    graph_dir : str, optional
        Directory holding the per-year GEXF files.
    emb_path : str, optional
        JSON file mapping paper id -> embedding vector.

    Returns
    -------
    node_df : pandas.DataFrame
        One row per node, indexed by the integer node id used in ``g``; the
        ``paper_id`` column holds the node's original GEXF id.
    g : dgl.DGLGraph
        Graph over the same integer ids, with ``g.ndata['feat']`` (float32)
        holding each node's embedding row.

    Raises
    ------
    KeyError
        If a node's paper id has no entry in the embedding file.
    """
    # load graph; relabel nodes 0..n-1 and keep the original id as 'paper_id'
    G = nx.read_gexf(os.path.join(graph_dir, '{}.gexf'.format(year)))
    G_i = nx.convert_node_labels_to_integers(G, label_attribute='paper_id')
    n_nodes = G_i.number_of_nodes()

    # load embeddings
    with open(emb_path) as infile:
        emb = pd.Series(json.load(infile))

    # embedding subset as a matrix: row i <-> node i
    node_df = pd.DataFrame.from_dict(dict(G_i.nodes(data=True)), orient='index')
    X = np.vstack(emb.loc[node_df['paper_id']].values)

    # source/target vectors for the dgl graph; force int64 so that an edgeless
    # graph (empty, object-dtype columns) still converts cleanly
    edge_list = nx.to_pandas_edgelist(G_i)
    src = edge_list['source'].to_numpy(dtype=np.int64)
    dst = edge_list['target'].to_numpy(dtype=np.int64)

    # num_nodes must be explicit: otherwise DGL infers max(id)+1 from the edges,
    # silently dropping trailing isolated nodes, and the feature matrix (one row
    # per node) would no longer match g.num_nodes()
    g = dgl.graph((src, dst), num_nodes=n_nodes)
    g.ndata['feat'] = torch.tensor(X, dtype=torch.float32)

    return node_df, g

In [None]:
year = 2014

# Experiment settings shared by the cells below.
config = {
    'r1': 15,              # first fan-out passed to HomoNeighborSampler
    'r2': 10,              # second fan-out passed to HomoNeighborSampler
    'lr': 0.0001,          # Adam learning rate
    'batch_size': 256,     # seed edges per mini-batch in the DGL DataLoader
    'dropout': 0.1,
    'hidden_dim': 256,     # hidden size given to IDRSAGEJK
    'num_workers': 10,     # DGL DataLoader worker processes
    'n_epochs': 10,
    'seed': 42,
    'params_path': 'idrsagejk_best_{}.pt'.format(year),  # best-checkpoint file
}

# Seed Python/NumPy/torch (seed_everything) and DGL's separate RNG so the
# random validation split and sampling are reproducible across kernel restarts.
seed_everything(config['seed'])
dgl.seed(config['seed'])

In [None]:
# Citation graph + SciBERT features for the configured year.
(node_df, g) = load_data(year=year)

In [None]:
# Target device for the model, graphs and data loader.
device = "cpu"

In [None]:
# Hold out a random 10% of edges as positive validation edges, pair them with
# the same number of sampled non-edges, and train on the remaining edges.
VAL_FRAC = 0.1
n_nodes = g.number_of_nodes()
u, v = g.edges()

eids = np.random.permutation(g.number_of_edges())
val_size = int(len(eids) * VAL_FRAC)
val_pos_u, val_pos_v = u[eids[:val_size]], v[eids[:val_size]]

# Request exactly as many negatives as validation positives (PyG may return
# slightly fewer). Drawing a subset of a larger sample with np.random.choice
# samples with replacement and could repeat a negative edge.
neg_edge_index = negative_sampling(edge_index=torch.vstack((u, v)),
                                   num_nodes=n_nodes,
                                   num_neg_samples=val_size)
val_neg_u, val_neg_v = neg_edge_index[0], neg_edge_index[1]

val_pos_g = dgl.graph((val_pos_u, val_pos_v), num_nodes=n_nodes)
val_neg_g = dgl.graph((val_neg_u, val_neg_v), num_nodes=n_nodes)

train_g = dgl.remove_edges(g, eids[:val_size])

# Rebuild the training edges as an undirected networkx graph
# (from_pandas_edgelist defaults to nx.Graph) and convert it to the project's
# typed heterograph. Nodes with no remaining edges are absent from G, which is
# presumably why the node count is passed explicitly.
train_src, train_dst = train_g.edges()
edge_list = pd.DataFrame({'source': train_src.numpy(), 'target': train_dst.numpy()})
G = nx.from_pandas_edgelist(edge_list)
g_hetero = convert_to_heterograph_group_isolates(G, n_nodes=train_g.number_of_nodes())

# Edge types are named by integers ('1', '2', ...); merge all of them into a
# single edge type for the homogeneous sampling graph.
n_types = max(int(t) for t in g_hetero.etypes)
e_tensors = [g_hetero.edges(etype=etype) for etype in sorted(g_hetero.etypes, key=int)]
src = torch.hstack([e[0] for e in e_tensors])
dst = torch.hstack([e[1] for e in e_tensors])

# num_nodes_dict keeps g_homo's node count equal to g_hetero's even when the
# highest-numbered nodes have no edges (DGL would otherwise infer max(id)+1).
g_homo = dgl.heterograph({('paper', '1', 'paper'): (src, dst)},
                         num_nodes_dict={'paper': g_hetero.num_nodes()})

# Attach features before moving; DGLGraph.to() returns a NEW graph rather than
# moving in place, so the results must be reassigned.
g_hetero.ndata['feat'] = train_g.ndata['feat']
g_hetero = g_hetero.to(device)
g_homo = g_homo.to(device)

del train_g, train_src, train_dst
del edge_list
del G
del src, dst
del g


In [None]:
# Model, optimiser and mini-batch edge loader for link prediction on g_hetero.
# Hidden size and worker count default to the previously hard-coded values.
hidden_dim = config.get('hidden_dim', 256)
num_workers = config.get('num_workers', 10)

model = IDRSAGEJK(g_hetero.ndata['feat'].shape[1], hidden_dim,
                  n_types, dropout=config['dropout'])
model.to(device)  # nn.Module.to moves parameters in place

opt = torch.optim.Adam(model.parameters(), lr=config['lr'])
# Seed edges for the loader: the edge IDs (last element of form='all') of
# every edge type.
edge_dict = {etype: g_hetero.edges(etype=etype, form='all')[-1] for etype in g_hetero.etypes}

# Neighbour sampling on the merged graph g_homo; the wrapper adds one
# uniformly sampled negative per positive edge.
sampler = HomoNeighborSampler([config['r1'], config['r2']], prefetch_node_feats=['feat'])
sampler = HetEdgePredictionSampler(
    sampler, g_homo=g_homo,
    negative_sampler=dgl.dataloading.negative_sampler.Uniform(1))
dataloader = dgl.dataloading.DataLoader(
    g_hetero, edge_dict, sampler,
    device=device, batch_size=config['batch_size'], shuffle=True,
    drop_last=False, num_workers=num_workers)

In [None]:

# Train with BCE on positive vs. sampled negative edge scores. After each
# epoch, evaluate on the held-out validation edges and checkpoint the weights
# with the lowest validation loss.
n_epochs = config.get('n_epochs', 10)
params_path = config.get('params_path', 'idrsagejk_best_{}.pt'.format(year))

best_loss = float('inf')  # any finite validation loss is an improvement
best_params = None
for epoch in range(n_epochs):
    model.train()
    t0 = time.time()
    with tqdm.tqdm(dataloader, desc='Epoch: {}'.format(epoch)) as tq:
        for input_nodes, pair_graph, neg_pair_graph, blocks in tq:
            x = {'paper': blocks[0].srcdata['feat']}
            pos_score, neg_score = model(pair_graph, neg_pair_graph, blocks, x)
            score = torch.cat([pos_score, neg_score])
            labels = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
            loss = F.binary_cross_entropy_with_logits(score, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
            tq.set_postfix({'loss': '{:.3f}'.format(loss.item())})
    model.eval()
    val_loss, val_auc, _, _ = evaluate(model, g_hetero, val_pos_g, val_neg_g)
    if val_loss < best_loss:
        best_loss = val_loss
        # state_dict() returns tensors sharing storage with the live
        # parameters, which later optimiser steps keep modifying; clone them
        # so this snapshot really is the best-epoch weights.
        best_params = {k: t.detach().clone() for k, t in model.state_dict().items()}
        torch.save(best_params, params_path)
    print('Epoch {}: val_loss={} val_auc={} ({:.1f}s)'.format(
        epoch, val_loss, val_auc, time.time() - t0))
print("Finished Training")


In [None]:
# Restore the lowest-validation-loss weights and embed every node.
model.load_state_dict(best_params)
model.eval()

# `g` is deleted in the data-split cell, so inference runs on g_hetero (the
# typed training graph the model was fitted on).
# TODO(review): confirm against IDRSAGEJK.inference that the positional args
# are (graph, device, batch_size, num_workers, buffer_device), and whether the
# embeddings should instead use a typed graph that includes validation edges.
with torch.no_grad():
    emb = model.inference(g_hetero, device, 4096, 0, device)

In [None]:
# Tensor -> float64 numpy array; .cpu() is required because .numpy() only
# works on CPU tensors (no-op when already on CPU).
emb = emb.detach().cpu().numpy().astype(float)

In [None]:
# Map each original paper id to its embedding row. Rows of `emb` are indexed
# by node_df's integer index (the node id), then the mapping is wrapped in a
# Series.
emb_dict = {
    paper_id: list(emb[row, :])
    for row, paper_id in node_df['paper_id'].items()
}
emb_s = pd.Series(emb_dict)

In [None]:
# Persist {paper_id: embedding} as JSON for this year.
out_path = 'graphsage_mean_{}.json'.format(year)
with open(out_path, 'w') as outfile:
    json.dump(emb_s.to_dict(), outfile)