In [86]:
import torch
from torch.utils.data import DataLoader

from train_vit import CheXpertDataset
from transformers import (
    ViTForImageClassification,
    AutoConfig
)

import pandas as pd
import numpy as np

pathologies = ['Atelectasis',
                'Cardiomegaly',
                'Consolidation',
                'Edema',
                'Pleural Effusion']

def get_predictions(ckpts, approach, train=False):
    dataset = CheXpertDataset(
                data_path='../data/raw/',
                uncertainty_policy=approach,
                train=False,
                resize_shape=(224, 224))
    dataloader = DataLoader(dataset, batch_size=dataset.__len__(), shuffle=False)


    models = []
    for checkpoint in ckpts:
        config = AutoConfig.from_pretrained(f"../models/base-vit/{approach}/checkpoint-{checkpoint}")
        model = ViTForImageClassification(config=config).eval()
        models.append(model)

    general_output = []
    for i_model, model in enumerate(models):
        multiindex = pd.MultiIndex.from_product([[f'model_{i_model}'], pathologies], names=['model', 'pathology'])
        for i_batch, sample_batched in enumerate(dataloader):
            with torch.no_grad():
                model_output = pd.DataFrame(model(sample_batched['pixel_values']).logits.numpy(), columns=multiindex)
                labels = pd.DataFrame(sample_batched['labels'], columns=pd.MultiIndex.from_product([['labels'], pathologies]))

        if len(general_output) == 0:
            general_output = pd.merge(labels, model_output, left_index=True, right_index=True).copy()
        else:
            general_output = pd.merge(general_output, model_output, left_index=True, right_index=True)
    return general_output

In [88]:
ckpts = ['622', '1244', '1867', '2489', '3111', '3734', '4356', '4979', '5601', '6220']
approach = 'U-Ignore'

ignore_results = get_predictions(ckpts, approach)
ignore_results.to_parquet(f'results/{approach}.pqt')

2023-08-21 15:05:05,203 - train_vit - INFO - Local database found.


In [90]:
ckpts = ['872', '1745', '2618', '3490', '4363', '5236', '6108', '6981', '7854', '8720']
approach = 'U-MultiClass'

multiclass_results = get_predictions(ckpts, approach)
multiclass_results.to_parquet(f'results/{approach}.pqt')

2023-08-21 15:10:50,367 - train_vit - INFO - Local database found.


In [91]:
ckpts = ['872', '1745', '2618', '3490', '4363', '5236', '6108', '6981', '7854', '8720']
approach = 'U-Ones'

ones_results = get_predictions(ckpts, approach)
ones_results.to_parquet(f'results/{approach}.pqt')

2023-08-21 15:15:11,758 - train_vit - INFO - Local database found.


In [92]:
ckpts = ['872', '1745', '2618', '3490', '4363', '5236', '6108', '6981', '7854', '8720']
approach = 'U-Zeros'

zeros_results = get_predictions(ckpts, approach)
zeros_results.to_parquet(f'results/{approach}.pqt')

2023-08-21 15:19:32,726 - train_vit - INFO - Local database found.


# Experimental Results

In [2]:
import pandas as pd

ignore_results = pd.read_parquet('results/U-Ignore.pqt')
multiclass_results = pd.read_parquet('results/U-MultiClass.pqt')
ones_results = pd.read_parquet('results/U-Zeros.pqt')
zeros_results = pd.read_parquet('results/U-Ones.pqt')

In [51]:
from sklearn.metrics import roc_auc_score, RocCurveDisplay
import math

def sigmoid(x):
    return 1 / (1 + math.exp(-x))

pathologies = ['Atelectasis',
                'Cardiomegaly',
                'Consolidation',
                'Edema',
                'Pleural Effusion']

models_names = [f"model_{i_model}" for i_model in range(10)]
multiindex = pd.MultiIndex.from_product([models_names, pathologies], names=['model', 'pathology'])

def get_roc_score(model_results):
    pred_total = pd.DataFrame(index=pd.Index(range(len(model_results))), columns=multiindex)
    roc_score_results = pd.DataFrame([], columns=models_names, index=pathologies)
    roc_score_results['mean'] = 0
    roc_curve = pd.DataFrame([], columns=models_names, index=pathologies)

    for pathology in pathologies:
        true = model_results.loc[:, ('labels', pathology)]

        for model in models_names:
            pred = model_results.loc[:, (model, pathology)].apply(sigmoid)
            pred_total.loc[:, (model, pathology)] = pred    
            
            roc_score_results.loc[pathology, model] = roc_auc_score(true, pred)
            #roc_curve.loc[pathology, model] = RocCurveDisplay.from_predictions(true, pred)

        mean_pred = model_results.loc[:, (model_results.columns.get_level_values(0) != 'labels', 'Atelectasis')].mean(axis=1)
        roc_score_results.loc[pathology, 'mean'] = roc_auc_score(true, mean_pred)
    #mean_proba = model_results.loc[:, model_results.columns.get_level_values(0) != 'labels'].mean(axis=1)
    #true
    #roc_score_results['mean'] = roc_auc_score(true, pred)

    return roc_score_results#, roc_curve

In [52]:
ones_auc_score = get_roc_score(ones_results) #, ones_auc_curve
ones_auc_score

Unnamed: 0,model_0,model_1,model_2,model_3,model_4,model_5,model_6,model_7,model_8,model_9,mean
Atelectasis,0.556169,0.440422,0.394399,0.657386,0.535877,0.621753,0.343588,0.456412,0.47638,0.400487,0.524351
Cardiomegaly,0.579288,0.463235,0.481928,0.454731,0.419826,0.400248,0.538891,0.575213,0.630758,0.409284,0.501949
Consolidation,0.581185,0.541987,0.469019,0.567616,0.375999,0.652797,0.410071,0.517564,0.413086,0.370722,0.535354
Edema,0.351911,0.51311,0.427866,0.642563,0.383774,0.630453,0.412111,0.450088,0.351911,0.567784,0.571664
Pleural Effusion,0.436411,0.598177,0.417732,0.335419,0.604165,0.516221,0.621861,0.456341,0.573688,0.523103,0.523818


0     -0.140286
1     -0.246217
2     -0.233676
3     -0.226710
4     -0.169320
         ...   
229   -0.148064
230   -0.126825
231   -0.184965
232   -0.194859
233   -0.181755
Length: 234, dtype: float32

In [134]:
ones_results.loc[:, ('labels', 'Atelectasis')]

0      0.0
1      0.0
2      0.0
3      0.0
4      0.0
      ... 
229    0.0
230    0.0
231    0.0
232    0.0
233    1.0
Name: (labels, Atelectasis), Length: 234, dtype: float32

In [91]:
ignore_auc_score = get_roc_score(pathologies, ignore_results, labels)
multi_auc_score = get_roc_score(pathologies, multiclass_results, labels)
ones_auc_score = get_roc_score(pathologies, ones_results, labels)
zeros_auc_score = get_roc_score(pathologies, zeros_results, labels)

# Relabel uncertainty

In [None]:
ckpts = ['622', '1244', '1867', '2489', '3111', '3734', '4356', '4979', '5601', '6220']
approach = 'U-Ignore'
train = False

ignore_results = get_predictions(ckpts, approach, train)
ignore_results.to_parquet(f'{approach}.pqt')