In [1]:
import numpy as np
import pandas as pd
import os

import dgl
import torch
import torch.nn as nn
from dgl.nn import GATv2Conv
import torch.nn.functional as F

from sklearn.metrics.pairwise import cosine_similarity

In [7]:
import seaborn as sns
import matplotlib.pyplot as plt

In [45]:
DATA_HOME = "/lyceum/jhk1c21/msc_project/data"
V14_PATH = os.path.join(DATA_HOME, "graph", "v14")
FILTERED_PATH = os.path.join(V14_PATH, "filtered")
output_dir = '/lyceum/jhk1c21/msc_project/batch'

In [46]:
# Load the data
nodes = pd.read_csv(os.path.join(V14_PATH, "nodes_v14.csv"), index_col='id')
similarity = pd.read_csv(os.path.join(FILTERED_PATH, "similarity_edges.csv"))

titles = np.load(os.path.join(FILTERED_PATH, 'title_embedding.npy'))
abstracts = np.load(os.path.join(FILTERED_PATH, 'abstract_embedding.npy'))
keywords = np.load(os.path.join(FILTERED_PATH, 'keywords_embedding.npy'))
domains = np.load(os.path.join(FILTERED_PATH, 'domains_embedding.npy'))

ids = np.load(os.path.join(FILTERED_PATH, "filtered_id.npy"))
edges = np.load(os.path.join(FILTERED_PATH, 'filtered_edge.npy'))

In [47]:
edge_df = pd.DataFrame()
edge_df['src'] = edges[:, 0]
edge_df['dst'] = edges[:, 1]

# convert id from str to numbers
id_to_int = {original_id: i for i, original_id in enumerate(ids)}
int_to_id = {i: original_id for original_id, i in id_to_int.items()}

edge_df['src'] = edge_df['src'].apply(lambda x: id_to_int[x])
edge_df['dst'] = edge_df['dst'].apply(lambda x: id_to_int[x])

In [48]:
node_features = np.concatenate([titles, abstracts, keywords, domains], axis=1)
tensor_node_features = torch.FloatTensor(node_features)

citation_network = dgl.graph((edge_df['src'].to_numpy(), edge_df['dst'].to_numpy()))
citation_network.ndata['features'] = torch.tensor(node_features, dtype=torch.float32)

In [49]:
df = pd.read_csv(os.path.join(FILTERED_PATH, "trained_sim.csv"))
out = torch.load(os.path.join(output_dir, 'result.pt'))

In [50]:
similarity['src'] =similarity['src'].apply(lambda x: id_to_int[x])
similarity['dst'] =similarity['dst'].apply(lambda x: id_to_int[x])

In [51]:
w1, w2, w3, w4 = 0.2, 0.2, 0.2, 0.4
df['org_sim'] = w1*similarity['title'] + w2*similarity['abstract'] + w3*similarity['keyword'] + w4*similarity['domain']

In [52]:
df['diff'] = df['sim'] - df['org_sim']

In [53]:
df

Unnamed: 0,src,dst,sim,org_sim,diff
0,14961,42725,0.984479,0.424087,0.560393
1,33767,141985,0.930674,0.706678,0.223996
2,64675,132132,0.818782,0.621357,0.197425
3,102146,98155,0.786727,0.611837,0.174889
4,118623,112343,0.958769,0.757628,0.201141
...,...,...,...,...,...
1273170,69751,63337,0.966002,0.618167,0.347835
1273171,65311,59922,0.974722,0.514026,0.460696
1273172,156811,12704,0.313755,0.668000,-0.354244
1273173,155283,108435,0.737572,0.614611,0.122961


In [54]:
df[(df['diff'] > 0)].sort_values('diff', ascending=False)

Unnamed: 0,src,dst,sim,org_sim,diff
152668,102266,49694,0.981489,0.205964,7.755246e-01
836615,109965,49694,0.947942,0.201236,7.467055e-01
384072,34169,71860,0.957934,0.223899,7.340356e-01
1104478,42498,49694,0.982493,0.259533,7.229599e-01
765947,128937,94944,0.947471,0.227589,7.198817e-01
...,...,...,...,...,...
437104,132935,132935,1.000000,1.000000,1.110223e-16
133160,12088,12088,1.000000,1.000000,1.110223e-16
830349,44999,44999,1.000000,1.000000,1.110223e-16
675109,120,120,1.000000,1.000000,1.110223e-16


In [55]:
print(df[(df['diff'] > 0)].shape[0]/df.shape[0])
print(df[(df['diff'] < 0)].shape[0]/df.shape[0])

0.7985453688613113
0.20128968916291948


In [56]:
src, dst = 34169, 71860
# src, dst = 109965, 49694
nodes.loc[[int_to_id[src], int_to_id[dst]]]['fos']

id
53e9a1f3b7602d9702af5839    [{'name': 'Rule-based machine translation', 'w...
53e9addbb7602d97037e5b8d    [{'name': 'Sequence alignment', 'w': 0.42699},...
Name: fos, dtype: object

In [57]:
similarity[(similarity['src'] == src) & (similarity['dst'] == dst)]

Unnamed: 0,src,dst,title,abstract,keyword,domain
384072,34169,71860,0.305689,0.000482,0.445756,0.183783


In [60]:
out[[src, dst]]

tensor([[ 0.0543,  0.0342,  0.0269,  0.0622,  0.0453, -0.0804, -0.0600, -0.0300,
         -0.0384, -0.0299, -0.0562, -0.0426, -0.0025, -0.0708, -0.0093, -0.0836,
         -0.0748, -0.0090,  0.0127, -0.0534,  0.0455, -0.0714, -0.0107,  0.0928,
          0.0662, -0.1196,  0.0626,  0.0310, -0.0128,  0.0307, -0.0024,  0.0289,
          0.0153, -0.0360,  0.1663,  0.0008,  0.1805, -0.0446,  0.0554, -0.0539,
         -0.0121, -0.1048,  0.3030,  0.0125, -0.0159, -0.0205,  0.0564,  0.1160,
          0.0345, -0.0016],
        [ 0.0127,  0.0693,  0.0307,  0.0484,  0.0642, -0.1278, -0.0630, -0.0434,
         -0.0461, -0.0290, -0.0342,  0.0030,  0.0018, -0.0687, -0.0307, -0.1067,
         -0.0718, -0.0138,  0.0267, -0.0633, -0.0085, -0.0996,  0.0236,  0.0870,
          0.0710, -0.1301,  0.0730,  0.0180, -0.0226, -0.0130, -0.0193,  0.0545,
          0.0068, -0.0592,  0.1737, -0.0041,  0.1769, -0.0403,  0.0521, -0.0582,
         -0.0413, -0.0706,  0.2943,  0.0091, -0.0413,  0.0129,  0.0403,  0.1111,
