In [1]:
from minimol import Minimol

featurizer = Minimol()

  predictor.load_state_dict(torch.load(state_dict_path), strict=False)


In [2]:
import os
import pickle
from contextlib import redirect_stdout, redirect_stderr

import pandas as pd
from datamol.mol import standardize_smiles
from tdc.benchmark_group import admet_group

from torch.nn import MSELoss as mse_loss
from torch.nn import BCEWithLogitsLoss as bce_loss
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class MultiDataset(Dataset):
    def __init__(self, samples, task_names):
        self.samples = samples['embeddings'].tolist()
        self.targets = samples[task_names].fillna(np.nan).values

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        target = self.targets[idx]
        return sample, target


class AdmetDataset(Dataset):
    def __init__(self, samples):
        self.samples = samples['embeddings'].tolist()
        self.targets = [float(target) for target in samples['Y'].tolist()]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = torch.tensor(self.samples[idx])
        target = torch.tensor(self.targets[idx])
        return sample, target


cache_path = '.cache/admet_cache.pkl'

if os.path.exists(cache_path):
    print("Loading from cache...")
    with open(cache_path, 'rb') as f:
        cache_data = pickle.load(f)
    task_losses = cache_data['task_losses']
    test_dataloaders = cache_data['test_dataloaders']
    validation_dataloaders = cache_data['validation_dataloaders']
    df = cache_data['df']
else:
    print("Cache not found. Running the program...")
    seed = 42
    batch_size = 128

    group = admet_group(path='admet_data/')
    num_mols = 0
    task_losses = {}
    test_dataloaders = {}
    validation_dataloaders = {}
    df = pd.DataFrame(columns=['smiles'])

    for dataset_i, dataset_name in enumerate(group.dataset_names):
        print(f"{dataset_i + 1} / {len(group.dataset_names)} - {dataset_name}")
        benchmark = group.get(dataset_name)
        name = benchmark['name']
        mols_test = benchmark['test']

        with open(os.devnull, 'w') as fnull, redirect_stdout(fnull), redirect_stderr(fnull): # Suppress output
            mols_train, mols_valid = group.get_train_valid_split(benchmark=name, split_type='default', seed=seed)
            mols_test['embeddings'] = featurizer(list(mols_test['Drug']))
            mols_valid['embeddings'] = featurizer(list(mols_valid['Drug']))

        temp_df = pd.DataFrame({
            'smiles': mols_train['Drug'],
            dataset_name: mols_train['Y']
        })

        num_mols += len(temp_df)
        df = pd.merge(df, temp_df, on='smiles', how='outer')

        task_losses[name] = bce_loss() if len(mols_test['Y'].unique()) == 2 else mse_loss()
        test_dataloaders[name] = DataLoader(AdmetDataset(mols_test), batch_size=batch_size, shuffle=False)
        validation_dataloaders[name] = DataLoader(AdmetDataset(mols_valid), batch_size=batch_size, shuffle=False)

    df['embeddings'] = featurizer(list(df['smiles']))

    cache_data = {
        'task_losses': task_losses,
        'test_dataloaders': test_dataloaders,
        'validation_dataloaders': validation_dataloaders,
        'df': df
    }
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    with open(cache_path, 'wb') as f:
        pickle.dump(cache_data, f)
    print("Cache saved.")

Loading from cache...


In [4]:
df.head()

Unnamed: 0,smiles,caco2_wang,hia_hou,pgp_broccatelli,bioavailability_ma,lipophilicity_astrazeneca,solubility_aqsoldb,bbb_martins,ppbr_az,vdss_lombardo,...,cyp3a4_substrate_carbonmangels,cyp2c9_substrate_carbonmangels,half_life_obach,clearance_microsome_az,clearance_hepatocyte_az,herg,ames,dili,ld50_zhu,embeddings
0,CNC1(c2ccccc2Cl)CCCCC1=O,-4.26,,,0.0,,,1.0,44.84,,...,,,,,,,,0.0,,"[tensor(1.1916), tensor(0.3334), tensor(0.8102..."
1,CNC1(c2ccccc2Cl)CCCCC1=O,-4.26,,,0.0,,,1.0,42.01,,...,,,,,,,,0.0,,"[tensor(1.1916), tensor(0.3334), tensor(0.8102..."
2,C/C=C/C/C=C/CCC(=O)[C@@H]1O[C@@H]1C(N)=O,-5.422406,,,,,,,,,...,,,,,,,,,,"[tensor(1.2584), tensor(0.7886), tensor(1.3459..."
3,O=C(NC1(C(=O)N[C@H](Cc2ccccc2)C(=O)NCCCC(=O)N2...,-5.769776,,,,,,,,,...,,,,,,,,,,"[tensor(1.8761), tensor(0.1187), tensor(1.0743..."
4,NC(=O)[C@H](Cc1ccccc1)NC(=O)[C@H](Cc1ccccc1)NC...,-7.431799,,,,,,,,,...,,,,,,,,,,"[tensor(2.0538), tensor(1.9126), tensor(1.1178..."


In [77]:
tdc_reference = {
    "tdc1": {"caco2_wang": 0.276, "bioavailability_ma": 0.748, "lipophilicity_astrazeneca": 0.467, "solubility_aqsoldb": 0.761, "hia_hou": 0.989, "pgp_broccatelli": 0.938, "bbb_martins": 0.916, "ppbr_az": 7.526, "vdss_lombardo": 0.713, "cyp2c9_veith": 0.859, "cyp2d6_veith": 0.790, "cyp3a4_veith": 0.916, "cyp2c9_substrate_carbonmangels": 0.474, "cyp2d6_substrate_carbonmangels": 0.736, "cyp3a4_substrate_carbonmangels": 0.662, "half_life_obach": 0.562, "clearance_hepatocyte_az": 0.498, "clearance_microsome_az": 0.630, "ld50_zhu": 0.552, "herg": 0.880, "ames": 0.871, "dili": 0.925},
    "tdc2": {'caco2_wang': 0.285, 'bioavailability_ma': 0.742, 'lipophilicity_astrazeneca': 0.470, 'solubility_aqsoldb': 0.776, 'hia_hou': 0.988, 'pgp_broccatelli': 0.935, 'bbb_martins': 0.915, 'ppbr_az': 7.660, 'vdss_lombardo': 0.707, 'cyp2c9_veith': 0.839, 'cyp2d6_veith': 0.739, 'cyp3a4_veith': 0.904, 'cyp2c9_substrate_carbonmangels': 0.437, 'cyp2d6_substrate_carbonmangels': 0.720, 'cyp3a4_substrate_carbonmangels': 0.650, 'half_life_obach': 0.557, 'clearance_hepatocyte_az': 0.466, 'clearance_microsome_az': 0.626, 'ld50_zhu': 0.588, 'herg': 0.874, 'ames': 0.869, 'dili': 0.919},
    "tdc3": {'caco2_wang': 0.287, 'bioavailability_ma': 0.730, 'lipophilicity_astrazeneca': 0.479, 'solubility_aqsoldb': 0.789, 'hia_hou': 0.986, 'pgp_broccatelli': 0.930, 'bbb_martins': 0.913, 'ppbr_az': 7.788, 'vdss_lombardo': 0.627, 'cyp2c9_veith': 0.829, 'cyp2d6_veith': 0.723, 'cyp3a4_veith': 0.902, 'cyp2c9_substrate_carbonmangels': 0.437, 'cyp2d6_substrate_carbonmangels': 0.713, 'cyp3a4_substrate_carbonmangels': 0.647, 'half_life_obach': 0.547, 'clearance_hepatocyte_az': 0.440, 'clearance_microsome_az': 0.625, 'ld50_zhu': 0.606, 'herg': 0.871, 'ames': 0.868, 'dili': 0.917},
    "tdc4": {'caco2_wang': 0.287, 'bioavailability_ma': 0.706, 'lipophilicity_astrazeneca': 0.525, 'solubility_aqsoldb': 0.792, 'hia_hou': 0.981, 'pgp_broccatelli': 0.929, 'bbb_martins': 0.912, 'ppbr_az': 7.914, 'vdss_lombardo': 0.609, 'cyp2c9_veith': 0.786, 'cyp2d6_veith': 0.721, 'cyp3a4_veith': 0.881, 'cyp2c9_substrate_carbonmangels': 0.433, 'cyp2d6_substrate_carbonmangels': 0.704, 'cyp3a4_substrate_carbonmangels': 0.640, 'half_life_obach': 0.544, 'clearance_hepatocyte_az': 0.439, 'clearance_microsome_az': 0.599, 'ld50_zhu': 0.621, 'herg': 0.856, 'ames': 0.850, 'dili': 0.909},
    "tdc5": {'caco2_wang': 0.289, 'bioavailability_ma': 0.672, 'lipophilicity_astrazeneca': 0.535, 'solubility_aqsoldb': 0.827, 'hia_hou': 0.978, 'pgp_broccatelli': 0.929, 'bbb_martins': 0.910, 'ppbr_az': 8.288, 'vdss_lombardo': 0.582, 'cyp2c9_veith': 0.783, 'cyp2d6_veith': 0.673, 'cyp3a4_veith': 0.876, 'cyp2c9_substrate_carbonmangels': 0.415, 'cyp2d6_substrate_carbonmangels': 0.686, 'cyp3a4_substrate_carbonmangels': 0.639, 'half_life_obach': 0.438, 'clearance_hepatocyte_az': 0.431, 'clearance_microsome_az': 0.597, 'ld50_zhu': 0.622, 'herg': 0.841, 'ames': 0.842, 'dili': 0.899},
    "mole":     {"caco2_wang": 0.310, "bioavailability_ma": 0.654, "lipophilicity_astrazeneca": 0.469, "solubility_aqsoldb": 0.792, "hia_hou": 0.963, "pgp_broccatelli": 0.915, "bbb_martins": 0.903, "ppbr_az": 8.073, "vdss_lombardo": 0.654, "cyp2c9_veith": 0.801, "cyp2d6_veith": 0.682, "cyp3a4_veith": 0.877, "cyp2c9_substrate_carbonmangels": 0.446, "cyp2d6_substrate_carbonmangels": 0.699, "cyp3a4_substrate_carbonmangels": 0.670, "half_life_obach": 0.549, "clearance_hepatocyte_az": 0.381, "clearance_microsome_az": 0.607, "ld50_zhu": 0.823, "herg": 0.813, "ames": 0.883, "dili": 0.577},
    "minimol":  {"caco2_wang": 0.350, "bioavailability_ma": 0.689, "lipophilicity_astrazeneca": 0.456, "solubility_aqsoldb": 0.741, "hia_hou": 0.993, "pgp_broccatelli": 0.942, "bbb_martins": 0.924, "ppbr_az": 7.696, "vdss_lombardo": 0.535, "cyp2c9_veith": 0.823, "cyp2d6_veith": 0.719, "cyp3a4_veith": 0.877, "cyp2c9_substrate_carbonmangels": 0.474, "cyp2d6_substrate_carbonmangels": 0.695, "cyp3a4_substrate_carbonmangels": 0.663, "half_life_obach": 0.495, "clearance_hepatocyte_az": 0.446, "clearance_microsome_az": 0.628, "ld50_zhu": 0.585, "herg": 0.846, "ames": 0.849, "dili": 0.956}
}

metrics = {
    "caco2_wang"                    : "MAE",
    "bioavailability_ma"            : "AUROC",
    "lipophilicity_astrazeneca"     : "MAE",
    "solubility_aqsoldb"            : "MAE",
    "hia_hou"                       : "AUROC",
    "pgp_broccatelli"               : "AUROC",
    "bbb_martins"                   : "AUROC",
    "ppbr_az"                       : "MAE",
    "vdss_lombardo"                 : "Spearman",
    "cyp2c9_veith"                  : "AUPRC",
    "cyp2d6_veith"                  : "AUPRC",
    "cyp3a4_veith"                  : "AUPRC",
    "cyp2c9_substrate_carbonmangels": "AUPRC",
    "cyp2d6_substrate_carbonmangels": "AUPRC",
    "cyp3a4_substrate_carbonmangels": "AUROC",
    "half_life_obach"               : "Spearman",
    "clearance_hepatocyte_az"       : "Spearman",
    "clearance_microsome_az"        : "Spearman",
    "ld50_zhu"                      : "MAE",
    "herg"                          : "AUROC",
    "ames"                          : "AUROC",
    "dili"                          : "AUROC"
}

import pandas as pd
result_reference_df = pd.DataFrame(tdc_reference)

def evaluate_new_model(df, new_model_dict, new_model_name):
    df['metrics'] = metrics
    df[new_model_name] = df.index.map(lambda x: new_model_dict[x][0])
    def rank_row(row):
        metric = metrics.get(row.name)
        row_without_rank = row.drop(labels=[f"{new_model_name}_rank", "metrics"], errors='ignore') 
        if metric in ["AUROC", "AUPRC", "Spearman"]:
            return row_without_rank.rank(ascending=False)
        elif metric in ["MAE"]:
            return row_without_rank.rank(ascending=True)

    rank_column = f"{new_model_name}_rank"
    df[rank_column] = df.apply(rank_row, axis=1)[new_model_name]
    
    cols = df.columns.tolist()
    cols.insert(0, cols.pop(cols.index('metrics')))
    df = df[cols]

    return df


def get_scores(df, models: list = ['tdc1', 'tdc2', 'tdc3', 'tdc4', 'tdc5', 'mole', 'minimol', 'this_run']):
    global metrics

    scores = {}
    for model in models:
        total_score = 0
        for dataset, metric in metrics.items():
            score = df.loc[dataset, model]
            score_min = df.loc[dataset, models].min()
            score_max = df.loc[dataset, models].max()
            total_score += score if metric in ['AUROC', 'AUPRC', 'Spearman'] else 1 - (score - score_min) / (score_max - score_min)
        scores[model] = total_score

    return scores


result_reference_df



Unnamed: 0,tdc1,tdc2,tdc3,tdc4,tdc5,mole,minimol
caco2_wang,0.28,0.28,0.29,0.29,0.29,0.31,0.35
bioavailability_ma,0.75,0.74,0.73,0.71,0.67,0.65,0.69
lipophilicity_astrazeneca,0.47,0.47,0.48,0.53,0.54,0.47,0.46
solubility_aqsoldb,0.76,0.78,0.79,0.79,0.83,0.79,0.74
hia_hou,0.99,0.99,0.99,0.98,0.98,0.96,0.99
pgp_broccatelli,0.94,0.94,0.93,0.93,0.93,0.92,0.94
bbb_martins,0.92,0.92,0.91,0.91,0.91,0.9,0.92
ppbr_az,7.53,7.66,7.79,7.91,8.29,8.07,7.7
vdss_lombardo,0.71,0.71,0.63,0.61,0.58,0.65,0.54
cyp2c9_veith,0.86,0.84,0.83,0.79,0.78,0.8,0.82


In [84]:
import json

import torch
import numpy as np
from tdc.benchmark_group import admet_group

group = admet_group(path='admet_data/')


class MultitaskTrainer:
    def __init__(self, device, batch_size: int = 256):
        self.device = device
        self.bs = batch_size
        self.task_losses = {}
        self.test_dataloaders = {}
        self.validation_dataloaders = {}
        self.train_dataloader = None
    
    def set_multitask_dataloader(self, data):
        task_names = list(self.task_losses.keys())
        self.train_dataloader = DataLoader(MultiDataset(data, task_names), batch_size=124, shuffle=True)

    def set_per_task_dataloaders(self, test_loaders, val_loaders, task_losses):
        self.task_losses = task_losses
        self.test_dataloaders = test_loaders
        self.validation_dataloaders = val_loaders

    def _eval(self, model, dataloaders, tdc_eval: bool = False):
        model.eval()

        predictions = {}
        total_loss = torch.tensor(0.0, requires_grad=True, device=self.device)
        per_task_losses = torch.tensor([0.0]*len(dataloaders.keys()), requires_grad=False, device=self.device)

        for task_i, (task_name, dataloader) in enumerate(dataloaders.items()):
            predictions[task_name] = []
            
            for batch_idx, (samples, targets) in enumerate(dataloader):
                samples = samples.to(device)
                targets = targets.to(device)

                output = model(samples, task=task_name)

                loss_fn = self.task_losses[task_name]
                task_loss = loss_fn(output.float(), targets.float())
                
                if tdc_eval:
                    if isinstance(loss_fn, torch.nn.BCEWithLogitsLoss):
                        output = torch.nn.functional.sigmoid(output)
                    predictions[task_name] += list(output.detach().cpu())

                per_task_losses[task_i] = per_task_losses[task_i] + (task_loss / len(dataloader))
                total_loss = total_loss + (task_loss / len(dataloader))

        total_loss = total_loss / len(dataloaders.keys())

        if tdc_eval:
            return total_loss, per_task_losses, predictions 
        
        return total_loss, per_task_losses
    
    def eval_on_val(self, model):
        return self._eval(model, self.validation_dataloaders)

    def eval_on_test(self, model):
        return self._eval(model, self.test_dataloaders)

    def _compute_train_loss(self, outputs, filtered_targets):
        total_loss = torch.tensor(0.0, requires_grad=True, device=self.device)
        per_task_losses = torch.tensor([0.0]*len(self.task_losses.keys()), requires_grad=False, device=self.device)

        for task_i, (task, loss) in enumerate(self.task_losses.items()):
            if task not in outputs.keys():
                continue
            
            output = outputs[task]
            target = filtered_targets[task]
            assert not torch.isnan(output).any(), "NaNs IN THE OUTPUT!!!"

            task_loss = loss(output.float(), target.float())
            total_loss = total_loss + task_loss
            per_task_losses[task_i] = task_loss

        return total_loss / len(outputs.keys()), per_task_losses

    def train(self, model, optimizer, scheduler, num_epochs, n_freeze_epochs):
        model.train()

        total_val_loss, per_task_val_loss = self._eval(model, self.validation_dataloaders)
        print(f"Epoch [0 / {num_epochs}], Val loss: {total_val_loss:.2f}")
        
        for epoch in range(num_epochs):
            
            if epoch > num_epochs - n_freeze_epochs:
                model.freeze_trunk()

            running_loss = 0.0
            running_per_task_loss = torch.tensor([0.0]*len(self.task_losses.keys()), requires_grad=False, device=self.device)

            for batch_idx, (samples, targets) in enumerate(self.train_dataloader):
                samples = samples.to(self.device)
                targets = targets.to(self.device)
                
                outputs, filtered_targets = model(samples, targets)
                total_loss, per_task_losses = self._compute_train_loss(outputs, filtered_targets)
                
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

                running_loss += total_loss.item()
                running_per_task_loss += per_task_losses
            
            scheduler.step()
            
            total_val_loss, per_task_val_loss = self._eval(model, self.validation_dataloaders)
            print(f"Epoch [{epoch + 1} / {num_epochs}], Train loss: {running_loss / len(self.train_dataloader):.2f}, Val loss: {total_val_loss:.2f}")

        total_test_loss, per_task_test_loss, predictions = self._eval(model, self.test_dataloaders, tdc_eval=True)
        print(f"Epoch [{epoch + 1} / {num_epochs}], Test loss: {total_test_loss:.2f}")
        self._compare_results(predictions)

    @staticmethod
    def _compare_results(predictions):
        tdc_evaluation = group.evaluate_many([predictions]*5)
        results_df = evaluate_new_model(result_reference_df, tdc_evaluation, 'this_run')
        scores_row = pd.DataFrame(get_scores(results_df), index=['Scores'])
        results_df = pd.concat([results_df, scores_row], axis=0)

        pd.set_option('display.max_columns', None)
        pd.set_option('display.expand_frame_repr', False)
        pd.options.display.float_format = '{:.2f}'.format
        print(results_df)

def trainer_factory(batch_size: int = 128):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    trainer = MultitaskTrainer(device)
    trainer.set_per_task_dataloaders(test_dataloaders, validation_dataloaders, task_losses)
    trainer.set_multitask_dataloader(df)
    return trainer

Found local copy...


In [86]:
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class ResidualMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, depth, dropout):
        super(ResidualMLP, self).__init__()
        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(input_dim, hidden_dim))
        self.layers.append(nn.BatchNorm1d(hidden_dim))
        for _ in range(depth - 1):
            self.layers.append(nn.ReLU())
            self.layers.append(nn.Dropout(dropout))
            self.layers.append(nn.Linear(hidden_dim, hidden_dim))
            self.layers.append(nn.BatchNorm1d(hidden_dim))
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        identity = x
        for layer in self.layers:
            x = layer(x)
        x += identity
        return self.output_layer(x)


class TaskHead(nn.Module):
    def __init__(self, input_dim, hidden_dim, depth, dropout):
        super(TaskHead, self).__init__()
        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(input_dim, hidden_dim))
        self.layers.append(nn.LayerNorm(hidden_dim))
        for _ in range(depth - 1):
            self.layers.append(nn.ReLU())
            self.layers.append(nn.Dropout(dropout))
            self.layers.append(nn.Linear(hidden_dim, hidden_dim))
            self.layers.append(nn.LayerNorm(hidden_dim))
        self.output_layer = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        identity = x
        for layer in self.layers:
            x = layer(x)
        x += identity
        return self.output_layer(x)


class MultiTaskModel(nn.Module):

    def __init__(self,
                 head_params         : dict,
                 trunk_params        : dict, 
                 dropout             : float,
                 tasks               : Union[tuple, list], 
                 uncertainty_weighing: bool = False):

        super(MultiTaskModel, self).__init__()

        if trunk_params['depth'] > 0:
            self.trunk = ResidualMLP(dropout=dropout, **trunk_params)
        else:
            assert trunk_params['input_dim'] == head_params['input_dim'], "Input size must match when trunk depth is 0."
            self.trunk = None
        
        self.heads = nn.ModuleDict({task: TaskHead(dropout=dropout, **head_params) for task in tasks})
        self.log_variance = nn.Parameter(torch.tensor(2.0), requires_grad=True) if uncertainty_weighing else None
        self.tasks = list(tasks)
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def freeze_trunk(self):
        if self.trunk:
            for param in self.trunk.parameters():
                param.requires_grad = False

    def forward(self, x, targets=None, task=None):
        x = self.trunk(x) if self.trunk else x
        
        outputs = {}
        filtered_targets = {}

        if task:
            return self.heads[task](x).squeeze()

        task_mask = ~torch.isnan(targets)
        for idx, task in enumerate(self.tasks): 
            if task_mask[:, idx].any():
                indices = torch.nonzero(task_mask[:, idx], as_tuple=False).squeeze()
                outputs[task] = self.heads[task](x[indices]).squeeze()
                filtered_targets[task] = targets[indices, idx].squeeze()

        return outputs, filtered_targets

In [92]:
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR

hparams = {
    "dropout": 0.5,
    "tasks": tuple(trainer.task_losses.keys()), 
    "head_params": {
        "depth": 3, 
        "input_dim": 256,
        "hidden_dim": 256
    },
    "trunk_params": {
        "depth": 4, 
        "input_dim": 512, 
        "hidden_dim": 512,
        "output_dim": 256
    }
}
lr              = 4e-3
n_epochs        = 20
n_warmup_epochs = 5
batch_size      = 2048
weight_decay    = 1e-4
n_freeze_epochs = 5

trainer = trainer_factory(batch_size)
model   = MultiTaskModel(**hparams).to(device)

optimizer        = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
warmup_scheduler = LinearLR(optimizer, start_factor=lr / 1e2, total_iters=n_warmup_epochs)
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs - n_warmup_epochs)
scheduler        = SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[n_warmup_epochs])

trainer.train(model, optimizer, scheduler, num_epochs=n_epochs, n_freeze_epochs=n_freeze_epochs)

Epoch [0 / 20], Val loss: 1566.21
Epoch [1 / 20], Train loss: 934.92, Val loss: 1539.21
Epoch [2 / 20], Train loss: 499.76, Val loss: 1035.63
Epoch [3 / 20], Train loss: 462.32, Val loss: 1024.52
Epoch [4 / 20], Train loss: 418.21, Val loss: 987.14
Epoch [5 / 20], Train loss: 397.27, Val loss: 984.20
Epoch [6 / 20], Train loss: 464.40, Val loss: 1043.79
Epoch [7 / 20], Train loss: 348.19, Val loss: 1003.67
Epoch [8 / 20], Train loss: 461.35, Val loss: 982.85
Epoch [9 / 20], Train loss: 375.08, Val loss: 938.08
Epoch [10 / 20], Train loss: 334.84, Val loss: 989.84
Epoch [11 / 20], Train loss: 915.14, Val loss: 1018.72
Epoch [12 / 20], Train loss: 308.75, Val loss: 895.22
Epoch [13 / 20], Train loss: 284.77, Val loss: 817.78
Epoch [14 / 20], Train loss: 289.26, Val loss: 805.19
Epoch [15 / 20], Train loss: 301.03, Val loss: 805.07
Epoch [16 / 20], Train loss: 226.54, Val loss: 816.26
Epoch [17 / 20], Train loss: 232.52, Val loss: 795.15
Epoch [18 / 20], Train loss: 136.09, Val loss: 793.

Epoch [0 / 20], Val loss: 1529.36
Epoch [1 / 20], Train loss: 987.49, Val loss: 1515.75
Epoch [2 / 20], Train loss: 478.12, Val loss: 997.80
Epoch [3 / 20], Train loss: 531.30, Val loss: 996.84
Epoch [4 / 20], Train loss: 440.09, Val loss: 944.69
Epoch [5 / 20], Train loss: 451.10, Val loss: 995.97
Epoch [6 / 20], Train loss: 493.83, Val loss: 958.22
Epoch [7 / 20], Train loss: 326.13, Val loss: 950.37
Epoch [8 / 20], Train loss: 319.23, Val loss: 902.01
Epoch [9 / 20], Train loss: 502.29, Val loss: 853.17
Epoch [10 / 20], Train loss: 384.61, Val loss: 883.08
Epoch [11 / 20], Train loss: 280.79, Val loss: 788.70
Epoch [12 / 20], Train loss: 209.52, Val loss: 767.53
Epoch [13 / 20], Train loss: 175.38, Val loss: 763.10
Epoch [14 / 20], Train loss: 239.95, Val loss: 702.57
Epoch [15 / 20], Train loss: 168.21, Val loss: 688.34
Epoch [16 / 20], Train loss: 111.77, Val loss: 711.55
Epoch [17 / 20], Train loss: 132.02, Val loss: 673.70
Epoch [18 / 20], Train loss: 106.15, Val loss: 673.86
Epoch [19 / 20], Train loss: 117.40, Val loss: 675.25
Epoch [20 / 20], Train loss: 117.32, Val loss: 676.77
Epoch [20 / 20], Test loss: 186.12

                                 metrics  tdc1  tdc2  tdc3  tdc4  tdc5  mole  minimol  this_run  this_run_rank
caco2_wang                           MAE  0.28  0.28  0.29  0.29  0.29  0.31     0.35      0.38           8.00
bioavailability_ma                 AUROC  0.75  0.74  0.73  0.71  0.67  0.65     0.69      0.64           8.00
lipophilicity_astrazeneca            MAE  0.47  0.47  0.48  0.53  0.54  0.47     0.46      0.55           8.00
solubility_aqsoldb                   MAE  0.76  0.78  0.79  0.79  0.83  0.79     0.74      1.08           8.00
hia_hou                            AUROC  0.99  0.99  0.99  0.98  0.98  0.96     0.99      0.97           7.00
pgp_broccatelli                    AUROC  0.94  0.94  0.93  0.93  0.93  0.92     0.94      0.93           7.00
bbb_martins                        AUROC  0.92  0.92  0.91  0.91  0.91  0.90     0.92      0.87           8.00
ppbr_az                              MAE  7.53  7.66  7.79  7.91  8.29  8.07     7.70      8.88           8.00
vdss_lombardo                   Spearman  0.71  0.71  0.63  0.61  0.58  0.65     0.54      0.35           8.00
cyp2c9_veith                       AUPRC  0.86  0.84  0.83  0.79  0.78  0.80     0.82      0.77           8.00
cyp2d6_veith                       AUPRC  0.79  0.74  0.72  0.72  0.67  0.68     0.72      0.64           8.00
cyp3a4_veith                       AUPRC  0.92  0.90  0.90  0.88  0.88  0.88     0.88      0.85           8.00
cyp2c9_substrate_carbonmangels     AUPRC  0.47  0.44  0.44  0.43  0.41  0.45     0.47      0.50           1.00
cyp2d6_substrate_carbonmangels     AUPRC  0.74  0.72  0.71  0.70  0.69  0.70     0.69      0.70           4.50
cyp3a4_substrate_carbonmangels     AUROC  0.66  0.65  0.65  0.64  0.64  0.67     0.66      0.67           2.00
half_life_obach                 Spearman  0.56  0.56  0.55  0.54  0.44  0.55     0.49      0.45           7.00
clearance_hepatocyte_az         Spearman  0.50  0.47  0.44  0.44  0.43  0.38     0.45      0.48           2.00
clearance_microsome_az          Spearman  0.63  0.63  0.62  0.60  0.60  0.61     0.63      0.56           8.00
ld50_zhu                             MAE  0.55  0.59  0.61  0.62  0.62  0.82     0.58      0.66           7.00
herg                               AUROC  0.88  0.87  0.87  0.86  0.84  0.81     0.85      0.79           8.00
ames                               AUROC  0.87  0.87  0.87  0.85  0.84  0.88     0.85      0.80           8.00
dili                               AUROC  0.93  0.92  0.92  0.91  0.90  0.58     0.96      0.91           6.00
Scores                               NaN 17.93 17.32 16.83 15.98 15.17 15.06    16.60     12.46            NaN

In [None]:
import wandb
import os

entity, project = "blazejba-gc", "multitask_pretraining"
run = wandb.init(entity=entity, project=project)

# Simulate logging metrics over epochs
for epoch in range(10):
    accuracy = 0.9 + epoch * 0.01
    wandb.log({"epoch": epoch, "accuracy": accuracy})

wandb.finish()