In [None]:
import numpy as np
#from geojson import GeoJSON
import pandas as pd
import json
import os
import glob
import shapely
from rtree import index
from shapely.ops import cascaded_union, unary_union
from shapely.plotting import plot_polygon
from collections import Counter
import matplotlib.pyplot as plt
from openslide import OpenSlide

from tiatoolbox import utils
from tiatoolbox.wsicore import wsireader
from tiatoolbox import data
from tiatoolbox.tools import stainnorm
from tqdm import tqdm
import h5py
import cv2

from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import StratifiedGroupKFold

from matplotlib.patches import Polygon
from matplotlib.colors import ListedColormap

from omegaconf import OmegaConf
# Load the preprocessing YAML and keep only the sub-section this notebook works with
preproc_conf = OmegaConf.load("../conf/preproc.yaml")[
    'classic_mil_on_embeddings_bag'
]['bracs_224_224_patches']

In [None]:
# Display the configured locations as a (workbook path, split directory) tuple.
# NOTE(review): the workbook path shown here is built from data_root_dir, but the cell that
# actually reads BRACS.xlsx builds it from cv_split_dir — confirm which directory holds the file.
preproc_conf.data_root_dir+'BRACS.xlsx', preproc_conf.cv_split_dir

### Locate annotations

In [None]:
# Root directory of the annotations, taken from the preprocessing config
annotation_folder = preproc_conf.annotation_root_dir

In [None]:
# Read the BRACS metadata workbook from the split directory.
# os.path.join keeps the path valid whether or not cv_split_dir ends with a path
# separator (the CSV-writing cells of this notebook add the "/" themselves).
bracs_df = pd.read_excel(os.path.join(preproc_conf.cv_split_dir, 'BRACS.xlsx'))
bracs_df.head()

### Find and fix patients that appear in more than one set

In [None]:
# For every patient, count how many distinct 'Set' values their rows carry
patient_set_overlap = bracs_df.groupby('Patient Id')['Set'].nunique()

# Patients listed under more than one set would leak across the splits
leaked_patients = patient_set_overlap.loc[patient_set_overlap.gt(1)]
leaked_patients.index

In [None]:
# Show all rows belonging to patient 67
bracs_df.loc[bracs_df['Patient Id'] == 67]

In [None]:
# Put every row of patient 67 into the 'Validation' set (modifies bracs_df in place).
# Presumably 67 is the patient flagged by the leakage check — confirm against its output.
is_patient_67 = bracs_df['Patient Id'] == 67
bracs_df.loc[is_patient_67, 'Set'] = 'Validation'

In [None]:
# Group by Patient Id and count the number of unique sets they appear in
patient_set_overlap = bracs_df.groupby('Patient Id')['Set'].nunique()

# Filter the patients that appear in more than one set
leaked_patients = patient_set_overlap[patient_set_overlap > 1]
leaked_patients.index

In [None]:
# Display the metadata table in its current state
bracs_df

### Save test fold

In [None]:
# Keep only the rows whose 'Set' is 'Testing' and renumber them from 0
bracs_df_test = bracs_df.loc[bracs_df['Set'].isin(['Testing'])].reset_index(drop=True)
bracs_df_test.head()

In [None]:
# (rows, columns) of the test subset
bracs_df_test.shape

In [None]:
# Write the selected columns of the test subset to CSV, without the index column
test_columns = ['WSI Filename', 'Patient Id', 'RoI ', 'WSI label']
test_csv_path = f"{preproc_conf.cv_split_dir}/test_split_stratified.csv"
bracs_df_test[test_columns].to_csv(test_csv_path, index=False)

### Split the rest into 5-fold CV

In [None]:
# Rows marked 'Validation' or 'Training' are pooled for cross-validation; renumber from 0
bracs_df_into_splits = (
    bracs_df
    .loc[bracs_df['Set'].isin(['Validation', 'Training'])]
    .reset_index(drop=True)
)
bracs_df_into_splits.head()

#### Generate splits

In [None]:
# Number of cross-validation folds
n_splits = 5

X = bracs_df_into_splits[['WSI Filename']]
# Stratify on the slide-level class label so the folds get a similar label mix.
# The previous target was 'Patient Id', which is the same column the split loop passes
# as `groups`; with that target no class stratification took place even though the
# output files are named "*_stratified_*".
y = bracs_df_into_splits['WSI label']

# No shuffle/random_state is passed, so the splitter runs with its defaults.
cv = StratifiedGroupKFold(n_splits=n_splits)

In [None]:
# Positional row indices of each fold, collected in fold order
train_splits, val_splits = [], []

# Patients are the groups: all rows of one patient stay on the same side of a fold
patient_groups = bracs_df_into_splits['Patient Id']

for fold, (train_idx, val_idx) in enumerate(cv.split(X, y, groups=patient_groups)):
    train_set = bracs_df_into_splits.iloc[train_idx]
    val_set = bracs_df_into_splits.iloc[val_idx]

    # Report size and per-patient row counts for both sides of the fold
    print(f"Fold {fold + 1}")
    for caption, subset in (("Train Set:      ", train_set), ("Validation Set: ", val_set)):
        print(caption, subset.shape, np.unique(subset['Patient Id'].values, return_counts=True))
    print("-" * 40)

    train_splits.append(train_idx)
    val_splits.append(val_idx)

In [None]:
# Check that no row index is shared between the validation sets of any two folds.
# Every pair (i, j) with i < j is compared; comparing only fold 0 against the others
# would miss an overlap between, e.g., folds 1 and 2.
for i in range(n_splits):
    for j in range(i + 1, n_splits):
        shared = sorted(set(val_splits[i]) & set(val_splits[j]))
        print(f"folds {i} & {j}: {shared}")
        assert not shared, f"validation folds {i} and {j} share {len(shared)} row(s)"

In [None]:
# Write one train CSV and one validation CSV per fold (selected columns, no index column)
split_columns = ['WSI Filename', 'Patient Id', 'RoI ', 'WSI label']

for s in range(n_splits):
    fold_source = bracs_df_into_splits[split_columns]

    # training rows of fold s
    train_csv_path = f"{preproc_conf.cv_split_dir}/train_split_stratified_{s}.csv"
    fold_source.iloc[train_splits[s]].to_csv(train_csv_path, index=False)

    # validation rows of fold s
    val_csv_path = f"{preproc_conf.cv_split_dir}/val_split_stratified_{s}.csv"
    fold_source.iloc[val_splits[s]].to_csv(val_csv_path, index=False)

### Check RoI sums

In [None]:
# Print, for each fold: the fold number, then the sum of the 'RoI ' column over the
# training rows, then the same sum over the validation rows. Nothing is written to disk.
roi_column = bracs_df_into_splits['RoI ']

for s in range(n_splits):
    print(s)

    # training-side total
    print(roi_column.iloc[train_splits[s]].values.sum())

    # validation-side total
    print(roi_column.iloc[val_splits[s]].values.sum())