In [1]:
import pandas as pd
import numpy as np
import torch
import scripts
from functools import lru_cache
import torchmetrics
from torch import nn
import optuna

  from .autonotebook import tqdm as notebook_tqdm


# Data loading

First we load the data. The basic idea is to create dictionaries with features associated to the drugs and cell-lines. In principle, the splits and the data shouldn't be changed

In [2]:
@lru_cache(maxsize = None)
def get_data(n_fold = 0, fp_radius = 2):
    smile_dict = pd.read_csv("data/smiles.csv", index_col=0)
    fp = scripts.FingerprintFeaturizer(R = fp_radius)
    drug_dict = fp(smile_dict.iloc[:, 1], smile_dict.iloc[:, 0])
    driver_genes = pd.read_csv("data/driver_genes.csv").loc[:, "symbol"].dropna()
    rnaseq = pd.read_csv("data/rnaseq_normcount.csv", index_col=0)
    driver_columns = rnaseq.columns.isin(driver_genes)
    filtered_rna = rnaseq.loc[:, driver_columns]
    tensor_exp = torch.Tensor(filtered_rna.to_numpy())
    cell_dict = {cell: tensor_exp[i] for i, cell in enumerate(filtered_rna.index.to_numpy())}
    data = pd.read_csv("data/GDSC12.csv", index_col=0)
    # default, remove data where lines or drugs are missing:
    data = data.query("SANGER_MODEL_ID in @cell_dict.keys() & DRUG_ID in @drug_dict.keys()")
    unique_cell_lines = data.loc[:, "SANGER_MODEL_ID"].unique()
    np.random.seed(420) # for comparibility, don't change it!
    np.random.shuffle(unique_cell_lines)
    folds = np.array_split(unique_cell_lines, 10)
    test_lines = folds[0]
    train_idxs = list(range(10))
    train_idxs.remove(n_fold)
    np.random.seed(420)
    validation_idx = np.random.choice(train_idxs)
    train_idxs.remove(validation_idx)
    train_lines = np.concatenate([folds[idx] for idx in train_idxs])
    validation_lines = folds[validation_idx]
    test_lines = folds[n_fold]
    train_data = data.query("SANGER_MODEL_ID in @train_lines")
    validation_data = data.query("SANGER_MODEL_ID in @validation_lines")
    test_data = data.query("SANGER_MODEL_ID in @test_lines")
    return (scripts.OmicsDataset_drugwise(cell_dict, drug_dict, train_data),
    scripts.OmicsDataset_drugwise(cell_dict, drug_dict, validation_data),
    scripts.OmicsDataset_drugwise(cell_dict, drug_dict, test_data))

# Configuration

we declare the configuration, this is going to be model-specific and we get the datasets

In [3]:
config = {"features" : {"fp_radius":2},
          "optimizer": {"batch_size": 4,
                        "clip_norm":19,
                        "learning_rate":0.0004592646200179472,
                        "stopping_patience":15},
          "model":{"embed_dim":485,
                 "hidden_dim":696,
                 "dropout":0.48541242824674574,
                 "n_layers": 4,
                 "norm": "batchnorm"},
         "env": {"fold": 0,
                 "device":"cpu",
                 "max_epochs": 100,
                 "search_hyperparameters":True}}

In [4]:
train_dataset, validation_dataset, test_dataset = get_data(n_fold = config["env"]["fold"],
                                                           fp_radius = config["features"]["fp_radius"])





In [5]:
cell_features, drug_features, target, drug_index, cell_indices, labels = train_dataset[4]

ins = train_dataset[7]
print(ins[5].shape)
print(ins[4].shape)

print("Drug Features Shape:", drug_features.shape)
print("Cell Features Shape:", cell_features.shape)  
print("Labels Shape:", labels.shape)
print("Target Shape:", target.shape)               
print("Drug Index Shape:", drug_index.shape)      
print("Cell Indices Shape:", cell_indices.shape) 




torch.Size([325])
torch.Size([325])
Drug Features Shape: torch.Size([323, 2048])
Cell Features Shape: torch.Size([323, 777])
Labels Shape: torch.Size([323])
Target Shape: torch.Size([323])
Drug Index Shape: torch.Size([323])
Cell Indices Shape: torch.Size([323])


In [None]:

def collate_fn_custom(batch):
    omics = []
    drugs = []
    targets = []
    cell_ids = []
    drug_ids = []
    labels = [] 

    for o, d, t, cid, did, r in batch:
        omics.append(o)  # Add each omics tensor
        drugs.append(d)  # Add each drugs tensor
        targets.append(t)  # Add each target tensor
        cell_ids.append(cid)  # Add each cell ID tensor
        drug_ids.append(did)  # Add each drug ID tensor
        labels.append(r)  # Add each label tensor

    # Concatenate along the first dimension for all tensors
    omics = torch.cat(omics, dim=0)
    drugs = torch.cat(drugs, dim=0)
    targets = torch.cat(targets, dim=0)
    cell_ids = torch.cat(cell_ids, dim=0)
    drug_ids = torch.cat(drug_ids, dim=0)
    labels = torch.cat(labels, dim=0)

    return omics, drugs, targets, cell_ids, drug_ids, labels


# Hyperparameter optimization

we wrap the function for training the model in a function that can be used by optuna

In [6]:
def train_model_optuna(trial, config):
    def pruning_callback(epoch, train_r):
        trial.report(train_r, step = epoch)
        if np.isnan(train_r):
            raise optuna.TrialPruned()
        if trial.should_prune():
            raise optuna.TrialPruned()
    config["model"] = {"embed_dim": trial.suggest_int("embed_dim", 64, 512),
                    "hidden_dim": trial.suggest_int("hidden_dim", 64, 2048),
                    "n_layers": trial.suggest_int("n_layers", 1, 6),
                    "norm": trial.suggest_categorical("norm", ["batchnorm", "layernorm", None]),
                    "dropout": trial.suggest_float("dropout", 0.0, 0.5)}
    config["optimizer"] = { "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-1, log=True),
                            "clip_norm": trial.suggest_int("clip_norm", 0.1, 20),
                            "batch_size": trial.suggest_int("batch_size", 2, 10),
                            "stopping_patience":10}
    try:
        R, model = scripts.train_model(config,
                                       train_dataset,
                                       validation_dataset,
                                       use_momentum=True,
                                       callback_epoch = pruning_callback,
                                       collate_fn=collate_fn_custom)
        
        return R
    except Exception as e:
        print(e)
        return 0

In [26]:
'''
def collate_fn_custom(batch):
    import torch

    omics = []
    drugs = []
    targets = []
    cell_ids = []
    drug_ids = []
    labels = [] 

    for o, d, t, cid, did, r in batch:
        omics.append(o)  # Add each omics tensor
        drugs.append(d)  # Add each drugs tensor
        targets.append(t)  # Add each target tensor
        cell_ids.append(cid)  # Add each cell ID tensor
        drug_ids.append(did)  # Add each drug ID tensor
        labels.append(r)  # Add each label tensor

    # Concatenate along the first dimension for all tensors
    omics = torch.cat(omics, dim=0)
    drugs = torch.cat(drugs, dim=0)
    targets = torch.cat(targets, dim=0)
    cell_ids = torch.cat(cell_ids, dim=0)
    drug_ids = torch.cat(drug_ids, dim=0)
    labels = torch.cat(labels, dim=0)

    # Split each tensor into two halves
    size = omics.size(0)
    if size % 2 != 0:
        omics = omics[:size - 1]
        drugs = drugs[:size - 1]
        targets = targets[:size - 1]
        cell_ids = cell_ids[:size - 1]
        drug_ids = drug_ids[:size - 1]
        labels = labels[:size - 1]


    mid_index = omics.size(0) // 2

    # Combine the halves along the last dimension
    
    #omics_combined = torch.cat((omics[:mid_index], omics[mid_index:]), dim=-1)
    #drugs_combined = torch.cat((drugs[:mid_index], drugs[mid_index:]), dim=-1)
    #targets_combined = torch.cat((targets[:mid_index], targets[mid_index:]), dim=-1)
    #cell_ids_combined = torch.cat((cell_ids[:mid_index], cell_ids[mid_index:]), dim=-1)
    #drug_ids_combined = torch.cat((drug_ids[:mid_index], drug_ids[mid_index:]), dim=-1)
    #labels_combined = torch.cat((labels[:mid_index], labels[mid_index:]), dim=-1)
    
    omics1 = omics[:mid_index]
    omics2 = omics[mid_index:]
    drugs1 = drugs[:mid_index]
    drugs2 = drugs[mid_index:]
    targets1 = targets[:mid_index]
    targets2 = targets[mid_index:]
    cell_ids1 = cell_ids[:mid_index]
    cell_ids2 = cell_ids[mid_index:]
    drug_ids1 = drug_ids[:mid_index]
    drug_ids2 = drug_ids[mid_index:]
    labels1 = labels[:mid_index]
    labels2 = labels[mid_index:]

    pairs = []    
    for i in range(0,mid_index):
        if labels1[i] == 0:
            pairs.append(0)
        elif labels2[i] == 0:
            pairs.append(0)
        elif labels1[i] == labels2[i]:
            pairs.append(1) # positive pairs
        else:
            pairs.append(-1)  # negative pairs 
        
    



    return omics1, drugs1, targets1, cell_ids1, drug_ids1, labels1, omics2, drugs2, targets2, cell_ids2, drug_ids2, labels2, pairs
'''

In [None]:



if config["env"]["search_hyperparameters"]:
    study_name = f"baseline_model"
    storage_name = "sqlite:///studies/{}.db".format(study_name)
    study = optuna.create_study(study_name=study_name,
                                storage=storage_name,
                                direction='maximize',
                                load_if_exists=True,
                                pruner=optuna.pruners.MedianPruner(n_startup_trials=30,
                                                               n_warmup_steps=5,
                                                               interval_steps=5))
    objective = lambda x: train_model_optuna(x, config)
    study.optimize(objective, n_trials=40)
    best_config = study.best_params
    print(best_config)
    config["model"]["embed_dim"] = best_config["embed_dim"]
    config["model"]["hidden_dim"] = best_config["hidden_dim"]
    config["model"]["n_layers"] = best_config["n_layers"]
    config["model"]["norm"] = best_config["norm"]
    config["model"]["dropout"] = best_config["dropout"]
    config["optimizer"]["learning_rate"] = best_config["learning_rate"]
    config["optimizer"]["clip_norm"] = best_config["clip_norm"]
    config["optimizer"]["batch_size"] = best_config["batch_size"]



[I 2025-01-29 11:36:46,310] Using an existing study with name 'baseline_model' instead of creating a new one.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:37:12,866] Trial 440 finished with value: 0.0 and parameters: {'embed_dim': 143, 'hidden_dim': 1199, 'n_layers': 6, 'norm': 'batchnorm', 'dropout': 0.3843137053806652, 'learning_rate': 0.016536121149937466, 'clip_norm': 0, 'batch_size': 8}. Best is trial 379 with value: 0.3149450400351803.
[I 2025-01-29 11:37:23,147] Trial 441 finished with value: 0.0 and parameters: {'embed_dim': 261, 'hidden_dim': 684, 'n_layers': 1, 'norm': 'layernorm', 'dropout': 0.44167644863012967, 'learning_rate': 1.0239342840802722e-05, 'clip_norm': 7, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:37:36,428] Trial 442 finished with value: 0.0 and parameters: {'embed_dim': 242, 'hidden_dim': 899, 'n_layers': 3, 'norm': None, 'dropout': 0.05107493962554678, 'learning_rate': 1.6232107518432208e-06, 'clip_norm': 12, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:38:00,797] Trial 443 finished with value: 0.0 and parameters: {'embed_dim': 323, 'hidden_dim': 1270, 'n_layers': 5, 'norm': 'batchnorm', 'dropout': 0.271367200840901, 'learning_rate': 5.43934195510429e-06, 'clip_norm': 8, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:38:14,611] Trial 444 finished with value: 0.0 and parameters: {'embed_dim': 399, 'hidden_dim': 417, 'n_layers': 2, 'norm': 'layernorm', 'dropout': 0.17726656087816606, 'learning_rate': 7.502472654988606e-05, 'clip_norm': 2, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:38:57,047] Trial 445 finished with value: 0.0 and parameters: {'embed_dim': 484, 'hidden_dim': 1835, 'n_layers': 6, 'norm': None, 'dropout': 0.23934136373100728, 'learning_rate': 7.278708697276208e-06, 'clip_norm': 6, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:39:28,512] Trial 446 finished with value: 0.0 and parameters: {'embed_dim': 371, 'hidden_dim': 1544, 'n_layers': 4, 'norm': 'batchnorm', 'dropout': 0.06746828575507618, 'learning_rate': 0.000679874455054611, 'clip_norm': 5, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:39:59,499] Trial 447 finished with value: 0.0 and parameters: {'embed_dim': 418, 'hidden_dim': 1620, 'n_layers': 5, 'norm': 'layernorm', 'dropout': 0.07941659447092617, 'learning_rate': 3.301994818743923e-06, 'clip_norm': 18, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:40:14,542] Trial 448 finished with value: 0.0 and parameters: {'embed_dim': 501, 'hidden_dim': 578, 'n_layers': 2, 'norm': None, 'dropout': 0.31434577481329384, 'learning_rate': 0.029754939806525237, 'clip_norm': 11, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:40:33,544] Trial 449 finished with value: 0.0 and parameters: {'embed_dim': 86, 'hidden_dim': 1965, 'n_layers': 3, 'norm': 'batchnorm', 'dropout': 0.29898780662557106, 'learning_rate': 2.1581724401307466e-06, 'clip_norm': 20, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:40:44,693] Trial 450 finished with value: 0.0 and parameters: {'embed_dim': 267, 'hidden_dim': 2042, 'n_layers': 1, 'norm': 'layernorm', 'dropout': 0.3484875286693332, 'learning_rate': 1.3608322362620205e-06, 'clip_norm': 14, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:41:11,049] Trial 451 finished with value: 0.0 and parameters: {'embed_dim': 348, 'hidden_dim': 1760, 'n_layers': 4, 'norm': None, 'dropout': 0.15052937542094907, 'learning_rate': 4.397980651010733e-05, 'clip_norm': 4, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:41:42,746] Trial 452 finished with value: 0.0 and parameters: {'embed_dim': 149, 'hidden_dim': 1886, 'n_layers': 6, 'norm': 'layernorm', 'dropout': 0.10626788019512035, 'learning_rate': 0.08430341634683529, 'clip_norm': 19, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:42:03,049] Trial 453 finished with value: 0.0 and parameters: {'embed_dim': 188, 'hidden_dim': 1437, 'n_layers': 4, 'norm': 'batchnorm', 'dropout': 0.42203488845827825, 'learning_rate': 4.590221045007482e-06, 'clip_norm': 8, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:42:15,439] Trial 454 finished with value: 0.0 and parameters: {'embed_dim': 219, 'hidden_dim': 1233, 'n_layers': 2, 'norm': None, 'dropout': 0.04874161811992537, 'learning_rate': 2.326101907592821e-05, 'clip_norm': 3, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:42:30,637] Trial 455 finished with value: 0.0 and parameters: {'embed_dim': 163, 'hidden_dim': 1487, 'n_layers': 3, 'norm': 'layernorm', 'dropout': 0.4095226473976821, 'learning_rate': 8.741126928797564e-06, 'clip_norm': 18, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:42:59,555] Trial 456 finished with value: 0.0 and parameters: {'embed_dim': 409, 'hidden_dim': 1357, 'n_layers': 5, 'norm': 'batchnorm', 'dropout': 0.19988781837343222, 'learning_rate': 1.4495020685863254e-05, 'clip_norm': 13, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'
'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:43:05,602] Trial 457 finished with value: 0.0 and parameters: {'embed_dim': 77, 'hidden_dim': 859, 'n_layers': 1, 'norm': 'layernorm', 'dropout': 0.03381489500102286, 'learning_rate': 0.00011579814897977366, 'clip_norm': 4, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.
[I 2025-01-29 11:43:26,197] Trial 458 finished with value: 0.0 and parameters: {'embed_dim': 438, 'hidden_dim': 1800, 'n_layers': 2, 'norm': None, 'dropout': 0.2554275231827138, 'learning_rate': 0.012032704393884533, 'clip_norm': 17, 'batch_size': 5}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:43:57,875] Trial 459 finished with value: 0.0 and parameters: {'embed_dim': 381, 'hidden_dim': 1289, 'n_layers': 4, 'norm': 'batchnorm', 'dropout': 0.2261495690894432, 'learning_rate': 0.001297473970520731, 'clip_norm': 10, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:44:06,883] Trial 460 finished with value: 0.0 and parameters: {'embed_dim': 64, 'hidden_dim': 160, 'n_layers': 3, 'norm': 'layernorm', 'dropout': 0.12213504661406405, 'learning_rate': 1.2541336753375493e-06, 'clip_norm': 1, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:44:24,848] Trial 461 finished with value: 0.0 and parameters: {'embed_dim': 490, 'hidden_dim': 69, 'n_layers': 5, 'norm': None, 'dropout': 0.01259988421949989, 'learning_rate': 0.002610018626556538, 'clip_norm': 6, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'
'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:44:59,792] Trial 462 finished with value: 0.0 and parameters: {'embed_dim': 474, 'hidden_dim': 964, 'n_layers': 6, 'norm': 'batchnorm', 'dropout': 0.058865825518531506, 'learning_rate': 6.205667549511449e-05, 'clip_norm': 9, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.
[I 2025-01-29 11:45:18,764] Trial 463 finished with value: 0.0 and parameters: {'embed_dim': 115, 'hidden_dim': 1179, 'n_layers': 4, 'norm': 'layernorm', 'dropout': 0.042866192080393856, 'learning_rate': 2.731638805695992e-06, 'clip_norm': 5, 'batch_size': 9}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:45:37,289] Trial 464 finished with value: 0.0 and parameters: {'embed_dim': 251, 'hidden_dim': 1737, 'n_layers': 2, 'norm': None, 'dropout': 0.1639356664392741, 'learning_rate': 0.0018034225501766719, 'clip_norm': 0, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:46:15,919] Trial 465 finished with value: 0.0 and parameters: {'embed_dim': 335, 'hidden_dim': 1853, 'n_layers': 4, 'norm': 'batchnorm', 'dropout': 0.47611440580222514, 'learning_rate': 1.7895298358673322e-06, 'clip_norm': 3, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:46:42,232] Trial 466 finished with value: 0.0 and parameters: {'embed_dim': 298, 'hidden_dim': 1394, 'n_layers': 3, 'norm': 'layernorm', 'dropout': 0.023832555564478566, 'learning_rate': 1.1182457494874362e-05, 'clip_norm': 8, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:46:50,139] Trial 467 finished with value: 0.0 and parameters: {'embed_dim': 172, 'hidden_dim': 1675, 'n_layers': 1, 'norm': None, 'dropout': 0.2951588350455141, 'learning_rate': 4.981287665443853e-05, 'clip_norm': 2, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:47:28,174] Trial 468 finished with value: 0.0 and parameters: {'embed_dim': 133, 'hidden_dim': 1465, 'n_layers': 6, 'norm': 'batchnorm', 'dropout': 0.32560518290716983, 'learning_rate': 5.702662041440673e-06, 'clip_norm': 19, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:47:45,208] Trial 469 finished with value: 0.0 and parameters: {'embed_dim': 154, 'hidden_dim': 1991, 'n_layers': 2, 'norm': 'layernorm', 'dropout': 0.3068213829439692, 'learning_rate': 4.088551055561867e-06, 'clip_norm': 7, 'batch_size': 8}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:48:12,520] Trial 470 finished with value: 0.0 and parameters: {'embed_dim': 512, 'hidden_dim': 1517, 'n_layers': 3, 'norm': None, 'dropout': 0.2082919560964289, 'learning_rate': 1.8298378171102916e-05, 'clip_norm': 4, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


[I 2025-01-29 11:48:46,248] Trial 471 finished with value: 0.0 and parameters: {'embed_dim': 288, 'hidden_dim': 1910, 'n_layers': 4, 'norm': 'batchnorm', 'dropout': 0.0625917757491359, 'learning_rate': 0.005824544863249496, 'clip_norm': 18, 'batch_size': 10}. Best is trial 379 with value: 0.3149450400351803.


'tuple' object has no attribute 'squeeze'


# Model training and evaluation

After we have a set of optimal hyperparameters we train our model. The train model function could be changed, but:
- test_dataset cannot be used until we call the final evaluation step
- the evaluation step cannot be modified, it must take the model produced by your pipeline, a dataloader that provides the correct data for your model, and the final metrics have to be printed

In [9]:
import scripts

%reload_ext autoreload

%autoreload 2

_, model = scripts.train_model(config, train_dataset,  validation_dataset=None, use_momentum=True, callback_epoch=None, collate_fn=collate_fn_custom)
device = torch.device(config["env"]["device"])
metrics = torchmetrics.MetricTracker(torchmetrics.MetricCollection(
    {"R_cellwise_residuals":scripts.GroupwiseMetric(metric=torchmetrics.functional.pearson_corrcoef,
                          grouping="drugs",
                          average="macro",
                          residualize=True),
    "R_cellwise":scripts.GroupwiseMetric(metric=torchmetrics.functional.pearson_corrcoef,
                          grouping="cell_lines",
                          average="macro",
                          residualize=False),
    "MSE":torchmetrics.MeanSquaredError()}))
metrics.to(device)
test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                      batch_size=config["optimizer"]["batch_size"],
                                      drop_last=False,
                                      shuffle=False,
                                      collate_fn=collate_fn_custom)


epoch : 0: train loss: 9.693892140542307 Smoothed R interaction (validation) None
epoch : 1: train loss: 7.8705514015689975 Smoothed R interaction (validation) None
epoch : 2: train loss: 7.202351277874362 Smoothed R interaction (validation) None
epoch : 3: train loss: 6.6356829558649375 Smoothed R interaction (validation) None
epoch : 4: train loss: 6.731945237805767 Smoothed R interaction (validation) None
epoch : 5: train loss: 6.667819669169765 Smoothed R interaction (validation) None


KeyboardInterrupt: 

In [32]:
final_metrics = scripts.evaluate_step(model, test_dataloader, metrics, device)
print(final_metrics)

  return torch.linalg.solve(A, Xy).T


{'MSE': 1.8528733253479004, 'R_cellwise': 0.31360891461372375, 'R_cellwise_residuals': 0.25742870569229126}
