In [37]:
import torch
from torch_geometric.nn import GCNConv, global_mean_pool, GATConv, GINConv, Sequential
from torch.nn import Sequential as Seq, Linear, ReLU
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets import MoleculeNet, QM9
# from torch_geometric.loader import DataLoader
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score, roc_auc_score
import numpy as np
from chienn.data.edge_graph import to_edge_graph, collate_circle_index
from chienn.data.featurization import smiles_to_3d_mol, mol_to_data
from chienn.data.featurize import *
from chienn.model.chienn_layer import *

In [38]:
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [39]:
def preprocess_features(dataset, return_failed=False):
    """Rebuild each MoleculeNet graph from its SMILES with ChiENN featurization.

    Args:
        dataset: iterable of graphs exposing ``.smiles`` and ``.y``.
        return_failed: if True, also return the indices (positions in ``dataset``)
            of molecules whose conversion raised.

    Returns:
        A list of converted graphs, each carrying the original ``y``. When
        ``return_failed`` is True, returns ``(data_list, failed_indices)`` instead.
    """
    data_list = []
    failed_indices = []
    for i, data in enumerate(tqdm(dataset)):
        try:
            processed_data = smiles_to_data_with_circle_index(data.smiles)
        except Exception:
            # Conversion raises for some SMILES; skip those molecules but remember
            # where they were. Catching Exception (not a bare except) keeps Ctrl-C
            # working during this long loop.
            failed_indices.append(i)
            continue
        processed_data.y = data.y
        data_list.append(processed_data)
    if return_failed:
        return data_list, failed_indices
    return data_list

In [40]:
# Raw Tox21 graphs from MoleculeNet. Their ChiENN-featurized version is cached on
# disk as 'processed_tox21.pt', built with:
#     torch.save(preprocess_features(dataset), 'processed_tox21.pt')
dataset = MoleculeNet(
    root="data/MoleculeNet",
    name="Tox21",
)


In [None]:
# Featurizing every molecule is slow, so reuse the on-disk cache when it exists and
# build (then save) it otherwise. This keeps Restart & Run All working on a fresh checkout.
PROCESSED_PATH = 'processed_tox21.pt'
try:
    # The cache is a pickled list of graph objects (not just tensors), so it cannot be
    # read with weights_only=True, which newer torch versions use by default.
    # Only load cache files you produced yourself.
    processed_dataset = torch.load(PROCESSED_PATH, weights_only=False)
except FileNotFoundError:
    processed_dataset = preprocess_features(dataset)
    torch.save(processed_dataset, PROCESSED_PATH)
# Fraction of molecules that survived featurization.
print(len(processed_dataset) / len(dataset))

In [29]:
def custom_collate(data_list):
    """DataLoader ``collate_fn``: batch a list of graphs via ``collate_with_circle_index``.

    The neighbourhood size passed to the collator is fixed at ``k_neighbors=3``.
    """
    batch = collate_with_circle_index(data_list, k_neighbors=3)
    return batch

In [30]:
# 80/20 random split. A fixed generator seed makes the train/validation membership
# reproducible across kernel restarts, so model comparisons use the same molecules.
SPLIT_SEED = 42
train_len = int(len(processed_dataset) * 0.8)
train_set, val_set = torch.utils.data.random_split(
    processed_dataset,
    [train_len, len(processed_dataset) - train_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, collate_fn=custom_collate)
val_loader = DataLoader(val_set, batch_size=128, collate_fn=custom_collate)

In [31]:
class GCN(torch.nn.Module):
    """ChiENN layer -> GCN layer -> global mean pooling -> linear head.

    Args:
        in_channels: node feature width of the input graphs. Also used as the ChiENN
            layer's ``hidden_dim`` and as the input width of ``conv2`` (93 in this notebook).
        hidden_channels: output width of the GCN layer.
        out_channels: number of outputs per graph (logits; one per task).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # NOTE(review): forward() does not use conv1; it is kept only so the module's
        # parameter / state_dict layout is unchanged. Consider removing it.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        # Previously hard-coded to 93. Presumably the ChiENN layer reads data.x directly
        # and preserves its width, so both widths must track in_channels.
        self.chilayer = ChiENNLayer(hidden_dim=in_channels)
        self.conv2 = GCNConv(in_channels, hidden_channels)
        self.ln = Linear(hidden_channels, out_channels)
        self.pool = global_mean_pool

    def forward(self, data):
        """Map a batched graph to per-graph outputs of width ``out_channels``."""
        # The ChiENN layer takes the whole batch (not just x) as input.
        x = F.relu(self.chilayer(data))
        x = F.dropout(x, training=self.training)
        x = F.relu(self.conv2(x, data.edge_index))
        x = self.pool(x, data.batch)
        return self.ln(x)

In [32]:
class GAT(torch.nn.Module):
    """ChiENN layer -> single-head GAT layer -> global mean pooling -> linear head.

    Args:
        in_channels: node feature width of the input graphs. Also used as the ChiENN
            layer's ``hidden_dim`` and as the input width of ``conv2`` (93 in this notebook).
        hidden_channels: output width of the GAT layer.
        out_channels: number of outputs per graph (logits; one per task).
        heads: attention heads for ``conv1``, which forward() currently does not use.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=4):
        super().__init__()
        # NOTE(review): forward() does not use conv1; it is kept only so the module's
        # parameter / state_dict layout is unchanged. Consider removing it.
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, concat=True)
        # Previously hard-coded to 93. Presumably the ChiENN layer reads data.x directly
        # and preserves its width, so both widths must track in_channels.
        self.chilayer = ChiENNLayer(hidden_dim=in_channels)
        self.conv2 = GATConv(in_channels, hidden_channels, heads=1, concat=False)
        self.ln = Linear(hidden_channels, out_channels)
        self.pool = global_mean_pool

    def forward(self, data):
        """Map a batched graph to per-graph outputs of width ``out_channels``."""
        # The ChiENN layer takes the whole batch (not just x) as input.
        x = F.relu(self.chilayer(data))
        x = F.dropout(x, training=self.training)
        x = F.relu(self.conv2(x, data.edge_index))
        x = self.pool(x, data.batch)
        return self.ln(x)

In [33]:
class GIN(torch.nn.Module):
    """ChiENN layer -> GIN layer (2-layer MLP) -> global mean pooling -> linear head.

    Args:
        in_channels: node feature width of the input graphs. Also used as the ChiENN
            layer's ``hidden_dim`` and as the input width of ``conv2``'s MLP (93 in this notebook).
        hidden_channels: hidden/output width of the GIN MLPs.
        out_channels: number of outputs per graph (logits; one per task).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # NOTE(review): forward() does not use conv1; it is kept only so the module's
        # parameter / state_dict layout is unchanged. Consider removing it.
        mlp1 = Seq(Linear(in_channels, hidden_channels), ReLU(), Linear(hidden_channels, hidden_channels))
        self.conv1 = GINConv(mlp1)
        # Previously hard-coded to 93. Presumably the ChiENN layer reads data.x directly
        # and preserves its width, so both widths must track in_channels.
        self.chilayer = ChiENNLayer(hidden_dim=in_channels)
        mlp2 = Seq(Linear(in_channels, hidden_channels), ReLU(), Linear(hidden_channels, hidden_channels))
        self.conv2 = GINConv(mlp2)
        self.ln = Linear(hidden_channels, out_channels)
        self.pool = global_mean_pool

    def forward(self, data):
        """Map a batched graph to per-graph outputs of width ``out_channels``."""
        # The ChiENN layer takes the whole batch (not just x) as input.
        x = F.relu(self.chilayer(data))
        x = F.dropout(x, training=self.training)
        x = F.relu(self.conv2(x, data.edge_index))
        x = self.pool(x, data.batch)
        return self.ln(x)


In [42]:
def validate(model, loader):
    """Evaluate ``model`` on ``loader`` and return the mean per-task ROC-AUC.

    Args:
        model: graph model returning one logit per task, with the same shape as ``data.y``.
        loader: iterable of batches that have ``.y`` labels (NaN = missing label)
            and a ``.to(device)`` method.

    Returns:
        numpy.float64: unweighted mean of the per-task ROC-AUCs. Only tasks whose
        labelled validation subset contains both classes are included. NaN if no
        task qualifies. The value is also printed.

    The model's train/eval mode is restored before returning. Calling this between
    training epochs therefore does not leave dropout disabled.
    """
    was_training = model.training
    model.eval()
    y_true, y_prob = [], []
    try:
        with torch.no_grad():
            for data in loader:  # use the argument, not a global loader
                data = data.to(device)
                y_prob.append(torch.sigmoid(model(data)).cpu())  # logits -> probabilities
                y_true.append(data.y.cpu())
    finally:
        model.train(was_training)

    y_true = torch.cat(y_true, dim=0).numpy()
    y_prob = torch.cat(y_prob, dim=0).numpy()

    # Score each task (column) separately on its labelled rows, then average.
    # Flattening all tasks into one binary problem would ignore average='macro'.
    task_aucs = []
    for task in range(y_true.shape[1]):
        labelled = ~np.isnan(y_true[:, task])
        labels = y_true[labelled, task]
        if np.unique(labels).size < 2:
            continue  # AUC is undefined when only one class is present
        task_aucs.append(roc_auc_score(labels, y_prob[labelled, task]))

    # Keep a numpy scalar so callers can still use .item() on the result.
    auc_roc = np.mean(task_aucs) if task_aucs else np.float64(np.nan)
    print(f"Validation AUC-ROC: {auc_roc:.4f}")
    return auc_roc

In [43]:
# Number of full passes over train_loader for each model trained below.
epochs: int = 20

In [None]:
# Train the ChiENN+GCN model; record mean training loss and validation ROC-AUC per epoch.
model = GCN(in_channels=93, hidden_channels=128, out_channels=12)
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.BCEWithLogitsLoss()
model.to(device)
losses = []
roc_auc_scores = []
for epoch in range(epochs):
    # Re-enter train mode every epoch: validation switches the model to eval mode.
    model.train()
    total_loss = 0.0
    for data in tqdm(train_loader):
        data = data.to(device)
        output = model(data)
        targets = data.y.float()
        # Labels may be NaN (missing); compute the loss on observed labels only.
        mask = ~torch.isnan(targets)
        loss = loss_fn(output[mask], targets[mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() detaches; summing the tensors would chain every batch's autograd graph.
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    print(f'epoch {epoch+1}: loss: {avg_loss:.4f}')
    losses.append(avg_loss)
    roc_auc_scores.append(float(validate(model, val_loader)))
torch.save(losses, 'gcn_tox21_chienn_losses.pt')
torch.save(roc_auc_scores, 'gcn_tox21_chienn_auc_roc.pt')

In [None]:
# Train the ChiENN+GAT model; record mean training loss and validation ROC-AUC per epoch.
model = GAT(in_channels=93, hidden_channels=128, out_channels=12)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# The model outputs logits for binary tasks (validation applies a sigmoid), so use
# BCE-with-logits like the GCN run rather than MSE on raw logits.
loss_fn = torch.nn.BCEWithLogitsLoss()
model.to(device)  # was hard-coded to 'cuda', which fails on CPU-only machines
losses = []
roc_auc_scores = []
for epoch in range(epochs):
    # Re-enter train mode every epoch: validation switches the model to eval mode.
    model.train()
    total_loss = 0.0
    for data in tqdm(train_loader):
        data = data.to(device)
        output = model(data)
        targets = data.y.float()
        # Labels may be NaN (missing); compute the loss on observed labels only.
        mask = ~torch.isnan(targets)
        loss = loss_fn(output[mask], targets[mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() detaches; summing the tensors would chain every batch's autograd graph.
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    print(f'epoch {epoch+1}: loss: {avg_loss:.4f}')
    losses.append(avg_loss)
    roc_auc_scores.append(float(validate(model, val_loader)))
torch.save(losses, 'gat_tox21_chienn_losses.pt')
torch.save(roc_auc_scores, 'gat_tox21_chienn_auc_roc.pt')

In [None]:
# Train the ChiENN+GIN model; record mean training loss and validation ROC-AUC per epoch.
model = GIN(in_channels=93, hidden_channels=128, out_channels=12)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# The model outputs logits for binary tasks (validation applies a sigmoid), so use
# BCE-with-logits like the GCN run rather than MSE on raw logits.
loss_fn = torch.nn.BCEWithLogitsLoss()
model.to(device)  # was hard-coded to 'cuda', which fails on CPU-only machines
losses = []
roc_auc_scores = []
for epoch in range(epochs):
    # Re-enter train mode every epoch: validation switches the model to eval mode.
    model.train()
    total_loss = 0.0
    for data in tqdm(train_loader):
        data = data.to(device)
        output = model(data)
        targets = data.y.float()
        # Labels may be NaN (missing); compute the loss on observed labels only.
        mask = ~torch.isnan(targets)
        loss = loss_fn(output[mask], targets[mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() detaches; summing the tensors would chain every batch's autograd graph.
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    print(f'epoch {epoch+1}: loss: {avg_loss:.4f}')
    losses.append(avg_loss)
    roc_auc_scores.append(float(validate(model, val_loader)))
torch.save(losses, 'gin_tox21_chienn_losses.pt')
torch.save(roc_auc_scores, 'gin_tox21_chienn_auc_roc.pt')