In [4]:
import torch

from train_vit import CheXpertDataset
from transformers import (
    ViTForImageClassification,
    AutoConfig
)

import pandas as pd
import numpy as np

pathologies = ['Atelectasis',
                'Cardiomegaly',
                'Consolidation',
                'Edema',
                'Pleural Effusion']

def get_predictions(ckpts, approach, train):
    val_dataset = CheXpertDataset(
            data_path='../data/raw/',
            uncertainty_policy=approach,
            train=train,
            resize_shape=(224, 224))

    models = []

    for checkpoint in ckpts:
        config = AutoConfig.from_pretrained(f"../models/{approach}/checkpoint-{checkpoint}")
        model = ViTForImageClassification(config=config)
        models.append(model)

    ds_len = val_dataset.__len__()
    multiindex = pd.MultiIndex.from_product([pathologies, list(range(ds_len))])
    res = pd.DataFrame([], index=multiindex)

    for model_number, model in enumerate(models):
        res[f'model_{model_number}'] = np.NaN
        for sample in range(ds_len):
            with torch.no_grad():
                res.loc[
                    (slice(None), sample),
                    f'model_{model_number}'] = model(val_dataset.__getitem__(sample)['pixel_values'][None, :]).logits.numpy()[0]
            
    return res, val_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [49]:
_, val_dataset = get_predictions([], 'U-Zeros', False)

labels = pd.DataFrame([], columns=pathologies, index=list(range(val_dataset.__len__())))

labels = labels.apply(lambda x: val_dataset.__getitem__(x.name)['labels'], axis=1)
labels = pd.DataFrame(labels.values.tolist(), index=labels.index, columns=pathologies)

labels.to_parquet('labels.pqt')

In [2]:
ckpts = ['622', '1244', '1867', '2489', '3111', '3734', '4356', '4979', '5601', '6220']
approach = 'U-Ignore'
train = False

ignore_results, _ = get_predictions(ckpts, approach, train)
ignore_results.to_parquet(f'{approach}.pqt')

2023-07-04 18:48:59,502 - train_vit - INFO - Local database found.


In [5]:
ckpts = ['872', '1745', '2618', '3490', '4363', '5236', '6108', '6981', '7854', '8720']
approach = 'U-MultiClass'
train = False

multiclass_results, _ = get_predictions(ckpts, approach, train)
multiclass_results.to_parquet(f'{approach}.pqt')

2023-07-04 18:58:32,304 - train_vit - INFO - Local database found.


In [6]:
ckpts = ['872', '1745', '2618', '3490', '4363', '5236', '6108', '6981', '7854', '8720']
approach = 'U-Ones'
train = False

ones_results, _ = get_predictions(ckpts, approach, train)
ones_results.to_parquet(f'{approach}.pqt')

2023-07-04 19:04:06,922 - train_vit - INFO - Local database found.


In [7]:
ckpts = ['872', '1745', '2618', '3490', '4363', '5236', '6108', '6981', '7854', '8720']
approach = 'U-Zeros'
train = False

zeros_results, _ = get_predictions(ckpts, approach, train)
zeros_results.to_parquet(f'{approach}.pqt')

2023-07-04 19:09:33,190 - train_vit - INFO - Local database found.


# Experimental Results

In [89]:
import pandas as pd

ignore_results = pd.read_parquet('U-Ignore.pqt')
multiclass_results = pd.read_parquet('U-MultiClass.pqt')
ones_results = pd.read_parquet('U-Zeros.pqt')
zeros_results = pd.read_parquet('U-Ones.pqt')

labels = pd.read_parquet('labels.pqt')

In [90]:
from sklearn.metrics import roc_auc_score

def get_roc_score(pathologies, model_results, true_labels):
    roc_score_results = pd.DataFrame([], columns=pathologies)

    for pathology in pathologies:
        for index, model in enumerate(model_results.columns):
            roc_score_results.loc[index, pathology] = roc_auc_score(true_labels.loc[:, pathology], model_results.loc[pathology, model])

    return roc_score_results

In [91]:
ignore_auc_score = get_roc_score(pathologies, ignore_results, labels)
multi_auc_score = get_roc_score(pathologies, multiclass_results, labels)
ones_auc_score = get_roc_score(pathologies, ones_results, labels)
zeros_auc_score = get_roc_score(pathologies, zeros_results, labels)

In [92]:
ignore_auc_score

Unnamed: 0,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion
0,0.463799,0.378632,0.755314,0.609759,0.39512
1,0.432711,0.434887,0.392281,0.617519,0.540888
2,0.385471,0.571226,0.306196,0.443034,0.511931
3,0.593101,0.473246,0.515302,0.510523,0.464742
4,0.561688,0.371456,0.620534,0.636567,0.62052
5,0.371023,0.410613,0.447158,0.588595,0.532755
6,0.497403,0.537473,0.435399,0.372016,0.41648
7,0.452354,0.488926,0.448817,0.408348,0.446599
8,0.527679,0.378012,0.307704,0.647854,0.432568
9,0.412744,0.56786,0.569124,0.46067,0.52641


In [93]:
multi_auc_score

Unnamed: 0,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion
0,0.428247,0.528172,0.506106,0.555791,0.531504
1,0.393588,0.43967,0.533997,0.461023,0.525159
2,0.683929,0.558381,0.371174,0.453851,0.630351
3,0.461201,0.552622,0.598221,0.615168,0.405845
4,0.470617,0.499646,0.658526,0.55485,0.455537
5,0.582873,0.532778,0.342982,0.449148,0.612208
6,0.497078,0.572466,0.649631,0.408348,0.527929
7,0.337906,0.430634,0.484999,0.568842,0.54786
8,0.566721,0.5854,0.606061,0.584715,0.566628
9,0.594481,0.484851,0.504598,0.410347,0.392439


In [94]:
ones_auc_score

Unnamed: 0,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion
0,0.40625,0.502481,0.513795,0.575897,0.467691
1,0.656899,0.483434,0.535354,0.456437,0.422111
2,0.399432,0.455439,0.473391,0.532393,0.520511
3,0.509416,0.437544,0.383838,0.448442,0.389668
4,0.62914,0.523831,0.645108,0.506526,0.486639
5,0.469237,0.54651,0.559928,0.669136,0.385915
6,0.559334,0.513643,0.369818,0.395532,0.47502
7,0.478653,0.481662,0.630032,0.426925,0.481455
8,0.328571,0.476258,0.400121,0.68736,0.52927
9,0.55974,0.574415,0.578622,0.500176,0.587631


In [95]:
zeros_auc_score

Unnamed: 0,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion
0,0.410471,0.508239,0.471883,0.451382,0.499062
1,0.655519,0.389794,0.494648,0.462669,0.500938
2,0.500731,0.589653,0.773406,0.61458,0.55778
3,0.460065,0.545446,0.564601,0.555085,0.348288
4,0.511607,0.423016,0.630182,0.29077,0.403879
5,0.485552,0.505758,0.414744,0.380482,0.446867
6,0.410552,0.521173,0.560531,0.402469,0.49924
7,0.352597,0.521793,0.53671,0.402822,0.400483
8,0.596347,0.621456,0.571536,0.422928,0.46501
9,0.463636,0.497077,0.455299,0.569195,0.527661


# Relabel uncertainty

In [None]:
ckpts = ['622', '1244', '1867', '2489', '3111', '3734', '4356', '4979', '5601', '6220']
approach = 'U-Ignore'
train = False

ignore_results = get_predictions(ckpts, approach, train)
ignore_results.to_parquet(f'{approach}.pqt')