# In [None]:
import os
os.environ['DGLBACKEND'] = 'pytorch'
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import dgl.function as fn
import dgl
from dgl.nn.functional import edge_softmax
from dgllife.data import Tox21
from dgllife.utils import SMILESToBigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer, RandomSplitter
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from dgl.data.utils import split_dataset
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score as rac
import torch.optim as optim
from tqdm.notebook import tqdm, trange
from scipy import signal
# Featurize each SMILES string into a DGL graph. Atom features are stored
# under 'feat' because that is the node-data key the model construction and
# the training/evaluation loops read; CanonicalAtomFeaturizer would otherwise
# use its own default field name.
smiles_to_g = SMILESToBigraph(
    node_featurizer=CanonicalAtomFeaturizer(atom_data_field='feat'),
    edge_featurizer=CanonicalBondFeaturizer(),
)
dataset = Tox21(smiles_to_g)
dataset[0]  # notebook preview of the first (smiles, graph, labels, mask) item


# Batching a list of datapoints for dataloader.
def collate_molgraphs(data):
    """Turn a list of dataset items into one mini-batch for a DataLoader.

    Parameters
    ----------
    data : list of tuple
        Items of the form ``(smiles, graph, labels, mask)``.

    Returns
    -------
    tuple
        ``(smiles_list, batched_graph, stacked_labels, stacked_masks)``, where
        labels and masks are stacked along a new leading (batch) dimension.
    """
    smiles, graphs, labels, masks = [list(column) for column in zip(*data)]

    batched_graph = dgl.batch(graphs)
    # Register zero initializers for node features, then edge features.
    for set_initializer in (batched_graph.set_n_initializer,
                            batched_graph.set_e_initializer):
        set_initializer(dgl.init.zero_initializer)

    stacked_labels = torch.stack(labels, dim=0)
    stacked_masks = torch.stack(masks, dim=0)
    return smiles, batched_graph, stacked_labels, stacked_masks


# 80/10/10 split with a fixed seed so results are reproducible across runs.
train_set, val_set, test_set = split_dataset(
    dataset, frac_list=[0.8, 0.1, 0.1], shuffle=True, random_state=0)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, collate_fn=collate_molgraphs)
# Evaluation loaders do not need shuffling; ROC-AUC is computed over the
# whole split, so a fixed order just keeps evaluation deterministic.
val_loader = DataLoader(val_set, batch_size=128, shuffle=False, collate_fn=collate_molgraphs)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False, collate_fn=collate_molgraphs)


class Meter(object):
    """Track and summarize model performance on a dataset for
    (multi-label) binary classification.

    Predictions, labels and masks are accumulated batch by batch with
    :meth:`update` and scored at the end with :meth:`roc_auc_score`.
    """

    def __init__(self):
        self.mask = []
        self.y_pred = []
        self.y_true = []

    def update(self, y_pred, y_true, mask):
        """Store the results of one iteration.

        Parameters
        ----------
        y_pred : float32 tensor
            Predicted logits with shape (B, T), B for batch size and T for
            the number of tasks.
        y_true : float32 tensor
            Ground truth labels with shape (B, T).
        mask : float32 tensor
            Mask with shape (B, T); entries > 0.5 mark existing labels.
        """
        # Detach and move to CPU so no autograd graph / GPU memory is retained.
        self.y_pred.append(y_pred.detach().cpu())
        self.y_true.append(y_true.detach().cpu())
        self.mask.append(mask.detach().cpu())

    def roc_auc_score(self):
        """Compute the ROC-AUC score for each task.

        Returns
        -------
        list of float
            One score per task. A task gets ``np.nan`` when its labelled
            samples contain fewer than two distinct classes (including the
            case of no labelled samples at all), since ROC-AUC is undefined
            there.
        """
        mask = torch.cat(self.mask, dim=0)
        y_pred = torch.cat(self.y_pred, dim=0)
        y_true = torch.cat(self.y_true, dim=0)
        # This assumes binary case only: logits -> probabilities.
        y_pred = torch.sigmoid(y_pred)
        n_tasks = y_true.shape[1]
        scores = []
        for task in range(n_tasks):
            task_wise_mask = mask[:, task] > 0.5
            task_wise_y_true = y_true[:, task][task_wise_mask]
            # `< 2` (not `== 1`) also covers tasks with zero labelled samples,
            # for which sklearn's roc_auc_score would raise.
            if torch.unique(task_wise_y_true).numel() < 2:
                scores.append(np.nan)
                continue
            task_wise_y_pred = y_pred[:, task][task_wise_mask]
            scores.append(rac(task_wise_y_true.numpy(), task_wise_y_pred.numpy()))
        return scores


class GATLayer1(nn.Module):
    """Graph attention layer whose head outputs are averaged.

    Each head projects node features with a bias-free linear map ``z = W h``,
    scores every edge ``u -> v`` with ``LeakyReLU(a_l . z_u + a_r . z_v)``,
    normalises the scores with DGL's ``edge_softmax`` (by default over the
    incoming edges of each destination node) and sums the attention-weighted
    source projections into each destination node. With ``num_heads > 1`` the
    per-head outputs are averaged.

    Parameters
    ----------
    in_feats : int
        Input node feature size D.
    out_feats : int
        Output node feature size H.
    num_heads : int
        Number of attention heads to average.
    activation : callable or None
        Applied to the (averaged) output; ``None`` disables it.
    negative_slope : float, optional
        Negative slope of the LeakyReLU applied to attention scores.
    attn_drop : float, optional
        Dropout probability applied to the normalised attention weights.
    """

    def __init__(self, in_feats, out_feats, num_heads, activation,
                 negative_slope=0.2, attn_drop=0.0):
        super(GATLayer1, self).__init__()
        self.num_heads = num_heads
        self.heads = nn.ModuleList()
        for _ in range(num_heads):
            self.heads.append(self.build_head(in_feats, out_feats))
        self.out_feats = out_feats
        self.activation = activation
        self.negative_slope = negative_slope
        self.attn_drop = nn.Dropout(attn_drop)
        # One attention vector per head for the source (l) and destination (r)
        # end of each edge.
        self.attn_l = nn.Parameter(torch.empty(num_heads, out_feats))
        self.attn_r = nn.Parameter(torch.empty(num_heads, out_feats))
        self.reset_parameters()

    def build_head(self, in_feats, out_feats):
        """Return the feature projection used by one attention head."""
        return nn.Linear(in_feats, out_feats, bias=False)

    def reset_parameters(self):
        """Xavier-initialise the head projections and attention vectors."""
        gain = nn.init.calculate_gain('relu')
        for head in self.heads:
            if isinstance(head, nn.Linear):
                nn.init.xavier_normal_(head.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_l, gain=gain)
        nn.init.xavier_normal_(self.attn_r, gain=gain)

    def _attend(self, g, h, idx):
        """Run attention head ``idx`` on graph ``g``; returns (N, out_feats)."""
        z = self.heads[idx](h)
        el = (z * self.attn_l[idx]).sum(dim=-1, keepdim=True)  # (N, 1) source-side score
        er = (z * self.attn_r[idx]).sum(dim=-1, keepdim=True)  # (N, 1) destination-side score
        # local_scope keeps the temporary node/edge fields off the caller's graph.
        with g.local_scope():
            g.ndata['_z'] = z
            g.ndata['_el'] = el
            g.ndata['_er'] = er
            g.apply_edges(fn.u_add_v('_el', '_er', '_e'))
            e = F.leaky_relu(g.edata.pop('_e'), self.negative_slope)
            g.edata['_a'] = self.attn_drop(edge_softmax(g, e))
            # NOTE(review): nodes without incoming edges (e.g. isolated atoms)
            # receive no messages here; confirm DGL zero-fills them for `sum`.
            g.update_all(fn.u_mul_e('_z', '_a', '_m'), fn.sum('_m', '_h'))
            return g.ndata['_h']

    def forward(self, g, h):
        """
        Parameters
        ----------
        g : dgl.DGLGraph
            DGLGraph for a (possibly batched) set of graphs.
        h : torch.Tensor
            Node features with shape (N, D), where N is the total number of
            nodes in ``g`` and D the input feature size.

        Returns
        -------
        torch.Tensor
            New node features with shape (N, H), H being ``out_feats``.
        """
        outs = [self._attend(g, h, i) for i in range(self.num_heads)]
        if self.num_heads > 1:
            h = torch.stack(outs, dim=0).mean(dim=0)
        else:
            h = outs[0]
        if self.activation is not None:
            h = self.activation(h)
        return h


class MultiHeadGATLayer(nn.Module):
    """GAT layer that concatenates ``num_heads`` independent attention heads.

    Output width is ``out_feats * num_heads``. Input dropout, an optional
    residual connection and the activation are applied around the heads, in
    that order (residual is added before the activation).

    Parameters
    ----------
    in_feats : int
        Input node feature size D.
    out_feats : int
        Per-head output size.
    num_heads : int
        Number of heads to concatenate.
    activation : callable or None
        Applied after the residual sum; ``None`` disables it.
    feat_drop : float
        Dropout probability on the input node features.
    attn_drop : float
        Dropout probability on the attention weights of each head.
    negative_slope : float
        LeakyReLU negative slope used for attention scores.
    residual : bool
        Add a (projected, if widths differ) copy of the input to the output.
    """

    def __init__(self, in_feats, out_feats, num_heads, activation, feat_drop, attn_drop, negative_slope, residual):
        super(MultiHeadGATLayer, self).__init__()
        self.residual = residual
        self.num_heads = num_heads
        self.negative_slope = negative_slope
        self.activation = activation
        self.feat_drop = nn.Dropout(feat_drop)
        # Heads run without activation; it is applied once after the residual.
        self.gat_layers = nn.ModuleList(
            GATLayer1(in_feats, out_feats, 1, None,
                      negative_slope=negative_slope, attn_drop=attn_drop)
            for _ in range(num_heads))
        if residual:
            out_dim = out_feats * num_heads
            # The residual branch must match the concatenated output width.
            self.res_fc = (nn.Identity() if in_feats == out_dim
                           else nn.Linear(in_feats, out_dim, bias=False))
        else:
            self.res_fc = None

    def forward(self, g, h):
        """
        Parameters
        ----------
        g : dgl.DGLGraph
            DGLGraph for a (possibly batched) set of graphs.
        h : torch.Tensor
            Node features with shape (N, D).

        Returns
        -------
        torch.Tensor
            New node features with shape (N, out_feats * num_heads), heads
            concatenated along the last dimension.
        """
        h_in = self.feat_drop(h)
        out = torch.cat([layer(g, h_in) for layer in self.gat_layers], dim=-1)
        if self.res_fc is not None:
            out = out + self.res_fc(h_in)
        if self.activation is not None:
            out = self.activation(out)
        return out


class GAT1(nn.Module):
    """Stacked multi-head GAT producing graph-level logits.

    ``num_layers`` hidden ``MultiHeadGATLayer``s (heads concatenated) are
    followed by an output ``MultiHeadGATLayer`` whose ``num_heads[-1]`` heads
    are averaged into ``num_classes`` logits per node. Those node logits are
    mean-pooled per graph with ``dgl.mean_nodes``.

    Parameters
    ----------
    g : dgl.DGLGraph or None
        Default graph used by ``forward`` when none is passed. For batched
        training, assign the current batch to ``self.g`` or pass it to
        ``forward``.
    num_layers : int
        Number of hidden GAT layers (at least 1).
    in_dim : int
        Input node feature size.
    num_hidden : int
        Per-head hidden size.
    num_classes : int
        Number of logits per graph.
    num_heads : sequence of int
        Heads per layer; needs ``num_layers + 1`` entries, the last one being
        the output layer's head count.
    activation, feat_drop, attn_drop, negative_slope, residual
        Forwarded to every ``MultiHeadGATLayer`` (the output layer uses no
        activation).

    Raises
    ------
    ValueError
        If ``num_layers < 1`` or ``len(num_heads) != num_layers + 1``.
    """

    def __init__(self,
                 g,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 num_heads,
                 activation=F.elu,
                 feat_drop=0.6,
                 attn_drop=0.6,
                 negative_slope=0.2,
                 residual=False):
        super(GAT1, self).__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be at least 1, got {num_layers}")
        if len(num_heads) != num_layers + 1:
            raise ValueError(
                f"num_heads must have num_layers + 1 = {num_layers + 1} entries "
                f"(one per hidden layer plus the output layer), got {len(num_heads)}")
        self.g = g
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.num_out_heads = num_heads[-1]
        self.activation = activation
        self.gat_layers = nn.ModuleList()
        self.gat_layers.append(MultiHeadGATLayer(in_dim, num_hidden, num_heads[0], activation,
                                                 feat_drop, attn_drop, negative_slope, residual))
        for l in range(1, num_layers):
            # Hidden layers concatenate heads, so the input width grows by the previous head count.
            self.gat_layers.append(MultiHeadGATLayer(num_hidden * num_heads[l-1],
                                                     num_hidden, num_heads[l], activation,
                                                     feat_drop, attn_drop, negative_slope, residual))
        self.gat_layers.append(MultiHeadGATLayer(num_hidden * num_heads[-2],
                                                 num_classes, num_heads[-1], None,
                                                 feat_drop, attn_drop, negative_slope, residual))

    def forward(self, inputs, g=None):
        """
        Parameters
        ----------
        inputs : torch.Tensor
            Node features with shape (N, in_dim), N being the total number of
            nodes in the graph.
        g : dgl.DGLGraph, optional
            Graph to run on; defaults to ``self.g``.

        Returns
        -------
        torch.Tensor
            Graph-level logits with shape (B, num_classes), B being the number
            of graphs in the batch.
        """
        graph = self.g if g is None else g
        if graph is None:
            raise ValueError("GAT1.forward needs a graph: pass `g` or set `self.g`")
        h = inputs
        for layer in self.gat_layers[:self.num_layers]:
            h = layer(graph, h)
        # Output layer concatenates heads: (N, heads * C) -> average over heads.
        node_logits = self.gat_layers[-1](graph, h)
        node_logits = node_logits.reshape(-1, self.num_out_heads, self.num_classes).mean(dim=1)
        # Labels are per graph, so pool node logits into one row per graph.
        with graph.local_scope():
            graph.ndata['_logits'] = node_logits
            return dgl.mean_nodes(graph, '_logits')


class EarlyStopping:
    """Flag when a monitored lower-is-better value stops improving.

    Every call compares the new value against the best one seen so far. A
    value strictly greater than ``best_score - delta`` counts as "no
    improvement" and bumps ``counter``; anything else becomes the new best
    and resets ``counter``. ``early_stop`` turns True once ``counter``
    reaches ``patience``.
    """

    def __init__(self, patience=5, delta=0.0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        # First observation just seeds the reference value.
        if self.best_score is None:
            self.best_score = val_loss
            return

        no_improvement = val_loss > self.best_score - self.delta
        if not no_improvement:
            self.best_score = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Molecules are batched per step, so there is no single graph to give GAT1
# up front; the graph for each batch is attached in _run_batch instead.
# Item layout is (smiles, graph, labels, mask), so index 1 is the graph.
in_dim = dataset[0][1].ndata['feat'].shape[-1]
model = GAT1(None,
             # Two hidden layers with 4 heads each plus a 6-head output layer:
             # num_heads needs num_layers + 1 entries.
             num_layers=2,
             in_dim=in_dim,
             num_hidden=128,
             num_classes=dataset.n_tasks,
             num_heads=[4, 4, 6],
             activation=F.elu,
             feat_drop=0.6,
             attn_drop=0.6,
             negative_slope=0.2,
             residual=True).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Positive-class weights computed on the training split only. reduction='none'
# keeps the per-entry loss so missing labels (mask == 0) can be excluded.
pos_weight = dataset.task_pos_weights(torch.as_tensor(train_set.indices, dtype=torch.long))
criterion = BCEWithLogitsLoss(pos_weight=pos_weight.to(device), reduction='none')


def _run_batch(batch):
    """Move one collated batch to `device` and return (logits, labels, masks)."""
    _, bg, label, mask = batch
    bg = bg.to(device)
    # GAT1.forward runs on the graph stored in `model.g`; point it at this batch.
    model.g = bg
    logits = model(bg.ndata['feat'])
    return logits, label.to(device), mask.to(device)


def _evaluate(loader):
    """Return per-task ROC-AUC scores of `model` on `loader` (eval mode, no grad)."""
    model.eval()
    meter = Meter()
    with torch.no_grad():
        for batch in loader:
            logits, label, mask = _run_batch(batch)
            meter.update(logits, label, mask)
    return meter.roc_auc_score()


# Training loop
num_epochs = 200
patience = 10
early_stopping = EarlyStopping(patience=patience)
best_val_score = -1
best_model_path = "best_model.pt"

for epoch in range(num_epochs):
    model.train()
    train_meter = Meter()
    for batch in train_loader:
        logits, label, mask = _run_batch(batch)
        # Mean BCE over labelled entries only; clamp avoids 0/0 on a batch
        # where every label is missing.
        per_entry_loss = criterion(logits, label.float()) * mask
        loss = per_entry_loss.sum() / mask.sum().clamp(min=1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_meter.update(logits, label, mask)

    # Tasks lacking both classes score NaN in Meter, so average with nanmean.
    train_score = np.nanmean(train_meter.roc_auc_score())
    val_score = np.nanmean(_evaluate(val_loader))
    print(f"Epoch {epoch + 1}: train ROC-AUC {train_score:.4f}, val ROC-AUC {val_score:.4f}")
    if val_score > best_val_score:
        best_val_score = val_score
        torch.save(model.state_dict(), best_model_path)
    # EarlyStopping treats lower values as better, so feed it the negated
    # validation ROC-AUC, every epoch, so improvements also reset its counter.
    early_stopping(-val_score)
    if early_stopping.early_stop:
        print("Early stopping!")
        break

# Load the best model (by validation ROC-AUC) and evaluate on the test set
model.load_state_dict(torch.load(best_model_path, map_location=device))
test_score = np.nanmean(_evaluate(test_loader))
print(f"Test ROC-AUC score: {test_score}")


#