In [1]:
import pandas as pd
import numpy as np
import torch
import scripts
from functools import lru_cache
import torchmetrics
from torch import nn
import optuna

  from .autonotebook import tqdm as notebook_tqdm


# Data loading

First we load the data. The basic idea is to create dictionaries with features associated to the drugs and cell-lines. In principle, the splits and the data shouldn't be changed

In [2]:
@lru_cache(maxsize = None)
def get_data(n_fold = 0, fp_radius = 2):
    smile_dict = pd.read_csv("data/smiles.csv", index_col=0)
    fp = scripts.FingerprintFeaturizer(R = fp_radius)
    drug_dict = fp(smile_dict.iloc[:, 1], smile_dict.iloc[:, 0])
    driver_genes = pd.read_csv("data/driver_genes.csv").loc[:, "symbol"].dropna()
    rnaseq = pd.read_csv("data/rnaseq_normcount.csv", index_col=0)
    driver_columns = rnaseq.columns.isin(driver_genes)
    filtered_rna = rnaseq.loc[:, driver_columns]
    tensor_exp = torch.Tensor(filtered_rna.to_numpy())
    cell_dict = {cell: tensor_exp[i] for i, cell in enumerate(filtered_rna.index.to_numpy())}
    data = pd.read_csv("data/GDSC12.csv", index_col=0)
    # default, remove data where lines or drugs are missing:
    data = data.query("SANGER_MODEL_ID in @cell_dict.keys() & DRUG_ID in @drug_dict.keys()")
    unique_cell_lines = data.loc[:, "SANGER_MODEL_ID"].unique()
    np.random.seed(420) # for comparibility, don't change it!
    np.random.shuffle(unique_cell_lines)
    folds = np.array_split(unique_cell_lines, 10)
    test_lines = folds[0]
    train_idxs = list(range(10))
    train_idxs.remove(n_fold)
    np.random.seed(420)
    validation_idx = np.random.choice(train_idxs)
    train_idxs.remove(validation_idx)
    train_lines = np.concatenate([folds[idx] for idx in train_idxs])
    validation_lines = folds[validation_idx]
    test_lines = folds[n_fold]
    train_data = data.query("SANGER_MODEL_ID in @train_lines")
    validation_data = data.query("SANGER_MODEL_ID in @validation_lines")
    test_data = data.query("SANGER_MODEL_ID in @test_lines")
    return (scripts.OmicsDataset_drugwise(cell_dict, drug_dict, train_data),
    scripts.OmicsDataset_drugwise(cell_dict, drug_dict, validation_data),
    scripts.OmicsDataset_drugwise(cell_dict, drug_dict, test_data))

# Configuration

we declare the configuration, this is going to be model-specific and we get the datasets

In [3]:
config = {"features" : {"fp_radius":2},
           "optimizer": {"batch_size": 4,
                         "clip_norm":19,
                         "learning_rate":0.00004592646200179472,
                         "stopping_patience":15,
                         "c_alpha":0.01},
           "model":{"embed_dim":485,
                  "hidden_dim":696,
                  "dropout":0.48541242824674574,
                  "n_layers": 4,
                  "norm": "batchnorm"},
          "env": {"fold": 0,
                  "device":"cpu",
                  "max_epochs": 100,
                  "search_hyperparameters":True}}

In [4]:
train_dataset, validation_dataset, test_dataset = get_data(n_fold = config["env"]["fold"],
                                                           fp_radius = config["features"]["fp_radius"])





In [5]:
cell_features, drug_features, target, drug_index, cell_indices, labels = train_dataset[4]

ins = train_dataset[7]
print(ins[5].shape)
print(ins[4].shape)

print("Drug Features Shape:", drug_features.shape)
print("Cell Features Shape:", cell_features.shape)  
print("Labels Shape:", labels.shape)
print("Target Shape:", target.shape)               
print("Drug Index Shape:", drug_index.shape)      
print("Cell Indices Shape:", cell_indices.shape) 




torch.Size([325])
torch.Size([325])
Drug Features Shape: torch.Size([323, 2048])
Cell Features Shape: torch.Size([323, 777])
Labels Shape: torch.Size([323])
Target Shape: torch.Size([323])
Drug Index Shape: torch.Size([323])
Cell Indices Shape: torch.Size([323])


In [6]:

def collate_fn_custom(batch):
    omics = []
    drugs = []
    targets = []
    cell_ids = []
    drug_ids = []
    labels = [] 

    for i, (o, d, t, cid, did, r) in enumerate(batch):
        omics.append(o)  # Add each omics tensor
        drugs.append(d)  # Add each drugs tensor
        targets.append(t)  # Add each target tensor
        cell_ids.append(cid)  # Add each cell ID tensor
        drug_ids.append(did)  # Add each drug ID tensor
        r = r.clone()
        r[r==1] = 2*i + 1
        r[r==2] = 2*(i + 1)
        labels.append(r)  # Add each label tensor
    # Concatenate along the first dimension for all tensors
    omics = torch.cat(omics, dim=0)
    drugs = torch.cat(drugs, dim=0)
    targets = torch.cat(targets, dim=0)
    cell_ids = torch.cat(cell_ids, dim=0)
    drug_ids = torch.cat(drug_ids, dim=0)
    labels = torch.cat(labels, dim=0)
    


    return omics, drugs, targets, cell_ids, drug_ids, labels


# Hyperparameter optimization

we wrap the function for training the model in a function that can be used by optuna

In [None]:
def train_model_optuna(trial, config):
    def pruning_callback(epoch, train_r):
        trial.report(train_r, step = epoch)
        if np.isnan(train_r):
            raise optuna.TrialPruned()
        if trial.should_prune():
            raise optuna.TrialPruned()
    config["model"] = {"embed_dim": trial.suggest_int("embed_dim", 64, 512),
                    "hidden_dim": trial.suggest_int("hidden_dim", 64, 2048),
                    "n_layers": trial.suggest_int("n_layers", 1, 6),
                    "norm": trial.suggest_categorical("norm", ["batchnorm", "layernorm", None]),
                    "dropout": trial.suggest_float("dropout", 0.0, 0.5)}
    config["optimizer"] = { "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-1, log=True),
                            "clip_norm": trial.suggest_int("clip_norm", 0.1, 20),
                            "batch_size": trial.suggest_int("batch_size", 6, 12),
                            "stopping_patience":10,
                            "c_alpha": trial.suggest_float("c_alpha", 0.001, 0.01)}
    try:
        R, model = scripts.train_model(config,
                                       train_dataset,
                                       validation_dataset,
                                       use_momentum=True,
                                       callback_epoch = pruning_callback,
                                       collate_fn=collate_fn_custom)
        
        return R
    except Exception as e:
        print(e)
        return 0

In [8]:
import scripts

%reload_ext autoreload

%autoreload 2



if config["env"]["search_hyperparameters"]:
    study_name = f"baseline_model"
    storage_name = "sqlite:///studies/{}.db".format(study_name)
    study = optuna.create_study(study_name=study_name,
                                storage=storage_name,
                                direction='maximize',
                                load_if_exists=True,
                                pruner=optuna.pruners.MedianPruner(n_startup_trials=30,
                                                               n_warmup_steps=5,
                                                               interval_steps=5))
    objective = lambda x: train_model_optuna(x, config)
    study.optimize(objective, n_trials=40)
    best_config = study.best_params
    print(best_config)
    config["model"]["embed_dim"] = best_config["embed_dim"]
    config["model"]["hidden_dim"] = best_config["hidden_dim"]
    config["model"]["n_layers"] = best_config["n_layers"]
    config["model"]["norm"] = best_config["norm"]
    config["model"]["dropout"] = best_config["dropout"]
    config["optimizer"]["learning_rate"] = best_config["learning_rate"]
    config["optimizer"]["clip_norm"] = best_config["clip_norm"]
    config["optimizer"]["batch_size"] = best_config["batch_size"]
    config["model"]["c_alpha"] = best_config["c_alpha"]



[I 2025-02-18 13:35:58,374] Using an existing study with name 'baseline_model' instead of creating a new one.
  return torch.linalg.solve(A, Xy).T
[I 2025-02-18 13:37:14,273] Trial 776 finished with value: 0.0 and parameters: {'embed_dim': 265, 'hidden_dim': 310, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.10208840726658444, 'learning_rate': 5.450025146041031e-05, 'clip_norm': 17, 'batch_size': 3, 'c_alpha': 0.010489403437990402}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 13:38:07,788] Trial 777 finished with value: 0.0 and parameters: {'embed_dim': 443, 'hidden_dim': 408, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.09336403656761488, 'learning_rate': 7.22585627340308e-05, 'clip_norm': 18, 'batch_size': 3, 'c_alpha': 0.005228611982399028}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 13:38:43,244] Trial 778 finished with value: 0.0 and parameters: {'embed_dim': 139, 'hidden_dim': 277, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.08032180873572226, 'learning_rate': 0.00010854875257767366, 'clip_norm': 16, 'batch_size': 5, 'c_alpha': 0.01381401252088977}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 13:41:52,282] Trial 779 finished with value: 0.0 and parameters: {'embed_dim': 130, 'hidden_dim': 459, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.12806210157605477, 'learning_rate': 0.00017079677485126306, 'clip_norm': 18, 'batch_size': 4, 'c_alpha': 0.004380824517558049}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 13:42:33,831] Trial 780 finished with value: 0.0 and parameters: {'embed_dim': 248, 'hidden_dim': 228, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.1035554794599842, 'learning_rate': 8.599733082119479e-05, 'clip_norm': 17, 'batch_size': 3, 'c_alpha': 0.017365451440899835}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 13:43:07,485] Trial 781 finished with value: 0.0 and parameters: {'embed_dim': 190, 'hidden_dim': 349, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.11454579877313943, 'learning_rate': 5.808520583711319e-05, 'clip_norm': 16, 'batch_size': 5, 'c_alpha': 0.008398270360818867}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 13:44:16,857] Trial 782 finished with value: 0.0 and parameters: {'embed_dim': 229, 'hidden_dim': 145, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.0819256016806876, 'learning_rate': 0.0001198882839618522, 'clip_norm': 17, 'batch_size': 3, 'c_alpha': 0.04005884698619659}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 13:44:52,225] Trial 783 finished with value: 0.0 and parameters: {'embed_dim': 160, 'hidden_dim': 313, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.11749405463374961, 'learning_rate': 7.75370218047305e-05, 'clip_norm': 18, 'batch_size': 3, 'c_alpha': 0.003924371911544542}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 13:46:47,696] Trial 784 finished with value: 0.0 and parameters: {'embed_dim': 459, 'hidden_dim': 238, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.0894439156813659, 'learning_rate': 0.00010169041682721735, 'clip_norm': 17, 'batch_size': 4, 'c_alpha': 0.014685086359714406}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 13:48:45,451] Trial 785 finished with value: 0.0 and parameters: {'embed_dim': 479, 'hidden_dim': 97, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.13299815908811333, 'learning_rate': 0.0001359548439260335, 'clip_norm': 16, 'batch_size': 4, 'c_alpha': 0.02295209140860993}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 13:49:47,346] Trial 786 finished with value: 0.0 and parameters: {'embed_dim': 240, 'hidden_dim': 339, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.3659773129929042, 'learning_rate': 0.0001851351701159502, 'clip_norm': 18, 'batch_size': 5, 'c_alpha': 0.030172226986796027}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 13:50:21,460] Trial 787 finished with value: 0.0 and parameters: {'embed_dim': 153, 'hidden_dim': 448, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.1015189522111862, 'learning_rate': 6.648979875448681e-05, 'clip_norm': 16, 'batch_size': 3, 'c_alpha': 0.09668302988821839}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 13:51:00,890] Trial 788 finished with value: 0.0 and parameters: {'embed_dim': 436, 'hidden_dim': 113, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.0944332853324994, 'learning_rate': 9.529809019677528e-05, 'clip_norm': 17, 'batch_size': 5, 'c_alpha': 0.009570303718432675}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 13:51:36,072] Trial 789 finished with value: 0.0 and parameters: {'embed_dim': 175, 'hidden_dim': 416, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.1161177766006473, 'learning_rate': 5.032256533445251e-05, 'clip_norm': 17, 'batch_size': 4, 'c_alpha': 0.0011710952247103877}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 13:52:15,964] Trial 790 finished with value: 0.0 and parameters: {'embed_dim': 278, 'hidden_dim': 165, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.0770919881197511, 'learning_rate': 0.00013023795203017204, 'clip_norm': 18, 'batch_size': 5, 'c_alpha': 0.05107779938833425}. Best is trial 752 with value: 0.34190896010425487.






[I 2025-02-18 13:53:59,902] Trial 791 finished with value: 0.0 and parameters: {'embed_dim': 287, 'hidden_dim': 258, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.11887817145788565, 'learning_rate': 0.00016261896415479843, 'clip_norm': 18, 'batch_size': 3, 'c_alpha': 0.017723819882147725}. Best is trial 752 with value: 0.34190896010425487.
[I 2025-02-18 13:54:34,347] Trial 792 finished with value: 0.0 and parameters: {'embed_dim': 78, 'hidden_dim': 205, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.10480558697682336, 'learning_rate': 7.206740952396231e-05, 'clip_norm': 1, 'batch_size': 3, 'c_alpha': 0.006791379577197798}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 13:55:53,158] Trial 793 finished with value: 0.0 and parameters: {'embed_dim': 467, 'hidden_dim': 316, 'n_layers': 6, 'norm': 'batchnorm', 'dropout': 0.4094170515603797, 'learning_rate': 8.658720456387971e-05, 'clip_norm': 17, 'batch_size': 3, 'c_alpha': 0.07088485196728905}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 13:56:55,071] Trial 794 finished with value: 0.0 and parameters: {'embed_dim': 64, 'hidden_dim': 366, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.10499817551083136, 'learning_rate': 0.00023084178673255453, 'clip_norm': 19, 'batch_size': 4, 'c_alpha': 0.012614377352270756}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 13:57:32,750] Trial 795 finished with value: 0.0 and parameters: {'embed_dim': 146, 'hidden_dim': 192, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.07629837960127925, 'learning_rate': 0.00011510716015965542, 'clip_norm': 17, 'batch_size': 3, 'c_alpha': 0.0037532584892282805}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 13:58:16,540] Trial 796 finished with value: 0.0 and parameters: {'embed_dim': 470, 'hidden_dim': 109, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.45151207092957685, 'learning_rate': 5.800388442310509e-05, 'clip_norm': 18, 'batch_size': 5, 'c_alpha': 0.02628136006155294}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 13:59:10,765] Trial 797 finished with value: 0.0 and parameters: {'embed_dim': 166, 'hidden_dim': 297, 'n_layers': 6, 'norm': 'batchnorm', 'dropout': 0.3399487889188295, 'learning_rate': 0.00010395077652965321, 'clip_norm': 0, 'batch_size': 3, 'c_alpha': 0.020993921710538585}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 14:04:30,981] Trial 798 finished with value: 0.0 and parameters: {'embed_dim': 452, 'hidden_dim': 362, 'n_layers': 6, 'norm': 'batchnorm', 'dropout': 0.12589088164453482, 'learning_rate': 0.00013965171027831261, 'clip_norm': 18, 'batch_size': 4, 'c_alpha': 0.009617737988158626}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 14:06:49,884] Trial 799 finished with value: 0.0 and parameters: {'embed_dim': 206, 'hidden_dim': 267, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.4996335652656398, 'learning_rate': 7.695027105780489e-05, 'clip_norm': 1, 'batch_size': 3, 'c_alpha': 0.007924528718017604}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 14:08:16,586] Trial 800 finished with value: 0.0 and parameters: {'embed_dim': 125, 'hidden_dim': 155, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.0902223461438596, 'learning_rate': 0.07519842711767094, 'clip_norm': 16, 'batch_size': 6, 'c_alpha': 0.0033644820524185277}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 14:14:34,078] Trial 801 finished with value: 0.0 and parameters: {'embed_dim': 255, 'hidden_dim': 525, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.38277355715820754, 'learning_rate': 0.00018152370720504597, 'clip_norm': 17, 'batch_size': 4, 'c_alpha': 0.006606668900605073}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 14:15:12,093] Trial 802 finished with value: 0.0 and parameters: {'embed_dim': 300, 'hidden_dim': 224, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.1443630698454927, 'learning_rate': 6.244536206364e-05, 'clip_norm': 17, 'batch_size': 6, 'c_alpha': 0.087718048277364}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 14:16:40,535] Trial 803 finished with value: 0.0 and parameters: {'embed_dim': 481, 'hidden_dim': 415, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.24964361354041603, 'learning_rate': 0.0002695972976514007, 'clip_norm': 9, 'batch_size': 5, 'c_alpha': 0.054415497137565196}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 14:17:30,950] Trial 804 finished with value: 0.0 and parameters: {'embed_dim': 96, 'hidden_dim': 469, 'n_layers': 6, 'norm': 'batchnorm', 'dropout': 0.10983043335289791, 'learning_rate': 4.804210099772134e-05, 'clip_norm': 16, 'batch_size': 6, 'c_alpha': 0.016290450487567033}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 14:18:11,544] Trial 805 finished with value: 0.0 and parameters: {'embed_dim': 140, 'hidden_dim': 1947, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.07376801236427696, 'learning_rate': 9.795960265492037e-05, 'clip_norm': 19, 'batch_size': 6, 'c_alpha': 0.03329516740492343}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 14:19:49,061] Trial 806 finished with value: 0.0 and parameters: {'embed_dim': 217, 'hidden_dim': 283, 'n_layers': 6, 'norm': 'batchnorm', 'dropout': 0.08469985882877346, 'learning_rate': 0.0001485315864964056, 'clip_norm': 4, 'batch_size': 6, 'c_alpha': 0.012847884715856642}. Best is trial 752 with value: 0.34190896010425487.






[I 2025-02-18 14:20:30,741] Trial 807 finished with value: 0.0 and parameters: {'embed_dim': 184, 'hidden_dim': 356, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.46500157069329945, 'learning_rate': 8.076781908889194e-05, 'clip_norm': 18, 'batch_size': 3, 'c_alpha': 0.01897737703360634}. Best is trial 752 with value: 0.34190896010425487.
[I 2025-02-18 14:21:15,521] Trial 808 finished with value: 0.0 and parameters: {'embed_dim': 105, 'hidden_dim': 2019, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.09780709675345431, 'learning_rate': 0.00012120606914739909, 'clip_norm': 17, 'batch_size': 6, 'c_alpha': 0.01061110278432441}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 14:22:05,543] Trial 809 finished with value: 0.0 and parameters: {'embed_dim': 461, 'hidden_dim': 194, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.11771195926850372, 'learning_rate': 0.00010118940629898607, 'clip_norm': 1, 'batch_size': 6, 'c_alpha': 0.002980046465232003}. Best is trial 752 with value: 0.34190896010425487.






[I 2025-02-18 14:27:47,122] Trial 810 finished with value: 0.0 and parameters: {'embed_dim': 445, 'hidden_dim': 1905, 'n_layers': 6, 'norm': 'batchnorm', 'dropout': 0.3495729395305669, 'learning_rate': 5.946026388127592e-05, 'clip_norm': 5, 'batch_size': 6, 'c_alpha': 0.005528454333350833}. Best is trial 752 with value: 0.34190896010425487.
[I 2025-02-18 14:28:28,729] Trial 811 finished with value: 0.0 and parameters: {'embed_dim': 477, 'hidden_dim': 128, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.13019297542027727, 'learning_rate': 8.102029231532702e-05, 'clip_norm': 19, 'batch_size': 6, 'c_alpha': 0.02566837079434395}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 14:29:38,000] Trial 812 finished with value: 0.0 and parameters: {'embed_dim': 230, 'hidden_dim': 90, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.08966500936246546, 'learning_rate': 0.00018270251827991145, 'clip_norm': 5, 'batch_size': 4, 'c_alpha': 0.040399160905909316}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 14:30:15,792] Trial 813 finished with value: 0.0 and parameters: {'embed_dim': 158, 'hidden_dim': 394, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.10099837726346586, 'learning_rate': 0.00012181454970579427, 'clip_norm': 0, 'batch_size': 6, 'c_alpha': 0.015664102161016262}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 14:31:04,758] Trial 814 finished with value: 0.0 and parameters: {'embed_dim': 273, 'hidden_dim': 71, 'n_layers': 6, 'norm': 'batchnorm', 'dropout': 0.4007792298452206, 'learning_rate': 6.926601450217504e-05, 'clip_norm': 4, 'batch_size': 6, 'c_alpha': 0.013347615214050139}. Best is trial 752 with value: 0.34190896010425487.





[I 2025-02-18 14:31:41,045] Trial 815 finished with value: 0.0 and parameters: {'embed_dim': 115, 'hidden_dim': 242, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.37339359270857464, 'learning_rate': 0.00011183966087362417, 'clip_norm': 2, 'batch_size': 6, 'c_alpha': 0.0019682467751211204}. Best is trial 752 with value: 0.34190896010425487.



{'embed_dim': 461, 'hidden_dim': 331, 'n_layers': 1, 'norm': 'batchnorm', 'dropout': 0.08924372196854445, 'learning_rate': 0.00012309642970619275, 'clip_norm': 17, 'batch_size': 3, 'c_alpha': 0.0010341689454045453}


# Model training and evaluation

After we have a set of optimal hyperparameters we train our model. The train model function could be changed, but:
- test_dataset cannot be used until we call the final evaluation step
- the evaluation step cannot be modified, it must take the model produced by your pipeline, a dataloader that provides the correct data for your model, and the final metrics have to be printed

In [10]:
import scripts

%reload_ext autoreload

%autoreload 2

_, model = scripts.train_model(config, train_dataset,  validation_dataset=None, use_momentum=True, callback_epoch=None, collate_fn=collate_fn_custom)
device = torch.device(config["env"]["device"])
metrics = torchmetrics.MetricTracker(torchmetrics.MetricCollection(
    {"R_cellwise_residuals":scripts.GroupwiseMetric(metric=torchmetrics.functional.pearson_corrcoef,
                          grouping="drugs",
                          average="macro",
                          residualize=True),
    "R_cellwise":scripts.GroupwiseMetric(metric=torchmetrics.functional.pearson_corrcoef,
                          grouping="cell_lines",
                          average="macro",
                          residualize=False),
    "MSE":torchmetrics.MeanSquaredError()}))
metrics.to(device)
test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                      batch_size=config["optimizer"]["batch_size"],
                                      drop_last=False,
                                      shuffle=False,
                                      collate_fn=collate_fn_custom)


epoch : 0: train loss: 7.993910234708053 Smoothed R interaction (validation) None
epoch : 1: train loss: 6.8697050248201075 Smoothed R interaction (validation) None
epoch : 2: train loss: 6.528052566143183 Smoothed R interaction (validation) None
epoch : 3: train loss: 6.073817754021058 Smoothed R interaction (validation) None
epoch : 4: train loss: 5.35983569175005 Smoothed R interaction (validation) None
epoch : 5: train loss: 4.56787799184139 Smoothed R interaction (validation) None
epoch : 6: train loss: 4.327051956493121 Smoothed R interaction (validation) None
epoch : 7: train loss: 3.9188241351109285 Smoothed R interaction (validation) None
epoch : 8: train loss: 3.9119853199674535 Smoothed R interaction (validation) None
epoch : 9: train loss: 3.48958902576795 Smoothed R interaction (validation) None


KeyboardInterrupt: 

In [16]:
final_metrics = scripts.evaluate_step(model, test_dataloader, metrics, device)
print(final_metrics)

  return torch.linalg.solve(A, Xy).T


{'MSE': 2.363308906555176, 'R_cellwise': 0.2313278466463089, 'R_cellwise_residuals': 0.21419179439544678}
