In [1]:
import random
import math
import gc

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

import torch
from torch import nn
from torch_geometric.data import Data

tqdm.pandas()

In [2]:
# Synthetic dataset: `n_samples` random graphs, each with between `min_nodes`
# and `max_nodes` nodes and between `min_edges` and 2 * n_nodes random edges.
n_samples = 1000
min_nodes = 4
max_nodes = 20
min_edges = 4

# Seed both RNGs used below so that Restart & Run All regenerates the same graphs.
data_seed = 0
random.seed(data_seed)
torch.manual_seed(data_seed)

data_list = []

for i in tqdm(range(n_samples)):
    # random.randint is inclusive on both ends.
    n_nodes = random.randint(min_nodes, max_nodes)

    # Node features: 3 random values followed by one all-zero column (x_mask).
    x_features = torch.randn((n_nodes, 3))
    x_mask = torch.zeros((n_nodes, 1))
    x = torch.cat([x_features, x_mask], dim=1)

    n_edges = random.randint(min_edges, n_nodes * 2)
    # torch.randint excludes its upper bound, so pass n_nodes (not n_nodes - 1);
    # otherwise the last node could never appear as an edge endpoint.
    edge_index = torch.randint(0, n_nodes, (2, n_edges), dtype=torch.long)

    # Edge attributes: 1 random value followed by two all-zero columns.
    edge_attr = torch.cat([torch.randn((n_edges, 1)), torch.zeros((n_edges, 2))], dim=1)

    data = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr
    )

    data_list.append(data)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [3]:
# Width of the node feature matrix: 3 random features plus the x_mask column.
data_list[0].x.shape[1]

4

In [4]:
# Width of the edge attribute matrix: 1 random value plus 2 all-zero columns.
data_list[0].edge_attr.shape[1]

3

In [5]:
# Hyper-parameters for context-prediction pre-training.
CONFIG = {
    "device": 0,               # CUDA device index, used when CUDA is available
    "batch_size": 256,         # batch size of the substructure/context loader
    "epochs": 100,             # number of training epochs
    "lr": 0.001,               # Adam learning rate (both GNNs)
    "decay": 0,                # Adam weight decay (both GNNs)
    "num_layer": 5,            # layer count of the substructure GNN
    "emb_dim": 300,            # embedding size of both GNNs
    "dropout_ratio": 0,        # drop_ratio passed to both GNNs
    "mask_rate": 0.15,         # not read by the context-prediction cells
    "mask_edge": 1,            # not read by the context-prediction cells
    "JK": "last",              # JK option passed to both GNNs
    "output_model_file": '',   # checkpoint path prefix; '' disables saving
    "gnn_type": "gat",         # gnn_type passed to both GNNs
    "seed": 0,
    "num_workers": 8,          # NOTE: the loader is currently built with num_workers=0
    "csize": 3,                # context size: l2 = l1 + csize
    "mode": "skipgram",        # "skipgram" or "cbow"
    # The two keys below are read by train() and were missing, which made the
    # training loop fail with a KeyError.
    "neg_samples": 1,          # negative samples per positive pair
    "context_pooling": "mean", # pooling for "cbow" mode: "sum", "mean" or "max"
}

In [6]:
# Layer count later handed to ExtractSubstructureContextPair and the substructure GNN.
CONFIG["num_layer"]

5

In [7]:
def pool_func(x, batch, mode = "sum"):
    """Pool `x` per group using the torch_geometric global pooling functions.

    Args:
        x: tensor of representations to pool.
        batch: group-assignment argument forwarded unchanged to the pooling
            function (train() passes `batch.batch_overlapped_context`).
        mode: "sum", "mean" or "max".

    Returns:
        The result of global_add_pool / global_mean_pool / global_max_pool.

    Raises:
        ValueError: if `mode` is not one of the supported names. Previously an
            unknown mode silently returned None and failed later, far from the cause.
    """
    if mode == "sum":
        return global_add_pool(x, batch)
    if mode == "mean":
        return global_mean_pool(x, batch)
    if mode == "max":
        return global_max_pool(x, batch)
    raise ValueError("Invalid pooling mode: %r (expected 'sum', 'mean' or 'max')" % (mode,))

def cycle_index(num, shift):
    """Return the indices 0..num-1 rotated left by `shift` positions.

    Example: cycle_index(5, 2) -> tensor([2, 3, 4, 0, 1]).

    The rotation is computed modulo `num`, so `shift == 0` (identity) and
    `shift > num` (wraps around) are valid; the previous slice-assignment
    implementation raised a size-mismatch error for both.

    Args:
        num: number of indices.
        shift: number of positions to rotate by.

    Returns:
        A 1-D integer tensor of length `num`.
    """
    if num == 0:
        # Nothing to rotate; also avoids a modulo by zero.
        return torch.arange(0)
    return (torch.arange(num) + shift) % num

# Kept for backward compatibility with cells that reference `criterion`.
criterion = nn.BCEWithLogitsLoss()

def _expand_by_context_size(rep, context_sizes):
    """Repeat row i of `rep` `context_sizes[i]` times and concatenate the results."""
    return torch.cat([rep[i].repeat((context_sizes[i], 1)) for i in range(len(rep))], dim = 0)

def train(CONFIG, model_substruct, model_context, loader, optimizer_substruct, optimizer_context, device):
    """Run one epoch of substructure/context prediction training.

    For every batch the substructure GNN embeds the centre node and the context
    GNN embeds the overlapping context nodes. Their dot products are treated as
    logits: true pairs are pushed towards 1, and pairs built by cyclically
    shifting the batch (negative samples) towards 0.

    Args:
        CONFIG: mapping read for "mode" ("cbow" or "skipgram"), "neg_samples"
            (defaults to 1 when absent) and, in "cbow" mode, "context_pooling"
            (defaults to "mean" when absent).
        model_substruct: model called as
            model(x_substruct, edge_index_substruct, edge_attr_substruct).
        model_context: model called as
            model(x_context, edge_index_context, edge_attr_context).
        loader: iterable of batches exposing the `*_substruct` / `*_context`
            attributes, `center_substruct_idx`, `overlap_context_substruct_idx`,
            `overlapped_context_size` and (cbow only) `batch_overlapped_context`.
        optimizer_substruct: optimizer stepping `model_substruct`.
        optimizer_context: optimizer stepping `model_context`.
        device: device each batch is moved to.

    Returns:
        (mean balanced loss, mean balanced accuracy), averaged over the batches.

    Raises:
        ValueError: on an invalid mode, a non-positive "neg_samples", or when
            `loader` yields no batches.
    """
    model_substruct.train()
    model_context.train()

    neg_samples = CONFIG.get("neg_samples", 1)
    if neg_samples < 1:
        raise ValueError("neg_samples must be >= 1")

    # Local loss object: the predictions are logits and the accuracy below
    # thresholds them at 0, so BCE-with-logits is required. Using a local
    # instance keeps this function independent of whatever the module-level
    # name `criterion` is rebound to in other cells.
    bce_with_logits = nn.BCEWithLogitsLoss()

    balanced_loss_accum = 0.0
    acc_accum = 0.0
    n_batches = 0

    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)

        # Substructure representation: embedding of each substructure's centre node.
        substruct_rep = model_substruct(batch.x_substruct, batch.edge_index_substruct, batch.edge_attr_substruct)[batch.center_substruct_idx]

        # Context representation: embeddings of the context nodes that overlap the substructure.
        overlapped_node_rep = model_context(batch.x_context, batch.edge_index_context, batch.edge_attr_context)[batch.overlap_context_substruct_idx]

        if CONFIG["mode"] == "cbow":
            # Positive context: the overlap nodes pooled into one vector per group.
            context_rep = pool_func(overlapped_node_rep, batch.batch_overlapped_context, mode = CONFIG.get("context_pooling", "mean"))
            # Negative contexts: the same vectors with their indices cyclically shifted.
            neg_context_rep = torch.cat([context_rep[cycle_index(len(context_rep), k + 1)] for k in range(neg_samples)], dim = 0)

            pred_pos = torch.sum(substruct_rep * context_rep, dim = 1)
            pred_neg = torch.sum(substruct_rep.repeat((neg_samples, 1)) * neg_context_rep, dim = 1)

        elif CONFIG["mode"] == "skipgram":
            # Positive pairs: each substructure vector against each of its own overlap nodes.
            expanded_substruct_rep = _expand_by_context_size(substruct_rep, batch.overlapped_context_size)
            pred_pos = torch.sum(expanded_substruct_rep * overlapped_node_rep, dim = 1)

            # Negative pairs: shift the substructure indices so that every
            # overlap node is matched with another entry's substructure.
            shifted_expanded_substruct_rep = torch.cat(
                [
                    _expand_by_context_size(
                        substruct_rep[cycle_index(len(substruct_rep), k + 1)],
                        batch.overlapped_context_size,
                    )
                    for k in range(neg_samples)
                ],
                dim = 0,
            )
            pred_neg = torch.sum(shifted_expanded_substruct_rep * overlapped_node_rep.repeat((neg_samples, 1)), dim = 1)

        else:
            raise ValueError("Invalid mode!")

        loss_pos = bce_with_logits(pred_pos.double(), torch.ones_like(pred_pos, dtype=torch.double))
        loss_neg = bce_with_logits(pred_neg.double(), torch.zeros_like(pred_neg, dtype=torch.double))

        optimizer_substruct.zero_grad()
        optimizer_context.zero_grad()

        # Weight the negative term by the number of negative samples.
        loss = loss_pos + neg_samples * loss_neg
        loss.backward()
        optimizer_substruct.step()
        optimizer_context.step()

        balanced_loss_accum += loss_pos.item() + loss_neg.item()
        pos_acc = float(torch.sum(pred_pos > 0).item()) / len(pred_pos)
        neg_acc = float(torch.sum(pred_neg < 0).item()) / len(pred_neg)
        acc_accum += 0.5 * (pos_acc + neg_acc)
        n_batches += 1

    if n_batches == 0:
        raise ValueError("loader yielded no batches")

    # Average over the number of batches. The previous code divided by the last
    # 0-based index, which over-reported both metrics and divided by zero when
    # the loader produced a single batch.
    return balanced_loss_accum / n_batches, acc_accum / n_batches

In [8]:
import argparse

from chem.dataloader import DataLoaderMasking, DataLoaderSubstructContext  # , DataListLoader
from chem.mydataset import MyDataset

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import numpy as np

from chem.model import GNN, GNN_graphpred

import pandas as pd

from chem.util import MaskNode, ExtractSubstructureContextPair

from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool

# train() scores substructure/context pairs as logits against 1/0 targets and
# measures accuracy by thresholding at 0, so the loss must be BCE-with-logits.
# (This cell used to rebind `criterion` to nn.MSELoss(), silently replacing
# the loss defined next to train().)
criterion = nn.BCEWithLogitsLoss()

import timeit
import warnings 
warnings.filterwarnings("ignore")

# Seed every RNG from the single CONFIG entry instead of hard-coded zeros.
# Python's `random` is seeded too, in case the dataset transform draws from it.
seed = CONFIG["seed"]
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
device = torch.device("cuda:" + str(CONFIG["device"])) if torch.cuda.is_available() else torch.device("cpu")
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)


# l1 and l2 are handed to ExtractSubstructureContextPair below;
# they differ by CONFIG["csize"].
l1 = CONFIG["num_layer"] - 1
l2 = l1 + CONFIG["csize"]

print(CONFIG["mode"])
print("num layer: %d l1: %d l2: %d" %(CONFIG["num_layer"], l1, l2))

#set up dataset and transform function.
dataset = MyDataset(data_list, transform = ExtractSubstructureContextPair(CONFIG["num_layer"], l1, l2))
# NOTE: num_workers is fixed at 0 here; CONFIG["num_workers"] is not used.
loader = DataLoaderSubstructContext(dataset, 
                                    batch_size=CONFIG["batch_size"], 
                                    shuffle=True, 
                                    num_workers=0)

# Input widths are taken from the first sample.
x_input_dim = data_list[0].x.shape[1]
edge_attr_input_dim = data_list[0].edge_attr.shape[1]

#set up models, one for pre-training and one for context embeddings
model_substruct = GNN(CONFIG["num_layer"], 
                      CONFIG["emb_dim"], 
                      x_input_dim=x_input_dim, 
                      edge_attr_input_dim=edge_attr_input_dim, 
                      JK = CONFIG["JK"], 
                      drop_ratio = CONFIG["dropout_ratio"], 
                      gnn_type = CONFIG["gnn_type"]).to(device)

# The context GNN uses l2 - l1 layers.
model_context = GNN(int(l2 - l1), 
                    CONFIG["emb_dim"], 
                    JK = CONFIG["JK"], 
                    x_input_dim=x_input_dim, 
                    edge_attr_input_dim=edge_attr_input_dim, 
                    drop_ratio = CONFIG["dropout_ratio"], 
                    gnn_type = CONFIG["gnn_type"]).to(device)

#set up optimizer for the two GNNs
optimizer_substruct = optim.Adam(model_substruct.parameters(), lr=CONFIG["lr"], weight_decay=CONFIG["decay"])
optimizer_context = optim.Adam(model_context.parameters(), lr=CONFIG["lr"], weight_decay=CONFIG["decay"])

for epoch in range(1, CONFIG["epochs"] + 1):
    print("====epoch " + str(epoch))

    train_loss, train_acc = train(CONFIG, model_substruct, model_context, loader, optimizer_substruct, optimizer_context, device)
    print(train_loss, train_acc)

# Only the substructure GNN is saved, and only when an output path is configured.
if CONFIG["output_model_file"] != "":
    torch.save(model_substruct.state_dict(), CONFIG["output_model_file"] + ".pth")

skipgram
num layer: 5 l1: 4 l2: 7
====epoch 1


Iteration:   0%|                                                                                                                                      | 0/4 [00:00<?, ?it/s]


ValueError: too many values to unpack (expected 2)

In [11]:
import importlib
from chem import model, util, batch

# Re-execute the chem modules after editing them on disk. Names bound earlier
# with `from chem.model import ...` keep pointing at the old objects until
# those import statements are run again.
for _edited_module in (model, util):
    importlib.reload(_edited_module)
importlib.reload(batch)

<module 'chem.batch' from '/Users/user/PycharmProjects/any-domain-pretrain-gnns/chem/batch.py'>