In [2]:
from minimol import Minimol

featurizer = Minimol()

  predictor.load_state_dict(torch.load(state_dict_path), strict=False)


In [None]:
import os
import pandas as pd
from datamol.mol import standardize_smiles
from tdc.benchmark_group import admet_group
from contextlib import redirect_stdout, redirect_stderr

from torch.nn import MSELoss as mse_loss
from torch.nn import BCEWithLogitsLoss as bce_loss

seed = 42
batch_size = 128

group = admet_group(path='admet_data/')

num_mols = 0
task_losses = {}
test_dataloaders = {}
validation_dataloaders = {}
df = pd.DataFrame(columns=['smiles'])

for dataset_i, dataset_name in enumerate(group.dataset_names):
    print(f"{dataset_i + 1} / {len(group.dataset_names)} - {dataset_name}")
    benchmark = group.get(dataset_name)
    name = benchmark['name']
    mols_test = benchmark['test']
    
    with open(os.devnull, 'w') as fnull, redirect_stdout(fnull), redirect_stderr(fnull): # suppress output
        mols_train, mols_valid = group.get_train_valid_split(benchmark=name, split_type='default', seed=seed)
        mols_test['embeddings'] = featurizer(list(mols_test['Drug']))
        mols_valid['embeddings'] = featurizer(list(mols_valid['Drug']))

    temp_df = pd.DataFrame({
        'smiles': mols_train['Drug'],
        dataset_name: mols_train['Y']
    })

    num_mols += len(temp_df)
    df = pd.merge(df, temp_df, on='smiles', how='outer')

    task_losses[name] =  bce_loss() if len(mols_test['Y'].unique()) == 2 else mse_loss()
    test_dataloaders[name] = DataLoader(AdmetDataset(mols_test), batch_size=batch_size, shuffle=False)
    validation_dataloaders[name] = DataLoader(AdmetDataset(mols_valid), batch_size=batch_size, shuffle=False)

df['embeddings'] = featurizer(list(df['smiles']))
print(f"All mols: {num_mols}")
print(f"Unique mols: {len(df)}")

In [4]:
df.head()

Unnamed: 0,smiles,caco2_wang,hia_hou,pgp_broccatelli,bioavailability_ma,lipophilicity_astrazeneca,solubility_aqsoldb,bbb_martins,ppbr_az,vdss_lombardo,...,cyp3a4_substrate_carbonmangels,cyp2c9_substrate_carbonmangels,half_life_obach,clearance_microsome_az,clearance_hepatocyte_az,herg,ames,dili,ld50_zhu,embeddings
0,CNC1(c2ccccc2Cl)CCCCC1=O,-4.26,,,0.0,,,1.0,44.84,,...,,,,,,,,0.0,,"[tensor(1.1916), tensor(0.3334), tensor(0.8102..."
1,CNC1(c2ccccc2Cl)CCCCC1=O,-4.26,,,0.0,,,1.0,42.01,,...,,,,,,,,0.0,,"[tensor(1.1916), tensor(0.3334), tensor(0.8102..."
2,C/C=C/C/C=C/CCC(=O)[C@@H]1O[C@@H]1C(N)=O,-5.422406,,,,,,,,,...,,,,,,,,,,"[tensor(1.2584), tensor(0.7886), tensor(1.3459..."
3,O=C(NC1(C(=O)N[C@H](Cc2ccccc2)C(=O)NCCCC(=O)N2...,-5.769776,,,,,,,,,...,,,,,,,,,,"[tensor(1.8761), tensor(0.1187), tensor(1.0743..."
4,NC(=O)[C@H](Cc1ccccc1)NC(=O)[C@H](Cc1ccccc1)NC...,-7.431799,,,,,,,,,...,,,,,,,,,,"[tensor(2.0538), tensor(1.9126), tensor(1.1178..."


In [39]:
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class MultiDataset(Dataset):
    def __init__(self, samples, task_names):
        self.samples = samples['embeddings'].tolist()
        self.targets = samples[task_names].fillna(np.nan).values

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        target = self.targets[idx]
        return sample, target


class AdmetDataset(Dataset):
    def __init__(self, samples):
        self.samples = samples['embeddings'].tolist()
        self.targets = [float(target) for target in samples['Y'].tolist()]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = torch.tensor(self.samples[idx])
        target = torch.tensor(self.targets[idx])
        return sample, target


class MultitaskTrainer:
    def __init__(self, device, batch_size: int = 256):
        self.device = device
        self.bs = batch_size
        self.task_losses = {}
        self.test_dataloaders = {}
        self.validation_dataloaders = {}
        self.train_dataloader = None
    
    def set_multitask_dataloader(self, data):
        task_names = list(self.task_losses.keys())
        self.train_dataloader = DataLoader(MultiDataset(data, task_names), batch_size=124, shuffle=True)

    def set_per_task_dataloaders(self, test_loaders, val_loaders, task_losses):
        self.task_losses = task_losses
        self.test_dataloaders = test_loaders
        self.validation_dataloaders = val_loaders

    def _eval(self, model, dataloaders):
        total_loss = torch.tensor(0.0, requires_grad=True, device=self.device)
        per_task_losses = torch.tensor([0.0]*len(dataloaders.keys()), requires_grad=False, device=self.device)

        for task_i, (task_name, dataloader) in enumerate(dataloaders.items()):
            for batch_idx, (samples, targets) in enumerate(dataloader):
                samples = samples.to(device)
                targets = targets.to(device)

                output = model(samples, task=task_name)

                task_loss = self.task_losses[task_name](output.float(), targets.float())
                per_task_losses[task_i] = per_task_losses[task_i] + (task_loss / len(dataloader))
                total_loss = total_loss + (task_loss / len(dataloader))

        return total_loss, per_task_losses
    
    def eval_on_val(self, model):
        return self._eval(model, self.validation_dataloaders)

    def eval_on_test(self, model):
        return self._eval(model, self.test_dataloaders)

    def _compute_train_loss(self, outputs, filtered_targets):
        total_loss = torch.tensor(0.0, requires_grad=True, device=self.device)
        per_task_losses = torch.tensor([0.0]*len(self.task_losses.keys()), requires_grad=False, device=self.device)

        for task_i, (task, loss) in enumerate(self.task_losses.items()):
            if task not in outputs.keys():
                continue
            
            output = outputs[task]
            target = filtered_targets[task]
            assert not torch.isnan(output).any(), "NaNs IN THE OUTPUT!!!"

            task_loss = loss(output.float(), target.float())
            total_loss = total_loss + task_loss 
            per_task_losses[task_i] = task_loss

        return total_loss / len(outputs.keys()), per_task_losses

    def train(self, model, optimizer, num_epochs):
        model.train()
        for epoch in range(num_epochs):
            running_loss = 0.0
            running_per_task_loss = torch.tensor([0.0]*len(self.task_losses.keys()), requires_grad=False, device=self.device)
            for batch_idx, (samples, targets) in enumerate(self.train_dataloader):
                samples = samples.to(self.device)
                targets = targets.to(self.device)
                
                outputs, filtered_targets = model(samples, targets)
                total_loss, per_task_losses = self._compute_train_loss(outputs, filtered_targets)
                
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

                running_loss += total_loss.item()
                running_per_task_loss += per_task_losses
            
            total_val_loss, per_task_val_loss = self._eval(model, self.validation_dataloaders)
            print(f"Epoch [{epoch + 1} / {num_epochs}], Train loss: {running_loss / len(self.train_dataloader):.2f}, Val loss: {total_val_loss}")


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trainer = MultitaskTrainer(device)
trainer.set_per_task_dataloaders(test_dataloaders, validation_dataloaders, task_losses)
trainer.set_multitask_dataloader(df)

In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class ResidualMLP(nn.Module):
    def __init__(self, input_size, head_input_size, hidden_size, depth, dropout):
        super(ResidualMLP, self).__init__()
        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(input_size, hidden_size))
        self.layers.append(nn.BatchNorm1d(hidden_size))
        for _ in range(depth - 1):
            self.layers.append(nn.ReLU())
            self.layers.append(nn.Dropout(dropout))
            self.layers.append(nn.Linear(hidden_size, hidden_size))
            self.layers.append(nn.BatchNorm1d(hidden_size))
        self.output_layer = nn.Linear(hidden_size, head_input_size)

    def forward(self, x):
        identity = x
        for layer in self.layers:
            x = layer(x)
        x += identity
        return self.output_layer(x)


class TaskHead(nn.Module):
    def __init__(self, hidden_size, depth, dropout):
        super(TaskHead, self).__init__()
        self.layers = nn.ModuleList()
        for _ in range(depth):
            self.layers.append(nn.Linear(hidden_size, hidden_size))
            self.layers.append(nn.LayerNorm(hidden_size))
            self.layers.append(nn.ReLU())
            self.layers.append(nn.Dropout(dropout))
        self.output_layer = nn.Linear(hidden_size, 1)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.output_layer(x)


class MultiTaskModel(nn.Module):
    def __init__(self, input_size, trunk_hidden_size, trunk_depth, head_hidden_size, head_depth, dropout, tasks):
        super(MultiTaskModel, self).__init__()
        self.trunk = ResidualMLP(input_size, head_hidden_size, trunk_hidden_size, trunk_depth, dropout)
        self.heads = nn.ModuleDict({task: TaskHead(head_hidden_size, head_depth, dropout) for task in tasks})
        self.tasks = list(tasks)
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x, targets=None, task=None):
        x = self.trunk(x)
        outputs = {}
        filtered_targets = {}

        if task:
            return self.heads[task](x)

        task_mask = ~torch.isnan(targets)
        for idx, task in enumerate(self.tasks): 
            if task_mask[:, idx].any():
                indices = torch.nonzero(task_mask[:, idx], as_tuple=False).squeeze()
                outputs[task] = self.heads[task](x[indices]).squeeze()
                filtered_targets[task] = targets[indices, idx].squeeze()
        return outputs, filtered_targets
    

In [40]:
import torch
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss as bce_loss
from torch.nn import MSELoss as mse_loss

hparams = {
    "trunk_hidden_size": 512,
    "head_hidden_size": 128, 
    "tasks": trainer.task_losses.keys(), 
    "input_size": 512, 
    "trunk_depth": 2, 
    "head_depth": 3, 
    "dropout": 0.5
}
model = MultiTaskModel(**hparams).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=.9, weight_decay=1e-4)
trainer.train(model, optimizer, num_epochs=10)

ValueError: Target size (torch.Size([58])) must be the same as input size (torch.Size([58, 1]))

In [None]:
import wandb
import os

entity, project = "blazejba-gc", "multitask_pretraining"
run = wandb.init(entity=entity, project=project)

# Simulate logging metrics over epochs
for epoch in range(10):
    accuracy = 0.9 + epoch * 0.01
    wandb.log({"epoch": epoch, "accuracy": accuracy})

wandb.finish()