# Generate Splits for the GDSC/DepMap dataset

## TODO

- [ ] Add a check that no drug appears in a test split without also appearing in the corresponding train split

In [None]:
from __future__ import annotations

import pickle

import pandas as pd
import polars as pl
import numpy as np

from pathlib import Path
from sklearn.model_selection import StratifiedKFold, KFold, train_test_split

In [None]:
# One seed reused by every split below (StratifiedKFold / train_test_split
# receive it explicitly); NumPy's global legacy RNG is seeded with it too.
SEED: int = 41
np.random.seed(seed=SEED)

In [None]:
# Paths are relative to the notebook's working directory.
data_folder = Path("../../../data/datasets/GDSCv2DepMap")
split_folder = Path("../../../data/inputs/GDSCv2DepMap/splits")
# parents=True so this also works when the intermediate
# ``inputs/GDSCv2DepMap`` directory does not exist yet (fresh checkout).
split_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Cell-line annotations; the split cells below read its ``model_id`` and
# ``oncotree_lineage`` columns.
model_info = pd.read_csv(data_folder.joinpath("CellLineOncotreeAnnotations.csv"))
model_info.head()

In [None]:
# Screen observations; the split cells below read its ``id`` and
# ``cell_id`` columns.
screen_data = pd.read_csv(data_folder.joinpath("ScreenDoseResponseLabels.csv"))
screen_data.head()

In [None]:
def strict_train_validation_split(
    train_model_ids: np.ndarray,
    train_tissues: np.ndarray,
    *,
    test_size: float = 0.11,
    random_state: int | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Carve a tissue-stratified validation set out of a training fold.

    Stratification needs at least two members per class, so models whose
    tissue occurs only once in ``train_tissues`` are never moved into
    validation; they always remain in the returned training set.

    Args:
        train_model_ids: IDs of the models in the training fold.
        train_tissues: Tissue label for each entry of ``train_model_ids``
            (same length, aligned by position).
        test_size: Fraction of the stratifiable models to place in
            validation. Defaults to ``0.11``.
        random_state: Seed passed to ``train_test_split``. ``None`` (the
            default) means "use the notebook-level ``SEED``".

    Returns:
        ``(train_ids, val_ids)`` where ``train_ids`` is ``train_model_ids``
        in its original order with every validation ID removed.

    Raises:
        ValueError: if the two inputs differ in length, or if no tissue
            has at least two models (nothing can be stratified).
    """
    if random_state is None:
        # Looked up at call time, exactly like the original hard-coded seed.
        random_state = SEED

    train_model_ids = np.asarray(train_model_ids)
    train_tissues = np.asarray(train_tissues)
    if len(train_model_ids) != len(train_tissues):
        raise ValueError(
            "train_model_ids and train_tissues must have the same length, "
            f"got {len(train_model_ids)} and {len(train_tissues)}"
        )

    uniq_tissues, tissue_counts = np.unique(train_tissues, return_counts=True)
    keep_tissues = uniq_tissues[tissue_counts >= 2]

    # Only tissues with >= 2 models can be stratified.
    mask = np.isin(train_tissues, keep_tissues)
    if not mask.any():
        raise ValueError(
            "no tissue has at least two models; "
            "cannot build a stratified validation set"
        )

    _, val_model_ids = train_test_split(
        train_model_ids[mask],
        random_state=random_state,
        stratify=train_tissues[mask],
        test_size=test_size,
    )

    remaining_ids = train_model_ids[
        np.isin(train_model_ids, val_model_ids, invert=True)
    ]
    return remaining_ids, val_model_ids

In [None]:
# Output directory for the splits whose test sets hold out whole cell lines.
# ``split_folder / ...`` is already a Path, so no extra ``Path(...)`` wrap;
# parents=True matches the ``mixed`` folder and tolerates a missing parent.
tumor_blind_folder = split_folder / "tumor_blind"
tumor_blind_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Cell-line-blind splits: each fold's test set is a disjoint group of cell
# lines (model IDs), with folds stratified by ``oncotree_lineage``.
cell_ids = model_info["model_id"].values
tissues = model_info["oncotree_lineage"].values

skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=SEED)
split_iterator = skf.split(cell_ids, tissues)


def _screen_ids_for(selected_cells):
    """Return the ``id`` of every screen row whose ``cell_id`` is in *selected_cells*."""
    in_selection = screen_data["cell_id"].isin(selected_cells)
    return screen_data.loc[in_selection, "id"].to_list()


for i, (fold_train_idx, fold_test_idx) in enumerate(split_iterator, start=1):
    # Move a tissue-stratified slice of the training cell lines to validation.
    fold_train_cells, fold_val_cells = strict_train_validation_split(
        cell_ids[fold_train_idx], tissues[fold_train_idx]
    )
    fold_test_cells = cell_ids[fold_test_idx]

    # Map cell-line membership to screen-row IDs, then pickle each list
    # (written in train, val, test order).
    fold_outputs = {
        f"train_{i}.pickle": _screen_ids_for(fold_train_cells),
        f"val_{i}.pickle": _screen_ids_for(fold_val_cells),
        f"test_{i}.pickle": _screen_ids_for(fold_test_cells),
    }
    for file_name, ids in fold_outputs.items():
        with open(tumor_blind_folder / file_name, "wb") as fh:
            pickle.dump(ids, fh)

In [None]:
# Output directory for the observation-level ("mixed") splits.
mixed_folder = split_folder.joinpath("mixed")
mixed_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Mixed splits: individual screen observations are split, stratified by
# cell line, so the same cell line can appear in train, val and test.
obs_ids = screen_data["id"].values
cell_ids = screen_data["cell_id"].values

skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=SEED)
split_iterator = skf.split(obs_ids, cell_ids)

for i, (fold_train_idx, fold_test_idx) in enumerate(split_iterator, start=1):
    # Hold out 11% of the fold's training observations for validation,
    # stratified by cell line.
    # NOTE(review): train_test_split raises ValueError if a cell line has a
    # single observation in this training fold; unlike the cell-line-blind
    # cell, no guard drops such classes here.
    fold_train_obs, fold_val_obs = train_test_split(
        obs_ids[fold_train_idx],
        random_state=SEED,
        stratify=cell_ids[fold_train_idx],
        test_size=0.11,
    )
    fold_test_obs = obs_ids[fold_test_idx]

    for out_path, ids in (
        (mixed_folder / f"train_{i}.pickle", fold_train_obs),
        (mixed_folder / f"val_{i}.pickle", fold_val_obs),
        (mixed_folder / f"test_{i}.pickle", fold_test_obs),
    ):
        with open(out_path, "wb") as fh:
            pickle.dump(ids.tolist(), fh)