In [1]:
import os
import pandas as pd
from datamol.mol import standardize_smiles
from tdc.benchmark_group import admet_group
from contextlib import redirect_stdout, redirect_stderr

group = admet_group(path='admet_data/')
seed = 42

columns = ['smiles']
num_mols = 0
df = pd.DataFrame(columns=columns)
task_types = {}

for dataset_i, dataset_name in enumerate(group.dataset_names):
    print(f"{dataset_i + 1} / {len(group.dataset_names)} - {dataset_name}")
    benchmark = group.get(dataset_name)
    name = benchmark['name']
    task_types[name] = 'classification' if len(benchmark['test']['Y'].unique()) == 2 else 'regression'
    
    with open(os.devnull, 'w') as fnull, redirect_stdout(fnull), redirect_stderr(fnull): # suppress output
        mols_train, mols_valid = group.get_train_valid_split(benchmark=name, split_type='default', seed=seed)

    temp_df = pd.DataFrame({
        'smiles': mols_train['Drug'],
        dataset_name: mols_train['Y']
    })

    num_mols += len(temp_df)
    df = pd.merge(df, temp_df, on='smiles', how='outer')

print(f"All mols: {num_mols}")
print(f"Unique mols: {len(df)}")


Found local copy...


1 / 22 - caco2_wang
2 / 22 - hia_hou
3 / 22 - pgp_broccatelli
4 / 22 - bioavailability_ma
5 / 22 - lipophilicity_astrazeneca
6 / 22 - solubility_aqsoldb
7 / 22 - bbb_martins
8 / 22 - ppbr_az
9 / 22 - vdss_lombardo
10 / 22 - cyp2d6_veith
11 / 22 - cyp3a4_veith
12 / 22 - cyp2c9_veith
13 / 22 - cyp2d6_substrate_carbonmangels
14 / 22 - cyp3a4_substrate_carbonmangels
15 / 22 - cyp2c9_substrate_carbonmangels
16 / 22 - half_life_obach
17 / 22 - clearance_microsome_az
18 / 22 - clearance_hepatocyte_az
19 / 22 - herg
20 / 22 - ames
21 / 22 - dili
22 / 22 - ld50_zhu
All mols: 53695
Unique mols: 35628


In [2]:
from minimol import Minimol

featurizer = Minimol()

In [None]:
df['embeddings'] = featurizer(list(df['smiles']))

In [4]:
df.head()

Unnamed: 0,smiles,caco2_wang,hia_hou,pgp_broccatelli,bioavailability_ma,lipophilicity_astrazeneca,solubility_aqsoldb,bbb_martins,ppbr_az,vdss_lombardo,...,cyp3a4_substrate_carbonmangels,cyp2c9_substrate_carbonmangels,half_life_obach,clearance_microsome_az,clearance_hepatocyte_az,herg,ames,dili,ld50_zhu,embeddings
0,CNC1(c2ccccc2Cl)CCCCC1=O,-4.26,,,0.0,,,1.0,44.84,,...,,,,,,,,0.0,,"[tensor(1.1916), tensor(0.3334), tensor(0.8102..."
1,CNC1(c2ccccc2Cl)CCCCC1=O,-4.26,,,0.0,,,1.0,42.01,,...,,,,,,,,0.0,,"[tensor(1.1916), tensor(0.3334), tensor(0.8102..."
2,C/C=C/C/C=C/CCC(=O)[C@@H]1O[C@@H]1C(N)=O,-5.422406,,,,,,,,,...,,,,,,,,,,"[tensor(1.2584), tensor(0.7886), tensor(1.3459..."
3,O=C(NC1(C(=O)N[C@H](Cc2ccccc2)C(=O)NCCCC(=O)N2...,-5.769776,,,,,,,,,...,,,,,,,,,,"[tensor(1.8761), tensor(0.1187), tensor(1.0743..."
4,NC(=O)[C@H](Cc1ccccc1)NC(=O)[C@H](Cc1ccccc1)NC...,-7.431799,,,,,,,,,...,,,,,,,,,,"[tensor(2.0538), tensor(1.9126), tensor(1.1178..."


In [5]:
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class MultiDataset(Dataset):
    def __init__(self, samples, tasks):
        self.samples = samples['embeddings'].tolist()
        self.targets = samples[tasks.keys()].fillna(np.nan).values

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        target = self.targets[idx]
        return sample, target

dataloader = DataLoader(MultiDataset(df, task_types), batch_size=64, shuffle=True)
print(f"Number of batches in the epoch: {len(dataloader)}")

Number of batches in the epoch: 557


In [61]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class ResidualMLP(nn.Module):
    def __init__(self, input_size, head_input_size, hidden_size, depth, dropout):
        super(ResidualMLP, self).__init__()
        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(input_size, hidden_size))
        self.layers.append(nn.BatchNorm1d(hidden_size))
        for _ in range(depth - 1):
            self.layers.append(nn.ReLU())
            self.layers.append(nn.Dropout(dropout))
            self.layers.append(nn.Linear(hidden_size, hidden_size))
            self.layers.append(nn.BatchNorm1d(hidden_size))
        self.output_layer = nn.Linear(hidden_size, head_input_size)

    def forward(self, x):
        identity = x
        for layer in self.layers:
            x = layer(x)
        x += identity
        return self.output_layer(x)


class TaskHead(nn.Module):
    def __init__(self, hidden_size, depth, dropout):
        super(TaskHead, self).__init__()
        self.layers = nn.ModuleList()
        for _ in range(depth):
            self.layers.append(nn.Linear(hidden_size, hidden_size))
            self.layers.append(nn.LayerNorm(hidden_size))
            self.layers.append(nn.ReLU())
            self.layers.append(nn.Dropout(dropout))
        self.output_layer = nn.Linear(hidden_size, 1)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.output_layer(x)


class MultiTaskModel(nn.Module):
    def __init__(self, input_size, trunk_hidden_size, trunk_depth, head_hidden_size, head_depth, dropout, tasks):
        super(MultiTaskModel, self).__init__()
        self.trunk = ResidualMLP(input_size, head_hidden_size, trunk_hidden_size, trunk_depth, dropout)
        self.heads = nn.ModuleDict({task: TaskHead(head_hidden_size, head_depth, dropout) for task in tasks.keys()})
        self.tasks = tasks
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x, targets=None):
        x = self.trunk(x)
        outputs = {}
        filtered_targets = {}

        if targets is None:  # inference
            return {task: self.heads[task](x) for task in self.tasks}

        task_mask = ~torch.isnan(targets)
        for idx, task in enumerate(self.tasks.keys()): 
            if task_mask[:, idx].any():
                indices = np.where(task_mask[:, idx])[0]
                outputs[task] = self.heads[task](x[indices]).squeeze()
                filtered_targets[task] = targets[indices, idx].squeeze()
        return outputs, filtered_targets


In [45]:
model = MultiTaskModel(
    trunk_hidden_size=512,
    head_hidden_size=128, 
    input_size=512, 
    trunk_depth=2, 
    head_depth=3, 
    tasks=task_types, 
    dropout=0.5 
)
for batch_idx, (samples, targets) in enumerate(dataloader):
    outputs, filtered_targets = model(samples, targets=targets)
    for task, output in outputs.items():
        if torch.isnan(output).any():
            print(task)
            print(output)
            break

In [64]:
import torch
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss as bce_loss
from torch.nn import MSELoss as mse_loss

def compute_loss(outputs, filtered_targets, task_types):
    total_loss = torch.tensor(0.0, requires_grad=True)

    for task, output in outputs.items():
        target = filtered_targets[task].float()

        if torch.isnan(output).any(): print("NANS IN THE OUTPUT")
        if torch.isnan(target).any(): print("NANS IN THE TARGETS")

        if task_types[task] == 'classification':
            loss = bce_loss()(output.float(), target.float())
        else:  # regression
            loss = mse_loss()(output.float(), target.float())

        total_loss = total_loss + loss 

    return total_loss / len(task_types.keys())

def train(model, dataloader, task_types, optimizer, num_epochs):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch_idx, (samples, targets) in enumerate(dataloader):
            outputs, filtered_targets = model(samples, targets=targets)
            total_loss = compute_loss(outputs, filtered_targets, task_types)
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()

        print(f"Epoch [{epoch + 1} / {num_epochs}], Train loss: {running_loss / len(dataloader):.2f}")


model = MultiTaskModel(
    trunk_hidden_size=512,
    head_hidden_size=128, 
    input_size=512, 
    trunk_depth=2, 
    head_depth=3, 
    tasks=task_types, 
    dropout=0.5 
)

optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9, weight_decay=0.0001)
train(model, dataloader, task_types, optimizer, num_epochs=10)


Epoch [1 / 10], Train loss: 378.97
Epoch [2 / 10], Train loss: 357.80
Epoch [3 / 10], Train loss: 281.99
Epoch [4 / 10], Train loss: 357.34
Epoch [5 / 10], Train loss: 302.66
Epoch [6 / 10], Train loss: 278.21
Epoch [7 / 10], Train loss: 265.68
Epoch [8 / 10], Train loss: 264.86
Epoch [9 / 10], Train loss: 309.25
Epoch [10 / 10], Train loss: 306.45
