In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold, train_test_split
import os

In [2]:
# Location of the cleaned WXS sample sheet (tab-separated text).
input_file = '/home/chb3333/yulab/chb3333/data_extraction/wxs_sample_sheet_clean.tsv'

# read_table uses a tab delimiter by default, so no explicit separator is needed.
wxs_tsv = pd.read_table(input_file)


In [3]:
# Columns that are removed from the sheet before any further processing.
# NOTE: this rebinds `wxs_tsv`, so re-running the cell without reloading the
# sheet raises a KeyError (the columns are already gone).
unused_columns = ['Data Category', 'Data Type', 'File ID', 'File Name', 'Sample ID', 'Sample Type']
wxs_tsv = wxs_tsv.drop(unused_columns, axis=1)

In [4]:
# Keep only rows whose 'Project ID' contains the literal substring 'TCGA'.
# regex=False makes it a plain substring test; na=False treats a missing
# project ID as a non-match instead of propagating NaN into the mask.
is_tcga = wxs_tsv['Project ID'].str.contains('TCGA', regex=False, na=False)
tcga_wxs_tsv = wxs_tsv[is_tcga]

In [6]:
# Collect every row whose 'Case ID' occurs more than once; keep=False flags
# all occurrences, not just the repeats after the first one.
duplicate_rows = tcga_wxs_tsv[tcga_wxs_tsv['Case ID'].duplicated(keep=False)]

# Save the duplicated rows for manual review. DataFrame.to_csv does not create
# missing parent directories, so make sure the target folder exists first.
duplicate_csv = "/home/chb3333/yulab/chb3333/gem-patho/data_extraction/duplicate_handling/duplicate_caseIDs.csv"
os.makedirs(os.path.dirname(duplicate_csv), exist_ok=True)
duplicate_rows.to_csv(duplicate_csv)
# TODO: for each duplicated case, extract the entry with the most mutations
# (latest and without errors) instead of relying on row order.

In [7]:
# Keep exactly one row per case. Duplicates are keyed on 'Case ID' (the same key
# used when the duplicated rows are collected for review) rather than on
# whole-row equality: if two rows for one case differ in any remaining column,
# both would survive a whole-row dedup and the same case could then be placed in
# both the training and the test partition of a fold.
# NOTE(review): keep='first' retains whichever row appears first in the sheet;
# replace this with an explicit selection rule once the preferred entry per case
# has been decided.
tcga_wxs_tsv_deduplicated = tcga_wxs_tsv.drop_duplicates(subset='Case ID', keep='first')

In [8]:
# Stratified 10-fold splitter; rows are shuffled before splitting and the seed
# is fixed so the fold assignment is reproducible between runs.
cv_settings = {"n_splits": 10, "shuffle": True, "random_state": 42}
skf = StratifiedKFold(**cv_settings)


In [9]:
# Show the 'Project ID' column, which serves as the stratification label for the folds.
tcga_wxs_tsv_deduplicated.loc[:, 'Project ID']

9          TCGA-OV
10         TCGA-OV
11         TCGA-OV
12         TCGA-OV
13         TCGA-OV
           ...    
17757    TCGA-UCEC
17758    TCGA-UCEC
17759    TCGA-UCEC
17760    TCGA-UCEC
17761    TCGA-UCEC
Name: Project ID, Length: 10190, dtype: object

In [10]:
# Distinct project labels, in order of first appearance in the frame.
project_column = tcga_wxs_tsv_deduplicated['Project ID']
unique_project_ids = project_column.unique()

# Heading and array are emitted on separate lines, as two consecutive prints would do.
print("Unique Project IDs:", unique_project_ids, sep="\n")

Unique Project IDs:
['TCGA-OV' 'TCGA-ESCA' 'TCGA-BRCA' 'TCGA-GBM' 'TCGA-HNSC' 'TCGA-KICH'
 'TCGA-LGG' 'TCGA-LAML' 'TCGA-KIRC' 'TCGA-KIRP' 'TCGA-LUSC' 'TCGA-MESO'
 'TCGA-LUAD' 'TCGA-CHOL' 'TCGA-COAD' 'TCGA-CESC' 'TCGA-DLBC' 'TCGA-ACC'
 'TCGA-BLCA' 'TCGA-LIHC' 'TCGA-UCS' 'TCGA-UVM' 'TCGA-SARC' 'TCGA-PRAD'
 'TCGA-READ' 'TCGA-THYM' 'TCGA-TGCT' 'TCGA-PCPG' 'TCGA-THCA' 'TCGA-STAD'
 'TCGA-PAAD' 'TCGA-SKCM' 'TCGA-UCEC']


In [11]:
# Count the distinct project labels (missing values are not counted).
project_column = tcga_wxs_tsv_deduplicated['Project ID']
unique_count = project_column.nunique(dropna=True)
print("Number of unique Project IDs:", unique_count)

Number of unique Project IDs: 33


In [12]:
# Root directory under which one "fold_<n>" sub-directory is created per
# cross-validation fold (see the fold-writing loop that joins onto this path).
base_path = "/home/chb3333/yulab/chb3333/gem-patho/data_extraction/kfolds/master_kfold"

In [13]:
# Build and persist the cross-validation folds. For every StratifiedKFold split
# the held-out indices become the test set, and the remaining rows are divided
# 90/10 into train and validation. Both levels of splitting are stratified on
# 'Project ID'.
project_labels = tcga_wxs_tsv_deduplicated['Project ID']

for fold, (dev_idx, holdout_idx) in enumerate(skf.split(tcga_wxs_tsv_deduplicated, project_labels)):
    dev_rows = tcga_wxs_tsv_deduplicated.iloc[dev_idx]
    test_fold = tcga_wxs_tsv_deduplicated.iloc[holdout_idx]

    # Carve the validation set out of the non-test rows (10 %, fixed seed).
    train_fold, val_fold = train_test_split(
        dev_rows,
        test_size=0.1,
        shuffle=True,
        stratify=dev_rows['Project ID'],
        random_state=42,
    )

    # One output directory per fold, numbered from 1.
    fold_path = os.path.join(base_path, f"fold_{fold+1}")
    os.makedirs(fold_path, exist_ok=True)

    # Write the partitions in train -> val -> test order; the row index is not stored.
    partitions = (
        ("train.parquet", train_fold),
        ("val.parquet", val_fold),
        ("test.parquet", test_fold),
    )
    for file_name, rows in partitions:
        rows.to_parquet(os.path.join(fold_path, file_name), index=False)

    print(f"Fold {fold+1} saved in {fold_path}")

Fold 1 saved in /home/chb3333/yulab/chb3333/gem-patho/data_extraction/kfolds/master_kfold/fold_1
Fold 2 saved in /home/chb3333/yulab/chb3333/gem-patho/data_extraction/kfolds/master_kfold/fold_2
Fold 3 saved in /home/chb3333/yulab/chb3333/gem-patho/data_extraction/kfolds/master_kfold/fold_3
Fold 4 saved in /home/chb3333/yulab/chb3333/gem-patho/data_extraction/kfolds/master_kfold/fold_4
Fold 5 saved in /home/chb3333/yulab/chb3333/gem-patho/data_extraction/kfolds/master_kfold/fold_5
Fold 6 saved in /home/chb3333/yulab/chb3333/gem-patho/data_extraction/kfolds/master_kfold/fold_6
Fold 7 saved in /home/chb3333/yulab/chb3333/gem-patho/data_extraction/kfolds/master_kfold/fold_7
Fold 8 saved in /home/chb3333/yulab/chb3333/gem-patho/data_extraction/kfolds/master_kfold/fold_8
Fold 9 saved in /home/chb3333/yulab/chb3333/gem-patho/data_extraction/kfolds/master_kfold/fold_9
Fold 10 saved in /home/chb3333/yulab/chb3333/gem-patho/data_extraction/kfolds/master_kfold/fold_10
