In [6]:
import pandas as pd
import glob
import os
from pathlib import Path

# Get repository root directory
# Try multiple methods to find the repo root
cwd = Path.cwd()
if (cwd / 'data').exists():
    REPO_ROOT = cwd
elif (cwd.parent / 'data').exists():
    REPO_ROOT = cwd.parent
else:
    # Fallback: assume we're in dev_notebooks and go up one level
    REPO_ROOT = cwd.parent

### Findings:

-- Hernia, mass, nodule, Fibrosis, Emphysema, Pneumothorax, Infiltration only have 0 labels in tr/val/test

In [7]:
# Investigate how good the tools are at predicting the labels
tr_df = pd.read_csv(REPO_ROOT / 'data' / 'openi' / 'labels' / 'Train.csv')
tr_df['id'] = tr_df['id'].apply(lambda x: str(x) + '.jpg')
tr_df.set_index('id', inplace=True)

tool_pred_csvs = glob.glob(str(REPO_ROOT / 'data' / 'openi' / 'predictions' / '*_train.csv'))

tool_metrics = {}
for tool in tool_pred_csvs:
    tool_name = tool.split('/')[-1][:-10]
    preds_df = pd.read_csv(tool).set_index('filename')
    # Check which columns have all values equal to 0.5
    cols_all_05 = [col for col in preds_df.columns if (preds_df[col] == 0.5).all()]
    preds_df.drop(cols_all_05, axis=1, inplace=True)
    preds_df = (preds_df > 0.5).astype(int)
    acc = (tr_df.loc[preds_df.index, preds_df.columns].values.flatten() == preds_df.values.flatten()).mean()

    tool_metrics[tool_name] = {
        'cols_doesnt_predict': cols_all_05,
        'acc': acc,
    }

In [8]:
tool_metrics

{'resnet_mgca_pt_openi': {'cols_doesnt_predict': ['Lung Lesion',
   'Fracture',
   'Lung Opacity',
   'Enlarged Cardiomediastinum'],
  'acc': np.float64(0.9494270435446907)},
 'densenet121_res224_chex': {'cols_doesnt_predict': ['Infiltration',
   'Emphysema',
   'Fibrosis',
   'Pleural_Thickening',
   'Nodule',
   'Mass',
   'Hernia'],
  'acc': np.float64(0.6384054448225571)},
 'densenet121_res224_all': {'cols_doesnt_predict': [],
  'acc': np.float64(0.6884135472370766)},
 'densenet_medical_mae_pt_openi': {'cols_doesnt_predict': ['Lung Lesion',
   'Fracture',
   'Lung Opacity',
   'Enlarged Cardiomediastinum'],
  'acc': np.float64(0.9497326203208556)},
 'densenet_mocov2_pt_openi': {'cols_doesnt_predict': ['Lung Lesion',
   'Fracture',
   'Lung Opacity',
   'Enlarged Cardiomediastinum'],
  'acc': np.float64(0.9493506493506494)},
 'densenet121_res224_mimic_nb': {'cols_doesnt_predict': ['Infiltration',
   'Emphysema',
   'Fibrosis',
   'Pleural_Thickening',
   'Nodule',
   'Mass',
   'Her

In [4]:
# Check which columns have all values equal to 0.5
cols_all_05 = [col for col in preds_df.columns if (preds_df[col] == 0.5).all()]
cols_all_05

[]

In [16]:
tr_df = pd.read_csv(REPO_ROOT / 'data' / 'openi' / 'labels' / 'Train.csv').set_index('id')
val_df = pd.read_csv(REPO_ROOT / 'data' / 'openi' / 'labels' / 'Valid.csv').set_index('id')
te_df = pd.read_csv(REPO_ROOT / 'data' / 'openi' / 'labels' / 'Test.csv').set_index('id')


In [22]:
for col in tr_df.columns:
    print(f"Pathology: {col}")
    print(f"tr: {tr_df[col].value_counts()}")
    print(f"val: {val_df[col].value_counts()}")
    print(f"te: {te_df[col].value_counts()}")
    print('-----')

Pathology: Atelectasis
tr: Atelectasis
0    820
1    115
Name: count, dtype: int64
val: Atelectasis
0    235
1     29
Name: count, dtype: int64
te: Atelectasis
0    140
1     16
Name: count, dtype: int64
-----
Pathology: Consolidation
tr: Consolidation
0    907
1     28
Name: count, dtype: int64
val: Consolidation
0    254
1     10
Name: count, dtype: int64
te: Consolidation
0    148
1      8
Name: count, dtype: int64
-----
Pathology: Infiltration
tr: Infiltration
0    935
Name: count, dtype: int64
val: Infiltration
0    264
Name: count, dtype: int64
te: Infiltration
0    156
Name: count, dtype: int64
-----
Pathology: Pneumothorax
tr: Pneumothorax
0    907
1     28
Name: count, dtype: int64
val: Pneumothorax
0    258
1      6
Name: count, dtype: int64
te: Pneumothorax
0    156
Name: count, dtype: int64
-----
Pathology: Edema
tr: Edema
0    919
1     16
Name: count, dtype: int64
val: Edema
0    260
1      4
Name: count, dtype: int64
te: Edema
0    154
1      2
Name: count, dtype: int64
