In [1]:
import pandas as pd
import numpy as np
import torch
import scripts
from functools import lru_cache
import torchmetrics
from torch import nn
import optuna

  from .autonotebook import tqdm as notebook_tqdm


# Data loading

First we load the data. The basic idea is to create dictionaries with features associated to the drugs and cell-lines. In principle, the splits and the data shouldn't be changed

In [2]:
@lru_cache(maxsize = None)
def get_data(n_fold = 0, fp_radius = 2):
    smile_dict = pd.read_csv("data/smiles.csv", index_col=0)
    fp = scripts.FingerprintFeaturizer(R = fp_radius)
    drug_dict = fp(smile_dict.iloc[:, 1], smile_dict.iloc[:, 0])
    driver_genes = pd.read_csv("data/driver_genes.csv").loc[:, "symbol"].dropna()
    rnaseq = pd.read_csv("data/rnaseq_normcount.csv", index_col=0)
    driver_columns = rnaseq.columns.isin(driver_genes)
    filtered_rna = rnaseq.loc[:, driver_columns]
    tensor_exp = torch.Tensor(filtered_rna.to_numpy())
    cell_dict = {cell: tensor_exp[i] for i, cell in enumerate(filtered_rna.index.to_numpy())}
    data = pd.read_csv("data/GDSC12.csv", index_col=0)
    # default, remove data where lines or drugs are missing:
    data = data.query("SANGER_MODEL_ID in @cell_dict.keys() & DRUG_ID in @drug_dict.keys()")
    unique_cell_lines = data.loc[:, "SANGER_MODEL_ID"].unique()
    np.random.seed(420) # for comparibility, don't change it!
    np.random.shuffle(unique_cell_lines)
    folds = np.array_split(unique_cell_lines, 10)
    test_lines = folds[0]
    train_idxs = list(range(10))
    train_idxs.remove(n_fold)
    np.random.seed(420)
    validation_idx = np.random.choice(train_idxs)
    train_idxs.remove(validation_idx)
    train_lines = np.concatenate([folds[idx] for idx in train_idxs])
    validation_lines = folds[validation_idx]
    test_lines = folds[n_fold]
    train_data = data.query("SANGER_MODEL_ID in @train_lines")
    validation_data = data.query("SANGER_MODEL_ID in @validation_lines")
    test_data = data.query("SANGER_MODEL_ID in @test_lines")
    return (scripts.OmicsDataset_drugwise(cell_dict, drug_dict, train_data),
    scripts.OmicsDataset_drugwise(cell_dict, drug_dict, validation_data),
    scripts.OmicsDataset_drugwise(cell_dict, drug_dict, test_data))

# Configuration

we declare the configuration, this is going to be model-specific and we get the datasets

In [8]:
config = {"features" : {"fp_radius":2},
           "optimizer": {"batch_size": 4,
                         "clip_norm":19,
                         "learning_rate":0.00004592646200179472,
                         "stopping_patience":15,
                         "c_alpha":0.001},
           "model":{"embed_dim":485,
                  "hidden_dim":696,
                  "dropout":0.48541242824674574,
                  "n_layers": 4,
                  "norm": "batchnorm"},
          "env": {"fold": 0,
                  "device":"cpu",
                  "max_epochs": 100,
                  "search_hyperparameters":True}}

In [4]:
train_dataset, validation_dataset, test_dataset = get_data(n_fold = config["env"]["fold"],
                                                           fp_radius = config["features"]["fp_radius"])





In [5]:
cell_features, drug_features, target, drug_index, cell_indices, labels = train_dataset[4]

ins = train_dataset[7]
print(ins[5].shape)
print(ins[4].shape)

print("Drug Features Shape:", drug_features.shape)
print("Cell Features Shape:", cell_features.shape)  
print("Labels Shape:", labels.shape)
print("Target Shape:", target.shape)               
print("Drug Index Shape:", drug_index.shape)      
print("Cell Indices Shape:", cell_indices.shape) 




torch.Size([325])
torch.Size([325])
Drug Features Shape: torch.Size([323, 2048])
Cell Features Shape: torch.Size([323, 777])
Labels Shape: torch.Size([323])
Target Shape: torch.Size([323])
Drug Index Shape: torch.Size([323])
Cell Indices Shape: torch.Size([323])


In [6]:

def collate_fn_custom(batch):
    omics = []
    drugs = []
    targets = []
    cell_ids = []
    drug_ids = []
    labels = [] 

    for i, (o, d, t, cid, did, r) in enumerate(batch):
        omics.append(o)  # Add each omics tensor
        drugs.append(d)  # Add each drugs tensor
        targets.append(t)  # Add each target tensor
        cell_ids.append(cid)  # Add each cell ID tensor
        drug_ids.append(did)  # Add each drug ID tensor
        r = r.clone()
        r[r==1] = 2*i + 1
        r[r==2] = 2*(i + 1)
        labels.append(r)  # Add each label tensor
    # Concatenate along the first dimension for all tensors
    omics = torch.cat(omics, dim=0)
    drugs = torch.cat(drugs, dim=0)
    targets = torch.cat(targets, dim=0)
    cell_ids = torch.cat(cell_ids, dim=0)
    drug_ids = torch.cat(drug_ids, dim=0)
    labels = torch.cat(labels, dim=0)
    


    return omics, drugs, targets, cell_ids, drug_ids, labels


# Hyperparameter optimization

we wrap the function for training the model in a function that can be used by optuna

In [7]:
def train_model_optuna(trial, config):
    def pruning_callback(epoch, train_r):
        trial.report(train_r, step = epoch)
        if np.isnan(train_r):
            raise optuna.TrialPruned()
        if trial.should_prune():
            raise optuna.TrialPruned()
    config["model"] = {"embed_dim": trial.suggest_int("embed_dim", 64, 512),
                    "hidden_dim": trial.suggest_int("hidden_dim", 64, 2048),
                    "n_layers": trial.suggest_int("n_layers", 1, 6),
                    "norm": trial.suggest_categorical("norm", ["batchnorm", "layernorm", None]),
                    "dropout": trial.suggest_float("dropout", 0.0, 0.5)}
    config["optimizer"] = { "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-1, log=True),
                            "clip_norm": trial.suggest_int("clip_norm", 0.1, 20),
                            "batch_size": trial.suggest_int("batch_size", 2, 10),
                            "stopping_patience":10}
    try:
        R, model = scripts.train_model(config,
                                       train_dataset,
                                       validation_dataset,
                                       use_momentum=True,
                                       callback_epoch = pruning_callback,
                                       collate_fn=collate_fn_custom)
        
        return R
    except Exception as e:
        print(e)
        return 0

In [None]:



if config["env"]["search_hyperparameters"]:
    study_name = f"baseline_model"
    storage_name = "sqlite:///studies/{}.db".format(study_name)
    study = optuna.create_study(study_name=study_name,
                                storage=storage_name,
                                direction='maximize',
                                load_if_exists=True,
                                pruner=optuna.pruners.MedianPruner(n_startup_trials=30,
                                                               n_warmup_steps=5,
                                                               interval_steps=5))
    objective = lambda x: train_model_optuna(x, config)
    study.optimize(objective, n_trials=40)
    best_config = study.best_params
    print(best_config)
    config["model"]["embed_dim"] = best_config["embed_dim"]
    config["model"]["hidden_dim"] = best_config["hidden_dim"]
    config["model"]["n_layers"] = best_config["n_layers"]
    config["model"]["norm"] = best_config["norm"]
    config["model"]["dropout"] = best_config["dropout"]
    config["optimizer"]["learning_rate"] = best_config["learning_rate"]
    config["optimizer"]["clip_norm"] = best_config["clip_norm"]
    config["optimizer"]["batch_size"] = best_config["batch_size"]



[I 2025-02-07 15:54:10,974] Using an existing study with name 'baseline_model' instead of creating a new one.


sensitive torch.Size([17, 512])
resistant torch.Size([8, 512])
torch.Size([1, 17, 512])
torch.Size([17, 1, 512])
sensitive torch.Size([15, 512])
resistant torch.Size([22, 512])
torch.Size([1, 15, 512])
torch.Size([15, 1, 512])
sensitive torch.Size([23, 512])
resistant torch.Size([7, 512])
torch.Size([1, 23, 512])
torch.Size([23, 1, 512])
sensitive torch.Size([20, 512])
resistant torch.Size([11, 512])
torch.Size([1, 20, 512])
torch.Size([20, 1, 512])
sensitive torch.Size([24, 512])
resistant torch.Size([15, 512])
torch.Size([1, 24, 512])
torch.Size([24, 1, 512])
sensitive torch.Size([15, 512])
resistant torch.Size([31, 512])
torch.Size([1, 15, 512])
torch.Size([15, 1, 512])
sensitive torch.Size([0, 512])
resistant torch.Size([0, 512])
torch.Size([1, 0, 512])
torch.Size([0, 1, 512])
sensitive torch.Size([0, 512])
resistant torch.Size([0, 512])
torch.Size([1, 0, 512])
torch.Size([0, 1, 512])
sensitive torch.Size([0, 512])
resistant torch.Size([0, 512])
torch.Size([1, 0, 512])
torch.Size([

[I 2025-02-07 15:54:12,752] Trial 520 finished with value: 0.0 and parameters: {'embed_dim': 512, 'hidden_dim': 1362, 'n_layers': 3, 'norm': None, 'dropout': 0.4421278678625774, 'learning_rate': 3.910583034110933e-06, 'clip_norm': 15, 'batch_size': 6}. Best is trial 379 with value: 0.3149450400351803.


sensitive torch.Size([5, 319])
resistant torch.Size([30, 319])
torch.Size([1, 5, 319])
torch.Size([5, 1, 319])
sensitive torch.Size([22, 319])
resistant torch.Size([4, 319])
torch.Size([1, 22, 319])
torch.Size([22, 1, 319])
sensitive torch.Size([19, 319])
resistant torch.Size([10, 319])
torch.Size([1, 19, 319])
torch.Size([19, 1, 319])
sensitive torch.Size([5, 319])
resistant torch.Size([9, 319])
torch.Size([1, 5, 319])
torch.Size([5, 1, 319])
sensitive torch.Size([40, 319])
resistant torch.Size([8, 319])
torch.Size([1, 40, 319])
torch.Size([40, 1, 319])
sensitive torch.Size([20, 319])
resistant torch.Size([7, 319])
torch.Size([1, 20, 319])
torch.Size([20, 1, 319])
sensitive torch.Size([43, 319])
resistant torch.Size([0, 319])
torch.Size([1, 43, 319])
torch.Size([43, 1, 319])
sensitive torch.Size([28, 319])
resistant torch.Size([4, 319])
torch.Size([1, 28, 319])
torch.Size([28, 1, 319])
sensitive torch.Size([13, 319])
resistant torch.Size([4, 319])
torch.Size([1, 13, 319])
torch.Size([

[I 2025-02-07 15:54:14,650] Trial 521 finished with value: 0.0 and parameters: {'embed_dim': 319, 'hidden_dim': 1985, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.22743928213134712, 'learning_rate': 5.8368812863601746e-06, 'clip_norm': 7, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


sensitive torch.Size([27, 418])
resistant torch.Size([1, 418])
torch.Size([1, 27, 418])
torch.Size([27, 1, 418])
sensitive torch.Size([4, 418])
resistant torch.Size([29, 418])
torch.Size([1, 4, 418])
torch.Size([4, 1, 418])
sensitive torch.Size([51, 418])
resistant torch.Size([0, 418])
torch.Size([1, 51, 418])
torch.Size([51, 1, 418])
sensitive torch.Size([24, 418])
resistant torch.Size([3, 418])
torch.Size([1, 24, 418])
torch.Size([24, 1, 418])
sensitive torch.Size([7, 418])
resistant torch.Size([27, 418])
torch.Size([1, 7, 418])
torch.Size([7, 1, 418])
sensitive torch.Size([17, 418])
resistant torch.Size([26, 418])
torch.Size([1, 17, 418])
torch.Size([17, 1, 418])
sensitive torch.Size([17, 418])
resistant torch.Size([5, 418])
torch.Size([1, 17, 418])
torch.Size([17, 1, 418])
sensitive torch.Size([26, 418])
resistant torch.Size([22, 418])
torch.Size([1, 26, 418])
torch.Size([26, 1, 418])
sensitive torch.Size([2, 418])
resistant torch.Size([8, 418])
torch.Size([1, 2, 418])
torch.Size([

[I 2025-02-07 15:54:17,014] Trial 522 finished with value: 0.0 and parameters: {'embed_dim': 418, 'hidden_dim': 1124, 'n_layers': 5, 'norm': 'layernorm', 'dropout': 0.05198663641852695, 'learning_rate': 8.442599640502513e-06, 'clip_norm': 3, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


# Model training and evaluation

After we have a set of optimal hyperparameters we train our model. The train model function could be changed, but:
- test_dataset cannot be used until we call the final evaluation step
- the evaluation step cannot be modified, it must take the model produced by your pipeline, a dataloader that provides the correct data for your model, and the final metrics have to be printed

In [10]:
import scripts

%reload_ext autoreload

%autoreload 2

_, model = scripts.train_model(config, train_dataset,  validation_dataset=None, use_momentum=True, callback_epoch=None, collate_fn=collate_fn_custom)
device = torch.device(config["env"]["device"])
metrics = torchmetrics.MetricTracker(torchmetrics.MetricCollection(
    {"R_cellwise_residuals":scripts.GroupwiseMetric(metric=torchmetrics.functional.pearson_corrcoef,
                          grouping="drugs",
                          average="macro",
                          residualize=True),
    "R_cellwise":scripts.GroupwiseMetric(metric=torchmetrics.functional.pearson_corrcoef,
                          grouping="cell_lines",
                          average="macro",
                          residualize=False),
    "MSE":torchmetrics.MeanSquaredError()}))
metrics.to(device)
test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                      batch_size=config["optimizer"]["batch_size"],
                                      drop_last=False,
                                      shuffle=False,
                                      collate_fn=collate_fn_custom)


sensitive torch.Size([18, 137])
resistant torch.Size([3, 137])
torch.Size([1, 18, 137])
torch.Size([18, 1, 137])
sensitive torch.Size([8, 137])
resistant torch.Size([22, 137])
torch.Size([1, 8, 137])
torch.Size([8, 1, 137])
sensitive torch.Size([19, 137])
resistant torch.Size([12, 137])
torch.Size([1, 19, 137])
torch.Size([19, 1, 137])
sensitive torch.Size([4, 137])
resistant torch.Size([6, 137])
torch.Size([1, 4, 137])
torch.Size([4, 1, 137])
sensitive torch.Size([0, 137])
resistant torch.Size([41, 137])
torch.Size([1, 0, 137])
torch.Size([0, 1, 137])
sensitive torch.Size([32, 137])
resistant torch.Size([1, 137])
torch.Size([1, 32, 137])
torch.Size([32, 1, 137])
sensitive torch.Size([22, 137])
resistant torch.Size([4, 137])
torch.Size([1, 22, 137])
torch.Size([22, 1, 137])
sensitive torch.Size([10, 137])
resistant torch.Size([27, 137])
torch.Size([1, 10, 137])
torch.Size([10, 1, 137])
sensitive torch.Size([3, 137])
resistant torch.Size([15, 137])
torch.Size([1, 3, 137])
torch.Size([3,

KeyError: 'c_alpha'

In [16]:
final_metrics = scripts.evaluate_step(model, test_dataloader, metrics, device)
print(final_metrics)

  return torch.linalg.solve(A, Xy).T


{'MSE': 2.363308906555176, 'R_cellwise': 0.2313278466463089, 'R_cellwise_residuals': 0.21419179439544678}
