In [11]:
import torch
from torch_geometric.nn import GCNConv, global_mean_pool, GATConv, GINConv, Sequential
from torch.nn import Sequential as Seq, Linear, ReLU
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets import MoleculeNet, QM9
# from torch_geometric.loader import DataLoader
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import numpy as np
from chienn.data.edge_graph import to_edge_graph, collate_circle_index
from chienn.data.featurization import smiles_to_3d_mol, mol_to_data
from chienn.data.featurize import *
from chienn.model.chienn_layer import *

In [12]:
def preprocess_features(dataset):
    """Featurize every molecule of ``dataset`` from its SMILES string.

    Each item's ``smiles`` is passed to ``smiles_to_data_with_circle_index``
    and the item's original target ``y`` is attached to the result.
    Molecules whose featurization raises are skipped (best effort).

    Args:
        dataset: Iterable of items exposing ``.smiles`` and ``.y``.

    Returns:
        list: The featurized items in input order, without the skipped ones.
    """
    data_list = []
    failed_idx = []
    for i, data in enumerate(tqdm(dataset)):
        smiles = data.smiles
        try:
            processed_data = smiles_to_data_with_circle_index(smiles)
            processed_data.y = data.y
        except Exception:
            # Catch ``Exception`` rather than using a bare ``except`` so that
            # Ctrl-C (KeyboardInterrupt) and SystemExit still stop the loop.
            failed_idx.append(i)
            continue
        data_list.append(processed_data)
    if failed_idx:
        # Surface how much data was silently dropped instead of hiding it.
        total = len(failed_idx) + len(data_list)
        print(f'preprocess_features: skipped {len(failed_idx)} of {total} molecules '
              f'(first failed indices: {failed_idx[:10]})')
    return data_list

In [2]:
# One-time preprocessing recipe, kept commented out; the line below loads its
# cached result instead. To rebuild 'processed_qm9.pt', uncomment and run once.
# Note that the recipe standardizes every target column (zero mean, unit std),
# so if the cache was produced this way, losses/metrics are in standardized units.
# dataset = QM9(root='data/MoleculeNet/qm9')
# mean = dataset.data.y.mean(dim=0)
# std = dataset.data.y.std(dim=0)
# dataset.data.y = (dataset.data.y - mean) / std
# processed_dataset = preprocess_features(dataset)
# torch.save(processed_dataset, 'processed_qm9.pt')
# NOTE(review): weights_only=False unpickles arbitrary Python objects; only
# load a cache file that you produced yourself.
dataset = torch.load('processed_qm9.pt', weights_only=False)

In [13]:
def custom_collate(data_list, k_neighbors=3):
    """Collate a list of featurized molecules into one batch for a DataLoader.

    Thin wrapper around ``collate_with_circle_index`` so it can be used as a
    ``collate_fn`` (which is called with the sample list only).

    Args:
        data_list: Samples to batch together.
        k_neighbors: Forwarded to ``collate_with_circle_index``; defaults to
            3, the value this notebook has always used.

    Returns:
        Whatever ``collate_with_circle_index`` returns for the given samples.
    """
    return collate_with_circle_index(data_list, k_neighbors=k_neighbors)

In [19]:
# 80/20 train/validation split. Explicit integer lengths put the previously
# unused ``train_len`` to work, and the seeded generator makes the split
# identical on every Restart & Run All, so runs are comparable.
train_len = int(len(dataset) * 0.8)
train_set, val_set = torch.utils.data.random_split(
    dataset,
    [train_len, len(dataset) - train_len],
    generator=torch.Generator().manual_seed(42),
)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, collate_fn=custom_collate)
val_loader = DataLoader(val_set, batch_size=64, collate_fn=custom_collate)

In [None]:
# Sanity check: sizes of the full dataset and of the train / validation splits.
len(dataset), len(train_set), len(val_set)

In [6]:
class GCN(torch.nn.Module):
    """ChiENN layer -> GCN layer -> mean pooling -> linear head.

    Args:
        in_channels: Width of the input features in ``data.x``. The ChiENN
            layer and the GCN layer that consumes its output are both sized
            from it (previously hard-coded to 93, which silently ignored this
            argument).
        hidden_channels: Width of the GCN layer output.
        out_channels: Number of values predicted per graph.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        # conv1 is currently bypassed in forward(); it is kept so the module's
        # attributes and state_dict layout stay unchanged.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        # Assumes ChiENNLayer consumes and emits ``hidden_dim``-wide features,
        # as the original 93 -> 93 wiring into conv2 implies.
        self.chilayer = ChiENNLayer(hidden_dim=in_channels)
        self.conv2 = GCNConv(in_channels, hidden_channels)
        self.ln = Linear(hidden_channels, out_channels)
        self.pool = global_mean_pool

    def forward(self, data):
        """Return one ``out_channels``-wide prediction per graph in ``data``."""
        edge_index = data.edge_index
        # The ChiENN layer reads what it needs from the batch object itself.
        x = F.elu(self.chilayer(data))
        x = F.dropout(x, training=self.training)
        x = F.elu(self.conv2(x, edge_index))
        # Average node embeddings per graph, then map to the output size.
        x = self.pool(x, data.batch)
        return self.ln(x)

In [7]:
class GAT(torch.nn.Module):
    """ChiENN layer -> GAT layer -> mean pooling -> linear head.

    Args:
        in_channels: Width of the input features in ``data.x``. The ChiENN
            layer and the GAT layer that consumes its output are both sized
            from it (previously hard-coded to 93, which silently ignored this
            argument).
        hidden_channels: Width of the GAT layer output.
        out_channels: Number of values predicted per graph.
        heads: Number of attention heads of the (currently bypassed) conv1.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=4):
        super(GAT, self).__init__()
        # conv1 is currently bypassed in forward(); it is kept so the module's
        # attributes and state_dict layout stay unchanged.
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, concat=True)
        # Assumes ChiENNLayer consumes and emits ``hidden_dim``-wide features,
        # as the original 93 -> 93 wiring into conv2 implies.
        self.chilayer = ChiENNLayer(hidden_dim=in_channels)
        self.conv2 = GATConv(in_channels, hidden_channels, heads=1, concat=False)
        self.ln = Linear(hidden_channels, out_channels)
        self.pool = global_mean_pool

    def forward(self, data):
        """Return one ``out_channels``-wide prediction per graph in ``data``."""
        edge_index = data.edge_index
        # The ChiENN layer reads what it needs from the batch object itself.
        x = F.elu(self.chilayer(data))
        x = F.dropout(x, training=self.training)
        x = F.elu(self.conv2(x, edge_index))
        # Average node embeddings per graph, then map to the output size.
        x = self.pool(x, data.batch)
        return self.ln(x)


In [8]:
class GIN(torch.nn.Module):
    """ChiENN layer -> GIN layer -> mean pooling -> linear head.

    Args:
        in_channels: Width of the input features in ``data.x``. The ChiENN
            layer and the GIN layer that consumes its output are both sized
            from it (previously hard-coded to 93, which silently ignored this
            argument).
        hidden_channels: Width of the GIN layer output.
        out_channels: Number of values predicted per graph.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GIN, self).__init__()
        # conv1 is currently bypassed in forward(); it is kept so the module's
        # attributes and state_dict layout stay unchanged.
        mlp1 = Seq(Linear(in_channels, hidden_channels), ReLU(), Linear(hidden_channels, hidden_channels))
        self.conv1 = GINConv(mlp1)
        # Assumes ChiENNLayer consumes and emits ``hidden_dim``-wide features,
        # as the original 93 -> 93 wiring into conv2 implies.
        self.chilayer = ChiENNLayer(hidden_dim=in_channels)
        mlp2 = Seq(Linear(in_channels, hidden_channels), ReLU(), Linear(hidden_channels, hidden_channels))
        self.conv2 = GINConv(mlp2)
        self.ln = Linear(hidden_channels, out_channels)
        self.pool = global_mean_pool

    def forward(self, data):
        """Return one ``out_channels``-wide prediction per graph in ``data``."""
        edge_index = data.edge_index
        # The ChiENN layer reads what it needs from the batch object itself.
        x = F.elu(self.chilayer(data))
        x = F.dropout(x, training=self.training)
        x = F.elu(self.conv2(x, edge_index))
        # Average node embeddings per graph, then map to the output size.
        x = self.pool(x, data.batch)
        return self.ln(x)


In [33]:
def validate(model, loader, task=0):
    """Evaluate ``model`` on ``loader`` for a single regression target.

    Puts the model into eval mode (and leaves it there; callers that keep
    training must call ``model.train()`` again), runs it without gradients
    over every batch and compares the flattened predictions with column
    ``task`` of ``data.y``.

    Args:
        model: Module called as ``model(batch)``.
        loader: Iterable of batches exposing ``.to(device)`` and ``.y``.
        task: Column of ``data.y`` to score. Defaults to 0, the column the
            single-task training cells train on, so they can omit it.

    Returns:
        tuple: ``(mae, rmse, r2)`` computed over the whole loader.
    """
    model.eval()
    # Run on whatever device the model lives on instead of assuming CUDA;
    # fall back to the former hard-coded device for parameter-less models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    y_true, y_pred = [], []
    with torch.no_grad():
        for data in tqdm(loader):
            output = model(data.to(device)).flatten()
            y_true.append(data.y[:, task].cpu())
            y_pred.append(output.cpu())

    y_true = torch.cat(y_true).numpy()
    y_pred = torch.cat(y_pred).numpy()

    # Metrics
    mae = mean_absolute_error(y_true, y_pred)
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    r2 = r2_score(y_true, y_pred)
    # Quick eyeball check: last prediction next to its target.
    print(y_pred[-1], y_true[-1])
    return mae, rmse, r2

In [34]:
# Number of passes over train_loader used by every training loop below.
epochs = 10

In [None]:
# Train a fresh GCN for each target column 3..11 of data.y and store the
# per-epoch training loss and validation metrics on disk.
for task in range(3, 12):
    print(f'task {task+1}:')
    model = GCN(in_channels=93, hidden_channels=32, out_channels=1)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    loss_fn = torch.nn.MSELoss()
    model.to('cuda')
    first = True
    losses = []
    validation = {'mae': [], 'rmse': [], 'r2': []}
    for epoch in range(epochs):
        # validate() switches the model to eval mode, so training mode must be
        # restored every epoch; otherwise dropout is off from epoch 2 onward.
        model.train()
        total_loss = 0.0
        for data in tqdm(train_loader):
            output = model(data.to('cuda')).flatten()
            targets = data.y[:, task].to('cuda')
            loss = loss_fn(output, targets)
            if first:
                # Show the very first batch loss as a starting reference.
                print(loss)
                first = False
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() yields a plain float; accumulating the tensor itself
            # would keep every batch's autograd graph alive until epoch end.
            total_loss += loss.item()
        epoch_loss = total_loss / len(train_loader)
        print(f'epoch {epoch+1}: loss: {epoch_loss}')
        mae, rmse, r2 = validate(model, val_loader, task)
        print(f'validate mae: {mae}, rmse: {rmse}, r2: {r2}')
        losses.append(epoch_loss)
        validation['mae'].append(mae)
        validation['rmse'].append(rmse)
        validation['r2'].append(r2)
    torch.save(losses, f'gcn_chienn_losses_{task+1}.pt')
    torch.save(validation, f'gcn_chienn_validation_{task+1}.pt')
    print(f'task {task+1} finished')

In [None]:
# Train a GAT on target column 0 of data.y and validate after every epoch.
model = GAT(in_channels=93, hidden_channels=32, out_channels=1)
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss()
model.to('cuda')
first = True
for epoch in range(epochs):
    # validate() switches the model to eval mode, so training mode must be
    # restored every epoch; otherwise dropout is off from epoch 2 onward.
    model.train()
    total_loss = 0.0
    for data in tqdm(train_loader):
        output = model(data.to('cuda')).flatten()
        targets = data.y[:, 0].to('cuda')
        loss = loss_fn(output, targets)
        if first:
            # Show the very first batch loss as a starting reference.
            print(loss)
            first = False
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() yields a plain float; accumulating the tensor itself would
        # keep every batch's autograd graph alive until the epoch ends.
        total_loss += loss.item()
    print(f'epoch {epoch+1}: loss: {total_loss / len(train_loader)}')
    # Score the same column that is being trained on (the task argument was
    # missing here, which raised a TypeError).
    mae, rmse, r2 = validate(model, val_loader, 0)
    print(f'validate mae: {mae}, rmse: {rmse}, r2: {r2}')

In [None]:
# Train a GIN on target column 0 of data.y and validate after every epoch.
model = GIN(in_channels=93, hidden_channels=32, out_channels=1)
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss()
model.to('cuda')
first = True
for epoch in range(epochs):
    # validate() switches the model to eval mode, so training mode must be
    # restored every epoch; otherwise dropout is off from epoch 2 onward.
    model.train()
    total_loss = 0.0
    for data in tqdm(train_loader):
        output = model(data.to('cuda')).flatten()
        targets = data.y[:, 0].to('cuda')
        loss = loss_fn(output, targets)
        if first:
            # Show the very first batch loss as a starting reference.
            print(loss)
            first = False
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() yields a plain float; accumulating the tensor itself would
        # keep every batch's autograd graph alive until the epoch ends.
        total_loss += loss.item()
    print(f'epoch {epoch+1}: loss: {total_loss / len(train_loader)}')
    # Score the same column that is being trained on (the task argument was
    # missing here, which raised a TypeError).
    mae, rmse, r2 = validate(model, val_loader, 0)
    print(f'validate mae: {mae}, rmse: {rmse}, r2: {r2}')