# Split Datasets

In [1]:
%load_ext autoreload
%autoreload 2

## 1. Converted Only Dataset

In [2]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from plotly.subplots import make_subplots
from src.config import PROCESSED_DATA_DIR, RAW_DATA_DIR
from src.utils import make_grouped_splits, persist_splits, make_grouped_holdout_split

In [None]:
converted_df = pd.read_csv(f'{PROCESSED_DATA_DIR}/cmmd_converted.csv')

In [None]:
converted_df.head()

In [None]:
converted_df.info()

In [None]:
converted_df['patientId'].nunique()

## Split by patient id and store in folder

In [None]:
# Split in train, validation and test, stratified by patientId to ensure that each patient is only in one of the sets and that the distribution of the target is the same in all sets

trainval_df, test_df = make_grouped_holdout_split(converted_df, patient_col='patientId', subtype_col='subtype')

In [None]:
px.pie(names=['TrainVal','Test'], values=[len(trainval_df), len(test_df)], title='Data split')

In [None]:
fig = make_subplots(rows=1, cols=2,
                    specs=[[{"type": "pie"}, {"type": "pie"}]],
                    subplot_titles=('Train set distribution',
                                    'Test set distribution'))

def add_pie(df, row, col):
    counts = df['subtype'].value_counts(normalize=True)
    fig.add_trace(go.Pie(labels=counts.index, values=counts.values), row=row, col=col)

add_pie(trainval_df, 1, 1)
add_pie(test_df, 1, 2)

fig.update_layout()

fig.show()

In [None]:
# Persist splitted images
persist_splits(trainval_df, None, test_df, patient_col='patientId', subtype_col='subtype')

In [None]:
trainval_df['convertedPath']

# Mammo-Bench Split

In [3]:
mb_df = pd.read_csv(f'{RAW_DATA_DIR}/MammoBench/mammo-bench_molecular_subtype.csv')

In [4]:
mb_df.head()

Unnamed: 0,source_dataset,laterality,view,preprocessed_image_path,classification,density,BIRADS,abnormality,molecular_subtype,raw_image_path,mask_path,ROI_path,x,y,radius,subject_age,source_subjectID,original_source_path
0,cmmd,L,CC,Preprocessed_Dataset/cmmd/cmmd_1026.jpg,Malignant,,,calcification,Luminal B,Original_Dataset/cmmd/cmmd_1026.jpg,Masks/cmmd/cmmd_1026.jpg,,,,,64.0,D2-0001,CMMD/CMMD/D2-0001/07-18-2011-NA-NA-75485/1.000...
1,cmmd,L,CC,Preprocessed_Dataset/cmmd/cmmd_1027.jpg,Malignant,,,calcification,Luminal B,Original_Dataset/cmmd/cmmd_1027.jpg,Masks/cmmd/cmmd_1027.jpg,,,,,69.0,D2-0002,CMMD/CMMD/D2-0002/07-18-2010-NA-NA-26354/1.000...
2,cmmd,L,CC,Preprocessed_Dataset/cmmd/cmmd_1028.jpg,Malignant,,,calcification,Luminal B,Original_Dataset/cmmd/cmmd_1028.jpg,Masks/cmmd/cmmd_1028.jpg,,,,,44.0,D2-0003,CMMD/CMMD/D2-0003/07-18-2010-NA-NA-57046/1.000...
3,cmmd,L,CC,Preprocessed_Dataset/cmmd/cmmd_1029.jpg,Malignant,,,calcification,Luminal B,Original_Dataset/cmmd/cmmd_1029.jpg,Masks/cmmd/cmmd_1029.jpg,,,,,38.0,D2-0004,CMMD/CMMD/D2-0004/07-18-2010-NA-NA-29234/1.000...
4,cmmd,L,CC,Preprocessed_Dataset/cmmd/cmmd_1030.jpg,Malignant,,,calcification,HER2-enriched,Original_Dataset/cmmd/cmmd_1030.jpg,Masks/cmmd/cmmd_1030.jpg,,,,,41.0,D2-0005,CMMD/CMMD/D2-0005/07-18-2010-NA-NA-26051/1.000...


In [5]:
mb_df.rename(columns={'source_subjectID': 'patientId', 'molecular_subtype': 'subtype'}, inplace=True)

In [6]:
print(mb_df['subtype'].value_counts())

subtype
Luminal B          1482
Luminal A           600
HER2-enriched       532
triple negative     342
Name: count, dtype: int64


In [7]:
mb_df['subtype'] = mb_df['subtype'].map({'Luminal B': 'luminal-b', 'Luminal A': 'luminal-a', 'HER2-enriched': 'her2-enriched', 'triple negative': 'triple-negative'})

In [8]:
mb_df['convertedPath'] = mb_df['preprocessed_image_path'].str.replace('Preprocessed_Dataset/cmmd/', '../data/raw/MammoBench/cmmd/')

In [9]:
mb_df.head()

Unnamed: 0,source_dataset,laterality,view,preprocessed_image_path,classification,density,BIRADS,abnormality,subtype,raw_image_path,mask_path,ROI_path,x,y,radius,subject_age,patientId,original_source_path,convertedPath
0,cmmd,L,CC,Preprocessed_Dataset/cmmd/cmmd_1026.jpg,Malignant,,,calcification,luminal-b,Original_Dataset/cmmd/cmmd_1026.jpg,Masks/cmmd/cmmd_1026.jpg,,,,,64.0,D2-0001,CMMD/CMMD/D2-0001/07-18-2011-NA-NA-75485/1.000...,../data/raw/MammoBench/cmmd/cmmd_1026.jpg
1,cmmd,L,CC,Preprocessed_Dataset/cmmd/cmmd_1027.jpg,Malignant,,,calcification,luminal-b,Original_Dataset/cmmd/cmmd_1027.jpg,Masks/cmmd/cmmd_1027.jpg,,,,,69.0,D2-0002,CMMD/CMMD/D2-0002/07-18-2010-NA-NA-26354/1.000...,../data/raw/MammoBench/cmmd/cmmd_1027.jpg
2,cmmd,L,CC,Preprocessed_Dataset/cmmd/cmmd_1028.jpg,Malignant,,,calcification,luminal-b,Original_Dataset/cmmd/cmmd_1028.jpg,Masks/cmmd/cmmd_1028.jpg,,,,,44.0,D2-0003,CMMD/CMMD/D2-0003/07-18-2010-NA-NA-57046/1.000...,../data/raw/MammoBench/cmmd/cmmd_1028.jpg
3,cmmd,L,CC,Preprocessed_Dataset/cmmd/cmmd_1029.jpg,Malignant,,,calcification,luminal-b,Original_Dataset/cmmd/cmmd_1029.jpg,Masks/cmmd/cmmd_1029.jpg,,,,,38.0,D2-0004,CMMD/CMMD/D2-0004/07-18-2010-NA-NA-29234/1.000...,../data/raw/MammoBench/cmmd/cmmd_1029.jpg
4,cmmd,L,CC,Preprocessed_Dataset/cmmd/cmmd_1030.jpg,Malignant,,,calcification,her2-enriched,Original_Dataset/cmmd/cmmd_1030.jpg,Masks/cmmd/cmmd_1030.jpg,,,,,41.0,D2-0005,CMMD/CMMD/D2-0005/07-18-2010-NA-NA-26051/1.000...,../data/raw/MammoBench/cmmd/cmmd_1030.jpg


In [11]:
train_df, test_df = make_grouped_holdout_split(mb_df, patient_col='patientId', subtype_col='subtype')

In [12]:
persist_splits(train_df, None, test_df, patient_col='patientId', subtype_col='subtype')