# Link Prediction

We create a graph with nodes labeled by the lexicon terms and corresponding document embeddings in node features. It then predicts existence of the edges by the similarity of node features.

We use the JS-divergence distance of two multivariate normal distributions of nodes features to calculate the edge weights.

Here is the steps we take:

- load and organize GPT-3 topic-embeddings in the nodes.
- Fit multivariate normal distributions to the node features (parameters are loc and diagonal_scale).
- Calculate the adjacency matrix using JS-divergence of the node features.


## Input

- `models/gpt3/abstracts_gpt3ada.npz`: GPT-3 embeddings, one document per row.
- `models/gpt3/abstracts_gpt3ada_pmids.csv`: the PMIDs of the document embeddings in the above embedding dataset.

## Output

- `models/gpt3/abstracts_metapath2vec.pkl`: a pickle file containing the trained metapath2vec model.

In [1]:
%reload_ext autoreload
%autoreload 2

import pandas as pd
import numpy as np
from umap import UMAP
import matplotlib.pyplot as plt
import seaborn as sns; sns.set_theme()  # noqa
import dash_bio as dashbio
from python.cogtext.datasets.pubmed import PubMedDataLoader
from python.cogtext.similarity_matrix import get_similarity_matrix
from sklearn.preprocessing import normalize
from tqdm import tqdm

import stellargraph as sg
from stellargraph.data import UniformRandomMetaPathWalk

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

2022-02-02 11:51:54.618115: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


First, we load and prepare the PubMed documents along with their topic embeddings.

In [2]:
# load weights, clusters and metadata (takes ~ 20sec)

clusters = pd.read_csv('models/gpt3/abstracts_gpt3ada_clusters.csv.gz', index_col=0)
weights = np.load('models/gpt3/abstracts_gpt3ada_weights.npz')['arr_0']
clusters['weights'] = list(weights)

# load abstracts
pubmed = PubMedDataLoader(preprocessed=False, drop_low_occurred_labels=False).load()
pubmed = pubmed.merge(clusters, on='pmid', how='left')

pubmed.dropna(subset=['cluster'], inplace=True)

print(f'Successfully create a list of {len(pubmed)} topic-embeddings.')

Successfully create a list of 293014 topic-embeddings.


In [3]:
# discard low-occurred labels

docs_per_label = pubmed.groupby('label').size()
low_appeared_labels = docs_per_label[docs_per_label < 5].index.to_list()

pubmed = pubmed.query('label not in @low_appeared_labels')

print(f'Successfully removed {len(low_appeared_labels)} low-appeared labels: {low_appeared_labels}')

Successfully removed 14 low-appeared labels: ['BackwardSpanTask', 'BearDragonTask', 'BoxCrossingDualTask', 'BoxesTask', 'CategorySwitchTask', 'Cued_Unpredictable_Switch_task', 'D2_target_detection', 'DegradedVigilanceTask', 'GiftWrap', 'GrassSnowTask', 'Incompatibility_test', 'KnockAndTapTask', 'PurposiveAction', 'ReverseCategorization']


## Node-node similarity

Here, we calculate the similarity between nodes using the node features. We use
KL-divergence of the node features a measure of distance between the two.

In [4]:
# KL model
def nll(X, dist):
  return - tf.reduce_mean(dist.log_prob(X))

@tf.function
def get_loss_and_grads(X_train, dist):
  with tf.GradientTape() as tape:
    tape.watch(dist.trainable_variables)
    loss = nll(X_train, dist)
  grads = tape.gradient(loss, dist.trainable_variables)
  return loss, grads

def fit_multivariate_normal(data, n_epochs=10, batch_size=100):
  dist = tfd.MultivariateNormalDiag(
    loc=tf.Variable(data.mean(axis=0), name='loc'),
    scale_diag=tf.Variable(np.ones(data.shape[1]), name='scale_diag'))

  optimizer = tf.keras.optimizers.Adam(learning_rate=0.05)

  for _ in range(n_epochs):
    # for batch in np.array_split(data, 1 + (data.shape[0] // batch_size)):
    batch = data  # use all the data in each epoch
    loss, grads = get_loss_and_grads(batch, dist)
    optimizer.apply_gradients(zip(grads, dist.trainable_variables))
    # loc_value = dist.loc.value()
  return dist


In [5]:
n_skip = 20
n_top_labels_per_category = 20

popular_labels = (pubmed.groupby(['category','label'])['pmid']
                        .count().sort_values(ascending=False)
                        .groupby('category').head(n_top_labels_per_category)
                        .index.get_level_values('label').to_list())
popular_pubmed = pubmed.query('label in @popular_labels')

node_features = popular_pubmed.groupby('label')['weights'].apply(lambda x: np.stack(x))

print(f"Popular labels: {', '.join(popular_labels)}")

tqdm.pandas()
node_dists = node_features.progress_apply(fit_multivariate_normal)

Popular labels: Attention, WorkingMemory, Planning, Initiation, Sequencing, ExecutiveFunction, Inhibition, Fluency, Reasoning, ProcessingSpeed, EpisodicMemory, Stroop, LongTermMemory, Shifting, ProblemSolving, CogntiveControl, InhibitoryControl, Verbal_fluency_task, ShortTermMemory, Mindfulness, SelectiveAttention, SelfRegulation, TMT_-_Trail_Making_Task, Digit_Span, WCST_-_Wisconsin_Card_Sort_Test, Go_NoGo, NBackTask, FlankerTask, StopSignalTask, IGT_-_Iowa_Gambling_task, PVT_-_Psychomotor_Vigilance_task, ContiniousPerformanceTask, Span_Task, CategoryFluencyTask, Simon_task, PEG_-_Pencil_Tapping_task, Semantic_Fluency_test, TowerOfLondon, DiscountingTask, Sorting_task


  0%|          | 0/40 [00:00<?, ?it/s]

Instructions for updating:
`scale_identity_multiplier` is deprecated; please combine it into `scale_diag` directly instead.


2022-02-02 11:52:15.496623: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
100%|██████████| 40/40 [00:44<00:00,  1.12s/it]


In [6]:
# cache the node-to-node kl distance matrix
n2n_js = np.zeros((len(node_dists), len(node_dists)))

for i,lbl_i in tqdm(enumerate(node_dists.index)):
  for j, lbl_j in enumerate(node_dists.index):
    P = node_dists.loc[lbl_i]
    Q = node_dists.loc[lbl_j]
    M = tfd.MultivariateNormalDiag(
      loc=(P.mean() + Q.mean())/2.,
      scale_diag=(P.stddev() + Q.stddev())/2.
    )
    js = .5 * (P.kl_divergence(M) + Q.kl_divergence(M)).numpy()
    # print(X_dists.loc[lbl_i].trainable_variables[1].shape)
    n2n_js[i,j] = js

n2n_js = pd.DataFrame(n2n_js, index=node_dists.index.to_list(), columns=node_dists.index.to_list())

40it [00:27,  1.45it/s]


In [8]:
# visualize the similarity between constructs pairs regardless of the underlying tasks

# # drop tasks for visualization
tasks = pubmed.query('category.str.contains("Task")')['label'].unique()
constructs = pubmed.query('category.str.contains("Construct")')['label'].unique()

n2n_sim = n2n_js.applymap(lambda x: np.exp(-x))

plot_data = n2n_sim.drop(index=constructs, columns=constructs, errors='ignore')

# # heatmap plot categorized by label category, either task or construct
# g = sns.clustermap(labels_sim,
#                    figsize=(25, 25),
#                    dendrogram_ratio=0.0001, 
#                    cbar_pos=(1.0, 0.73, 0.03, 0.25),
#                    robust=True)
# plt.show()

dashbio.Clustergram(
  data=plot_data,
  column_labels=plot_data.columns.to_list(),
  row_labels=plot_data.index.to_list(),
  cluster='all',
  center_values=False,
  height=800,
  width=1000,
  display_ratio=[0.001, 0.001],
  color_map='RdBu_r',
  hidden_labels=['col'],
  row_dist='euclidean',
  col_dist='euclidean',
  # standardize='col'
)


## Graph

In [13]:
node_avg_embeddings = node_features.apply(lambda x: x.mean(axis=0))

sim = n2n_js.applymap(lambda x: np.exp(-x))
sim.index.name = 'source'
sim.columns.name = 'target'
sim = sim.reset_index().melt(id_vars=['source'], value_vars=sim.columns, value_name='weight')

adj = sim

# arbitrary threshold to keep Simon connected in the graph
adj = sim.query('target != source').query('weight >= .91').copy()

task_features = sg.IndexedArray(np.empty(tasks.shape).reshape(-1,1), index=tasks)
construct_features = sg.IndexedArray(np.empty(constructs.shape).reshape(-1,1), index=constructs)

G = sg.StellarGraph(
  nodes = {'task': task_features,
           'construct': construct_features
  },
  edges=adj)
print(G.info())

StellarGraph: Undirected multigraph
 Nodes: 156, Edges: 1464

 Node types:
  task: [85]
    Features: float64 vector, length 1
    Edge types: task-default->construct, task-default->task
  construct: [71]
    Features: float64 vector, length 1
    Edge types: construct-default->construct, construct-default->task

 Edge types:
    construct-default->task: [774]
        Weights: range=[0.91002, 0.953902], mean=0.933167, std=0.00979739
        Features: none
    construct-default->construct: [380]
        Weights: range=[0.912328, 0.954768], mean=0.942605, std=0.00720964
        Features: none
    task-default->task: [310]
        Weights: range=[0.910201, 0.949784], mean=0.926574, std=0.00991365
        Features: none


In [14]:
# metapath2vec

from gensim.models import Word2Vec

# metapath schemas as a list of lists of node types.
metapaths = [
    ['task', 'construct', 'task'],
    ['construct', 'task', 'construct'],
    # ['construct', 'task', 'task', 'construct'],
    # ['task', 'construct', 'task', 'construct', 'task'],
    # ['task', 'construct', 'construct', 'task'],
]

walks = UniformRandomMetaPathWalk(G).run(
    nodes=list(G.nodes()),  # root nodes
    length=3,  # maximum length of a random walk
    n=10,  # number of random walks per root node
    metapaths=metapaths,  # the metapaths
)
print(f'[MetaPath2Vec] Created {len(walks)} random walks.')
# DEBUG: print('Random walks starting from the "RewardProcessing" node:',
# DEBUG:      [' -> '.join(w) for w in walks if w[0] == 'RewardProcessing'])

print('[MetaPath2Vec] Now training the Word2Vec model...', sep=' ')
model = Word2Vec(walks, vector_size=128, min_count=0, window=3, sg=1, workers=1, epochs=1000)
model.save('models/gpt3/abstracts_metapath2vec.pkl')

print('[MetaPath2Vec] Done! Model saved to `models/gpt3/abstracts_metapath2vec.pkl`.')

[MetaPath2Vec] Created 1560 random walks.
[MetaPath2Vec] Now training the Word2Vec model...
[MetaPath2Vec] Done! Model saved to `models/gpt3/abstracts_metapath2vec.pkl`.


In [21]:
# query the joint graph

from gensim.models import Word2Vec
model = Word2Vec.load('models/gpt3/abstracts_metapath2vec.pkl')

#  visual-spatial and organizational processing abilities, as well as nonverbal problem-solving skills. 
 
# FIXME: there is a typo in "CogntiveControl"; it comes from a type in the original lexicon.
model.wv.most_similar(
  positive=['Simon_task'],
  # negative=['Sorting_task','Go_NoGo'],
  topn=20)

# TODO filter by category


[('FlankerTask', 0.6479361653327942),
 ('StopSignalTask', 0.636583149433136),
 ('TowerOfLondon', 0.6256452798843384),
 ('Fluency', 0.612270712852478),
 ('InhibitoryControl', 0.601482093334198),
 ('Inhibition', 0.5987109541893005),
 ('Attention', 0.5787180662155151),
 ('DiscountingTask', 0.5702905654907227),
 ('Span_Task', 0.5652682781219482),
 ('Verbal_fluency_task', 0.5599583983421326),
 ('EpisodicMemory', 0.5560160875320435),
 ('ExecutiveFunction', 0.5529468655586243),
 ('ProcessingSpeed', 0.5521318316459656),
 ('WorkingMemory', 0.5419722199440002),
 ('PVT_-_Psychomotor_Vigilance_task', 0.5362783074378967),
 ('Sorting_task', 0.5287689566612244),
 ('ContiniousPerformanceTask', 0.5238953828811646),
 ('Semantic_Fluency_test', 0.5211595892906189),
 ('Go_NoGo', 0.5187562704086304),
 ('PEG_-_Pencil_Tapping_task', 0.5165704488754272)]

In [22]:
# Plot all tasks and constructs in a lower dimensional 3d space

import plotly.express as px

# projections = PCA(n_components=2, random_state=0).fit_transform(label_embeddings)
projections_3d = UMAP(n_components=3, random_state=0).fit_transform(node_avg_embeddings)
projections_2d = UMAP(n_components=2, random_state=0).fit_transform(node_avg_embeddings)

projections_3d = pd.DataFrame(projections_3d, index=node_avg_embeddings.index).reset_index()
projections_2d = pd.DataFrame(projections_2d, index=node_avg_embeddings.index).reset_index()

fig = px.scatter_3d(projections_3d,
                    x=0, y=1, z=2,
                    color='category', hover_name='label',
                    title='Popular labels in the topic space',
                    color_discrete_sequence=['red','blue'],
                    width=600, height=600,)
fig.show()


fig, ax = plt.subplots(1,1, figsize=(15,15))
sns.scatterplot(data=projections_2d, x=0, y=1, ax=ax, s=50, hue='category')

for (i,lbl, _,x, y) in projections_2d.itertuples():
  lbl = f'{lbl[:12]}...' if len(lbl)>10 else lbl
  ax.text(x+0.01, y-np.random.random()*.02, f'{lbl}', alpha=0.5, fontsize=12)

ax.set(xlabel='UMAP 1', ylabel='UMAP 2')
plt.suptitle('2D projection of the topic embeddings',y=.9)
plt.show()

In [15]:
%reload_ext watermark
%watermark
%watermark -iv -p umap,pytorch,scikit-learn,python.cogtext

Last updated: 2022-01-21T16:09:50.086463+01:00

Python implementation: CPython
Python version       : 3.9.9
IPython version      : 8.0.0

Compiler    : Clang 11.1.0 
OS          : Darwin
Release     : 21.2.0
Machine     : x86_64
Processor   : i386
CPU cores   : 12
Architecture: 64bit

umap          : 0.5.2
pytorch       : not installed
scikit-learn  : 0.0
python.cogtext: 0.1.2022012116

plotly      : 5.3.1
seaborn     : 0.11.2
pandas      : 1.3.4
dash_bio    : 0.8.0
numpy       : 1.20.3
sys         : 3.9.9 | packaged by conda-forge | (main, Dec 20 2021, 02:38:53) 
[Clang 11.1.0 ]
matplotlib  : 3.4.3
stellargraph: 1.2.1

