In [45]:
import random
import math
import gc

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

import torch
from torch_geometric.data import Data

tqdm.pandas()

In [46]:
n_samples = 1000
min_nodes = 4
max_nodes = 20
min_edges = 4
data_seed = 0  # seed for the synthetic-data generator (Python `random` and torch)

# Seed both RNGs drawn from below so the synthetic dataset is identical on every
# Restart & Run All.
random.seed(data_seed)
torch.manual_seed(data_seed)

data_list = []

for _ in tqdm(range(n_samples)):
    n_nodes = random.randint(min_nodes, max_nodes)
    x = torch.randn((n_nodes, 3))

    n_edges = random.randint(min_edges, n_nodes * 2)
    # torch.randint's `high` is exclusive: use `n_nodes` (not `n_nodes - 1`) so that
    # every node index 0..n_nodes-1 can appear as an edge endpoint.
    # NOTE(review): both endpoints are drawn independently, so self-loops (i == i)
    # can occur while the self_loop column below is always 0 — confirm intended.
    edge_index = torch.randint(0, n_nodes, (2, n_edges), dtype=torch.long)

    edge_features = torch.randn((n_edges, 1))
    self_loops = torch.zeros((n_edges, 1))
    edge_mask = torch.zeros((n_edges, 1))
    # edge_attr columns: [feature_0 ... feature_n, self_loop, mask]
    edge_attr = torch.cat([edge_features, self_loops, edge_mask], dim=1)

    data_list.append(Data(x=x, edge_index=edge_index, edge_attr=edge_attr))

  0%|          | 0/1000 [00:00<?, ?it/s]

In [47]:
# Sanity check: display the node count of the first generated graph.
data_list[0].num_nodes

17

In [53]:
# Hyper-parameters for masked-edge pre-training, collected in one place.
# Key order matches the historical literal so any iteration over CONFIG is unchanged.
CONFIG = dict(
    device=0,                # CUDA device index
    batch_size=256,
    epochs=100,              # number of training epochs
    lr=0.001,                # optimizer learning rate
    decay=0,                 # optimizer weight decay
    num_layer=5,             # number of GNN layers
    emb_dim=300,             # embedding width
    dropout_ratio=0,
    mask_rate=0.15,          # mask rate handed to the edge-masking transform
    mask_edge=0,
    JK="last",               # jumping-knowledge mode handed to the GNN
    output_model_file='',    # intended checkpoint path for the trained model
    gnn_type="gin",
    seed=0,
    num_workers=8,
)

In [54]:
# Sanity check: display the configured number of layers.
CONFIG["num_layer"]

5

In [55]:
def compute_accuracy(pred, target):
    """Return the fraction of rows whose arg-max prediction equals ``target``.

    Args:
        pred: 2-D score tensor; the predicted class of each row is its arg-max
            along dim 1. Gradients are not tracked (the tensor is detached).
        target: tensor of per-row labels compared element-wise to those indices.

    Returns:
        Python float in [0, 1]: number of matching rows divided by ``len(pred)``.

    NOTE(review): if ``pred`` has a single column, its arg-max is always 0, so this
    only counts targets equal to 0 — confirm the metric is meaningful for the caller.
    """
    _, predicted_class = torch.max(pred.detach(), dim=1)
    n_correct = (predicted_class == target).sum().item()
    return float(n_correct) / len(pred)


def train(config, model_list, loader, optimizer_list, device, loss_fn=None):
    """Run one epoch of masked-edge training.

    For every batch, the GNN produces node embeddings, the embeddings of the two
    endpoints of each masked edge are summed, and the linear head predicts the first
    column of ``batch.mask_edge_label`` from that sum.

    Args:
        config: run configuration (unused here; kept for interface compatibility).
        model_list: ``[gnn_model, linear_pred_edges]``.
        loader: iterable of batches exposing ``x``, ``edge_index``, ``edge_attr``,
            ``masked_edge_idx`` and ``mask_edge_label``.
        optimizer_list: ``[optimizer_model, optimizer_linear_pred_edges]``.
        device: device every batch is moved to.
        loss_fn: loss module/function; defaults to the notebook-level ``criterion``
            (defined in a later cell, so it must exist before this is called).

    Returns:
        ``(mean_loss, mean_accuracy)`` over the batches actually trained on. Batches
        lacking ``masked_edge_idx`` are skipped and excluded from both means.

    Raises:
        RuntimeError: if no batch could be used (empty loader or every batch skipped).
    """
    if loss_fn is None:
        loss_fn = criterion

    model, linear_pred_edges = model_list
    optimizer_model, optimizer_linear_pred_edges = optimizer_list

    model.train()
    linear_pred_edges.train()

    loss_accum = 0.0
    acc_accum = 0.0
    n_used_batches = 0
    n_broken_batches = 0

    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)
        # Batches without masked edges carry no training signal: report and skip them.
        if getattr(batch, "masked_edge_idx", None) is None:
            n_broken_batches += 1
            print("FUCKED UP BATCHES: ", n_broken_batches)
            print(batch.x, batch.edge_index, batch.edge_attr)
            continue

        optimizer_model.zero_grad()
        optimizer_linear_pred_edges.zero_grad()

        # edge_index is already an integer tensor on `device`; re-wrapping it in the
        # legacy torch.LongTensor constructor fails for CUDA tensors.
        node_rep = model(batch.x, batch.edge_index, batch.edge_attr)

        ### predict the masked edges from the sum of their endpoint embeddings.
        masked_edge_index = batch.edge_index[:, batch.masked_edge_idx]
        edge_rep = node_rep[masked_edge_index[0]] + node_rep[masked_edge_index[1]]
        pred_edge = linear_pred_edges(edge_rep)

        # Target: first attribute column of each masked edge.
        edge_label = batch.mask_edge_label[:, 0]

        acc_accum += compute_accuracy(pred_edge, edge_label)

        # A single-output head yields shape (E, 1) while the target is (E,); MSE-style
        # losses would silently broadcast that to (E, E), so drop the trailing dim.
        pred_for_loss = pred_edge.squeeze(-1) if pred_edge.size(-1) == 1 else pred_edge
        # Match the target dtype to the prediction (a Double/Long target against a
        # Float prediction breaks the loss/backward).
        loss = loss_fn(pred_for_loss, edge_label.to(pred_for_loss.dtype))
        loss.backward()

        optimizer_model.step()
        optimizer_linear_pred_edges.step()

        loss_accum += loss.item()
        n_used_batches += 1

    if n_used_batches == 0:
        raise RuntimeError(
            "train(): no usable batches (loader empty or every batch lacked masked_edge_idx)"
        )
    return loss_accum / n_used_batches, acc_accum / n_used_batches

In [71]:
import argparse

from chem.dataloader import DataLoaderMasking  # , DataListLoader
from chem.mydataset import MyDataset

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import numpy as np

from chem.model import GNN, GNN_graphpred

import pandas as pd

from chem.util import MaskEdge

from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool

criterion = nn.MSELoss()

import random
import timeit
import warnings
# NOTE(review): this silences *every* warning for the rest of the session, including
# PyTorch's MSELoss warning about mismatched input/target sizes (silent broadcasting).
warnings.filterwarnings("ignore")

# Seed every RNG the pipeline may draw from. Python's `random` is included in case the
# project-local MaskEdge transform samples with it (TODO confirm).
SEED = CONFIG["seed"]
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)
device = torch.device("cuda:" + str(CONFIG["device"])) if torch.cuda.is_available() else torch.device("cpu")
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print("num layer: %d mask rate: %f mask edge: %d" %(CONFIG["num_layer"], CONFIG["mask_rate"], CONFIG["mask_edge"]))

dataset = MyDataset(data_list, transform=MaskEdge(mask_rate=CONFIG["mask_rate"]))
# NOTE(review): batch_size and num_workers are hard-coded here rather than read from
# CONFIG["batch_size"] / CONFIG["num_workers"]; confirm this is intentional.
loader = DataLoaderMasking(dataset, batch_size=1, shuffle=True, num_workers=0)

model = GNN(CONFIG["num_layer"],
            CONFIG["emb_dim"],
            x_input_dim=data_list[0].x.shape[1],
            edge_attr_input_dim=data_list[0].edge_attr.shape[1],
            JK=CONFIG["JK"],
            drop_ratio=CONFIG["dropout_ratio"],
            gnn_type=CONFIG["gnn_type"]).to(device)

# Single-output head: one scalar prediction per masked edge.
linear_pred_edges = torch.nn.Linear(CONFIG["emb_dim"], 1).to(device)

model_list = [model, linear_pred_edges]

# Set up one Adam optimizer per module.
optimizer_model = optim.Adam(model.parameters(), lr=CONFIG["lr"], weight_decay=CONFIG["decay"])
optimizer_linear_pred_edges = optim.Adam(linear_pred_edges.parameters(), lr=CONFIG["lr"], weight_decay=CONFIG["decay"])

optimizer_list = [optimizer_model, optimizer_linear_pred_edges]

for epoch in range(1, CONFIG["epochs"] + 1):
    print("====epoch " + str(epoch))

    train_loss, train_acc = train(CONFIG, model_list, loader, optimizer_list, device)
    print(train_loss, train_acc)

# CONFIG stores the checkpoint prefix under "output_model_file" (there is no
# "model_file" key); an empty value disables saving.
output_model_file = CONFIG.get("output_model_file", "")
if output_model_file:
    torch.save(model.state_dict(), output_model_file + ".pth")

num layer: 5 mask rate: 0.150000 mask edge: 0
====epoch 1


Iteration:  23%|███████████████████████████▊                                                                                             | 230/1000 [00:03<00:10, 75.91it/s]


KeyboardInterrupt: 

In [69]:
import importlib

# Import by qualified name instead of `from chem import model, util, mydataset, batch`:
# that form rebinds the notebook globals `model` (the GNN instance) and `batch` to
# modules, silently breaking any later cell that relies on them.
import chem.batch
import chem.model
import chem.mydataset
import chem.util

# NOTE: importlib.reload does not update names previously bound with
# `from chem.model import GNN` (etc.); re-run the cell containing those imports
# afterwards to pick up the reloaded definitions.
importlib.reload(chem.model)
importlib.reload(chem.util)
importlib.reload(chem.mydataset)
importlib.reload(chem.batch)

<module 'chem.batch' from '/Users/user/PycharmProjects/any-domain-pretrain-gnns/chem/batch.py'>