In [1]:
import torch
from torch_geometric.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR

from gnn import GNN

import os
from tqdm import tqdm
import argparse
import time
import numpy as np
import random

torch.cuda.is_available()

Using backend: pytorch


True

### hard-coded arguments

Explain the GNN model (the checkpoint loaded below is a GIN with a virtual node).

In [2]:
# Stand-in for the namespace that the main_gnn command line would normally provide.
class Argument(object):
    """Bare attribute container used in place of parsed CLI arguments."""
    name = "args"


args = Argument()
# Hard-coded settings, stored as instance attributes on `args`.
vars(args).update(
    batch_size=256,
    num_workers=0,
    num_layers=5,
    emb_dim=600,
    drop_ratio=0,
    graph_pooling="sum",
    checkpoint_dir="models/gin-virtual/checkpoint",
    device=0,
)

In [3]:
# Run on the CUDA device with index args.device when CUDA is available, otherwise on the CPU.
# (To force CPU execution, assign torch.device("cpu") unconditionally.)
if torch.cuda.is_available():
    device = torch.device("cuda:" + str(args.device))
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [4]:
# Keyword arguments forwarded to the GNN constructor, copied from `args` under the same names.
shared_params = {
    key: getattr(args, key)
    for key in ('num_layers', 'emb_dim', 'drop_ratio', 'graph_pooling')
}


### load model

In [5]:
from gnn import GNN

In [6]:
"""
LOAD Checkpoint data
"""
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint saved from a GPU run can still be loaded when only the CPU is available.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
checkpoint = torch.load(os.path.join(args.checkpoint_dir, 'checkpoint.pt'), map_location=device)
checkpoint.keys()

dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict', 'best_val_mae', 'num_params'])

In [7]:
# Model variant: display name, message-passing layer type, and virtual-node switch.
gnn_name, gnn_type, virtual_node = "gin-virtual", "gin", True

In [8]:
# Build the network with the configured hyper-parameters, move it to the
# selected device and restore the weights stored in the checkpoint.
model = GNN(gnn_type = gnn_type, virtual_node = virtual_node, **shared_params).to(device)
model.load_state_dict(checkpoint["model_state_dict"])

# Start in evaluation mode; the triplet-training cell further down calls model.train().
model.eval()

type(model)

gnn.GNN

In [9]:
# Regression criterion: mean absolute error (L1).
reg_criterion = torch.nn.L1Loss()

# Fresh Adam optimizer over all model parameters, plus a step-wise learning-rate
# schedule (multiply the rate by 0.25 every 300 scheduler steps).
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, gamma=0.25, step_size=300)

### load data

In [10]:
### importing OGB-LSC
from ogb.lsc import PygPCQM4MDataset, PCQM4MEvaluator

# Instantiate the PyG flavour of the PCQM4M dataset, using 'dataset/' as its root directory.
# PCQM4MEvaluator is imported but not used in the cells visible in this notebook.
dataset = PygPCQM4MDataset(root = 'dataset/')

In [11]:
# Index tensors of the dataset splits; displayed in train / test / valid order.
split_idx = dataset.get_idx_split()
tuple(split_idx[part] for part in ("train", "test", "valid"))

(tensor([      0,       1,       2,  ..., 3045357, 3045358, 3045359]),
 tensor([3426030, 3426031, 3426032,  ..., 3803450, 3803451, 3803452]),
 tensor([3045360, 3045361, 3045362,  ..., 3426027, 3426028, 3426029]))

### triplet loss

In [14]:
"""
define triplet loss
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.nn import global_add_pool

class TripletLossRegression(nn.Module):
    """
        Triplet-style hinge loss whose negative term is scaled by a ratio of
        ground-truth gaps.

        For every row the loss is

            relu(||positive - anchor|| - c * ||negative - anchor|| + margin)

        with c = |negative_gt - anchor_gt| / (|positive_gt - anchor_gt| + eps),
        and the result is averaged over all rows.

        anchor, negative and positive are embedding tensors; the norm is taken
        along dim=1, so they are expected to be (rows, features) matrices.
        anchor_gt, negative_gt and positive_gt hold the matching ground-truth
        values. Note the positional order: negative comes BEFORE positive.
    """

    def __init__(self, margin: float = 0.0, eps=1e-6):
        super().__init__()
        # eps keeps the denominator of the ground-truth ratio away from zero.
        self.eps = eps
        self.margin = margin

    def forward(self, anchor: Tensor, negative: Tensor, positive: Tensor,
                anchor_gt: Tensor, negative_gt: Tensor, positive_gt: Tensor) -> Tensor:
        # Embedding distances from the anchor, one value per row.
        d_pos = torch.linalg.norm(positive - anchor, dim=1)
        d_neg = torch.linalg.norm(negative - anchor, dim=1)

        # Ratio of label gaps: how much farther (in label space) the negative
        # is from the anchor than the positive is.
        gap_neg = (negative_gt - anchor_gt).abs()
        gap_pos = (positive_gt - anchor_gt).abs() + self.eps
        ratio = gap_neg / gap_pos

        hinge = torch.relu((d_pos - ratio * d_neg) + self.margin)
        return hinge.mean()




In [15]:
"""
get embedding
"""
# Most recent forward-hook outputs, keyed by the label given to get_activation.
model_activation = {}


def get_activation(name):
    """Return a forward hook that stores the hooked module's output under ``name``.

    The returned callable has the (module, inputs, output) signature used by
    ``register_forward_hook`` and overwrites ``model_activation[name]`` on
    every call.
    """
    def _record(module, inputs, output):
        model_activation[name] = output
    return _record

# Attach the recording hook to model.gnn_node: after every forward pass the output
# of that sub-module is available as model_activation['gnn_node'].
# The returned handle is not kept, so the hook cannot be removed later through it.
model.gnn_node.register_forward_hook(get_activation('gnn_node'))


<torch.utils.hooks.RemovableHandle at 0x2888ee77fc8>

In [12]:
"""
load training dataset
"""

# Split used for every loader built from `name` below.
# NOTE(review): the split selected here is "valid" although the loader is called
# train_loader - confirm whether training on the validation split is intended.
name = "valid"

# Sequential (unshuffled) loader over the selected split.
train_loader = DataLoader(dataset[split_idx[name]], batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)
train_loader

<torch_geometric.data.dataloader.DataLoader at 0x288fc3dcf48>

In [13]:
"""
load triplet dataset
"""
# Three separately shuffled loaders over the same split; zipping them later
# pairs every anchor batch with an unrelated positive and negative batch.
anchor_loader, positive_loader, negative_loader = (
    DataLoader(dataset[split_idx[name]], batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    for _ in range(3)
)


In [None]:
"""
dynamic triplet dataset based on error
"""

# 1. get losses for training dataset

In [16]:
# One pass of joint MAE + triplet-loss training over the zipped triplet loaders.
# Intended signature if this cell is turned into a function:
# def triplet_loss_train(model, device, anchor_loader, negative_loader, positive_loader, optimizer, gnn_name):
# NOTE: the model is left in train() mode once this cell has finished.
model.train()
loss_accum = 0
triplet_loss_criterion = TripletLossRegression()

# Only the anchor loader is wrapped in tqdm; zip stops at the shortest loader.
for step, (anchor_batch, negative_batch, positive_batch) in \
        enumerate(zip(tqdm(anchor_loader, desc="Iteration"), negative_loader, positive_loader)):
    # The forward hook on model.gnn_node overwrites model_activation['gnn_node'] on
    # every forward pass, so each embedding must be read right after its own
    # model(...) call and before the next one.
    anchor_batch = anchor_batch.to(device)
    pred_anchor = model(anchor_batch).view(-1,)
    anchor_embed = model_activation['gnn_node']
    # Pool the captured gnn_node output with the batch assignment vector.
    anchor_embed = model.pool(anchor_embed, anchor_batch.batch)

    # pred_neg / pred_pos are not used below; these forward passes are run for the
    # hook output they leave behind.
    negative_batch = negative_batch.to(device)
    pred_neg = model(negative_batch).view(-1,)
    neg_embed = model_activation['gnn_node']
    neg_embed = model.pool(neg_embed, negative_batch.batch)

    positive_batch = positive_batch.to(device)
    pred_pos= model(positive_batch).view(-1,)
    pos_embed = model_activation['gnn_node']
    pos_embed = model.pool(pos_embed, positive_batch.batch)

    optimizer.zero_grad()
    # 1. MAE loss (computed on the anchor predictions only)
    mae_loss = reg_criterion(pred_anchor, anchor_batch.y)
    
    # 2. Triplet Loss - the criterion takes (anchor, negative, positive) in that order,
    #    for the embeddings and for the ground-truth values alike.
    tll_loss = triplet_loss_criterion(anchor_embed, neg_embed, pos_embed,
                                      anchor_batch.y, negative_batch.y, positive_batch.y)
    loss = mae_loss + tll_loss

    # Extra KL term for the 'gin-virtual-bnn' variant; gnn_name is set to
    # "gin-virtual" earlier in this notebook, so the branch is not taken here.
    if gnn_name == 'gin-virtual-bnn':
        kl_loss = model.get_kl_loss()[0]
        loss += kl_loss

    loss.backward()
    optimizer.step()

    # Accumulate the scalar loss value without keeping the graph alive.
    loss_accum += loss.detach().cpu().item()
    
#     break

# return loss_accum / (step + 1)
# Mean loss per iteration (step is the last zero-based iteration index).
loss_accum / (step + 1)

Iteration: 100%|███████████████████████████████████████████████████████████████████| 1487/1487 [12:21<00:00,  2.00it/s]


0.9765289135041779

In [17]:
# Unconditionally raises, which stops a "Run All" at this point.
# NOTE(review): presumably a deliberate barrier between the training cells above and
# the explanation cells below - confirm, and consider removing it.
raise Exception("")

Exception: 

In [None]:
""" 
IMPORTANT: GRAPH QUERY ID
Pick the graph
"""
# Position (within the validation split) of the graph to explain.
selectedID = 75088 #0 #131054
# Slice rather than index so that queryID stays a 1-element index tensor.
queryID = split_idx["valid"][selectedID:selectedID + 1]

# Loader holding only the selected graph. The following cells read `valid_loader`,
# which was otherwise never defined before its first use on a fresh kernel.
valid_loader = DataLoader(dataset[queryID], batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)
queryID

In [None]:
# Materialise and display every batch produced by valid_loader.
list(valid_loader)

## predict

In [None]:
# First batch of valid_loader (the whole loader is materialised to get it).
batch = list(valid_loader)[0]
# First graph of that batch; this is the sample perturbed in the sections below.
data = batch[0]
data

In [None]:
# The triplet-training cell leaves the model in train() mode; switch back to
# evaluation mode so train-time-only layer behaviour (e.g. dropout, batch-statistics
# normalisation) does not affect - or get updated by - this prediction.
model.eval()
batch = batch.to(device)
with torch.no_grad():
    # Flatten the model output to a 1-D tensor of predictions.
    pred = model(batch).view(-1,)

pred

In [None]:
# Scalar label and scalar prediction of the selected graph (.item() requires
# each tensor to hold exactly one element).
y_true, y_pred = data.y.item(), pred.item()
(y_true, y_pred)

## plot sample

In [None]:
import networkx as nx
import matplotlib.pyplot as plt

In [None]:
def plotGraph(data, y_pred, y_true, ax, printnodelabel=False, printedgelabel=False):
    """Draw one graph sample on ``ax`` with a spring layout.

    Args:
        data: sample exposing ``x``, ``edge_index`` and ``edge_attr`` tensors.
        y_pred: predicted value, shown in the title with two decimals.
        y_true: ground-truth value, shown in the title with two decimals.
        ax: matplotlib axes to draw on.
        printnodelabel: if True, label each node with its stringified feature row.
        printedgelabel: if True, label each edge with its stringified feature row.
    """
    G = nx.Graph()
    # Nodes are numbered by row position; the feature row is stored as a string.
    G.add_nodes_from(
        (idx, {"feat": str(feat)}) for idx, feat in enumerate(data.x.tolist())
    )
    # edge_index is (2, E): transpose to get one (source, target) pair per edge.
    # The graph is undirected, so a reversed duplicate collapses onto the same edge.
    G.add_edges_from(
        (src, dst, {"feat": str(attr)})
        for (src, dst), attr in zip(data.edge_index.T.tolist(), data.edge_attr.tolist())
    )

    pos = nx.spring_layout(G)
    ax.set_title("pred={:.2f}, true={:.2f}".format(y_pred, y_true))

    draw_kwargs = {"ax": ax, "node_size": 40}
    if printnodelabel:
        draw_kwargs["labels"] = nx.get_node_attributes(G, 'feat')
    nx.draw(G, pos, **draw_kwargs)

    if printedgelabel:
        nx.draw_networkx_edge_labels(G, pos, ax=ax, edge_labels=nx.get_edge_attributes(G, "feat"))


In [None]:
# Draw the selected graph with edge-feature labels but without node labels.
fig, ax = plt.subplots()
plotGraph(data, y_pred, y_true, ax, printnodelabel=False, printedgelabel=True)

## perturb edge feature

Each edge has three categorical features with (5, 6, 2) possible values, respectively.

In [None]:
import ogb.utils as utils

In [None]:
# Per-column sizes of the bond (edge) features as reported by OGB; used below as
# exclusive upper bounds when clamping perturbed edge attributes.
edgeFeatDims = utils.features.get_bond_feature_dims()
edgeFeatDims

In [None]:
# Build perturbed copies of `data` in which only the edge features are changed.
perturb_data_list = []

# Largest absolute offset added to any edge feature.
edge_noise_max = 4

for _ in range(5000):
    # clone original data
    pData = data.clone()

    # create random noise
    # np.random.randint excludes `high`, so +1 is needed to draw offsets
    # symmetrically from [-edge_noise_max, edge_noise_max].
    randomNoise = np.random.randint(low=-edge_noise_max, high=edge_noise_max + 1, size=data.edge_attr.shape)
    # Match dtype/device of edge_attr (numpy's default integer width is platform dependent).
    randomNoise = torch.as_tensor(randomNoise, dtype=pData.edge_attr.dtype, device=pData.edge_attr.device)

    # add edge_attr noise
    pData.edge_attr += randomNoise

    # Clamp every feature column back into its own valid range [0, dim - 1].
    for col, dim in enumerate(edgeFeatDims):
        pData.edge_attr[:, col] = pData.edge_attr[:, col].clip(0, dim - 1)

    perturb_data_list.append(pData)

len(perturb_data_list)

In [None]:
valid_loader = DataLoader(perturb_data_list, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)

# Predict on EVERY batch: taking only the first batch would evaluate just
# args.batch_size of the perturbed samples and silently drop the rest.
# eval() is needed because the triplet-training cell leaves the model in train() mode.
model.eval()
batch_preds = []
with torch.no_grad():
    for batch in valid_loader:
        batch = batch.to(device)
        batch_preds.append(model(batch)) #.view(-1,)
# One row per perturbed sample, in loader order.
pred = torch.cat(batch_preds, dim=0)

pred.shape

In [None]:
# Distribution of predictions on the perturbed samples; the red line marks the
# prediction for the unperturbed graph.
plt.hist(pred.flatten().tolist())
plt.title("Perturb edge features. Label: {:.2f}".format(y_true))
plt.axvline(x=y_pred, color="r")
plt.show()

Given fixed node features and topology, perturbing the edge features doesn't disturb the output much.

## perturb node features


In [None]:
# Per-column sizes of the atom (node) features as reported by OGB; used below as
# exclusive upper bounds when clamping perturbed node features.
nodeDims = utils.features.get_atom_feature_dims()
nodeDims

In [None]:
# Build perturbed copies of `data` in which only the node features are changed.
perturb_data_list = []

for _ in range(1000):
    # clone original data
    pData = data.clone()

    # create random noise
    # np.random.randint excludes `high`; high=2 yields offsets in {-1, 0, +1}.
    randomNoise = np.random.randint(low=-1, high=2, size=data.x.shape)
    # Match dtype/device of x (numpy's default integer width is platform dependent).
    randomNoise = torch.as_tensor(randomNoise, dtype=pData.x.dtype, device=pData.x.device)

    # add node feature noise
    pData.x += randomNoise

    # Clamp every feature column against ITS OWN values and range [0, dim - 1];
    # reading from a different column would overwrite the feature instead of clamping it.
    for col, dim in enumerate(nodeDims):
        pData.x[:, col] = pData.x[:, col].clip(0, dim - 1)

    perturb_data_list.append(pData)

len(perturb_data_list)

In [None]:
# perturb_data_list = [data]

# for i in range(1):
#     pData = data.clone()
# #     pData.x[-1, 0] = torch.tensor(i)
#     pData.x[-1] = torch.tensor([ 5,  0,  4,  5,  3,  0,  2,  0,  0])
#     perturb_data_list.append(pData)


In [None]:
valid_loader = DataLoader(perturb_data_list, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)

# Predict on EVERY batch: taking only the first batch would evaluate just
# args.batch_size of the perturbed samples and silently drop the rest.
# eval() is needed because the triplet-training cell leaves the model in train() mode.
model.eval()
batch_preds = []
with torch.no_grad():
    for batch in valid_loader:
        batch = batch.to(device)
        batch_preds.append(model(batch)) #.view(-1,)
# One row per perturbed sample, in loader order.
pred = torch.cat(batch_preds, dim=0)

pred.shape #, pred

In [None]:
# Distribution of predictions on the perturbed samples; the red line marks the
# prediction for the unperturbed graph.
plt.hist(pred.flatten().tolist())
plt.title("Perturb node features. Label: {:.2f}".format(y_true))
plt.axvline(x=y_pred, color="r")
plt.show()

The prediction seems very sensitive to the node features.

## perturb topology

In [None]:
# keep backup
backup = data.edge_index.clone()
backup

In [None]:
# Build perturbed copies of `data` whose connectivity is changed by random
# double-edge swaps, while node features are kept as they are.
perturb_data_list = []

for i in range(1000):
    # clone original data
    pData = data.clone()
    
    # noise parameters
    # Number of swaps requested per perturbed graph (constant across iterations).
    noEdgeSwap = 3

    # create edges
    # Convert the sample into an undirected networkx graph; node/edge feature rows
    # are stored as strings under the "feat" attribute.
    edges = pData.edge_index.T.tolist()
    edges = np.array(edges)
    edges = [(x[0][0], x[0][1], {"feat": str(x[1])}) for x in list(zip(edges.tolist(), pData.edge_attr.tolist()))]
    nodes = [(x[0], {"feat": str(x[1])}) for x in enumerate(pData.x.tolist())]
    G = nx.Graph()
    G.add_nodes_from(nodes)
    G.add_edges_from(edges)

    # swap edges
    G = nx.double_edge_swap(G, noEdgeSwap)
    # both directions
    # New edge list: all edges in G's order, followed by all of them reversed.
    newEdges = list(G.edges()) + [(x[1], x[0]) for x in G.edges()]
    newEdges = torch.tensor(newEdges).T
    # set value
    # NOTE(review): edge_index is replaced with a re-ordered edge list, but
    # pData.edge_attr is left untouched, so row k of edge_attr likely no longer
    # describes edge k. This would perturb the edge-feature assignment as well as
    # the topology - confirm whether that is intended.
    pData.edge_index = newEdges

    perturb_data_list.append(pData)
    
    # visualise some graphs
    # Every 50th perturbed graph is drawn in its own small figure.
    if i % 50 == 0:
        plt.figure(figsize=(2, 2))
        nx.draw(G)
        plt.show()
    
len(perturb_data_list)

In [None]:
valid_loader = DataLoader(perturb_data_list, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)

# Predict on EVERY batch: taking only the first batch would evaluate just
# args.batch_size of the perturbed samples and silently drop the rest.
# eval() is needed because the triplet-training cell leaves the model in train() mode.
model.eval()
batch_preds = []
with torch.no_grad():
    for batch in valid_loader:
        batch = batch.to(device)
        batch_preds.append(model(batch)) #.view(-1,)
# One row per perturbed sample, in loader order.
pred = torch.cat(batch_preds, dim=0)

pred.shape

In [None]:
# Distribution of predictions on the perturbed samples; the red line marks the
# prediction for the unperturbed graph.
plt.hist(pred.flatten().tolist())
plt.title("Perturb topology. Label: {:.2f}".format(y_true))
plt.axvline(x=y_pred, color="r")
plt.show()

Topology doesn't seem to affect the score much either.