In [1]:
from minimol import Minimol

featurizer = Minimol()

  predictor.load_state_dict(torch.load(state_dict_path), strict=False)


In [2]:
import os
import pickle
from contextlib import redirect_stdout, redirect_stderr

import pandas as pd
from datamol.mol import standardize_smiles
from tdc.benchmark_group import admet_group

from torch.nn import MSELoss as mse_loss
from torch.nn import BCEWithLogitsLoss as bce_loss
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class MultiDataset(Dataset):
    def __init__(self, samples, task_names):
        self.samples = samples['embeddings'].tolist()
        self.targets = samples[task_names].fillna(np.nan).values

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        target = self.targets[idx]
        return sample, target


class AdmetDataset(Dataset):
    def __init__(self, samples):
        self.samples = samples['embeddings'].tolist()
        self.targets = [float(target) for target in samples['Y'].tolist()]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = torch.tensor(self.samples[idx])
        target = torch.tensor(self.targets[idx])
        return sample, target


cache_path = '.cache/admet_cache.pkl'

if os.path.exists(cache_path):
    print("Loading from cache...")
    with open(cache_path, 'rb') as f:
        cache_data = pickle.load(f)
    task_losses = cache_data['task_losses']
    test_dataloaders = cache_data['test_dataloaders']
    validation_dataloaders = cache_data['validation_dataloaders']
    df = cache_data['df']
else:
    print("Cache not found. Running the program...")
    seed = 42
    batch_size = 128

    group = admet_group(path='admet_data/')
    num_mols = 0
    task_losses = {}
    test_dataloaders = {}
    validation_dataloaders = {}
    df = pd.DataFrame(columns=['smiles'])

    for dataset_i, dataset_name in enumerate(group.dataset_names):
        print(f"{dataset_i + 1} / {len(group.dataset_names)} - {dataset_name}")
        benchmark = group.get(dataset_name)
        name = benchmark['name']
        mols_test = benchmark['test']

        with open(os.devnull, 'w') as fnull, redirect_stdout(fnull), redirect_stderr(fnull): # Suppress output
            mols_train, mols_valid = group.get_train_valid_split(benchmark=name, split_type='default', seed=seed)
            mols_test['embeddings'] = featurizer(list(mols_test['Drug']))
            mols_valid['embeddings'] = featurizer(list(mols_valid['Drug']))

        temp_df = pd.DataFrame({
            'smiles': mols_train['Drug'],
            dataset_name: mols_train['Y']
        })

        num_mols += len(temp_df)
        df = pd.merge(df, temp_df, on='smiles', how='outer')

        task_losses[name] = bce_loss() if len(mols_test['Y'].unique()) == 2 else mse_loss()
        test_dataloaders[name] = DataLoader(AdmetDataset(mols_test), batch_size=batch_size, shuffle=False)
        validation_dataloaders[name] = DataLoader(AdmetDataset(mols_valid), batch_size=batch_size, shuffle=False)

    df['embeddings'] = featurizer(list(df['smiles']))

    cache_data = {
        'task_losses': task_losses,
        'test_dataloaders': test_dataloaders,
        'validation_dataloaders': validation_dataloaders,
        'df': df
    }
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    with open(cache_path, 'wb') as f:
        pickle.dump(cache_data, f)
    print("Cache saved.")

Loading from cache...


In [31]:
df.head()

Unnamed: 0,smiles,caco2_wang,hia_hou,pgp_broccatelli,bioavailability_ma,lipophilicity_astrazeneca,solubility_aqsoldb,bbb_martins,ppbr_az,vdss_lombardo,...,cyp3a4_substrate_carbonmangels,cyp2c9_substrate_carbonmangels,half_life_obach,clearance_microsome_az,clearance_hepatocyte_az,herg,ames,dili,ld50_zhu,embeddings
0,CNC1(c2ccccc2Cl)CCCCC1=O,-4.26,,,0.0,,,1.0,44.84,,...,,,,,,,,0.0,,"[tensor(1.1916), tensor(0.3334), tensor(0.8102..."
1,CNC1(c2ccccc2Cl)CCCCC1=O,-4.26,,,0.0,,,1.0,42.01,,...,,,,,,,,0.0,,"[tensor(1.1916), tensor(0.3334), tensor(0.8102..."
2,C/C=C/C/C=C/CCC(=O)[C@@H]1O[C@@H]1C(N)=O,-5.422406,,,,,,,,,...,,,,,,,,,,"[tensor(1.2584), tensor(0.7886), tensor(1.3459..."
3,O=C(NC1(C(=O)N[C@H](Cc2ccccc2)C(=O)NCCCC(=O)N2...,-5.769776,,,,,,,,,...,,,,,,,,,,"[tensor(1.8761), tensor(0.1187), tensor(1.0743..."
4,NC(=O)[C@H](Cc1ccccc1)NC(=O)[C@H](Cc1ccccc1)NC...,-7.431799,,,,,,,,,...,,,,,,,,,,"[tensor(2.0538), tensor(1.9126), tensor(1.1178..."


In [3]:
tdc_reference = {
    "top1_tdc": {
        "caco2_wang": 0.276,
        "bioavailability_ma": 0.748,
        "lipophilicity_astrazeneca": 0.467,
        "solubility_aqsoldb": 0.761,
        "hia_hou": 0.989,
        "pgp_broccatelli": 0.938,
        "bbb_martins": 0.916,
        "ppbr_az": 7.526,
        "vdss_lombardo": 0.713,
        "cyp2c9_veith": 0.859,
        "cyp2d6_veith": 0.790,
        "cyp3a4_veith": 0.916,
        "cyp2c9_substrate_carbonmangels": 0.474,
        "cyp2d6_substrate_carbonmangels": 0.736,
        "cyp3a4_substrate_carbonmangels": 0.662,
        "half_life_obach": 0.562,
        "clearance_hepatocyte_az": 0.498,
        "clearance_microsome_az": 0.630,
        "ld50_zhu": 0.552,
        "herg": 0.880,
        "ames": 0.871,
        "dili": 0.925
    },
    "mole": {
        "caco2_wang": 0.310,
        "bioavailability_ma": 0.654,
        "lipophilicity_astrazeneca": 0.469,
        "solubility_aqsoldb": 0.792,
        "hia_hou": 0.963,
        "pgp_broccatelli": 0.915,
        "bbb_martins": 0.903,
        "ppbr_az": 8.073,
        "vdss_lombardo": 0.654,
        "cyp2c9_veith": 0.801,
        "cyp2d6_veith": 0.682,
        "cyp3a4_veith": 0.877,
        "cyp2c9_substrate_carbonmangels": 0.446,
        "cyp2d6_substrate_carbonmangels": 0.699,
        "cyp3a4_substrate_carbonmangels": 0.670,
        "half_life_obach": 0.549,
        "clearance_hepatocyte_az": 0.381,
        "clearance_microsome_az": 0.607,
        "ld50_zhu": 0.823,
        "herg": 0.813,
        "ames": 0.883,
        "dili": 0.577
    },
    "minimol": {
        "caco2_wang": 0.350,
        "bioavailability_ma": 0.689,
        "lipophilicity_astrazeneca": 0.456,
        "solubility_aqsoldb": 0.741,
        "hia_hou": 0.993,
        "pgp_broccatelli": 0.942,
        "bbb_martins": 0.924,
        "ppbr_az": 7.696,
        "vdss_lombardo": 0.535,
        "cyp2c9_veith": 0.823,
        "cyp2d6_veith": 0.719,
        "cyp3a4_veith": 0.877,
        "cyp2c9_substrate_carbonmangels": 0.474,
        "cyp2d6_substrate_carbonmangels": 0.695,
        "cyp3a4_substrate_carbonmangels": 0.663,
        "half_life_obach": 0.495,
        "clearance_hepatocyte_az": 0.446,
        "clearance_microsome_az": 0.628,
        "ld50_zhu": 0.585,
        "herg": 0.846,
        "ames": 0.849,
        "dili": 0.956
    }
}

metrics = {
    "caco2_wang": "MAE",
    "bioavailability_ma": "AUROC",
    "lipophilicity_astrazeneca": "MAE",
    "solubility_aqsoldb": "MAE",
    "hia_hou": "AUROC",
    "pgp_broccatelli": "AUROC",
    "bbb_martins": "AUROC",
    "ppbr_az": "MAE",
    "vdss_lombardo": "Spearman",
    "cyp2c9_veith": "AUPRC",
    "cyp2d6_veith": "AUPRC",
    "cyp3a4_veith": "AUPRC",
    "cyp2c9_substrate_carbonmangels": "AUPRC",
    "cyp2d6_substrate_carbonmangels": "AUPRC",
    "cyp3a4_substrate_carbonmangels": "AUROC",
    "half_life_obach": "Spearman",
    "clearance_hepatocyte_az": "Spearman",
    "clearance_microsome_az": "Spearman",
    "ld50_zhu": "MAE",
    "herg": "AUROC",
    "ames": "AUROC",
    "dili": "AUROC"
}


import pandas as pd
result_reference_df = pd.DataFrame(tdc_reference)

def evaluate_new_model(df, new_model_dict, new_model_name):
    df[new_model_name] = df.index.map(lambda x: new_model_dict[x][0])
    df['metrics'] = metrics
    def rank_row(row):
        metric = metrics.get(row.name)
        row_without_rank = row.drop(labels=[f"{new_model_name}_rank", "metrics"], errors='ignore') 
        if metric in ["AUROC", "AUPRC", "Spearman"]:
            return row_without_rank.rank(ascending=False)
        elif metric in ["MAE"]:
            return row_without_rank.rank(ascending=True)

    rank_column = f"{new_model_name}_rank"
    df[rank_column] = df.apply(rank_row, axis=1)[new_model_name]
    
    return df

result_reference_df



Unnamed: 0,top1_tdc,mole,minimol
caco2_wang,0.276,0.31,0.35
bioavailability_ma,0.748,0.654,0.689
lipophilicity_astrazeneca,0.467,0.469,0.456
solubility_aqsoldb,0.761,0.792,0.741
hia_hou,0.989,0.963,0.993
pgp_broccatelli,0.938,0.915,0.942
bbb_martins,0.916,0.903,0.924
ppbr_az,7.526,8.073,7.696
vdss_lombardo,0.713,0.654,0.535
cyp2c9_veith,0.859,0.801,0.823


In [7]:
import torch
import numpy as np
from tdc.benchmark_group import admet_group

group = admet_group(path='admet_data/')


class MultitaskTrainer:
    def __init__(self, device, batch_size: int = 256):
        self.device = device
        self.bs = batch_size
        self.task_losses = {}
        self.test_dataloaders = {}
        self.validation_dataloaders = {}
        self.train_dataloader = None
    
    def set_multitask_dataloader(self, data):
        task_names = list(self.task_losses.keys())
        self.train_dataloader = DataLoader(MultiDataset(data, task_names), batch_size=124, shuffle=True)

    def set_per_task_dataloaders(self, test_loaders, val_loaders, task_losses):
        self.task_losses = task_losses
        self.test_dataloaders = test_loaders
        self.validation_dataloaders = val_loaders

    def _eval(self, model, dataloaders, tdc_eval: bool = False):
        predictions = {}
        total_loss = torch.tensor(0.0, requires_grad=True, device=self.device)
        per_task_losses = torch.tensor([0.0]*len(dataloaders.keys()), requires_grad=False, device=self.device)

        for task_i, (task_name, dataloader) in enumerate(dataloaders.items()):
            predictions[task_name] = []
            
            for batch_idx, (samples, targets) in enumerate(dataloader):
                samples = samples.to(device)
                targets = targets.to(device)

                output = model(samples, task=task_name)

                loss_fn = self.task_losses[task_name]
                task_loss = loss_fn(output.float(), targets.float())
                
                if tdc_eval:
                    if isinstance(loss_fn, torch.nn.BCEWithLogitsLoss):
                        output = torch.nn.functional.sigmoid(output)
                    predictions[task_name] += list(output.detach().cpu())

                per_task_losses[task_i] = per_task_losses[task_i] + (task_loss / len(dataloader))
                total_loss = total_loss + (task_loss / len(dataloader))

        total_loss = total_loss / len(dataloaders.keys())

        if tdc_eval:
            return total_loss, per_task_losses, predictions 
        
        return total_loss, per_task_losses
    
    def eval_on_val(self, model):
        return self._eval(model, self.validation_dataloaders)

    def eval_on_test(self, model):
        return self._eval(model, self.test_dataloaders)

    def _compute_train_loss(self, outputs, filtered_targets):
        total_loss = torch.tensor(0.0, requires_grad=True, device=self.device)
        per_task_losses = torch.tensor([0.0]*len(self.task_losses.keys()), requires_grad=False, device=self.device)

        for task_i, (task, loss) in enumerate(self.task_losses.items()):
            if task not in outputs.keys():
                continue
            
            output = outputs[task]
            target = filtered_targets[task]
            assert not torch.isnan(output).any(), "NaNs IN THE OUTPUT!!!"

            task_loss = loss(output.float(), target.float())
            total_loss = total_loss + task_loss
            per_task_losses[task_i] = task_loss

        return total_loss / len(outputs.keys()), per_task_losses

    def train(self, model, optimizer, num_epochs):
        model.train()

        total_val_loss, per_task_val_loss = self._eval(model, self.validation_dataloaders)
        print(f"Epoch [0 / {num_epochs}], Val loss: {total_val_loss:.2f}")
        
        for epoch in range(num_epochs):
            running_loss = 0.0
            running_per_task_loss = torch.tensor([0.0]*len(self.task_losses.keys()), requires_grad=False, device=self.device)
            for batch_idx, (samples, targets) in enumerate(self.train_dataloader):
                samples = samples.to(self.device)
                targets = targets.to(self.device)
                
                outputs, filtered_targets = model(samples, targets)
                total_loss, per_task_losses = self._compute_train_loss(outputs, filtered_targets)
                
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

                running_loss += total_loss.item()
                running_per_task_loss += per_task_losses
            
            total_val_loss, per_task_val_loss = self._eval(model, self.validation_dataloaders)
            print(f"Epoch [{epoch + 1} / {num_epochs}], Train loss: {running_loss / len(self.train_dataloader):.2f}, Val loss: {total_val_loss:.2f}")

        total_test_loss, per_task_test_loss, predictions = self._eval(model, self.test_dataloaders, tdc_eval=True)
        print(f"Epoch [{epoch + 1} / {num_epochs}], Test loss: {total_test_loss:.2f}")
        tdc_evaluation = group.evaluate_many([predictions]*5)
        print(evaluate_new_model(result_reference_df, tdc_evaluation, 'this run'))


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trainer = MultitaskTrainer(device)
trainer.set_per_task_dataloaders(test_dataloaders, validation_dataloaders, task_losses)
trainer.set_multitask_dataloader(df)

Found local copy...


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class ResidualMLP(nn.Module):
    def __init__(self, input_size, head_input_size, hidden_size, depth, dropout):
        super(ResidualMLP, self).__init__()
        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(input_size, hidden_size))
        self.layers.append(nn.BatchNorm1d(hidden_size))
        for _ in range(depth - 1):
            self.layers.append(nn.ReLU())
            self.layers.append(nn.Dropout(dropout))
            self.layers.append(nn.Linear(hidden_size, hidden_size))
            self.layers.append(nn.BatchNorm1d(hidden_size))
        self.output_layer = nn.Linear(hidden_size, head_input_size)

    def forward(self, x):
        identity = x
        for layer in self.layers:
            x = layer(x)
        x += identity
        return self.output_layer(x)


class TaskHead(nn.Module):
    def __init__(self, hidden_size, depth, dropout):
        super(TaskHead, self).__init__()
        self.layers = nn.ModuleList()
        for _ in range(depth):
            self.layers.append(nn.Linear(hidden_size, hidden_size))
            self.layers.append(nn.LayerNorm(hidden_size))
            self.layers.append(nn.ReLU())
            self.layers.append(nn.Dropout(dropout))
        self.output_layer = nn.Linear(hidden_size, 1)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.output_layer(x)


class MultiTaskModel(nn.Module):
    def __init__(self, input_size, trunk_hidden_size, trunk_depth, head_hidden_size, head_depth, dropout, tasks, uncertainty_weighing: bool = False):
        super(MultiTaskModel, self).__init__()
        self.trunk = ResidualMLP(input_size, head_hidden_size, trunk_hidden_size, trunk_depth, dropout)
        self.heads = nn.ModuleDict({task: TaskHead(head_hidden_size, head_depth, dropout) for task in tasks})
        self.log_variance = nn.Parameter(torch.tensor(2.0), requires_grad=True) if uncertainty_weighing else None
        self.tasks = list(tasks)
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x, targets=None, task=None):
        x = self.trunk(x)
        outputs = {}
        filtered_targets = {}

        if task:
            return self.heads[task](x).squeeze()

        task_mask = ~torch.isnan(targets)
        for idx, task in enumerate(self.tasks): 
            if task_mask[:, idx].any():
                indices = torch.nonzero(task_mask[:, idx], as_tuple=False).squeeze()
                outputs[task] = self.heads[task](x[indices]).squeeze()
                filtered_targets[task] = targets[indices, idx].squeeze()
        return outputs, filtered_targets
    

In [9]:
import torch
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss as bce_loss
from torch.nn import MSELoss as mse_loss
pd.set_option('display.max_columns', None)
pd.set_option('display.expand_frame_repr', False)

hparams = {
    "trunk_hidden_size": 512,
    "head_hidden_size": 128, 
    "tasks": trainer.task_losses.keys(), 
    "input_size": 512, 
    "trunk_depth": 2, 
    "head_depth": 3, 
    "dropout": 0.5
}
model = MultiTaskModel(**hparams).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=.9, weight_decay=1e-4)
trainer.train(model, optimizer, num_epochs=10)

Epoch [0 / 10], Val loss: 1557.76
Epoch [1 / 10], Train loss: 649.34, Val loss: 1101.61
Epoch [2 / 10], Train loss: 667.91, Val loss: 1058.86
Epoch [3 / 10], Train loss: 613.77, Val loss: 1064.19
Epoch [4 / 10], Train loss: 536.10, Val loss: 1057.59
Epoch [5 / 10], Train loss: 426.99, Val loss: 1061.56
Epoch [6 / 10], Train loss: 447.92, Val loss: 1066.94
Epoch [7 / 10], Train loss: 564.03, Val loss: 1038.85
Epoch [8 / 10], Train loss: 473.82, Val loss: 1059.58
Epoch [9 / 10], Train loss: 413.23, Val loss: 1056.36
Epoch [10 / 10], Train loss: 489.38, Val loss: 1037.92
Epoch [10 / 10], Test loss: 241.90
                                top1_tdc   mole  minimol  this run   metrics  this run_rank
caco2_wang                         0.276  0.310    0.350     0.781       MAE            4.0
bioavailability_ma                 0.748  0.654    0.689     0.556     AUROC            4.0
lipophilicity_astrazeneca          0.467  0.469    0.456     1.029       MAE            4.0
solubility_aqsoldb    

In [None]:
import wandb
import os

entity, project = "blazejba-gc", "multitask_pretraining"
run = wandb.init(entity=entity, project=project)

# Simulate logging metrics over epochs
for epoch in range(10):
    accuracy = 0.9 + epoch * 0.01
    wandb.log({"epoch": epoch, "accuracy": accuracy})

wandb.finish()