In [1]:
import pandas as pd
import numpy as np
import torch
import scripts
from functools import lru_cache
import torchmetrics
from torch import nn
import optuna

  from .autonotebook import tqdm as notebook_tqdm


# Data loading

First, we load the data. The basic idea is to create dictionaries of features associated with the drugs and the cell lines. In principle, the splits and the data should not be changed.

In [2]:
@lru_cache(maxsize = None)
def get_data(n_fold = 0, fp_radius = 2):
    """Build the train / validation / test datasets for one cross-validation fold.

    Reads ``data/smiles.csv``, ``data/driver_genes.csv``,
    ``data/rnaseq_normcount.csv`` and ``data/GDSC12.csv``, builds one feature
    mapping for drugs and one for cell lines, and splits the response table
    by cell line into 10 folds.

    The result is memoized by ``lru_cache``, so calling again with the same
    arguments returns the very same dataset objects without re-reading files.

    Parameters
    ----------
    n_fold : int
        Index of the fold held out as the test set. Must be in ``range(10)``;
        any other value makes ``list.remove`` raise ``ValueError``.
    fp_radius : int
        Forwarded to ``scripts.FingerprintFeaturizer`` as ``R``.

    Returns
    -------
    tuple
        ``(train, validation, test)``, each a ``scripts.OmicsDataset_drugwise``
        built from the same two feature mappings and a disjoint set of cell lines.
    """
    # Drug features: the featurizer is called with the second and first columns
    # of the SMILES table (after the index column has been consumed).
    smile_dict = pd.read_csv("data/smiles.csv", index_col=0)
    fp = scripts.FingerprintFeaturizer(R = fp_radius)
    drug_dict = fp(smile_dict.iloc[:, 1], smile_dict.iloc[:, 0])

    # Cell-line features: keep only the expression columns whose name appears
    # in the driver-gene "symbol" column, one tensor row per cell line.
    driver_genes = pd.read_csv("data/driver_genes.csv").loc[:, "symbol"].dropna()
    rnaseq = pd.read_csv("data/rnaseq_normcount.csv", index_col=0)
    driver_columns = rnaseq.columns.isin(driver_genes)
    filtered_rna = rnaseq.loc[:, driver_columns]
    tensor_exp = torch.Tensor(filtered_rna.to_numpy())
    cell_dict = {cell: tensor_exp[i] for i, cell in enumerate(filtered_rna.index.to_numpy())}

    data = pd.read_csv("data/GDSC12.csv", index_col=0)
    # default, remove data where lines or drugs are missing:
    # (the query strings below reference local variables by name via "@",
    # so those locals must not be renamed without updating the strings)
    data = data.query("SANGER_MODEL_ID in @cell_dict.keys() & DRUG_ID in @drug_dict.keys()")

    # Shuffle the cell lines once with a fixed seed and cut them into 10 folds.
    unique_cell_lines = data.loc[:, "SANGER_MODEL_ID"].unique()
    np.random.seed(420) # for comparability, don't change it!
    np.random.shuffle(unique_cell_lines)
    folds = np.array_split(unique_cell_lines, 10)

    # Hold out n_fold for testing, draw one of the remaining nine folds for
    # validation (re-seeded so the draw does not depend on the shuffle above),
    # and train on the other eight.
    train_idxs = list(range(10))
    train_idxs.remove(n_fold)
    np.random.seed(420)
    validation_idx = np.random.choice(train_idxs)
    train_idxs.remove(validation_idx)

    train_lines = np.concatenate([folds[idx] for idx in train_idxs])
    validation_lines = folds[validation_idx]
    test_lines = folds[n_fold]

    train_data = data.query("SANGER_MODEL_ID in @train_lines")
    validation_data = data.query("SANGER_MODEL_ID in @validation_lines")
    test_data = data.query("SANGER_MODEL_ID in @test_lines")
    return (scripts.OmicsDataset_drugwise(cell_dict, drug_dict, train_data),
    scripts.OmicsDataset_drugwise(cell_dict, drug_dict, validation_data),
    scripts.OmicsDataset_drugwise(cell_dict, drug_dict, test_data))

# Configuration

We declare the configuration, which is model-specific, and then load the datasets.

In [9]:
# Single source of truth for the run. The "model" and "optimizer" sections are
# overwritten by the hyperparameter search when env.search_hyperparameters is True.
config = {
    # Forwarded to get_data(fp_radius=...).
    "features": {"fp_radius": 2},
    # batch_size is also used for the test DataLoader; the other entries are
    # consumed inside scripts.train_model (not visible from this notebook).
    "optimizer": {
        "batch_size": 4,
        "clip_norm": 19,
        "learning_rate": 0.0004592646200179472,
        "stopping_patience": 15,
    },
    "model": {
        "embed_dim": 485,
        "hidden_dim": 696,
        "dropout": 0.48541242824674574,
        "n_layers": 4,
        "norm": "batchnorm",
    },
    "env": {
        "fold": 0,  # forwarded to get_data(n_fold=...)
        "device": "cpu",  # wrapped in torch.device for the metrics
        "max_epochs": 100,
        "search_hyperparameters": True,  # gates the optuna search cell
    },
}

In [5]:
# Materialise the three splits for the configured fold and fingerprint radius.
# get_data is memoized, so re-running this cell does not re-read the CSV files.
train_dataset, validation_dataset, test_dataset = get_data(
    n_fold = config["env"]["fold"],
    fp_radius = config["features"]["fp_radius"],
)





In [5]:
# Sanity check: pull one item out of the training dataset and look at the
# shape of each of its six components.
# NOTE(review): positions 3 and 4 are named (drug_index, cell_indices) here,
# but collate_fn_custom names the same positions (cell ids, drug ids).
# One of the two is mislabelled -- confirm against scripts.OmicsDataset_drugwise.
cell_features, drug_features, target, drug_index, cell_indices, labels = train_dataset[4]

# A second item, accessed positionally; prints the shapes of elements 5 and 4.
ins = train_dataset[7]
print(ins[5].shape)
print(ins[4].shape)

print("Drug Features Shape:", drug_features.shape)
print("Cell Features Shape:", cell_features.shape)  
print("Labels Shape:", labels.shape)
print("Target Shape:", target.shape)               
print("Drug Index Shape:", drug_index.shape)      
print("Cell Indices Shape:", cell_indices.shape) 




torch.Size([325])
torch.Size([325])
Drug Features Shape: torch.Size([323, 2048])
Cell Features Shape: torch.Size([323, 777])
Labels Shape: torch.Size([323])
Target Shape: torch.Size([323])
Drug Index Shape: torch.Size([323])
Cell Indices Shape: torch.Size([323])


# Hyperparameter optimization

We wrap the model-training function in an objective function that Optuna can call.

In [6]:
def train_model_optuna(trial, config):
    """Optuna objective: sample hyperparameters, train once, return the score.

    The sampled values are written into ``config["model"]`` and
    ``config["optimizer"]`` in place, so the caller's dict reflects the most
    recent trial after this function returns. Training uses the notebook-level
    ``train_dataset``, ``validation_dataset`` and ``collate_fn_custom``, which
    are looked up when the function is called (not when it is defined).

    Parameters
    ----------
    trial : optuna.trial.Trial
        Trial used to sample hyperparameters and to report intermediate scores.
    config : dict
        Run configuration; its "model" and "optimizer" entries are replaced.

    Returns
    -------
    The first value returned by ``scripts.train_model``, or ``0`` if training
    raised an exception other than ``optuna.TrialPruned``.

    Raises
    ------
    optuna.TrialPruned
        When the pruner asks to stop the trial or the reported score is NaN.
    """
    def pruning_callback(epoch, train_r):
        """Report the epoch score to optuna and stop the trial if it is hopeless."""
        trial.report(train_r, step = epoch)
        # A NaN score cannot recover, so it is treated like a pruner verdict.
        if np.isnan(train_r) or trial.should_prune():
            raise optuna.TrialPruned()

    config["model"] = {"embed_dim": trial.suggest_int("embed_dim", 64, 512),
                    "hidden_dim": trial.suggest_int("hidden_dim", 64, 2048),
                    "n_layers": trial.suggest_int("n_layers", 1, 6),
                    "norm": trial.suggest_categorical("norm", ["batchnorm", "layernorm", None]),
                    "dropout": trial.suggest_float("dropout", 0.0, 0.5)}
    # suggest_int needs integer bounds. The previous lower bound of 0.1 was
    # truncated to 0, which let the search sample clip_norm == 0.
    # NOTE(review): assumes clip_norm is a gradient-norm ceiling, for which 0
    # would suppress every update -- confirm against scripts.train_model.
    config["optimizer"] = { "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-1, log=True),
                            "clip_norm": trial.suggest_int("clip_norm", 1, 20),
                            "batch_size": trial.suggest_int("batch_size", 2, 10),
                            "stopping_patience":10}
    try:
        R, _ = scripts.train_model(config,
                                   train_dataset,
                                   validation_dataset,
                                   use_momentum=True,
                                   callback_epoch = pruning_callback,
                                   collate_fn=collate_fn_custom)
        return R
    except optuna.TrialPruned:
        # TrialPruned is an Exception subclass: without this clause the generic
        # handler below would swallow it and the trial would be stored as
        # finished with value 0 instead of pruned, skewing the pruner's statistics.
        raise
    except Exception as e:
        # Best effort: a crashed configuration scores 0 rather than aborting the study.
        print(e)
        return 0

In [None]:
# Superseded version of collate_fn_custom, kept only for reference.
# It is wrapped in a string literal, so nothing here is defined or executed;
# the live implementation is the `def collate_fn_custom` further down.
'''
def collate_fn_custom(batch):
    omics = []
    drugs = []
    targets = []
    cell_ids = []
    drug_ids = []
    labels = [] 

    for o, d, t, cid, did, r in batch:
        omics.append(o)  # Add each omics tensor
        drugs.append(d)  # Add each drugs tensor
        targets.append(t)  # Add each target tensor
        cell_ids.append(cid)  # Add each cell ID tensor
        drug_ids.append(did)  # Add each drug ID tensor
        labels.append(r)  # Add each label tensor

    # Concatenate along the first dimension for all tensors
    omics = torch.cat(omics, dim=0)
    drugs = torch.cat(drugs, dim=0)
    targets = torch.cat(targets, dim=0)
    cell_ids = torch.cat(cell_ids, dim=0)
    drug_ids = torch.cat(drug_ids, dim=0)
    labels = torch.cat(labels, dim=0)

    return omics, drugs, targets, cell_ids, drug_ids, labels
'''

In [26]:
def collate_fn_custom(batch):
    """Collate dataset items into two row-aligned halves plus a pair code per row.

    Each item in ``batch`` is a 6-tuple of tensors. Every field is
    concatenated across items along dim 0; if the resulting row count is odd
    the last row is dropped; the rows are then split into a first and a second
    half so that row ``i`` of the first half is paired with row ``i`` of the
    second half.

    Parameters
    ----------
    batch : sequence of 6-tuples of torch.Tensor
        Items as produced by the dataset. The last field of each tuple holds
        the labels used to build the pair codes.

    Returns
    -------
    tuple of length 13
        The six fields of the first half, the six fields of the second half
        (same field order as the input tuples), then ``pairs``: a list of
        ``int`` with one entry per row pair --
        ``0`` if either label equals 0, ``1`` if both labels are equal
        (positive pair), ``-1`` otherwise (negative pair).
    """
    # Transpose the batch into one group per field and join each group along dim 0.
    omics, drugs, targets, cell_ids, drug_ids, labels = (
        torch.cat(parts, dim=0) for parts in zip(*batch)
    )
    fields = [omics, drugs, targets, cell_ids, drug_ids, labels]

    # Pairing needs an even number of rows: drop the trailing row if necessary.
    n_rows = omics.size(0)
    if n_rows % 2 != 0:
        fields = [field[:n_rows - 1] for field in fields]

    mid_index = fields[0].size(0) // 2
    first_half = [field[:mid_index] for field in fields]
    second_half = [field[mid_index:] for field in fields]

    # Pair codes, computed for all rows at once. Start from "negative pair",
    # mark equal labels as positive, then let "label == 0" override both.
    labels1, labels2 = first_half[-1], second_half[-1]
    matching = (labels1 == labels2).reshape(-1)
    unlabeled = ((labels1 == 0) | (labels2 == 0)).reshape(-1)
    pair_codes = torch.full((mid_index,), -1, dtype=torch.long, device=labels1.device)
    pair_codes[matching] = 1
    pair_codes[unlabeled] = 0
    pairs = pair_codes.tolist()  # plain list of ints, as callers received before

    return (*first_half, *second_half, pairs)


In [None]:



# Optional hyperparameter search. The study lives in a SQLite file and is
# resumed if it already exists, so every run adds trials to the same study.
if config["env"]["search_hyperparameters"]:
    study_name = "baseline_model"
    storage_name = f"sqlite:///studies/{study_name}.db"
    study = optuna.create_study(
        study_name=study_name,
        storage=storage_name,
        direction='maximize',
        load_if_exists=True,
        pruner=optuna.pruners.MedianPruner(
            n_startup_trials=30,
            n_warmup_steps=5,
            interval_steps=5,
        ),
    )

    def objective(x):
        """Bind the notebook-level config to the optuna objective."""
        return train_model_optuna(x, config)

    study.optimize(objective, n_trials=40)
    best_config = study.best_params
    print(best_config)

    # Copy the winning values back into the config used for the final training run.
    config["model"].update(
        {name: best_config[name]
         for name in ("embed_dim", "hidden_dim", "n_layers", "norm", "dropout")}
    )
    config["optimizer"].update(
        {name: best_config[name]
         for name in ("learning_rate", "clip_norm", "batch_size")}
    )



[I 2025-01-09 16:42:52,276] Using an existing study with name 'baseline_model' instead of creating a new one.
  return torch.linalg.solve(A, Xy).T





[I 2025-01-09 16:46:17,299] Trial 425 finished with value: 0.0 and parameters: {'embed_dim': 492, 'hidden_dim': 1655, 'n_layers': 4, 'norm': 'batchnorm', 'dropout': 0.2130240404343698, 'learning_rate': 1.2799926346797886e-06, 'clip_norm': 11, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


# Model training and evaluation

Once we have a set of optimal hyperparameters, we train our model. The model-training function may be changed, but:
- test_dataset cannot be used until we call the final evaluation step
- the evaluation step cannot be modified, it must take the model produced by your pipeline, a dataloader that provides the correct data for your model, and the final metrics have to be printed

In [35]:
# `scripts` is already imported in the first cell; it is re-imported here
# together with the autoreload magics so edits to the module are picked up.
import scripts

%reload_ext autoreload

%autoreload 2

# Final training run with the (possibly searched) config. No validation set is
# passed, which is why the log reports "None" for the validation score.
_, model = scripts.train_model(config, train_dataset,  validation_dataset=None, use_momentum=True, callback_epoch=None, collate_fn=collate_fn_custom)
device = torch.device(config["env"]["device"])
# Metrics for the final evaluation.
# NOTE(review): "R_cellwise_residuals" is configured with grouping="drugs"
# while "R_cellwise" uses grouping="cell_lines" -- confirm the key names
# match what is actually being grouped.
metrics = torchmetrics.MetricTracker(torchmetrics.MetricCollection(
    {"R_cellwise_residuals":scripts.GroupwiseMetric(metric=torchmetrics.functional.pearson_corrcoef,
                          grouping="drugs",
                          average="macro",
                          residualize=True),
    "R_cellwise":scripts.GroupwiseMetric(metric=torchmetrics.functional.pearson_corrcoef,
                          grouping="cell_lines",
                          average="macro",
                          residualize=False),
    "MSE":torchmetrics.MeanSquaredError()}))
metrics.to(device)
# NOTE(review): collate_fn_custom drops the last row of any batch whose row
# count is odd, so those test rows never reach the metrics -- confirm this is
# acceptable for the final evaluation.
test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                      batch_size=config["optimizer"]["batch_size"],
                                      drop_last=False,
                                      shuffle=False,
                                      collate_fn=collate_fn_custom)




epoch : 0: train loss: 7.532603945487585 Smoothed R interaction (validation) None
epoch : 1: train loss: 5.926481442573743 Smoothed R interaction (validation) None
epoch : 2: train loss: 5.679132684683188 Smoothed R interaction (validation) None
epoch : 3: train loss: 5.1906854922954855 Smoothed R interaction (validation) None
epoch : 4: train loss: 4.475778503295703 Smoothed R interaction (validation) None
epoch : 5: train loss: 4.747343601324619 Smoothed R interaction (validation) None
epoch : 6: train loss: 4.332390890671657 Smoothed R interaction (validation) None
epoch : 7: train loss: 4.22994574216696 Smoothed R interaction (validation) None
epoch : 8: train loss: 3.7467021391941953 Smoothed R interaction (validation) None
epoch : 9: train loss: 3.4907981126736374 Smoothed R interaction (validation) None
epoch : 10: train loss: 3.799708529924735 Smoothed R interaction (validation) None
epoch : 11: train loss: 3.54850672911375 Smoothed R interaction (validation) None
epoch : 12: t

In [32]:
# Final evaluation: the only place where the test split is consumed.
# Depends on `model`, `test_dataloader`, `metrics` and `device` from the
# training cell above, which must have been run first.
final_metrics = scripts.evaluate_step(model, test_dataloader, metrics, device)
print(final_metrics)

  return torch.linalg.solve(A, Xy).T


{'MSE': 1.8528733253479004, 'R_cellwise': 0.31360891461372375, 'R_cellwise_residuals': 0.25742870569229126}
