In [23]:
import torch
from torch_geometric.nn import GCNConv, global_mean_pool, GATConv, GINConv, Sequential
from torch.nn import Sequential as Seq, Linear, ReLU
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets import MoleculeNet, QM9
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score, roc_auc_score
import numpy as np


In [24]:
# Pick the compute device once: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [25]:
# Tox21 split of MoleculeNet, rooted at data/MoleculeNet. The training cells in
# this notebook predict 12 outputs per molecule and mask out NaN labels.
dataset = MoleculeNet(
    name="Tox21",
    root="data/MoleculeNet",
)

In [26]:
# 80/20 train/validation split.
# A fixed generator seed keeps the split identical across kernel restarts, so the
# loss/AUC curves saved by the training cells are comparable between runs.
# NOTE: `train_len` is not consumed by the split itself (fractions are passed
# directly); it is kept in case other cells read it.
train_len = int(len(dataset) * 0.8)
split_generator = torch.Generator().manual_seed(42)
train_set, val_set = torch.utils.data.random_split(
    dataset, [0.8, 0.2], generator=split_generator
)
# Only the training loader shuffles; validation order does not matter.
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
val_loader = DataLoader(val_set, batch_size=128)

In [27]:
class GCN(torch.nn.Module):
    """Two-layer GCN encoder followed by graph pooling and a linear head.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of both graph-convolution layers.
        out_channels: Number of values predicted per graph.

    ``forward`` returns the raw output of the linear head (no activation), one
    row per graph in the batch.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Layers are created in the same order as before so that seeded
        # parameter initialisation is unaffected.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        # `ln` is the linear prediction head (not a layer norm).
        self.ln = Linear(hidden_channels, out_channels)
        self.pool = global_mean_pool

    def forward(self, data):
        """Run the network on a batch exposing ``x``, ``edge_index`` and ``batch``."""
        node_feats = data.x.float()  # node features are cast to float before convolution
        edges = data.edge_index

        hidden = self.conv1(node_feats, edges).relu()
        # Dropout is active only while the module is in training mode.
        hidden = F.dropout(hidden, training=self.training)
        hidden = self.conv2(hidden, edges).relu()

        # Collapse node embeddings into one vector per graph, then predict.
        graph_emb = self.pool(hidden, data.batch)
        return self.ln(graph_emb)

In [28]:
class GAT(torch.nn.Module):
    """Two-layer graph attention network with graph pooling and a linear head.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Per-head width of the attention layers.
        out_channels: Number of values predicted per graph.
        heads: Number of attention heads in the first layer (default 4).

    ``forward`` returns the raw output of the linear head (no activation), one
    row per graph in the batch.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=4):
        super().__init__()
        # First layer concatenates its heads, so the second layer's input width
        # is hidden_channels * heads.
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, concat=True)
        widened = hidden_channels * heads
        # Second layer uses a single head without concatenation.
        self.conv2 = GATConv(widened, hidden_channels, heads=1, concat=False)
        # `ln` is the linear prediction head (not a layer norm).
        self.ln = Linear(hidden_channels, out_channels)
        self.pool = global_mean_pool

    def forward(self, data):
        """Run the network on a batch exposing ``x``, ``edge_index`` and ``batch``."""
        node_feats = data.x.float()  # node features are cast to float before attention
        edges = data.edge_index

        hidden = self.conv1(node_feats, edges).relu()
        # Dropout is active only while the module is in training mode.
        hidden = F.dropout(hidden, training=self.training)
        hidden = self.conv2(hidden, edges).relu()

        # Collapse node embeddings into one vector per graph, then predict.
        graph_emb = self.pool(hidden, data.batch)
        return self.ln(graph_emb)

In [29]:
class GIN(torch.nn.Module):
    """Two-layer GIN encoder followed by graph pooling and a linear head.

    Each ``GINConv`` wraps a small MLP (Linear -> ReLU -> Linear).

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the MLPs and of the node embeddings.
        out_channels: Number of values predicted per graph.

    ``forward`` returns the raw output of the linear head (no activation), one
    row per graph in the batch.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # MLPs and layers are built in the same order as before so that seeded
        # parameter initialisation is unaffected.
        first_mlp = Seq(
            Linear(in_channels, hidden_channels),
            ReLU(),
            Linear(hidden_channels, hidden_channels),
        )
        self.conv1 = GINConv(first_mlp)
        second_mlp = Seq(
            Linear(hidden_channels, hidden_channels),
            ReLU(),
            Linear(hidden_channels, hidden_channels),
        )
        self.conv2 = GINConv(second_mlp)
        # `ln` is the linear prediction head (not a layer norm).
        self.ln = Linear(hidden_channels, out_channels)
        self.pool = global_mean_pool

    def forward(self, data):
        """Run the network on a batch exposing ``x``, ``edge_index`` and ``batch``."""
        node_feats = data.x.float()  # node features are cast to float before convolution
        edges = data.edge_index

        hidden = self.conv1(node_feats, edges).relu()
        # Dropout is active only while the module is in training mode.
        hidden = F.dropout(hidden, training=self.training)
        hidden = self.conv2(hidden, edges).relu()

        # Collapse node embeddings into one vector per graph, then predict.
        graph_emb = self.pool(hidden, data.batch)
        return self.ln(graph_emb)


In [30]:
def validate(model, loader):
    """Evaluate ``model`` on ``loader`` and return the ROC-AUC of its predictions.

    The model is switched to eval mode for the duration of the evaluation and
    then put back into whichever mode (train/eval) it was in on entry, so a
    training loop that calls this between epochs keeps dropout enabled.

    Args:
        model: Module mapping a batch to per-graph logits.
        loader: Iterable of batches to evaluate. Each batch must support
            ``.to(device)`` and expose targets in ``.y`` (NaN = missing label).

    Returns:
        The value returned by ``roc_auc_score`` over all non-NaN targets.
    """
    was_training = model.training
    model.eval()
    y_true, y_pred = [], []
    try:
        with torch.no_grad():
            # Iterate the loader that was passed in, not the global val_loader.
            for data in loader:
                data = data.to(device)
                output = torch.sigmoid(model(data))  # Apply sigmoid for probabilities

                # Mask NaN targets in validation
                mask = ~torch.isnan(data.y)
                y_true.append(data.y[mask].cpu())
                y_pred.append(output[mask].cpu())
    finally:
        # Restore the caller's mode even if evaluation raised.
        model.train(was_training)

    # Concatenate and evaluate AUC-ROC.
    # NOTE: boolean masking flattens targets/predictions to 1-D, so labels from
    # all tasks are pooled into one binary problem; `average` has no effect here.
    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()
    auc_roc = roc_auc_score(y_true, y_pred, average='macro')
    print(f"Validation AUC-ROC: {auc_roc:.4f}")
    return auc_roc

In [31]:
# Number of full passes over train_loader; shared by every training cell below.
epochs = 20

In [None]:
# Train the GCN on Tox21, recording the mean training loss and the validation
# ROC-AUC after every epoch, then save both curves to disk.
model = GCN(in_channels=dataset.num_node_features, hidden_channels=128, out_channels=12)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Loss is computed on raw logits; validate() applies the sigmoid itself.
loss_fn = torch.nn.BCEWithLogitsLoss()
model.to(device)
losses = []
auc_roc = []
for epoch in range(epochs):
    # validate() puts the model in eval mode, so training mode (dropout) must be
    # re-enabled at the start of every epoch, not just once before the loop.
    model.train()
    total_loss = 0.0
    for data in tqdm(train_loader):
        data = data.to(device)
        output = model(data)
        targets = data.y.float()
        # Missing labels are NaN; only labelled (graph, task) entries contribute.
        mask = ~torch.isnan(targets)
        loss = loss_fn(output[mask], targets[mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate a plain float so no autograd history is kept across steps.
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    print(f'epoch {epoch+1}: loss: {avg_loss:.4f}')
    losses.append(avg_loss)
    # Validate once per epoch; float() accepts both numpy scalars and Python floats.
    auc_roc.append(float(validate(model, val_loader)))
torch.save(losses, 'gcn_tox21_losses.pt')
torch.save(auc_roc, 'gcn_tox21_auc_roc.pt')

In [None]:
# Train the GAT on Tox21, recording the mean training loss and the validation
# ROC-AUC after every epoch, then save both curves to disk.
model = GAT(in_channels=dataset.num_node_features, hidden_channels=128, out_channels=12)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Same objective as the GCN cell (binary cross-entropy on logits) so the saved
# loss curves are comparable across models; validate() applies the sigmoid.
loss_fn = torch.nn.BCEWithLogitsLoss()
# Use the notebook-wide `device` rather than a hard-coded 'cuda' so the cell
# also runs on CPU-only machines.
model.to(device)
losses = []
auc_roc = []
for epoch in range(epochs):
    # validate() puts the model in eval mode, so training mode (dropout) must be
    # re-enabled at the start of every epoch, not just once before the loop.
    model.train()
    total_loss = 0.0
    for data in tqdm(train_loader):
        data = data.to(device)
        output = model(data)
        targets = data.y.float()
        # Missing labels are NaN; only labelled (graph, task) entries contribute.
        mask = ~torch.isnan(targets)
        loss = loss_fn(output[mask], targets[mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate a plain float so no autograd history is kept across steps.
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    print(f'epoch {epoch+1}: loss: {avg_loss:.4f}')
    losses.append(avg_loss)
    # Validate once per epoch; float() accepts both numpy scalars and Python floats.
    auc_roc.append(float(validate(model, val_loader)))
torch.save(losses, 'gat_tox21_losses.pt')
torch.save(auc_roc, 'gat_tox21_auc_roc.pt')

In [None]:
# Train the GIN on Tox21, recording the mean training loss and the validation
# ROC-AUC after every epoch, then save both curves to disk.
model = GIN(in_channels=dataset.num_node_features, hidden_channels=128, out_channels=12)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Same objective as the GCN cell (binary cross-entropy on logits) so the saved
# loss curves are comparable across models; validate() applies the sigmoid.
loss_fn = torch.nn.BCEWithLogitsLoss()
# Use the notebook-wide `device` rather than a hard-coded 'cuda' so the cell
# also runs on CPU-only machines.
model.to(device)
losses = []
auc_roc = []
for epoch in range(epochs):
    # validate() puts the model in eval mode, so training mode (dropout) must be
    # re-enabled at the start of every epoch, not just once before the loop.
    model.train()
    total_loss = 0.0
    for data in tqdm(train_loader):
        data = data.to(device)
        output = model(data)
        targets = data.y.float()
        # Missing labels are NaN; only labelled (graph, task) entries contribute.
        mask = ~torch.isnan(targets)
        loss = loss_fn(output[mask], targets[mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate a plain float so no autograd history is kept across steps.
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    print(f'epoch {epoch+1}: loss: {avg_loss:.4f}')
    losses.append(avg_loss)
    # float() accepts both numpy scalars and Python floats.
    auc_roc.append(float(validate(model, val_loader)))
torch.save(losses, 'gin_tox21_losses.pt')
torch.save(auc_roc, 'gin_tox21_auc_roc.pt')