In [36]:
import random
import math
import gc

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

import torch
from torch import nn
from torch_geometric.data import Data
from torch_geometric.nn.inits import uniform

tqdm.pandas()

In [37]:
n_samples = 1000
min_nodes = 4
max_nodes = 20
min_edges = 4

data_list = []

for i in tqdm(range(n_samples)):
    n_nodes = random.randint(min_nodes, max_nodes)
    x_features = torch.randn((n_nodes, 3))
    x_mask = torch.zeros((n_nodes, 1))
    x = torch.cat([x_features, x_mask], dim=1)
        
#     max_edges = math.factorial(n_nodes)
    n_edges = random.randint(min_edges, n_nodes * 2)
    edge_index = torch.randint(0, n_nodes - 1, (2, n_edges)).long()
    
    direction = torch.tensor([[int(edge_index[0][i] > edge_index[1][i])] for i in range(n_edges)])
    
    edge_attr = torch.cat([torch.randn((n_edges, 1)), direction, torch.zeros((n_edges, 2))], dim=1)

    data = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr
    )
    
    data_list.append(data)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [38]:
data_list[0].x.shape[1]

4

In [39]:
data_list[0].edge_attr.shape[1]

4

In [40]:
CONFIG = {
    "device": 0,
    "batch_size": 256,
    "epochs": 100,
    "lr": 0.001,
    "decay": 0,
    "num_layer": 5,
    "emb_dim": 300,
    "dropout_ratio": 0,
    "mask_rate": 0.15,
    "mask_edge": 1,
    "JK": "last",
    "output_model_file": '',
    "gnn_type": "gat",
    "seed": 0, 
    "num_workers": 8
}

In [41]:
CONFIG["num_layer"]

5

In [42]:
def cycle_index(num, shift):
    arr = torch.arange(num) + shift
    arr[-shift:] = torch.arange(shift)
    return arr

class Discriminator(nn.Module):
    def __init__(self, hidden_dim):
        super(Discriminator, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.reset_parameters()

    def reset_parameters(self):
        size = self.weight.size(0)
        uniform(size, self.weight)

    def forward(self, x, summary):
        h = torch.matmul(summary, self.weight)
        return torch.sum(x * h, dim = 1)

class Infomax(nn.Module):
    def __init__(self, gnn, discriminator):
        super(Infomax, self).__init__()
        self.gnn = gnn
        self.discriminator = discriminator
        self.loss = nn.BCEWithLogitsLoss()
        self.pool = global_mean_pool


def train(args, model, device, loader, optimizer):
    model.train()

    train_acc_accum = 0
    train_loss_accum = 0

    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)
        node_emb = model.gnn(batch.x, batch.edge_index, batch.edge_attr)
        summary_emb = torch.sigmoid(model.pool(node_emb, batch.batch))

        positive_expanded_summary_emb = summary_emb[batch.batch]

        shifted_summary_emb = summary_emb[cycle_index(len(summary_emb), 1)]
        negative_expanded_summary_emb = shifted_summary_emb[batch.batch]

        positive_score = model.discriminator(node_emb, positive_expanded_summary_emb)
        negative_score = model.discriminator(node_emb, negative_expanded_summary_emb)      

        optimizer.zero_grad()
        loss = model.loss(positive_score, torch.ones_like(positive_score)) + model.loss(negative_score, torch.zeros_like(negative_score))
        loss.backward()

        optimizer.step()

        train_loss_accum += float(loss.detach().cpu().item())
        acc = (torch.sum(positive_score > 0) + torch.sum(negative_score < 0)).to(torch.float32)/float(2*len(positive_score))
        train_acc_accum += float(acc.detach().cpu().item())

    return train_acc_accum/step, train_loss_accum/step

In [46]:
import argparse

from chem.dataloader import DataLoaderMasking  # , DataListLoader
from chem.mydataset import MyDataset
from torch_geometric.data import DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import numpy as np

from chem.model import GNN, GNN_graphpred

import pandas as pd

from chem.util import MaskNode

from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool

criterion = nn.MSELoss()

torch.manual_seed(0)
np.random.seed(0)
device = torch.device("cuda:" + str(CONFIG["device"])) if torch.cuda.is_available() else torch.device("cpu")
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)


#set up dataset
dataset = MyDataset(data_list)
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

#set up model
gnn = GNN(CONFIG["num_layer"], 
          CONFIG["emb_dim"], 
          x_input_dim = data_list[0].x.shape[1],
          edge_attr_input_dim = data_list[0].edge_attr.shape[1],
          JK = CONFIG["JK"],
          drop_ratio = CONFIG["dropout_ratio"], 
          gnn_type = CONFIG["gnn_type"])

discriminator = Discriminator(CONFIG["emb_dim"])
model = Infomax(gnn, discriminator)
model.to(device)

#set up optimizer
optimizer = optim.Adam(model.parameters(), lr=CONFIG["lr"], weight_decay=CONFIG["decay"])

for epoch in range(1, CONFIG["epochs"] + 1):
    print("====epoch " + str(epoch))

    train_loss = train(CONFIG, model=model, device=device, loader=loader, optimizer=optimizer)
    print(train_loss)

if not CONFIG["model_file"] == "":
    torch.save(model.state_dict(), CONFIG["model_file"] + ".pth")



====epoch 1


Iteration: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:15<00:00, 64.07it/s]


(0.5005005005005005, 4.4379396946938545)
====epoch 2


Iteration: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:15<00:00, 64.62it/s]


(0.5005005005005005, 2.654238336556428)
====epoch 3


Iteration:  27%|████████████████████████████████▎                                                                                        | 267/1000 [00:04<00:13, 55.92it/s]


KeyboardInterrupt: 

In [None]:
import importlib
from chem import model, util, batch

importlib.reload(model)
importlib.reload(util)
importlib.reload(batch)

In [None]:
!cat "chem/batch.py"