In [1]:
import os
import pandas as pd
import argparse
from sklearn.model_selection import train_test_split


# Notebook configuration, expressed as an argparse parser so the same options
# can be reused from a script. Help texts state the actual defaults.
parser = argparse.ArgumentParser(description='Creating splits for whole slide classification')
parser.add_argument('--csv_path', type=str, default="./tcga_ov_dfs/clinical_dfs_data.csv")
parser.add_argument('--data_name', type=str, default="tcga_ov_dfs")
parser.add_argument('--seed', type=int, default=58,
                    help='random seed (default: 58)')
parser.add_argument('--k', type=int, default=5,
                    help='number of splits (default: 5)')
parser.add_argument('--val_frac', type=float, default=0.2,
                    help='fraction of labels for validation (default: 0.2)')
parser.add_argument('--test_frac', type=float, default=0.2,
                    help='fraction of labels for test (default: 0.2)')
# Comma-separated column names used for stratification in addition to "event";
# the strings "None"/"none" disable the extra columns (see the sampling cell).
parser.add_argument('--stratify_col', default="group,time")
# Parse an explicit empty argument list so every option takes its default and
# the notebook kernel's own sys.argv is never inspected.
args = parser.parse_args([])

In [4]:
# Folder that will receive the generated split files for this dataset.
split_dir = f"../splits/{args.data_name}"
os.makedirs(split_dir, exist_ok=True)

# Clinical table; treated as zip-compressed when the path mentions ".zip".
csv_compression = "zip" if ".zip" in args.csv_path else None
df = pd.read_csv(args.csv_path, compression=csv_compression)

# One row per case: the first row of every case_id is kept.
case_df = df.drop_duplicates("case_id").reset_index(drop=True)
assert case_df["case_id"].duplicated().sum() == 0

# Case ids listed in each molecular-modality table.
modality_csvs = {
    "rna": "../datasets_csv/tcga_ov_rna.csv.zip",
    "pro": "../datasets_csv/tcga_ov_pro.csv.zip",
    "dna": "../datasets_csv/tcga_ov_dna.csv.zip",
    "mut": "../datasets_csv/tcga_ov_mut.csv.zip",
    "cnv": "../datasets_csv/tcga_ov_cnv.csv.zip",
}
listed_cases = {
    modality: pd.read_csv(csv_path, compression="zip")["case_id"].values.tolist()
    for modality, csv_path in modality_csvs.items()
}

# Keep only the ids that also appear in the clinical table.
known_case_ids = case_df["case_id"].values
kept_cases = {
    modality: [case for case in listed if case in known_case_ids]
    for modality, listed in listed_cases.items()
}
rna_cases = kept_cases["rna"]
pro_cases = kept_cases["pro"]
dna_cases = kept_cases["dna"]
mut_cases = kept_cases["mut"]
cnv_cases = kept_cases["cnv"]
print(case_df.shape, len(rna_cases), len(pro_cases), len(dna_cases), len(mut_cases), len(cnv_cases))

(490, 24) 245 359 478 266 477


In [6]:
# Summary statistics of the per-case "survival_months" column; its quartiles
# are what the next cell uses (via pd.qcut with q=4) to build the time bins.
case_df["survival_months"].describe()

count    490.000000
mean      21.456755
std       22.180881
min        0.530000
25%        9.070000
50%       14.700000
75%       26.155000
max      180.060000
Name: survival_months, dtype: float64

In [7]:
# Discretise survival time into four classes delimited by the cohort quartiles.
# Only the bin edges are taken from qcut; indexing with [1] avoids assigning to
# `_`, which would shadow IPython's "last output" variable.
time_breaks = pd.qcut(case_df["survival_months"], q=4, retbins=True, labels=False)[1]
# Widen the outer edges: bins are left-closed/right-open (right=False), so the
# maximum would fall outside the last bin if the upper edge stayed at the max.
time_breaks[0] = 0
time_breaks[-1] *= 2
disc_labels = pd.cut(case_df["survival_months"], bins=time_breaks, labels=False, right=False, include_lowest=True)
# DataFrame.insert raises ValueError if the column already exists, so drop a
# previous "time" column first; this keeps the cell re-runnable.
if "time" in case_df.columns:
    case_df = case_df.drop(columns="time")
case_df.insert(2, 'time', disc_labels.values.astype(int))
print(time_breaks)
case_df["time"].value_counts()

[  0.      9.07   14.7    26.155 360.12 ]


1    124
3    123
2    122
0    121
Name: time, dtype: int64

In [8]:
# Split the cohort into cases present in every modality list ("intersecting")
# and all other cases ("remaining"). case_id is unique in case_df, so a single
# row mask describes both parts.
in_every_modality = pd.Series(True, index=case_df.index)
for case_set in [rna_cases, pro_cases, dna_cases, mut_cases, cnv_cases]:
    in_every_modality = in_every_modality & case_df["case_id"].isin(case_set)

intersecting_case_df = case_df[in_every_modality].copy()
remaining_case_df = case_df[~in_every_modality]
remaining_case_df.shape, intersecting_case_df.shape, case_df.shape

((373, 25), (117, 25), (490, 25))

In [9]:
# Sampling fractions for the test split, computed separately for the two pools.
n_cases = case_df.shape[0]

# Test cases to draw from the intersecting pool: 5% of the whole cohort.
test_intersection_min = n_cases*.05
print("Minimum test size: ", test_intersection_min)
test_intersection_percentage = test_intersection_min/intersecting_case_df.shape[0]

# The rest of the overall test budget is drawn from the remaining pool.
test_total = n_cases*args.test_frac
test_remaining = test_total - test_intersection_min
test_remaining_percentage = test_remaining/remaining_case_df.shape[0]

test_intersection_percentage, test_remaining_percentage

Minimum test size:  24.5


(0.2094017094017094, 0.1970509383378016)

In [10]:
# Stratification always includes "event"; further columns come from the
# comma-separated --stratify_col option unless it is disabled.
strats = ["event"]
if args.stratify_col not in ["None", "none", None]:
    strats.extend(args.stratify_col.split(","))

print("Stratification: ", strats)


def _draw_test_cases(pool, frac):
    """Sample `frac` of every stratum of `pool`, reseeding with args.seed per stratum."""
    return pool.groupby(strats, group_keys=False).apply(
        lambda stratum: stratum.sample(frac=frac, random_state=args.seed)
    )


# Test cases are drawn per pool; whatever is left forms the development set.
intersecting_test_cases = _draw_test_cases(intersecting_case_df, test_intersection_percentage)
intersecting_dev_df = intersecting_case_df.drop(intersecting_test_cases.index).copy()

remaining_test_cases = _draw_test_cases(remaining_case_df, test_remaining_percentage)
remaining_dev_df = remaining_case_df.drop(remaining_test_cases.index).copy()

intersecting_test_cases.shape, intersecting_dev_df.shape, remaining_test_cases.shape, remaining_dev_df.shape

Stratification:  ['event', 'group', 'time']


((24, 25), (93, 25), (74, 25), (299, 25))

In [11]:
# All test case ids as one plain list: intersecting pool first, then remaining.
test_cases = [
    *intersecting_test_cases["case_id"].values,
    *remaining_test_cases["case_id"].values,
]

# For every modality, the subset of its cases that landed in the test split.
rna_test, cnv_test, dna_test, pro_test, mut_test = (
    [case for case in modality_cases if case in test_cases]
    for modality_cases in (rna_cases, cnv_cases, dna_cases, pro_cases, mut_cases)
)
tuple(map(len, (rna_test, cnv_test, dna_test, pro_test, mut_test)))

(51, 97, 97, 65, 51)

In [12]:
from itertools import  combinations
import pandas as pd
# Case ids available per molecular modality.
genetic_dict = {
    "rna": rna_cases,
    "cnv": cnv_cases,
    "dna": dna_cases,
    "pro": pro_cases,
    "mut": mut_cases,
}

# The fold-generation cell rebinds `test_cases` to a DataFrame; iterating that
# would walk column names and silently report wrong counts, so fail loudly.
if not isinstance(test_cases, list):
    raise TypeError(
        "test_cases must be the list of test case ids; re-run the cell that builds it"
    )

# Sets give constant-time membership checks for the counting below.
modality_sets = {name: set(ids) for name, ids in genetic_dict.items()}
all_case_ids = case_df["case_id"].values.tolist()

# For each combination of modalities, count the test cases and the cohort cases
# that have all of them, and the resulting test fraction.
# NOTE(review): range(1, 5) covers combinations of 1 to 4 modalities; with five
# modalities the all-five combination is not listed. Confirm this is intended.
dataset_sizes = []
for combo_size in range(1, 5):
    for combo in combinations(genetic_dict.keys(), combo_size):
        shared = set.intersection(*(modality_sets[name] for name in combo))
        n_test = sum(1 for case in test_cases if case in shared)
        n_all = sum(1 for case in all_case_ids if case in shared)
        # An empty combination has no defined fraction; NaN avoids a
        # ZeroDivisionError and is skipped by the .min() summaries below.
        test_frq = n_test / n_all if n_all else float("nan")
        dataset_sizes.append([combo, n_test, n_all, test_frq])

dataset_sizes = pd.DataFrame(dataset_sizes, columns=["Genetic", "Test", "All", "Test_Frq"])
print("Minimum frq: ", dataset_sizes["Test_Frq"].min())
print("Minimum test: ", dataset_sizes["Test"].min())
dataset_sizes

Minimum frq:  0.16666666666666666
Minimum test:  24


Unnamed: 0,Genetic,Test,All,Test_Frq
0,"(rna,)",51,245,0.208163
1,"(cnv,)",97,477,0.203354
2,"(dna,)",97,478,0.202929
3,"(pro,)",65,359,0.181058
4,"(mut,)",51,266,0.191729
5,"(rna, cnv)",50,243,0.205761
6,"(rna, dna)",51,238,0.214286
7,"(rna, pro)",39,197,0.19797
8,"(rna, mut)",34,150,0.226667
9,"(cnv, dna)",96,469,0.204691


In [13]:

# Build args.k train/validation splits of the development cases. The test split
# is the same in every fold; only the validation draw is reseeded (seed + k).
for k in range(args.k):
    fold_seed = args.seed + k

    # Stratified validation sample from each development pool; the rest trains.
    intersecting_val_cases = intersecting_dev_df.groupby(strats, group_keys=False).apply(
        lambda stratum: stratum.sample(frac=args.val_frac, random_state=fold_seed)
    )
    intersecting_train_cases = intersecting_dev_df.drop(intersecting_val_cases.index).copy()

    remaining_val_cases = remaining_dev_df.groupby(strats, group_keys=False).apply(
        lambda stratum: stratum.sample(frac=args.val_frac, random_state=fold_seed)
    )
    remaining_train_cases = remaining_dev_df.drop(remaining_val_cases.index).copy()

    # Merge both pools into the final per-fold partitions.
    train_cases = pd.concat([intersecting_train_cases, remaining_train_cases])
    val_cases = pd.concat([intersecting_val_cases, remaining_val_cases])
    test_cases = pd.concat([intersecting_test_cases, remaining_test_cases])
    partitions = (train_cases, val_cases, test_cases)

    # Relative frequency of each stratification variable in train / val / test.
    ratio_reports = (
        ("Event ratios: ", "event"),
        ("Group ratios: ", "group"),
        ("Label ratios: ", "time"),
    )
    for position, (heading, column) in enumerate(ratio_reports):
        if position > 0:
            print()
        print(heading)
        print("\t Total: ")
        for part in partitions:
            print(part[column].value_counts()/len(part))

    # One column of case ids per partition; shorter columns are padded with NaN.
    split_df = pd.concat(
        [
            part[["case_id"]].rename(columns={"case_id": split_name}).reset_index(drop=True)
            for split_name, part in zip(("train", "val", "test"), partitions)
        ],
        axis=1,
    )

    # Slide-level rows belonging to each partition's cases.
    train_slides = df[df["case_id"].isin(train_cases["case_id"].values)]
    val_slides = df[df["case_id"].isin(val_cases["case_id"].values)]
    test_slides = df[df["case_id"].isin(test_cases["case_id"].values)]

    # The test section is reported only when the test split has slides.
    sections = [
        ("\nTrain:", train_cases, train_slides),
        ("Validation:", val_cases, val_slides),
    ]
    if len(test_slides) > 0:
        sections.append(("Test:", test_cases, test_slides))

    print(f"\nFold: {k}")

    # Patient and slide counts per event class.
    for title, cases, slides in sections:
        print(title)
        for event_class in sorted(pd.unique(df["event"])):
            n_patients = (cases["event"]==event_class).sum()
            n_slides = (slides["event"]==event_class).sum()
            print("\tNb of patients on class {}: {} ({:.2f}%)" .format(event_class, n_patients, 100*n_patients/len(cases)))
            print("\tNb of slides on class {}: {} ({:.2f}%)" .format(event_class, n_slides, 100*n_slides/len(slides)))

    # Patient counts per discretised time class.
    for title, cases, _ in sections:
        print(title)
        for time_class in sorted(pd.unique(case_df["time"])):
            n_patients = (cases["time"]==time_class).sum()
            print("\tNb of patients on class {}: {} ({:.2f}%)" .format(time_class, n_patients, 100*n_patients/len(cases)))

    split_df.to_csv(os.path.join(split_dir, "splits_{}.csv".format(k)))

Event ratios: 
	 Total: 
1.0    0.724359
0.0    0.275641
Name: event, dtype: float64
1.0    0.7125
0.0    0.2875
Name: event, dtype: float64
1.0    0.734694
0.0    0.265306
Name: event, dtype: float64

Group ratios: 
	 Total: 
1.0    0.842949
0.0    0.157051
Name: group, dtype: float64
1.0    0.85
0.0    0.15
Name: group, dtype: float64
1.0    0.836735
0.0    0.163265
Name: group, dtype: float64

Label ratios: 
	 Total: 
1    0.256410
3    0.253205
2    0.246795
0    0.243590
Name: time, dtype: float64
0    0.25
3    0.25
1    0.25
2    0.25
Name: time, dtype: float64
0    0.255102
2    0.255102
3    0.244898
1    0.244898
Name: time, dtype: float64

Fold: 0

Train:
	Nb of patients on class 0.0: 86 (27.56%)
	Nb of slides on class 0.0: 180 (27.52%)
	Nb of patients on class 1.0: 226 (72.44%)
	Nb of slides on class 1.0: 474 (72.48%)
Validation:
	Nb of patients on class 0.0: 23 (28.75%)
	Nb of slides on class 0.0: 49 (28.49%)
	Nb of patients on class 1.0: 57 (71.25%)
	Nb of slides on class

In [15]:
import numpy as np
# Sanity check over the written split files: every fold must share the test
# set of fold 0, and no later fold should repeat fold 0's training set.
# Folds are counted with args.k (the number actually written), and the split
# table gets its own name so the slide-level `df` is not overwritten.
base_train = None
for fold in range(args.k):
    fold_split = pd.read_csv(f"{split_dir}/splits_{fold}.csv")
    # Columns are NaN-padded to a common length in the CSV; drop the padding.
    fold_train = fold_split["train"].dropna().values
    fold_test = fold_split["test"].dropna().values
    if base_train is None:
        base_train = fold_train
        base_test = fold_test
    elif np.array_equal(base_train, fold_train):
        print("equal splits..")
    if not np.array_equal(base_test, fold_test):
        print("incorrect test splits..")

In [12]:
# Ad-hoc check: element-wise comparison of the fold-0 "train" column from two
# other split folders.
# NOTE(review): these paths are hard-coded and start at "../../splits", whereas
# split_dir above is built from "../splits/"; confirm the working directory.
# NOTE: `==` between two Series requires identically-labelled indexes; pandas
# raises ValueError when the two files have different numbers of rows.
df1 = pd.read_csv("../../splits/tcga_ov_son/splits_0.csv")
df2 = pd.read_csv("../../splits/tcga_ov/splits_0.csv")
df1["train"] == df2["train"]

0      False
1      False
2      False
3      False
4      False
       ...  
364     True
365     True
366     True
367     True
368     True
Name: train, Length: 369, dtype: bool