In [1]:
import json

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split

# Validation split (Binary, Zero shot prediction)

- This code prepares a dataset for zero-shot prediction in drug response modeling:
    - Loads and processes a drug response dataset.
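    - Splits data into train, validation, and test sets; the validation set is a random 20% sample of the training rows, so it shares cell lines and compounds with train.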
    - Ensures no overlap of cell lines or compounds between train and test sets.
    - Creates features (X) and labels (y) for each set.
    - Saves processed data as CSV files and NumPy arrays.
- Key aspects of the zero-shot prediction setup:
    - Completely different cell lines and compounds in train vs. test sets.
    - Evaluates the model's ability to predict for unseen cell-compound combinations.
    - Tests generalization to novel cell lines and compounds.
- Data refinement for zero-shot setup:
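    - Filters and concatenates entries based on NSCs associated with binary class 0, refining the dataset to focus on underrepresented data points for improved model training on these specific NSCs. (Note: no such filter-and-concatenate step appears in the cells of this notebook up to the point where the splits are saved; the only compound-level filter applied there is restricting to the NSCs in `drug_list.pt`.)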
    - Removes every row whose combination of `NSC` and `CELL_NAME` occurs more than once (all copies are dropped, not just the extras) to prevent any bias or error in predictions due to duplicate entries. This ensures that the refined dataset used for zero-shot prediction contains only unique combinations of NSC and cell names, maintaining the integrity of the model’s evaluation.

In [2]:
# Load the drug-response table; index_col=0 turns the first CSV column into the row index.
df = pd.read_csv("../data/IC50wBinary.csv.gz", index_col=0)
# Bare expression so the notebook renders the frame.
df

Unnamed: 0,NSC,CELL_NAME,binary
0,1,786-0,0.0
1,1,A498,0.0
2,1,A549/ATCC,0.0
3,1,ACHN,0.0
4,1,BT-549,0.0
...,...,...,...
331553,837892,TK-10,0.0
331554,837892,U251,0.0
331555,837892,UACC-257,0.0
331556,837892,UACC-62,0.0


In [3]:
# Load the dictionary used below to translate the raw CELL_NAME values.
# The encoding is pinned to UTF-8 so decoding does not depend on the
# platform's locale default.
with open("name_map.json", "r", encoding="utf-8") as f:
    name_map = json.load(f)

In [4]:
# Translate the raw cell-line names through name_map.
# Validate first so that an unmapped name still fails with a KeyError (the
# exception a plain dict lookup would raise) but reports every offending name
# at once; then translate in a single vectorized pass instead of a
# Python-level loop over every row.
unmapped_names = set(df["CELL_NAME"]) - set(name_map)
if unmapped_names:
    raise KeyError(
        f"CELL_NAME values missing from name_map: {sorted(unmapped_names, key=str)}"
    )
df["CELL_NAME"] = df["CELL_NAME"].map(name_map)

# Keep only the cell lines that appear as a column of the gene-expression table.
cols = pd.read_csv("../data/nci60_gene_exp.csv", index_col=0).columns
df = df[df.CELL_NAME.isin(cols)]
df

Unnamed: 0,NSC,CELL_NAME,binary
0,1,786_0,0.0
1,1,A498,0.0
2,1,A549,0.0
3,1,ACHN,0.0
4,1,BT_549,0.0
...,...,...,...
331553,837892,TK_10,0.0
331554,837892,U251,0.0
331555,837892,UACC_257,0.0
331556,837892,UACC_62,0.0


In [5]:
# Class balance of the binary label at this point in the pipeline.
df["binary"].value_counts()

binary
1.0    174669
0.0    141109
Name: count, dtype: int64

In [6]:
# Collect every row whose (NSC, CELL_NAME) pair occurs more than once
# (keep=False flags all copies, not just the later ones) and tally their labels.
duplicated_df = df.loc[df.duplicated(subset=["NSC", "CELL_NAME"], keep=False)]
binary_counts = duplicated_df.binary.value_counts()
binary_counts

binary
0.0    23
1.0     9
Name: count, dtype: int64

In [7]:
# Drop every row that shares its (NSC, CELL_NAME) pair with another row;
# keep=False discards all copies rather than retaining one of them.
df = df.drop_duplicates(subset=["NSC", "CELL_NAME"], keep=False)
df

Unnamed: 0,NSC,CELL_NAME,binary
0,1,786_0,0.0
1,1,A498,0.0
2,1,A549,0.0
3,1,ACHN,0.0
4,1,BT_549,0.0
...,...,...,...
331553,837892,TK_10,0.0
331554,837892,U251,0.0
331555,837892,UACC_257,0.0
331556,837892,UACC_62,0.0


In [8]:
# Class balance of the binary label, recomputed on the current contents of df.
df["binary"].value_counts()

binary
1.0    174660
0.0    141086
Name: count, dtype: int64

In [9]:
# Load the collection of NSC identifiers that df is filtered against below.
# NOTE(review): weights_only=False makes torch.load use full pickle
# deserialization, which can execute arbitrary code contained in the file.
# That is acceptable only for a trusted local artifact — confirm the file's
# provenance, and check whether weights_only=True would load this payload.
drug_list = torch.load("../data/drug_list.pt", weights_only=False)

In [10]:
# Restrict the response table to the compounds listed in drug_list.
df = df.loc[df["NSC"].isin(drug_list)]
df

Unnamed: 0,NSC,CELL_NAME,binary
61,186,A498,0.0
62,186,A549,0.0
63,186,CAKI_1,0.0
64,186,CCRF_CEM,0.0
65,186,COLO205,0.0
...,...,...,...
327432,820799,TK_10,1.0
327433,820799,U251,1.0
327434,820799,UACC_257,1.0
327435,820799,UACC_62,1.0


In [11]:
# For each cell line, count the distinct compounds it appears with in df
# (sorted ascending, so the sparsest cell lines come first).
unique_drugs_per_cell = (
    df.groupby("CELL_NAME")["NSC"]
    .nunique()
    .sort_values()
)
unique_drugs_per_cell

CELL_NAME
MDA_N          315
BT_549         790
HS578T         797
T47D           804
SR             809
MDA_MB_231     810
PC_3           811
DU_145         811
MDA_MB_435     821
M14            862
TK_10          864
786_0          869
ACHN           872
MCF7           872
NCI_ADR_RES    881
A498           892
EKVX           895
HOP_92         897
RXF_393        912
HL_60          913
HCC_2998       915
NCI_H522       917
RPMI_8226      919
MALME_3M       919
UACC_62        920
CAKI_1         921
SNB_75         922
NCI_H226       923
SK_MEL_2       924
SK_OV_3        924
CCRF_CEM       925
SK_MEL_5       926
LOXIMVI        929
K_562          929
OVCAR_4        932
NCI_H322M      934
OVCAR_3        934
IGROV1         934
UACC_257       934
SK_MEL_28      936
NCI_H23        936
SF_295         936
SF_539         937
UO_31          938
NCI_H460       938
SNB_19         940
HT29           940
OVCAR_5        940
KM12           941
COLO205        942
SN12C          942
HOP_62         943
MO

In [12]:
# For each compound, count the distinct cell lines it appears with in df
# (sorted ascending, so the sparsest compounds come first).
unique_cells_per_drug = (
    df.groupby("NSC")["CELL_NAME"]
    .nunique()
    .sort_values()
)
unique_cells_per_drug

NSC
663883    33
656160    33
123538    33
656159    36
645797    37
          ..
83950     60
89671     60
93134     60
654663    60
261726    60
Name: CELL_NAME, Length: 952, dtype: int64

In [13]:
# Transpose the gene table so the CSV's column labels become the index
# (presumably gene symbols, since they are matched against dti's `gene` column).
genes = pd.read_csv("../data/genes.csv").transpose()
# Drug-gene score table, restricted to genes present in that index.
dti = pd.read_csv("../data/drug_gene_score.csv.gz")
dti = dti.loc[dti["gene"].isin(genes.index)]
dti

Unnamed: 0,NSC,gene,PMID_count,log,Y,log_Y,CID
0,3188.0,AAK1,0.0,0.0,0.5,0.5,4.0
1,3188.0,ADRB1,0.0,0.0,0.5,0.5,4.0
2,3188.0,BMP2K,0.0,0.0,0.5,0.5,4.0
5,3188.0,CACNB3,0.0,0.0,0.5,0.5,4.0
10,3188.0,CREBBP,0.0,0.0,0.5,0.5,4.0
...,...,...,...,...,...,...,...
1038831,852991.0,BRAF,0.0,0.0,0.5,0.5,156297592.0
1038832,852991.0,NRAS,0.0,0.0,0.5,0.5,156297592.0
1038833,841442.0,ATM,0.0,0.0,0.5,0.5,156487652.0
1038834,841442.0,CDK12,0.0,0.0,0.5,0.5,156487652.0


In [14]:
# Get unique CELL_NAME and NSC
# Series.unique() returns values in order of first appearance. The split cell
# samples from these arrays with a fixed seed, so their order influences
# which items end up in train vs. test.
unique_cells = df["CELL_NAME"].unique()
unique_nscs = df["NSC"].unique()

# Report the size of each pool.
print("unique cells: ", len(unique_cells))
print("unique nscs: ", len(unique_nscs))

unique cells:  60
unique nscs:  952


In [15]:
# Seed NumPy's global RNG so the same pools are drawn on every fresh run.
np.random.seed(42)
# Draw 70% of the cell lines (count truncated by int()) without replacement
# for the training pool.
train_cells = np.random.choice(
    unique_cells, size=int(len(unique_cells) * 0.7), replace=False
)
# Every cell line not drawn above goes to the test pool.
test_cells = np.setdiff1d(unique_cells, train_cells)

# Same scheme for compounds, with a 60% training share. This draw continues
# the RNG stream started above, so its result depends on the cell-line draw
# having been executed first.
train_nscs = np.random.choice(
    unique_nscs, size=int(len(unique_nscs) * 0.6), replace=False
)
# Every compound not drawn above goes to the test pool.
test_nscs = np.setdiff1d(unique_nscs, train_nscs)

In [16]:
# Train keeps rows whose compound AND cell line are both in the train pools;
# test keeps rows where both are in the test pools. Rows that pair a train
# compound with a test cell line (or vice versa) match neither filter and are
# left out of every split.
train = df[df.NSC.isin(train_nscs) & df.CELL_NAME.isin(train_cells)]
test = df[df.NSC.isin(test_nscs) & df.CELL_NAME.isin(test_cells)]

# Shuffle the row order of both frames with a fixed seed.
train = train.sample(frac=1, random_state=42)
test = test.sample(frac=1, random_state=42)

# Hold out 20% of the training rows (count truncated by int()) as validation,
# then remove them from train by index label
# (assumes index labels are unique — TODO confirm).
# NOTE(review): validation rows come from the train pool, so validation shares
# compounds and cell lines with train; only the test set is zero-shot.
val_size = int(len(train) * 0.2)
val = train.sample(n=val_size, random_state=42)
train = train.drop(val.index)

In [17]:
# Summarise each split: row count plus the number of distinct compounds and
# cell lines it covers.
for name, dataset in {"Train": train, "Validation": val, "Test": test}.items():
    print(f"\n{name} set:")
    print("Total number of data:", dataset.shape[0])
    print(f"Number of unique NSCs: {dataset['NSC'].nunique()}")
    print(f"Number of unique CELL_NAMEs: {dataset['CELL_NAME'].nunique()}")


Train set:
Total number of data: 18067
Number of unique NSCs: 571
Number of unique CELL_NAMEs: 42

Validation set:
Total number of data: 4516
Number of unique NSCs: 571
Number of unique CELL_NAMEs: 42

Test set:
Total number of data: 6525
Number of unique NSCs: 381
Number of unique CELL_NAMEs: 18


In [18]:
# Check for overlaps
# Build the sets of compounds and cell lines actually present in the final
# train and test frames. Note that these assignments rebind train_nscs /
# train_cells / test_nscs / test_cells from the sampled arrays to sets.
train_nscs = set(train["NSC"])
train_cells = set(train["CELL_NAME"])
test_nscs = set(test["NSC"])
test_cells = set(test["CELL_NAME"])

print("\nOverlap check:")
print(f"NSC overlap between train and test: {len(train_nscs.intersection(test_nscs))}")
print(
    f"CELL_NAME overlap between train and test: {len(train_cells.intersection(test_cells))}"
)

# Enforce the zero-shot guarantee instead of relying on someone reading the
# printed zeros: the test set must share no compound and no cell line with
# either the training or the validation data.
assert train_nscs.isdisjoint(test_nscs), "train and test share NSCs"
assert train_cells.isdisjoint(test_cells), "train and test share cell lines"
assert test_nscs.isdisjoint(val["NSC"]), "validation and test share NSCs"
assert test_cells.isdisjoint(val["CELL_NAME"]), "validation and test share cell lines"


Overlap check:
NSC overlap between train and test: 0
CELL_NAME overlap between train and test: 0


In [19]:
# Labels for each split as plain Python lists.
y_train = train["binary"].tolist()
y_val = val["binary"].tolist()
y_test = test["binary"].tolist()

# Feature frames: each split with the label column removed.
X_train = train.drop(columns="binary")
X_val = val.drop(columns="binary")
X_test = test.drop(columns="binary")

In [20]:
# Persist each split: the feature frame as CSV (row index omitted) and the
# matching labels as a .npy array.
X_train.to_csv("../data/train_IC50.csv", index=False)
np.save("../data/train_IC50_labels.npy", y_train)

X_val.to_csv("../data/valid_IC50.csv", index=False)
np.save("../data/valid_IC50_labels.npy", y_val)

X_test.to_csv("../data/test_IC50.csv", index=False)
np.save("../data/test_IC50_labels.npy", y_test)

In [21]:
# Echo the training labels for a visual check.
# NOTE(review): this renders the whole list; a slice such as y_train[:20]
# would keep the stored notebook output small.
y_train

[1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0

In [22]:
# Echo the test labels for a visual check.
# NOTE(review): this renders the whole list; a slice such as y_test[:20]
# would keep the stored notebook output small.
y_test

[1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0

In [23]:
# Echo the validation labels for a visual check.
# NOTE(review): this renders the whole list; a slice such as y_val[:20]
# would keep the stored notebook output small.
y_val

[0.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0