# IMPORT

In [None]:
import torch
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, GRUCell
from torch_geometric.data import Data
from sklearn.metrics import roc_auc_score,average_precision_score

import random

import bisect 

import gc
import copy

from itertools import permutations

import pandas as pd

from torch_geometric.utils import negative_sampling, erdos_renyi_graph, shuffle_node, to_networkx
import torch_geometric.transforms as T
from torch_geometric.utils import train_test_split_edges
from torch_geometric.transforms import RandomLinkSplit,NormalizeFeatures,Constant,OneHotDegree
from torch_geometric.utils import from_networkx
from torch_geometric.nn import GCNConv,SAGEConv,GATv2Conv, GINConv, Linear
from scipy.stats import entropy

import torch
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

import copy
import itertools
import json

In [None]:
%load_ext autoreload
%autoreload 2

# LOAD DATASET

In [None]:
from steemitdata import get_steemit_dataset

In [None]:
# Load the Steemit snapshots twice; the two variants differ only in the
# `preprocess` mode handed to get_steemit_dataset.

# Variant with a constant encoder as node features.
snapshots_c = get_steemit_dataset(preprocess='constant')

# Variant with textual features as node features.
snapshots_t = get_steemit_dataset(preprocess='text')

In [None]:
import os

# Write each text-feature snapshot's node features (`x`) and `edge_index`
# to disk, one pair of files per snapshot index.
# torch.save does not create missing parent directories, so make sure the
# output folder exists before the first write.
os.makedirs('steemit-t3gnn-data', exist_ok=True)
for i, snapshot in enumerate(snapshots_t):
    torch.save(snapshot.x, f'steemit-t3gnn-data/{i}_x.pt')
    torch.save(snapshot.edge_index, f'steemit-t3gnn-data/{i}_edge_index.pt')

In [None]:
# Snapshots with random node features: load the 'constant' variant again and
# overwrite every snapshot's `x` with samples from a standard normal.
snapshots_ts = get_steemit_dataset(preprocess='constant')
for snap in snapshots_ts:
    # 384 columns per node -- presumably the width of the text features in
    # snapshots_t; TODO confirm against steemitdata.
    snap.x = torch.randn((snap.num_nodes, 384))

# LOAD MODEL

In [None]:
from t3gnn import T3GConvGRU, T3EvolveGCNH, T3EvolveGCNO, T3GNN, T3MLP
from traineval import *

## Experiments

In [None]:
import random

# Use the GPU when one is present, otherwise fall back to the CPU so this
# cell (and anything that reads `device`) still works on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every random number generator used in this notebook with the same
# value: PyTorch (CPU and all CUDA devices), NumPy and the stdlib `random`.
torch.manual_seed(41)
torch.cuda.manual_seed_all(41)
np.random.seed(41)
random.seed(41)

# Ask cuDNN for deterministic kernels and disable its auto-tuner, which may
# otherwise select different algorithms from run to run.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Release cached GPU memory left over from earlier cells.
torch.cuda.empty_cache()

In [None]:
# Hidden sizes passed positionally to every training call in this notebook
# (presumably the widths of the first and second layers -- confirm in t3gnn).
hidden_conv1, hidden_conv2 = 64, 32

## Train on the link prediction task using sentence embeddings and network structure

### Self-loops

In [None]:
# Run train_roland on the text-feature snapshots with self-loops enabled.
# The result is kept in ro_selfloops_avgpr (presumably an average-precision
# score, judging by the name -- confirm in traineval).
ro_selfloops_avgpr = train_roland(
    snapshots_t,
    hidden_conv1,
    hidden_conv2,
    update='mlp',
    add_self_loops=True,
)

### Baselines

In [None]:
# Run the baseline comparison on the text-feature snapshots. train_models
# returns four values, unpacked here in order (presumably one score per
# model -- confirm the order against traineval.train_models).
ro_avgpr, gcgru_avgpr, evo_avgpr, evh_avgpr = train_models(
    snapshots_t,
    hidden_conv1,
    hidden_conv2,
    update='mlp',
)

### Random Negative Sampling

In [None]:
# Random-negative-sampling variant, run on the text-feature snapshots.
ro_randomsample_avgpr = train_roland_random(
    snapshots_t,
    hidden_conv1,
    hidden_conv2,
    update='mlp',
)

### No Self loops, Skip Connections, Content MLP

In [None]:
# NOTE(review): this call is identical to the one in the "Self-loops" cell
# and overwrites the same variable, yet the section heading reads
# "No Self loops". Confirm whether add_self_loops=True is intended here.
ro_selfloops_avgpr = train_roland(
    snapshots_t,
    hidden_conv1,
    hidden_conv2,
    update='mlp',
    add_self_loops=True,
)

In [None]:
# Variant run with skip_connections=True.
ro_skip_avgpr = train_roland(
    snapshots_t,
    hidden_conv1,
    hidden_conv2,
    update='mlp',
    skip_connections=True,
)

# Variant run with content_mlp=True.
ro_content_avgpr = train_roland(
    snapshots_t,
    hidden_conv1,
    hidden_conv2,
    update='mlp',
    content_mlp=True,
)

### Feature shuffling, T3MLP, EdgeBank

In [None]:
# Variant run with shuffle_node_features=True; the banner labels the output
# that follows in the notebook.
print('Node shuffling')
ro_shuffling_avgpr = train_roland(
    snapshots_t,
    hidden_conv1,
    hidden_conv2,
    update='mlp',
    shuffle_node_features=True,
)

In [None]:
# Run train_mlp on the text-feature snapshots with the same arguments as the
# other runs; the banner labels the output that follows.
print('T3MLP')
ro_mlp_avgpr = train_mlp(
    snapshots_t,
    hidden_conv1,
    hidden_conv2,
    update='mlp',
)

In [None]:
# EdgeBank run: edge_bank receives only the snapshots, no hidden sizes.
print('EdgeBank')
ro_edgebank_avgpr = edge_bank(
    snapshots_t,
)

### Constant and Random Features

In [None]:
# No-features run: the snapshots loaded with the 'constant' preprocessing.
ro_constant_avgpr = train_roland(
    snapshots_c,
    hidden_conv1,
    hidden_conv2,
    update='mlp',
)

# Random-features run: the snapshots whose `x` was replaced by torch.randn.
ro_randomf_avgpr = train_roland(
    snapshots_ts,
    hidden_conv1,
    hidden_conv2,
    update='mlp',
)