# Generate train/validation/test splits for harmonized cell line data

In [None]:
from __future__ import annotations

import pickle

import pandas as pd
import polars as pl
import numpy as np

from pathlib import Path
from sklearn.model_selection import StratifiedKFold, train_test_split, KFold

In [None]:
# Single seed reused by every splitter in this notebook so the generated folds are reproducible.
SEED = 41
np.random.seed(seed=SEED)

In [None]:
data_folder = Path("../../data/inputs/GDSCCellLine")
# Per-model annotations; the "model_id" and "cancer_type" columns drive the tumor-blind folds.
model_info = pd.read_csv(data_folder.joinpath("MetaModelAnnotations.csv"))
model_info.head()

In [None]:
# Fitted dose-response labels; the "id" and "cell_id" columns are used to build the splits below.
fitted_dose_response = pd.read_csv(data_folder.joinpath("LabelDoseResponse.csv"))
fitted_dose_response.head()

In [None]:
def strict_train_validation_split(
    train_model_ids: np.ndarray,
    train_tissues: np.ndarray,
    test_size: float = 0.11,
    random_state: int | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Split a fold's training models into train and validation models, stratified by tissue.

    ``train_test_split`` cannot stratify a class with fewer than two members,
    so models whose tissue occurs only once are never drawn into the
    validation set; they always stay in the returned training set.

    Parameters
    ----------
    train_model_ids
        Model identifiers of the training fold, aligned element-wise with
        ``train_tissues``.
    train_tissues
        Tissue / cancer-type label for each entry of ``train_model_ids``.
    test_size
        Fraction of the stratifiable models held out for validation
        (default ``0.11``, matching the original behavior).
    random_state
        Seed passed to ``train_test_split``. ``None`` (the default) uses the
        notebook-level ``SEED``.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``(train_model_ids, val_model_ids)``. The two arrays share no id and
        together cover every input id.

    Raises
    ------
    ValueError
        If no tissue has at least two models, so no stratified validation
        set can be drawn.
    """
    if random_state is None:
        # Resolved lazily so callers can override the notebook-wide seed.
        random_state = SEED

    uniq_tissues, tissue_counts = np.unique(train_tissues, return_counts=True)
    # Stratified splitting requires at least two members per class.
    keep_tissues = uniq_tissues[tissue_counts >= 2]
    stratifiable = np.isin(train_tissues, keep_tissues)

    if not stratifiable.any():
        raise ValueError(
            "no tissue has at least two models; cannot build a stratified "
            "validation split"
        )

    _, val_model_ids = train_test_split(
        train_model_ids[stratifiable],
        random_state=random_state,
        stratify=train_tissues[stratifiable],
        test_size=test_size,
    )

    # Everything not drawn for validation (including singleton tissues) stays in train.
    is_train = np.isin(train_model_ids, val_model_ids, invert=True)
    return train_model_ids[is_train], val_model_ids

In [None]:
# Split strategies: tumor-blind, drug-blind, mixed, and disjoint

In [None]:
# Tumor-blind splits: folds are drawn over cell lines (models), stratified by
# cancer type, so each cell line's observations land in exactly one of
# train / val / test for a given fold.
strict_split_folder = data_folder / "splits" / "tumor_blind"
strict_split_folder.mkdir(exist_ok=True, parents=True)

model_ids = model_info["model_id"].to_numpy()
tissues = model_info["cancer_type"].to_numpy()

skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=SEED)

# Folds are numbered from 1 in the output file names.
for i, (train_idx, test_idx) in enumerate(skf.split(model_ids, tissues), 1):
    train_model_ids, val_model_ids = strict_train_validation_split(
        model_ids[train_idx], tissues[train_idx]
    )
    test_model_ids = model_ids[test_idx]

    fold_model_ids = {
        "train": train_model_ids,
        "val": val_model_ids,
        "test": test_model_ids,
    }
    for split_name, split_model_ids in fold_model_ids.items():
        # Map the split's cell lines to the ids of their dose-response observations.
        split_obs_ids = fitted_dose_response.loc[
            fitted_dose_response["cell_id"].isin(split_model_ids), "id"
        ].to_list()
        with open(strict_split_folder / f"{split_name}_{i}.pickle", "wb") as fh:
            pickle.dump(split_obs_ids, fh)

In [None]:
# Output directory for the mixed (observation-level) splits; created if missing.
mixed_split_folder = Path("../../data/inputs/GDSCCellLine/splits/mixed")
mixed_split_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Display the output folder used for the mixed splits.
mixed_split_folder

In [None]:
# Mixed splits: folds are drawn over individual dose-response observations,
# stratified by cell line, so a cell line can appear in train, val and test.
obs_ids = fitted_dose_response["id"].to_numpy()
cell_ids = fitted_dose_response["cell_id"].to_numpy()

skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=SEED)

# Folds are numbered from 1 in the output file names.
for i, (train_idx, test_idx) in enumerate(skf.split(obs_ids, cell_ids), 1):
    # Carve the validation set out of the training fold, again stratified by cell line.
    train_obs_ids, val_obs_ids = train_test_split(
        obs_ids[train_idx],
        random_state=SEED,
        stratify=cell_ids[train_idx],
        test_size=0.11,
    )

    fold_obs_ids = {
        "train": train_obs_ids,
        "val": val_obs_ids,
        "test": obs_ids[test_idx],
    }
    for split_name, split_obs_ids in fold_obs_ids.items():
        with open(mixed_split_folder / f"{split_name}_{i}.pickle", "wb") as fh:
            # Store plain Python lists rather than NumPy arrays.
            pickle.dump(split_obs_ids.tolist(), fh)

In [None]:
# NOTE: what I need to do is to use the index as the split instead of the model_id

In [None]:

# NOTE: now use kfold with the index
# TODO:
#   - [X] refactor split loaders (run_screendl)
#   - [X] refactor the split file
#   - [X] add a new method for selecting on the index
#   - [X] regenerate the splits for strict versions
#   - [X] generate splits for the lenient versions
#   - [X] fix the fold_i+1 issues
#   - [ ] rerun with the new splits for the 4-way experiment
#   - [ ] once this is running, commit everything
#   - [ ] When loading the dataset, I could set the index column and then use .loc
#   - [ ] Pandera schema validation?

In [None]:
# Just do a kfold cv across the responses