In [1]:
import torch

from train_vit import CheXpertDataset
from transformers import (
    ViTForImageClassification,
    AutoConfig
)

import pandas as pd
import numpy as np

# Pathology labels, in the order used for the first level of the
# (pathology, sample) index that get_predictions builds.
pathologies = [
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Pleural Effusion',
]

def get_predictions(ckpts, approach, train):
    """Collect the logits of every saved checkpoint of one approach over a dataset.

    Args:
        ckpts: Iterable of checkpoint identifiers. Each one is resolved to the
            directory ``../models/{approach}/checkpoint-{checkpoint}``.
        approach: Uncertainty policy name. It is passed to ``CheXpertDataset``
            as ``uncertainty_policy`` and is also the model sub-directory name.
        train: Forwarded unchanged to ``CheXpertDataset(train=...)``.

    Returns:
        A ``pandas.DataFrame`` indexed by a ``(pathology, sample index)``
        MultiIndex with one float column ``model_{i}`` per checkpoint, where
        ``i`` is the position of the checkpoint in ``ckpts``.

    Raises:
        ValueError: If a model does not return exactly one logit per entry of
            ``pathologies`` for a sample.
    """
    val_dataset = CheXpertDataset(
        data_path='../data/raw/',
        uncertainty_policy=approach,
        train=train,
        resize_shape=(224, 224))

    models = []
    for checkpoint in ckpts:
        # from_pretrained restores the trained weights stored in the checkpoint
        # directory. Building the model from its config alone would leave it
        # with freshly initialised (random) weights.
        model = ViTForImageClassification.from_pretrained(
            f"../models/{approach}/checkpoint-{checkpoint}")
        model.eval()  # inference only: make sure dropout is disabled
        models.append(model)

    ds_len = len(val_dataset)
    n_labels = len(pathologies)
    multiindex = pd.MultiIndex.from_product([pathologies, list(range(ds_len))])

    # Buffer laid out as (model, sample, pathology). Filling a NumPy array and
    # building the frame once avoids one MultiIndex .loc write per sample.
    logits = np.full((len(models), ds_len, n_labels), np.nan, dtype=np.float64)

    with torch.no_grad():
        for sample in range(ds_len):
            # Read each sample once and reuse it for every model; [None, :]
            # adds the leading batch dimension.
            pixel_values = val_dataset[sample]['pixel_values'][None, :]
            for model_number, model in enumerate(models):
                sample_logits = model(pixel_values).logits.numpy()[0]
                if sample_logits.shape != (n_labels,):
                    raise ValueError(
                        f"model_{model_number} returned logits of shape "
                        f"{sample_logits.shape} for sample {sample}; expected "
                        f"({n_labels},) to match `pathologies`.")
                logits[model_number, sample] = sample_logits

    # The index is pathology-major (all samples of the first pathology, then
    # the next one, ...), so transpose (sample, pathology) before flattening.
    res = pd.DataFrame(
        {f'model_{model_number}': logits[model_number].T.reshape(-1)
         for model_number in range(len(models))},
        index=multiindex)

    return res

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run every saved U-Ignore checkpoint over the dataset (train=False) and
# store the resulting logits as '<approach>.pqt'.
approach = 'U-Ignore'
train = False
ckpts = [
    '622', '1244', '1867', '2489', '3111',
    '3734', '4356', '4979', '5601', '6220',
]

ignore_results = get_predictions(ckpts=ckpts, approach=approach, train=train)
ignore_results.to_parquet(f'{approach}.pqt')

2023-07-04 18:48:59,502 - train_vit - INFO - Local database found.


In [5]:
# Run every saved U-MultiClass checkpoint over the dataset (train=False) and
# store the resulting logits as '<approach>.pqt'.
approach = 'U-MultiClass'
train = False
ckpts = [
    '872', '1745', '2618', '3490', '4363',
    '5236', '6108', '6981', '7854', '8720',
]

multiclass_results = get_predictions(ckpts=ckpts, approach=approach, train=train)
multiclass_results.to_parquet(f'{approach}.pqt')

2023-07-04 18:58:32,304 - train_vit - INFO - Local database found.


In [6]:
# Run every saved U-Ones checkpoint over the dataset (train=False) and
# store the resulting logits as '<approach>.pqt'.
approach = 'U-Ones'
train = False
ckpts = [
    '872', '1745', '2618', '3490', '4363',
    '5236', '6108', '6981', '7854', '8720',
]

ones_results = get_predictions(ckpts=ckpts, approach=approach, train=train)
ones_results.to_parquet(f'{approach}.pqt')

2023-07-04 19:04:06,922 - train_vit - INFO - Local database found.


In [7]:
# Run every saved U-Zeros checkpoint over the dataset (train=False) and
# store the resulting logits as '<approach>.pqt'.
approach = 'U-Zeros'
train = False
ckpts = [
    '872', '1745', '2618', '3490', '4363',
    '5236', '6108', '6981', '7854', '8720',
]

zeros_results = get_predictions(ckpts=ckpts, approach=approach, train=train)
zeros_results.to_parquet(f'{approach}.pqt')

2023-07-04 19:09:33,190 - train_vit - INFO - Local database found.


# Experimental Results

# Relabel uncertainty

In [None]:
# NOTE(review): this cell uses the same checkpoints, approach and output file
# name as the earlier U-Ignore cell, so running it rebinds `ignore_results`
# and overwrites 'U-Ignore.pqt' — confirm that this repeat is intended.
approach = 'U-Ignore'
train = False
ckpts = [
    '622', '1244', '1867', '2489', '3111',
    '3734', '4356', '4979', '5601', '6220',
]

ignore_results = get_predictions(ckpts=ckpts, approach=approach, train=train)
ignore_results.to_parquet(f'{approach}.pqt')