# Survival prediction with ABMIL

In [1]:
import pandas as pd 
import numpy as np
import os
import seaborn as sns
from patho_bench.SplitFactory import SplitFactory

In [2]:
# Optional setup: uncomment the line below to install torchsurv if it is not already installed.
# !pip install torchsurv

## Setup the config and split file



Examples for a bunch of datasets can be found here: https://huggingface.co/datasets/MahmoodLab/Patho-Bench/tree/main/cptac_ccrcc

The config file (config.yaml) should look like this (this is the CPTAC-CCRCC one):

```
datasets:
  - cptac_ccrcc

task_col: OS
extra_cols:
  - OS_event
  - OS_days
task_type: survival
metrics:
  - cindex
label_dict:
  0: "Quartile 1, Event 0"
  1: "Quartile 2, Event 0"
  2: "Quartile 3, Event 0"
  3: "Quartile 4, Event 0"
  4: "Quartile 1, Event 1"
  5: "Quartile 2, Event 1"
  6: "Quartile 3, Event 1"
  7: "Quartile 4, Event 1"

sample_col: case_id
num_samples: 94
```

Where: 

`datasets` = name of dataset \
`task_col` = the column in the split file where the task label is specified \
`extra_cols` = additional columns needed from the split file \
`sample_col` = the column in the split file that identifies each sample; its values should match the .h5 files with patch features



The split file (k=all.tsv) needs to look something like:

```

case_id   slide_id      OS_event OS_days 	OS  fold_0 	fold_1 	fold_2	fold_3	fold_4
C3L-00026 C3L-00026-21	0	     1769.0	 0	train	train	train	train	test
```



In [3]:
# Load the TCGA-BRCA clinical table (tab-separated) and list the columns it provides
clindata_tcga_path = '/mnt/d/data/tcga-brca/brca_tcga_gdc_clinical_data.tsv'
clindata_tcga = pd.read_table(clindata_tcga_path)  # read_table uses sep='\t' by default

clindata_tcga.columns

Index(['Study ID', 'Patient ID', 'Sample ID', 'Diagnosis Age',
       'American Joint Committee on Cancer Publication Version Type',
       'Biopsy Site', 'Cancer Type', 'Cancer Type Detailed',
       'Disease Free (Months)', 'Disease Free Status', 'Disease Type',
       'Ethnicity Category', 'Fraction Genome Altered',
       'ICD-10 Classification', 'Is FFPE', 'Morphology', 'Mutation Count',
       'Oncotree Code', 'Overall Survival (Months)', 'Overall Survival Status',
       'Other Patient ID', 'Other Sample ID', 'AJCC Pathologic M-Stage',
       'AJCC Pathologic N-Stage', 'AJCC Pathologic Stage',
       'AJCC Pathologic T-Stage', 'Primary Diagnosis',
       'Patient Primary Tumor Site', 'Prior Malignancy', 'Prior Treatment',
       'Project Identifier', 'Project Name', 'Project State', 'Race Category',
       'Number of Samples Per Patient', 'Sample Type', 'Sex',
       'Patient's Vital Status', 'Year of Death', 'Year of Diagnosis'],
      dtype='object')

In [4]:
# Build the survival table expected by the split file:
#   case_id  - patient identifier (renamed from 'Patient ID')
#   OS_event - 1.0 if the patient died, 0.0 if alive at last follow-up (censored)
#   OS_days  - overall survival time converted from months to days
#   OS       - dummy task column (always 0)
DAYS_PER_MONTH = 30.44

# Map the status labels explicitly. Any value that is not listed here (including a
# missing status) becomes NaN instead of being silently counted as a death event.
OS_STATUS_TO_EVENT = {'0:LIVING': 0.0, '1:DECEASED': 1.0}

split_tsv = (
    clindata_tcga[['Patient ID', 'Overall Survival (Months)', 'Overall Survival Status']]
    .rename(columns={'Patient ID': 'case_id'})
    .assign(
        OS_event=lambda df: df['Overall Survival Status'].map(OS_STATUS_TO_EVENT),
        # Round to whole days to remove floating-point noise from the unit conversion
        # (e.g. 3261.9999999999995 instead of 3262.0).
        OS_days=lambda df: (df['Overall Survival (Months)'] * DAYS_PER_MONTH).round(),
        OS=0,
    )
)
split_tsv

Unnamed: 0,case_id,Overall Survival (Months),Overall Survival Status,OS_event,OS_days,OS
0,TCGA-3C-AAAU,132.950066,0:LIVING,0.0,4047.0,0
1,TCGA-3C-AALI,131.570302,0:LIVING,0.0,4005.0,0
2,TCGA-3C-AALJ,48.423127,0:LIVING,0.0,1474.0,0
3,TCGA-3C-AALK,47.568988,0:LIVING,0.0,1448.0,0
4,TCGA-4H-AAAK,11.432326,0:LIVING,0.0,348.0,0
...,...,...,...,...,...,...
1098,TCGA-WT-AB44,29.007884,0:LIVING,0.0,883.0,0
1099,TCGA-XX-A899,15.341656,0:LIVING,0.0,467.0,0
1100,TCGA-XX-A89A,16.031537,0:LIVING,0.0,488.0,0
1101,TCGA-Z7-A8R5,107.982917,0:LIVING,0.0,3287.0,0


In [5]:
# Collect the slides that have patch embeddings and derive the patient (case) ID of each one
patch_embeddings_dirs = '/mnt/d/data/tcga-brca/trident_processed/20x_256px_0px_overlap/features_conch_v1'

# os.listdir returns entries in arbitrary order; sort so the listing is reproducible
available_slides = sorted(os.listdir(patch_embeddings_dirs))
slide_ids = [os.path.splitext(x)[0] for x in available_slides if x.endswith('.h5')]

# The patient ID is the leading 'TCGA-XX-XXXX' part (12 characters) of the slide name
available_patient_ids = [x[:len('TCGA-XX-XXXX')] for x in slide_ids]

# One row per slide: the patient it belongs to and the slide (feature file) name
df_cases = pd.DataFrame({
    'case_id': available_patient_ids,
    'slide_id': slide_ids
})
df_cases.head()

In [6]:
# Double check that every patient (case_id) has exactly one slide; the split is built
# with one row per patient, so several slides for one patient would duplicate rows.
_patient_ids = pd.Series(available_patient_ids, dtype=object)
patients_with_multiple_slides = sorted(_patient_ids[_patient_ids.duplicated()].unique())
assert not patients_with_multiple_slides, (
    f"Expected exactly one slide per patient, but these patients have more than one: "
    f"{patients_with_multiple_slides}"
)

In [7]:
# Keep only the clinical rows whose patient has a processed slide available
has_features = split_tsv['case_id'].isin(available_patient_ids)
split_tsv_filtered = split_tsv.loc[has_features]
split_tsv_filtered

Unnamed: 0,case_id,Overall Survival (Months),Overall Survival Status,OS_event,OS_days,OS
252,TCGA-AC-A23G,73.850197,0:LIVING,0.0,2248.0,0
379,TCGA-AO-A1KO,20.43364,0:LIVING,0.0,622.0,0
404,TCGA-AR-A0TZ,107.161629,1:DECEASED,1.0,3262.0,0
524,TCGA-BH-A0AW,20.43364,0:LIVING,0.0,622.0,0
741,TCGA-D8-A1JH,13.994744,0:LIVING,0.0,426.0,0
835,TCGA-E2-A15D,17.279895,0:LIVING,0.0,526.0,0
952,TCGA-E9-A5FL,0.788436,0:LIVING,0.0,24.0,0
971,TCGA-EW-A1P1,39.750329,0:LIVING,0.0,1210.0,0


In [8]:
# Attach each case's slide_id by left-joining the slide table (df_cases) on case_id
split_tsv_filtered = pd.merge(split_tsv_filtered, df_cases, how='left', on='case_id')
split_tsv_filtered


Unnamed: 0,case_id,Overall Survival (Months),Overall Survival Status,OS_event,OS_days,OS,slide_id
0,TCGA-AC-A23G,73.850197,0:LIVING,0.0,2248.0,0,TCGA-AC-A23G-01Z-00-DX1.2F0326F7-6B77-4B3F-B4F...
1,TCGA-AO-A1KO,20.43364,0:LIVING,0.0,622.0,0,TCGA-AO-A1KO-01Z-00-DX1.EEB5E0A0-92B2-42CD-9F7...
2,TCGA-AR-A0TZ,107.161629,1:DECEASED,1.0,3262.0,0,TCGA-AR-A0TZ-01Z-00-DX1.2D58BE38-03F6-4310-8E0...
3,TCGA-BH-A0AW,20.43364,0:LIVING,0.0,622.0,0,TCGA-BH-A0AW-01Z-00-DX1.9D50A0D2-B103-411C-831...
4,TCGA-D8-A1JH,13.994744,0:LIVING,0.0,426.0,0,TCGA-D8-A1JH-01Z-00-DX1.4A4F2502-612C-421D-9F6...
5,TCGA-E2-A15D,17.279895,0:LIVING,0.0,526.0,0,TCGA-E2-A15D-01Z-00-DX1.AA5AF847-3635-4BAF-AAC...
6,TCGA-E9-A5FL,0.788436,0:LIVING,0.0,24.0,0,TCGA-E9-A5FL-01Z-00-DX1.FB810D6A-303E-45DF-BEF...
7,TCGA-EW-A1P1,39.750329,0:LIVING,0.0,1210.0,0,TCGA-EW-A1P1-01Z-00-DX1.4B670029-4B3B-4D76-8EA...


In [9]:
# Create contiguous k-fold splits: every case is 'test' in exactly one fold and 'train'
# in all others. With the 8 tutorial cases and 4 folds this gives 2 test cases per fold.
NUM_FOLDS = 4
num_cases = len(split_tsv_filtered)
assert num_cases >= NUM_FOLDS, (
    f"Need at least {NUM_FOLDS} cases to build {NUM_FOLDS} folds, got {num_cases}."
)

# Fold assignment is positional (row order), so it does not depend on the index labels.
# Row i is tested in fold floor(i * NUM_FOLDS / num_cases), e.g. 0,0,1,1,2,2,3,3 for 8 rows.
test_fold_of_case = np.arange(num_cases) * NUM_FOLDS // num_cases
for fold_idx in range(NUM_FOLDS):
    split_tsv_filtered[f'fold_{fold_idx}'] = np.where(test_fold_of_case == fold_idx, 'test', 'train')
    


In [None]:
# TODO: Don't do this
# This is for making sure that the evaluation script works when we have so few cases

# WARNING: this overwrites the real censoring indicator. Every case is marked as an
# observed event (OS_event = 1), so anything computed after this cell does not reflect
# the true survival data. Remove this cell when running on a realistically sized cohort.
split_tsv_filtered.loc[:, 'OS_event'] = 1

In [11]:
# Display the final split table (clinical columns, slide_id and the fold_* assignments)
split_tsv_filtered

Unnamed: 0,case_id,Overall Survival (Months),Overall Survival Status,OS_event,OS_days,OS,slide_id,fold_0,fold_1,fold_2,fold_3
0,TCGA-AC-A23G,73.850197,0:LIVING,1.0,2248.0,0,TCGA-AC-A23G-01Z-00-DX1.2F0326F7-6B77-4B3F-B4F...,test,train,train,train
1,TCGA-AO-A1KO,20.43364,0:LIVING,1.0,622.0,0,TCGA-AO-A1KO-01Z-00-DX1.EEB5E0A0-92B2-42CD-9F7...,test,train,train,train
2,TCGA-AR-A0TZ,107.161629,1:DECEASED,1.0,3262.0,0,TCGA-AR-A0TZ-01Z-00-DX1.2D58BE38-03F6-4310-8E0...,train,test,train,train
3,TCGA-BH-A0AW,20.43364,0:LIVING,1.0,622.0,0,TCGA-BH-A0AW-01Z-00-DX1.9D50A0D2-B103-411C-831...,train,test,train,train
4,TCGA-D8-A1JH,13.994744,0:LIVING,1.0,426.0,0,TCGA-D8-A1JH-01Z-00-DX1.4A4F2502-612C-421D-9F6...,train,train,test,train
5,TCGA-E2-A15D,17.279895,0:LIVING,1.0,526.0,0,TCGA-E2-A15D-01Z-00-DX1.AA5AF847-3635-4BAF-AAC...,train,train,test,train
6,TCGA-E9-A5FL,0.788436,0:LIVING,1.0,24.0,0,TCGA-E9-A5FL-01Z-00-DX1.FB810D6A-303E-45DF-BEF...,train,train,train,test
7,TCGA-EW-A1P1,39.750329,0:LIVING,1.0,1210.0,0,TCGA-EW-A1P1-01Z-00-DX1.4B670029-4B3B-4D76-8EA...,train,train,train,test


In [12]:
# Save the split TSV file
split_tsv_filtered_path = './_tutorial_splits/tcga_brca/OS/k=all.tsv'

# to_csv does not create missing parent directories, so make sure the folder exists first
os.makedirs(os.path.dirname(split_tsv_filtered_path), exist_ok=True)
split_tsv_filtered.to_csv(split_tsv_filtered_path, sep='\t', index=False)

In [13]:


# Identifiers for this experiment: patch-feature model, training dataset and task
model_name, train_source, task_name = 'conch_v1', 'tcga_brca', 'OS'

# Unpacked as (split path, task config path); the split path is read with pandas in the next cell
path_to_split, path_to_task_config = SplitFactory.from_hf(
    './_tutorial_splits',
    train_source,
    task_name,
)


In [14]:
# Read the tab-separated split file back from disk (read_table uses sep='\t' by default)
split_df = pd.read_table(path_to_split)

In [15]:
# Preview the first five rows of the reloaded split
split_df.head(n=5)

Unnamed: 0,case_id,Overall Survival (Months),Overall Survival Status,OS_event,OS_days,OS,slide_id,fold_0,fold_1,fold_2,fold_3
0,TCGA-AC-A23G,73.850197,0:LIVING,1.0,2248.0,0,TCGA-AC-A23G-01Z-00-DX1.2F0326F7-6B77-4B3F-B4F...,test,train,train,train
1,TCGA-AO-A1KO,20.43364,0:LIVING,1.0,622.0,0,TCGA-AO-A1KO-01Z-00-DX1.EEB5E0A0-92B2-42CD-9F7...,test,train,train,train
2,TCGA-AR-A0TZ,107.161629,1:DECEASED,1.0,3262.0,0,TCGA-AR-A0TZ-01Z-00-DX1.2D58BE38-03F6-4310-8E0...,train,test,train,train
3,TCGA-BH-A0AW,20.43364,0:LIVING,1.0,622.0,0,TCGA-BH-A0AW-01Z-00-DX1.9D50A0D2-B103-411C-831...,train,test,train,train
4,TCGA-D8-A1JH,13.994744,0:LIVING,1.0,426.0,0,TCGA-D8-A1JH-01Z-00-DX1.4A4F2502-612C-421D-9F6...,train,train,test,train


In [19]:
from patho_bench.ExperimentFactory import ExperimentFactory

# Inputs for building the internal dataset: split file, task config and output folder
split = f'./_tutorial_splits/tcga_brca/{task_name}/k=all.tsv'
task_config = f'./_tutorial_splits/tcga_brca/{task_name}/config.yaml'
saveto = f'./_tutorial_finetune/{train_source}/{task_name}/{model_name}'

# Folder with the per-slide patch features, plus the dataset options
patch_embeddings_dirs = '/mnt/d/data/tcga-brca/trident_processed/20x_256px_0px_overlap/features_conch_v1'
combine_slides_per_patient, COMBINE_TRAIN_VAL = False, False
bag_size = 1

# Note: `split` is rebound by this call. It goes in as the path string defined above
# and comes back as the loaded split object.
_prepare_args = (
    split,
    task_config,
    saveto,
    combine_slides_per_patient,
    COMBINE_TRAIN_VAL,
    patch_embeddings_dirs,
)
split, task_info, internal_dataset = ExperimentFactory._prepare_internal_dataset(*_prepare_args, bag_size=bag_size)

Loaded split from ./_tutorial_splits/tcga_brca/OS/k=all.tsv with 8 samples and 4 folds assigned.


In [20]:
# Show the loaded split's summary (number of samples, folds and the first few samples)
split

Split with 8 samples and 4 folds assigned.
First 5 samples:
{'id': 'TCGA-AC-A23G', 'labels': {'OS': 0}, 'folds': ['test', 'train', 'train', 'train'], 'OS_event': [1.0], 'OS_days': [2248.0], 'slide_id': ['TCGA-AC-A23G-01Z-00-DX1.2F0326F7-6B77-4B3F-B4FA-59ADB785AA07']}
{'id': 'TCGA-AO-A1KO', 'labels': {'OS': 0}, 'folds': ['test', 'train', 'train', 'train'], 'OS_event': [1.0], 'OS_days': [622.0], 'slide_id': ['TCGA-AO-A1KO-01Z-00-DX1.EEB5E0A0-92B2-42CD-9F7A-00E9250B561F']}
{'id': 'TCGA-AR-A0TZ', 'labels': {'OS': 0}, 'folds': ['train', 'test', 'train', 'train'], 'OS_event': [1.0], 'OS_days': [3261.9999999999995], 'slide_id': ['TCGA-AR-A0TZ-01Z-00-DX1.2D58BE38-03F6-4310-8E06-F1A523FB0904']}
{'id': 'TCGA-BH-A0AW', 'labels': {'OS': 0}, 'folds': ['train', 'test', 'train', 'train'], 'OS_event': [1.0], 'OS_days': [622.0], 'slide_id': ['TCGA-BH-A0AW-01Z-00-DX1.9D50A0D2-B103-411C-831E-8520C3D50173']}
{'id': 'TCGA-D8-A1JH', 'labels': {'OS': 0}, 'folds': ['train', 'train', 'test', 'train'], 'OS_even

In [None]:
from patho_bench.ExperimentFactory import ExperimentFactory

# From here on model_name refers to the model being finetuned (ABMIL)
model_name = 'abmil'
train_source = 'tcga_brca'
task_name = 'OS'

# ABMIL needs these extra model kwargs; other models do not
abmil_kwargs = {
    'input_feature_dim': 512,  # change this to match the model in use
    'n_heads': 1,
    'head_dim': 512,
    'dropout': 0.25,
    'gated': False,
}

# All finetuning settings in one place, passed to the factory as keyword arguments
finetune_kwargs = dict(
    split=f'./_tutorial_splits/tcga_brca/{task_name}/k=all.tsv',
    task_config=f'./_tutorial_splits/tcga_brca/{task_name}/config.yaml',
    patch_embeddings_dirs=patch_embeddings_dirs,
    saveto=f'./_tutorial_finetune/{train_source}/{task_name}/{model_name}',
    combine_slides_per_patient=False,
    model_name=model_name,
    bag_size=2048,
    base_learning_rate=0.0003,
    layer_decay=None,
    gradient_accumulation=1,
    weight_decay=0.00001,
    num_epochs=2,
    scheduler_type='cosine',
    optimizer_type='AdamW',
    balanced=True,
    save_which_checkpoints='last-1',
    model_kwargs=abmil_kwargs,
)
experiment = ExperimentFactory.finetune(**finetune_kwargs)

# Train on every fold, evaluate on the held-out test cases, then collect the c-index
experiment.train()
experiment.test()
result = experiment.report_results('cindex')

Loaded split from ./_tutorial_splits/tcga_brca/OS/k=all.tsv with 8 samples and 4 folds assigned.

Experiment dir: ./_tutorial_finetune/tcga_brca/OS/abmil
############################################################################################################
Training: Fold 1 of 4...


      Epoch 1 train: 100%|███████████| 2/2 [00:35<00:00, 17.51s/it, avg_loss=0.0000, num_batches=6/6, num_samples=6]


############################################################################################################
Training: Fold 2 of 4...


      Epoch 1 train: 100%|███████████| 2/2 [00:44<00:00, 22.13s/it, avg_loss=0.0000, num_batches=6/6, num_samples=6]


############################################################################################################
Training: Fold 3 of 4...


      Epoch 1 train: 100%|███████████| 2/2 [00:30<00:00, 15.21s/it, avg_loss=0.0000, num_batches=6/6, num_samples=6]


############################################################################################################
Training: Fold 4 of 4...


      Epoch 1 train: 100%|███████████| 2/2 [00:33<00:00, 16.90s/it, avg_loss=0.0000, num_batches=6/6, num_samples=6]
  0%|                                                                                         | 0/4 [00:00<?, ?it/s]
Running test split on 2 samples: 100%|████████████████████████████████████████████████| 4/4 [00:27<00:00,  6.94s/it]

Final summary metrics: {'cindex': {'mean': 0.5, 'se': 0.25, 'formatted': '0.500 ± 0.250'}}





KeyError: 'c-index'