# HIPT TCGA-BRCA split statistics

In [None]:
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from pathlib import Path

In [None]:
# Work from this fixed checkout directory so that the relative 'data/...'
# split paths used in the cells below resolve against it.
os.chdir(Path('/data/pathology/projects/ais-cap/code/git/clemsgrs/hipt'))

### Loading splits

In [None]:
# Fold index and dataset name used by the single-fold cells below.
fold_num, dataset_name = 0, 'tcga_brca'

In [None]:
# Read the training split CSV of the selected fold and preview its first rows.
fold_dir = Path(f'data/{dataset_name}/splits/fold_{fold_num}')
train_csv = fold_dir / 'train.csv'
train_df = pd.read_csv(train_csv)
train_df.head()

In [None]:
# Per-label slide counts of the loaded training split as a (label, count) table.
# rename_axis + reset_index(name=...) produces the same two columns on pandas 1.x
# (Series named 'label', unnamed index) and on pandas >= 2.0 (Series named
# 'count', index named 'label'); renaming columns after reset_index() does not,
# because the column names it sees differ between those versions.
train_df['label'].value_counts().rename_axis('label').reset_index(name='count')

What we want to plot:

- the label (subtype) distribution
- the number of slides per label

In [None]:
# Slide-level view of the training split with a readable subtype column:
# label 0 -> 'IDC', any other label -> 'ILC'.
# .copy() makes tmp an independent frame, so adding the column below does not
# raise SettingWithCopyWarning (column-subset selection is flagged as a copy).
tmp = train_df[['slide_id', 'label']].copy()
tmp['subtype'] = tmp['label'].apply(lambda x: 'IDC' if x == 0 else 'ILC')

In [None]:
# Bar chart of training-slide counts per subtype for the selected fold, with
# each bar annotated by its count.
ax = sns.countplot(data=tmp, x='subtype')
ax.bar_label(ax.containers[0], padding=5)
ax.set_xlabel('label')
ax.set_ylabel('# slide')
ax.set_ylim(0, 699)  # fixed y-range (hard-coded upper bound of 699)
ax.set_title(f'fold_{fold_num}/train', pad=10)
plt.show()

In [None]:
# Per-label training-slide counts for every fold, stacked into one long frame
# (one row per fold x label) so seaborn can average the counts across folds.
dfs = []
nfold = 10
for i in range(nfold):
    fold_dir = Path(f'data/{dataset_name}/splits/fold_{i}')
    train_df = pd.read_csv(fold_dir / 'train.csv')
    # rename_axis + reset_index(name=...) yields (label, count) columns on both
    # pandas 1.x and >= 2.0 (where value_counts already names its result 'count').
    counts = train_df['label'].value_counts().rename_axis('label').reset_index(name='count')
    dfs.append(counts)

# ignore_index avoids duplicate index labels (each fold's frame starts at 0).
df = pd.concat(dfs, ignore_index=True)
df['subtype'] = df['label'].apply(lambda x: 'IDC' if x == 0 else 'ILC')
df.head()

In [None]:
# Mean training-slide count per subtype across folds; error bars show one
# standard deviation and each bar is annotated with its mean.
ax = sns.barplot(data=df, x='subtype', y='count', errorbar='sd')
ax.bar_label(ax.containers[0], padding=5)
ax.set_xlabel('label')
ax.set_ylabel('# slide')
ax.set_ylim(0, 699)  # fixed y-range (hard-coded upper bound of 699)
ax.set_title('Average Sample Count (train)', pad=10)
plt.show()

In [None]:
# Per-label slide counts for every fold and every partition (train / tune /
# test), stacked into one long frame tagged with its partition name.
dfs = []
nfold = 10
for i in range(nfold):
    fold_dir = Path(f'data/{dataset_name}/splits/fold_{i}')
    for partition in ['train', 'tune', 'test']:
        split_df = pd.read_csv(fold_dir / f'{partition}.csv')
        # rename_axis + reset_index(name=...) yields (label, count) columns on
        # both pandas 1.x and >= 2.0 (where value_counts names its result 'count').
        counts = split_df['label'].value_counts().rename_axis('label').reset_index(name='count')
        counts['partition'] = partition  # scalar broadcasts to every row
        dfs.append(counts)

# ignore_index avoids duplicate index labels (each small frame starts at 0).
df = pd.concat(dfs, ignore_index=True)
df['subtype'] = df['label'].apply(lambda x: 'IDC' if x == 0 else 'ILC')
df.head()

In [None]:
# Grouped bar plot of the mean slide count per subtype, split by partition,
# with +/- 1 sd error bars across folds and each bar labelled with its mean.
pad = 10  # vertical gap, in data units (# slides), between error-bar top and label
save = True

plt.figure(dpi=100)
ax = sns.barplot(data=df, x='subtype', y='count', hue='partition', errorbar='sd')
# Pair each bar with its error-bar line. This relies on seaborn drawing one
# Line2D per bar in the same order as ax.patches -- NOTE(review): confirm for
# the installed seaborn version. zip() stops at the shorter sequence instead
# of raising IndexError if the counts ever differ.
for patch, errline in zip(ax.patches, ax.lines):
    mean = patch.get_height()
    if mean != mean:  # NaN height: no data for this subtype/partition pair
        continue
    # The error-bar line spans [mean - sd, mean + sd]; its largest y value is
    # the top end, so the label sits `pad` units above the error bar.
    err_top = max(errline.get_ydata())
    if err_top != err_top:  # NaN sd (e.g. a single observation): use the bar top
        err_top = mean
    txt_x = patch.get_x() + patch.get_width() / 2
    plt.text(txt_x, err_top + pad, f'{mean:.1f}', horizontalalignment='center')

plt.xlabel('label')
plt.ylabel('# slide')
plt.ylim(0, 699)
plt.title('Average Sample Count', pad=10)
plt.tight_layout()
if save:
    plt.savefig('average_sample_count.png', dpi=300)
plt.show()