In [1]:
import os
import torch
import numpy as np
from tqdm import tqdm

from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

from qm9_dataset import QM9DGLDataset

import dgl
from dgllife.model.gnn.gat import GAT

import pytorch_lightning as pl
import torchmetrics.functional as tm
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# Notebook-wide settings: where the data lives, which target to regress,
# and the options handed to the DataLoaders.
config = dict(
    data_path="./data/",             # first positional argument of QM9DGLDataset
    task="mu",                       # second positional argument of QM9DGLDataset
    train_data="qm9_train_data.pt",  # file_name of the train-mode dataset
    test_data="qm9_test_data.pt",    # file_name of the test-mode dataset
    batch_size=256,                  # batch_size of every DataLoader
    num_workers=16,                  # num_workers of every DataLoader
)

In [2]:
# Options that are identical for the train, validation and test loaders.
loader_options = {
    "batch_size": config["batch_size"],
    "num_workers": config["num_workers"],
}

# Labelled data: load the training file, then split it into a training part
# and a validation part (0.8 is forwarded to train_val_random_split).
dataset = QM9DGLDataset(
    config["data_path"],
    config["task"],
    file_name=config["train_data"],
    mode='train',
)
train_dataset, val_dataset = dataset.train_val_random_split(0.8)

# Both splits are batched with the parent dataset's collate_fn; only the
# training loader shuffles.
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=dataset.collate_fn,
    **loader_options,
)
valid_dataloader = DataLoader(
    val_dataset,
    shuffle=False,
    collate_fn=dataset.collate_fn,
    **loader_options,
)

# Test Dataset: built in 'test' mode and batched with its own collate_fn.
test_dataset = QM9DGLDataset(
    config["data_path"],
    config["task"],
    file_name=config["test_data"],
    mode='test',
)
test_dataloader = DataLoader(
    test_dataset,
    shuffle=False,
    collate_fn=test_dataset.collate_fn,
    **loader_options,
)

print(f"Train set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")
print(f"Test set size: {len(test_dataset)}")

Loaded train-set, task: mu, source: ./data/, length: 98123
Loaded test-set, task: mu, source: ./data/, length: 32708
Train set size: 78498
Validation set size: 19625
Test set size: 32708


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch.glob import AvgPooling


class MLPReadout(nn.Module):
    """Feed-forward head whose hidden width is halved at every layer.

    The stack consists of ``num_layers`` blocks of ``Linear -> ReLU``, where
    block ``k`` maps ``input_dim // 2**k`` features to
    ``input_dim // 2**(k + 1)`` features, followed by one final ``Linear``
    that maps ``input_dim // 2**num_layers`` features to ``output_dim``.

    Args:
        input_dim: Number of input features (size of the last dimension).
        output_dim: Number of output features.
        num_layers: Number of hidden ``Linear -> ReLU`` blocks.
    """

    def __init__(self, input_dim, output_dim, num_layers=2):
        super().__init__()

        # Feature widths: input_dim, input_dim // 2, ..., input_dim // 2**num_layers.
        widths = [input_dim // 2 ** depth for depth in range(num_layers + 1)]

        stack = []
        # Hidden blocks: each consecutive pair of widths becomes Linear -> ReLU.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out, bias=True))
            stack.append(nn.ReLU())

        # Output projection from the narrowest hidden width.
        stack.append(nn.Linear(widths[-1], output_dim, bias=True))

        # The attribute name determines the state_dict keys ("FC_layers.<idx>.*").
        self.FC_layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.FC_layers(x)

    
class GNN(nn.Module):
    """Graph-level regressor: atom embedding -> GAT backbone -> mean pooling -> MLP head.

    Args:
        num_atom_type: Number of rows of the atom-type embedding table.
        graph_encoder_dim: Embedding width; also the input width of the GAT backbone.
        hidden_dims: Hidden feature size of each GAT layer (one entry per layer).
        num_heads: Number of attention heads of each GAT layer (one entry per layer).
    """

    def __init__(self, num_atom_type, graph_encoder_dim,
                 hidden_dims=(32, 32, 64, 64, 128, 256),
                 num_heads=(4, 4, 6, 6, 8, 12)):
        super().__init__()

        # The defaults are immutable tuples so that no list object is shared
        # between instances; copy into fresh lists before handing them on.
        hidden_dims = list(hidden_dims)
        num_heads = list(num_heads)

        self.embedding = nn.Embedding(num_atom_type, graph_encoder_dim)
        self.backbone = GAT(in_feats=graph_encoder_dim,
                            hidden_feats=hidden_dims,
                            num_heads=num_heads,
                            feat_drops=[0.1] * len(hidden_dims),
                            # nn.ReLU holds no parameters, so sharing one instance is harmless.
                            activations=[nn.ReLU()] * len(hidden_dims))

        self.pooling_layer = AvgPooling()
        # The head is sized for hidden_dims[-1] input features, i.e. it assumes the
        # backbone's last layer emits hidden_dims[-1] features per node
        # (NOTE(review): relies on the GAT default head aggregation - confirm).
        self.MLP_layer = MLPReadout(input_dim=hidden_dims[-1], output_dim=1, num_layers=3)

    def forward(self, graph):
        """Predict one scalar per graph.

        Args:
            graph: Graph (or batch of graphs) whose node data field ``'f'``
                holds the indices looked up in the embedding table.

        Returns:
            Output of the MLP head, whose last dimension has size 1
            (presumably one row per graph in the batch).
        """
        node_types = graph.ndata['f']
        node_embeddings = self.embedding(node_types)

        node_repr = self.backbone(graph, node_embeddings)
        graph_embedding = self.pooling_layer(graph, node_repr)

        return self.MLP_layer(graph_embedding)

In [6]:
class Mu_predictor(pl.LightningModule):
    """Lightning wrapper that optimises a model with MSE loss and reports MAE.

    Args:
        model: Module called as ``model(graph)`` to produce predictions.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def step(self, batch):
        """Run one labelled batch.

        Args:
            batch: A ``(graph, y)`` pair.

        Returns:
            Tuple ``(pred, loss, acc)`` where ``loss`` is the mean squared
            error and ``acc`` is the mean absolute error of ``pred`` vs ``y``.
        """
        graph, target = batch
        prediction = self.model(graph)
        mse = F.mse_loss(prediction, target)
        # Despite the "acc" naming used by callers, this is an error metric (lower is better).
        mae = tm.mean_absolute_error(prediction, target)

        return prediction, mse, mae

    def training_step(self, batch, batch_idx):
        """Return the MSE loss of the batch and log its epoch-level mean."""
        # The MAE returned by step() is not logged during training.
        loss = self.step(batch)[1]
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log epoch-level validation MSE ('valid_loss') and MAE ('valid_acc')."""
        _, loss, mae = self.step(batch)
        self.log_dict({'valid_loss': loss, "valid_acc": mae},
                      on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log epoch-level test MSE ('test_loss') and MAE ('test_acc')."""
        _, loss, mae = self.step(batch)
        self.log_dict({'test_loss': loss, "test_acc": mae},
                      on_step=False, on_epoch=True, prog_bar=True)

    def predict_step(self, batch, batch_idx):
        """Return the model output for ``batch``; the batch is passed to the model as-is."""
        graph = batch
        return self.model(graph)

    def configure_optimizers(self):
        """Adam (lr=1e-4) combined with a cosine-annealing schedule (T_max=20)."""
        adam = torch.optim.Adam(self.parameters(), lr=1e-4)
        cosine = torch.optim.lr_scheduler.CosineAnnealingLR(adam, T_max=20)

        return dict(optimizer=adam, lr_scheduler=cosine)
    
    
model = GNN(num_atom_type=5, graph_encoder_dim=64)
predictor = Mu_predictor(model)

# Keep the three checkpoints with the lowest monitored 'valid_loss'.
callbacks = [
    ModelCheckpoint(
        monitor='valid_loss',
        save_top_k=3,
        dirpath='weights/GAT_large_batch',
        filename='GAT-{epoch:03d}-{valid_loss:.4f}-{valid_acc:.4f}',
    ),
]

# Warm start from the checkpoint of an earlier run when it is available.
# load_from_checkpoint builds a new module, so it is called on the class
# rather than on an instance. `model` must be passed again because it is an
# __init__ argument of Mu_predictor.
ckpt_fname = "GAT-epoch=199-valid_loss=0.4528-valid_acc=0.4824.ckpt"
ckpt_path = "weights/GAT_large_batch/" + ckpt_fname
if os.path.isfile(ckpt_path):
    predictor = Mu_predictor.load_from_checkpoint(ckpt_path, model=model)
else:
    # A fresh environment has no earlier run; fall back to the new weights
    # instead of failing, so the notebook survives Restart & Run All.
    print(f"Checkpoint {ckpt_path} not found; training starts from freshly initialised weights.")

# `accelerator` / `devices` replace the deprecated `gpus=1` argument.
trainer = pl.Trainer(
    max_epochs=200,
    accelerator="gpu",
    devices=1,
    enable_progress_bar=True,
    callbacks=callbacks,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
# Train on the training loader; the validation loader drives the monitored 'valid_loss'.
trainer.fit(
    model=predictor,
    train_dataloaders=train_dataloader,
    val_dataloaders=valid_dataloader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type | Params
-------------------------------
0 | model | GNN  | 7.4 M 
-------------------------------
7.4 M     Trainable params
0         Non-trainable params
7.4 M     Total params
29.661    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=200` reached.


In [8]:
# Checkpoint recorded from the original run; used only as a fallback below.
ckpt_fname = "GAT-epoch=150-valid_loss=0.3930-valid_acc=0.4281.ckpt"

# Prefer the best checkpoint tracked by this session's ModelCheckpoint callback:
# checkpoint file names embed the epoch and metric values, so a hard-coded name
# only exists for one particular training run.
checkpoint_cb = trainer.checkpoint_callback
best_ckpt_path = checkpoint_cb.best_model_path if checkpoint_cb is not None else ""
if not best_ckpt_path:
    # Nothing was tracked in this session (e.g. training was skipped).
    best_ckpt_path = "weights/GAT_large_batch/" + ckpt_fname

# load_from_checkpoint builds a new module, so call it on the class.
predictor = Mu_predictor.load_from_checkpoint(best_ckpt_path, model=model)

# One entry per batch of the test loader.
pred = trainer.predict(predictor, test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: 307it [00:00, ?it/s]

In [9]:
def to_np(x):
    """Return tensor ``x`` as a NumPy array (detached from autograd, on the CPU)."""
    return x.detach().cpu().numpy()


# Convert every batch of predictions, stack them along axis 0, and write the
# result as plain text.
preds = np.concatenate([to_np(batch_pred) for batch_pred in tqdm(pred)], axis=0)
np.savetxt('pred.csv', preds)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [00:00<00:00, 130784.63it/s]
