In [1]:
import pandas as pd
import numpy as np
import torch
import scripts
from functools import lru_cache
import torchmetrics
from torch import nn
import optuna

  from .autonotebook import tqdm as notebook_tqdm


# Data loading

First we load the data. The basic idea is to create dictionaries with features associated to the drugs and cell-lines. In principle, the splits and the data shouldn't be changed

In [2]:
@lru_cache(maxsize = None)
def get_data(n_fold = 0, fp_radius = 2):
    smile_dict = pd.read_csv("data/smiles.csv", index_col=0)
    fp = scripts.FingerprintFeaturizer(R = fp_radius)
    drug_dict = fp(smile_dict.iloc[:, 1], smile_dict.iloc[:, 0])
    driver_genes = pd.read_csv("data/driver_genes.csv").loc[:, "symbol"].dropna()
    rnaseq = pd.read_csv("data/rnaseq_normcount.csv", index_col=0)
    driver_columns = rnaseq.columns.isin(driver_genes)
    filtered_rna = rnaseq.loc[:, driver_columns]
    tensor_exp = torch.Tensor(filtered_rna.to_numpy())
    cell_dict = {cell: tensor_exp[i] for i, cell in enumerate(filtered_rna.index.to_numpy())}
    data = pd.read_csv("data/GDSC12.csv", index_col=0)
    # default, remove data where lines or drugs are missing:
    data = data.query("SANGER_MODEL_ID in @cell_dict.keys() & DRUG_ID in @drug_dict.keys()")
    unique_cell_lines = data.loc[:, "SANGER_MODEL_ID"].unique()
    np.random.seed(420) # for comparibility, don't change it!
    np.random.shuffle(unique_cell_lines)
    folds = np.array_split(unique_cell_lines, 10)
    test_lines = folds[0]
    train_idxs = list(range(10))
    train_idxs.remove(n_fold)
    np.random.seed(420)
    validation_idx = np.random.choice(train_idxs)
    train_idxs.remove(validation_idx)
    train_lines = np.concatenate([folds[idx] for idx in train_idxs])
    validation_lines = folds[validation_idx]
    test_lines = folds[n_fold]
    train_data = data.query("SANGER_MODEL_ID in @train_lines")
    validation_data = data.query("SANGER_MODEL_ID in @validation_lines")
    test_data = data.query("SANGER_MODEL_ID in @test_lines")
    return (scripts.OmicsDataset_drugwise(cell_dict, drug_dict, train_data),
    scripts.OmicsDataset_drugwise(cell_dict, drug_dict, validation_data),
    scripts.OmicsDataset_drugwise(cell_dict, drug_dict, test_data))

# Configuration

we declare the configuration, this is going to be model-specific and we get the datasets

In [4]:
config = {"features" : {"fp_radius":2},
           "optimizer": {"batch_size": 4,
                         "clip_norm":19,
                         "learning_rate":0.00004592646200179472,
                         "stopping_patience":15,
                         "c_alpha":0.01},
           "model":{"embed_dim":485,
                  "hidden_dim":696,
                  "dropout":0.48541242824674574,
                  "n_layers": 4,
                  "norm": "batchnorm"},
          "env": {"fold": 0,
                  "device":"cpu",
                  "max_epochs": 100,
                  "search_hyperparameters":True}}

In [5]:
train_dataset, validation_dataset, test_dataset = get_data(n_fold = config["env"]["fold"],
                                                           fp_radius = config["features"]["fp_radius"])





In [6]:
cell_features, drug_features, target, drug_index, cell_indices, labels = train_dataset[4]

ins = train_dataset[7]
print(ins[5].shape)
print(ins[4].shape)

print("Drug Features Shape:", drug_features.shape)
print("Cell Features Shape:", cell_features.shape)  
print("Labels Shape:", labels.shape)
print("Target Shape:", target.shape)               
print("Drug Index Shape:", drug_index.shape)      
print("Cell Indices Shape:", cell_indices.shape) 




torch.Size([325])
torch.Size([325])
Drug Features Shape: torch.Size([323, 2048])
Cell Features Shape: torch.Size([323, 777])
Labels Shape: torch.Size([323])
Target Shape: torch.Size([323])
Drug Index Shape: torch.Size([323])
Cell Indices Shape: torch.Size([323])


In [7]:

def collate_fn_custom(batch):
    omics = []
    drugs = []
    targets = []
    cell_ids = []
    drug_ids = []
    labels = [] 

    for i, (o, d, t, cid, did, r) in enumerate(batch):
        omics.append(o)  # Add each omics tensor
        drugs.append(d)  # Add each drugs tensor
        targets.append(t)  # Add each target tensor
        cell_ids.append(cid)  # Add each cell ID tensor
        drug_ids.append(did)  # Add each drug ID tensor
        r = r.clone()
        r[r==1] = 2*i + 1
        r[r==2] = 2*(i + 1)
        labels.append(r)  # Add each label tensor
    # Concatenate along the first dimension for all tensors
    omics = torch.cat(omics, dim=0)
    drugs = torch.cat(drugs, dim=0)
    targets = torch.cat(targets, dim=0)
    cell_ids = torch.cat(cell_ids, dim=0)
    drug_ids = torch.cat(drug_ids, dim=0)
    labels = torch.cat(labels, dim=0)
    


    return omics, drugs, targets, cell_ids, drug_ids, labels


# Hyperparameter optimization

we wrap the function for training the model in a function that can be used by optuna

In [None]:
def train_model_optuna(trial, config):
    def pruning_callback(epoch, train_r):
        trial.report(train_r, step = epoch)
        if np.isnan(train_r):
            raise optuna.TrialPruned()
        if trial.should_prune():
            raise optuna.TrialPruned()
    config["model"] = {"embed_dim": trial.suggest_int("embed_dim", 64, 512),
                    "hidden_dim": trial.suggest_int("hidden_dim", 64, 2048),
                    "n_layers": trial.suggest_int("n_layers", 1, 6),
                    "norm": trial.suggest_categorical("norm", ["batchnorm", "layernorm", None]),
                    "dropout": trial.suggest_float("dropout", 0.0, 0.5)}
    config["optimizer"] = { "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-1, log=True),
                            "clip_norm": trial.suggest_int("clip_norm", 0.1, 20),
                            "batch_size": trial.suggest_int("batch_size", 2, 6),
                            "stopping_patience":10,
                            "c_alpha": trial.suggest_float("c_alpha", 0.001, 0.1)}
    try:
        R, model = scripts.train_model(config,
                                       train_dataset,
                                       validation_dataset,
                                       use_momentum=True,
                                       callback_epoch = pruning_callback,
                                       collate_fn=collate_fn_custom)
        
        return R
    except Exception as e:
        print(e)
        return 0

In [9]:
import scripts

%reload_ext autoreload

%autoreload 2



if config["env"]["search_hyperparameters"]:
    study_name = f"baseline_model"
    storage_name = "sqlite:///studies/{}.db".format(study_name)
    study = optuna.create_study(study_name=study_name,
                                storage=storage_name,
                                direction='maximize',
                                load_if_exists=True,
                                pruner=optuna.pruners.MedianPruner(n_startup_trials=30,
                                                               n_warmup_steps=5,
                                                               interval_steps=5))
    objective = lambda x: train_model_optuna(x, config)
    study.optimize(objective, n_trials=40)
    best_config = study.best_params
    print(best_config)
    config["model"]["embed_dim"] = best_config["embed_dim"]
    config["model"]["hidden_dim"] = best_config["hidden_dim"]
    config["model"]["n_layers"] = best_config["n_layers"]
    config["model"]["norm"] = best_config["norm"]
    config["model"]["dropout"] = best_config["dropout"]
    config["optimizer"]["learning_rate"] = best_config["learning_rate"]
    config["optimizer"]["clip_norm"] = best_config["clip_norm"]
    config["optimizer"]["batch_size"] = best_config["batch_size"]
    config["model"]["c_alpha"] = best_config["c_alpha"]



[I 2025-02-14 15:43:14,035] Using an existing study with name 'baseline_model' instead of creating a new one.
  return torch.linalg.solve(A, Xy).T
[I 2025-02-14 16:01:26,148] Trial 723 finished with value: 0.0 and parameters: {'embed_dim': 442, 'hidden_dim': 363, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.10718473630354604, 'learning_rate': 0.00016956947424386467, 'clip_norm': 16, 'batch_size': 12, 'c_alpha': 0.004761148731176376}. Best is trial 379 with value: 0.3149450400351803.





[I 2025-02-14 16:10:31,305] Trial 724 finished with value: 0.0 and parameters: {'embed_dim': 469, 'hidden_dim': 419, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.11431379071590313, 'learning_rate': 0.00011833022529002715, 'clip_norm': 17, 'batch_size': 11, 'c_alpha': 0.02847918716102137}. Best is trial 379 with value: 0.3149450400351803.





[I 2025-02-14 16:13:04,596] Trial 725 finished with value: 0.0 and parameters: {'embed_dim': 463, 'hidden_dim': 357, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.10882361398257548, 'learning_rate': 0.00020879863371138418, 'clip_norm': 16, 'batch_size': 12, 'c_alpha': 0.024186120979785487}. Best is trial 379 with value: 0.3149450400351803.





[I 2025-02-14 16:14:47,627] Trial 726 finished with value: 0.0 and parameters: {'embed_dim': 476, 'hidden_dim': 395, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.10049514452138893, 'learning_rate': 0.000152268989268865, 'clip_norm': 17, 'batch_size': 11, 'c_alpha': 0.024114870392056516}. Best is trial 379 with value: 0.3149450400351803.





[I 2025-02-14 16:16:29,939] Trial 727 finished with value: 0.0 and parameters: {'embed_dim': 473, 'hidden_dim': 484, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.12182394534395748, 'learning_rate': 0.0002034415260268505, 'clip_norm': 18, 'batch_size': 11, 'c_alpha': 0.06629208995676242}. Best is trial 379 with value: 0.3149450400351803.





[I 2025-02-14 16:17:22,290] Trial 728 finished with value: 0.0 and parameters: {'embed_dim': 484, 'hidden_dim': 352, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.13613440898385418, 'learning_rate': 0.00010548538097863983, 'clip_norm': 17, 'batch_size': 11, 'c_alpha': 0.011734915728058275}. Best is trial 379 with value: 0.3149450400351803.





[W 2025-02-14 16:17:38,196] Trial 729 failed with parameters: {'embed_dim': 471, 'hidden_dim': 289, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.12081951580743718, 'learning_rate': 0.00014801626986905912, 'clip_norm': 17, 'batch_size': 12, 'c_alpha': 0.04038640151047405} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/home/karolyi/miniconda/envs/project_thesis/lib/python3.13/site-packages/optuna/study/_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
  File "/tmp/ipykernel_1615646/2979125567.py", line 19, in <lambda>
    objective = lambda x: train_model_optuna(x, config)
                          ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^
  File "/tmp/ipykernel_1615646/1248304885.py", line 19, in train_model_optuna
    R, model = scripts.train_model(config,
               ~~~~~~~~~~~~~~~~~~~^^^^^^^^
                                   train_dataset,
                                   ^^^^^^^^^^^^^^
    ...<2 lines>...
     

KeyboardInterrupt: 

# Model training and evaluation

After we have a set of optimal hyperparameters we train our model. The train model function could be changed, but:
- test_dataset cannot be used until we call the final evaluation step
- the evaluation step cannot be modified, it must take the model produced by your pipeline, a dataloader that provides the correct data for your model, and the final metrics have to be printed

In [13]:
import scripts

%reload_ext autoreload

%autoreload 2

_, model = scripts.train_model(config, train_dataset,  validation_dataset=None, use_momentum=True, callback_epoch=None, collate_fn=collate_fn_custom)
device = torch.device(config["env"]["device"])
metrics = torchmetrics.MetricTracker(torchmetrics.MetricCollection(
    {"R_cellwise_residuals":scripts.GroupwiseMetric(metric=torchmetrics.functional.pearson_corrcoef,
                          grouping="drugs",
                          average="macro",
                          residualize=True),
    "R_cellwise":scripts.GroupwiseMetric(metric=torchmetrics.functional.pearson_corrcoef,
                          grouping="cell_lines",
                          average="macro",
                          residualize=False),
    "MSE":torchmetrics.MeanSquaredError()}))
metrics.to(device)
test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                      batch_size=config["optimizer"]["batch_size"],
                                      drop_last=False,
                                      shuffle=False,
                                      collate_fn=collate_fn_custom)


KeyError: 'c_alpha'

In [16]:
final_metrics = scripts.evaluate_step(model, test_dataloader, metrics, device)
print(final_metrics)

  return torch.linalg.solve(A, Xy).T


{'MSE': 2.363308906555176, 'R_cellwise': 0.2313278466463089, 'R_cellwise_residuals': 0.21419179439544678}
