In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Climb one directory at a time until the working-directory path no longer
# contains 'notebooks'. Presumably this puts the kernel at the project root so
# that the `src.*` imports and the relative `models/` and `data/` paths used
# further down resolve -- confirm against the repository layout.
while 'notebooks' in os.getcwd():
    os.chdir('..')

import json
import torch
import torch_geometric.transforms as T
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
from sklearn.metrics import roc_auc_score
import logging
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.offline as pyo
import numpy as np
import pandas as pd
from torch_geometric.loader import ShaDowKHopSampler, NeighborLoader
from torch_geometric.utils import k_hop_subgraph

from src.torch_geo_models import GraphSAGE, LinkPredictor
from src.data.gamma.arxiv import load_data, get_val_test_edges, prepare_adjencency, get_edge_index_from_adjencency
from src.train.gamma import GammaGraphSage

In [3]:
# Timestamped, INFO-level logging for everything that runs in this notebook.
log_settings = dict(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s : %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
logging.basicConfig(**log_settings)

In [4]:
# Sanity check: report whether PyTorch can see a CUDA device.
torch.cuda.is_available()

True

In [5]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

## Data Loading

In [6]:
# Load the dataset, prepare its adjacency (requested with to_symmetric=True)
# and move the result onto the selected device in a single expression.
data = prepare_adjencency(load_data(), to_symmetric=True).to(device)

In [7]:
# Largest value obtained when summing adj_t along dim=1 (one sum per row).
data.adj_t.sum(dim=1).max()

tensor(13161., device='cuda:0')

## Load $\Gamma$ function

In [8]:
# Read the stored description of the best run. The context manager closes the
# file handle deterministically instead of leaving it to the garbage collector.
with open('models/gamma_graph_sage/best_model_info.json') as metrics_file:
    best_model_metrics = json.load(metrics_file)

### Load model

In [9]:
# The JSON values are coerced to int before being handed to the loader.
best_run = int(best_model_metrics['run'])
best_epoch = int(best_model_metrics['epoch'])

gamma = GammaGraphSage.load_model(best_run, best_epoch, device, data.num_nodes)

## Rewire graph

In [10]:
# Extract COO coordinates from the transposed adj_t and stack them into a
# 2 x E edge_index that lives on the same device as the adjacency.
adj_device = data.adj_t.device()
row, col, edge_attr = data.adj_t.t().coo()
edge_index = torch.stack((row, col)).to(adj_device)

In [11]:
def get_node_highest_gamma_edges(node, depth, edge_index, data, gamma_model=None):
    """Return the highest-scoring candidate edges leaving ``node``.

    Candidates are every node returned by ``k_hop_subgraph(node, depth,
    edge_index)`` except ``node`` itself. Each pair ``(node, candidate)`` is
    scored by the gamma model, and the best-scoring pairs are kept.

    Args:
        node: Integer id of the source node.
        depth: Number of hops passed to ``k_hop_subgraph``.
        edge_index: Edge index tensor passed to ``k_hop_subgraph``.
        data: Object exposing ``adj_t``. ``adj_t.device()`` gives the working
            device, and ``adj_t[node].sum()`` gives the number of edges to keep
            (presumably the node's degree -- confirm that ``adj_t`` is
            unweighted).
        gamma_model: Optional scorer exposing ``forward(edges, adj_t)``. When
            omitted, the notebook-level ``gamma`` object is used, as before.

    Returns:
        A floating-point tensor of shape ``(3, k)`` whose rows are source id,
        target id and score, with columns ordered by descending score. ``k`` is
        at most ``adj_t[node].sum()`` and at most the number of candidates; it
        is 0 when there are no candidates. Node ids are stored in the score's
        floating-point dtype (float32 represents integers exactly up to 2**24).
    """
    scorer = gamma if gamma_model is None else gamma_model
    adj_device = data.adj_t.device()

    khop_nodes, _, _, _ = k_hop_subgraph(node, depth, edge_index)
    target = khop_nodes[khop_nodes != node].to(adj_device)

    # Nothing to score (e.g. an isolated node): return an empty, well-shaped
    # result instead of feeding an empty batch to the model.
    if target.numel() == 0:
        return torch.empty((3, 0), device=adj_device)

    # Built directly on the target's device with the target's integer dtype,
    # avoiding a Python list round-trip through the CPU.
    source = torch.full_like(target, node)

    # The scores are only ranked, never back-propagated, so skip autograd.
    with torch.no_grad():
        gamma_preds = scorer.forward(torch.stack([source, target]), data.adj_t)

    # reshape(-1) keeps a 1-D vector even with a single candidate, whereas a
    # bare squeeze() would collapse it to 0-d and break the stack below.
    scores = gamma_preds.detach().reshape(-1)

    # Cast the ids explicitly rather than relying on implicit dtype promotion.
    node_neighborhood = torch.stack([
        source.to(scores.dtype),
        target.to(scores.dtype),
        scores,
    ])

    n_keep = int(data.adj_t[node].sum())
    best_first = scores.argsort(descending=True)

    return node_neighborhood[:, best_first[:n_keep]]
        

In [12]:
# Number of nodes in the graph; the rewiring loop below visits each of them.
data.num_nodes

169343

In [13]:
# For every node, keep its top-scoring candidate edges as a NumPy array on the
# CPU, keyed by node id.
depth = 3
node_edges_dict = {}
for node in range(data.num_nodes):
    # Progress marker every 10,000 nodes.
    if not node % 10000:
        print(f'{node + 1}/{data.num_nodes}')
    top_edges = get_node_highest_gamma_edges(node, depth, edge_index, data)
    node_edges_dict[node] = top_edges.detach().cpu().numpy()

1/169343
10001/169343
20001/169343
30001/169343
40001/169343
50001/169343
60001/169343
70001/169343
80001/169343
90001/169343
100001/169343
110001/169343
120001/169343
130001/169343
140001/169343
150001/169343
160001/169343


In [18]:
# Inspect the raw 4-tuple returned by k_hop_subgraph for node 0; only its first
# element (the node ids) is used by get_node_highest_gamma_edges.
k_hop_subgraph(0, depth, edge_index)

(tensor([     0,     14,     31,  ..., 169339, 169340, 169341], device='cuda:0'),
 tensor([[     0,      0,      0,  ..., 169341, 169341, 169341],
         [   411,    640,   1162,  ..., 140574, 147834, 163274]],
        device='cuda:0'),
 tensor([0], device='cuda:0'),
 tensor([ True,  True,  True,  ...,  True, False, False], device='cuda:0'))

In [None]:
out_path = 'data/graph_modifications/01-1-rewired_edges_same_degrees.csv'

# Stream the result one node at a time: the first chunk creates the file and
# writes the header, every later chunk is appended without one.
for it, edges_info in enumerate(node_edges_dict.values()):
    is_first_chunk = it == 0
    edges_df = pd.DataFrame(
        edges_info.T.tolist(),
        columns=['source', 'target', 'weight'])
    edges_df.to_csv(
        out_path,
        mode='w' if is_first_chunk else 'a',
        header=is_first_chunk,
        index=False)