# Relabel uncertainty

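Relabeling of uncertain labels for the U-SelfTrained approach: model predictions are collected for the CheXpert samples that carry at least one uncertain (-1) label.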

In [48]:
import io
from typing import List, Union

import pandas as pd
import numpy as np
from PIL import Image

from torch import Tensor
from torch.utils.data import Dataset
import torchvision.transforms as T

from google.cloud import storage

import logging

# Pathologies used in the original CheXpert paper (default label columns).
pathologies = [
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Pleural Effusion',
]

# Uncertainty-handling policies compared in the original paper.
# The order matters: other code may look policies up by position.
uncertainty_policies = [
    'U-Ignore',
    'U-Zeros',
    'U-Ones',
    'U-SelfTrained',
    'U-MultiClass',
]


# #####################
# # Create a Dataset ##
# #####################
class UncetaintyOnlyCheXpertDataset(Dataset):
    """CheXpert dataset restricted to rows with at least one uncertain label.

    Only rows whose selected pathology columns contain a ``-1`` in the CSV
    are kept. The chosen uncertainty policy is then applied to the labels of
    those rows.
    """

    def __init__(self,
                 data_path: Union[str, None] = None,
                 uncertainty_policy: str = 'U-Ones',
                 logger: logging.Logger = logging.getLogger(__name__),
                 pathologies: List[str] = pathologies,
                 train: bool = True,
                 resize_shape: tuple = (256, 256)) -> None:
        """ Initialize dataset and preprocess according to uncertainty policy.

        Args:
            data_path (str): Prefix prepended to ``CheXpert-v1.0/<split>.csv``
            and to every image path listed in that csv. When the local csv
            cannot be read, the data is fetched from the cloud bucket.
            uncertainty_policy (str): Uncertainty policies compared in the
            original paper. Options: 'U-Ignore', 'U-Zeros', 'U-Ones',
            'U-SelfTrained', and 'U-MultiClass'.
            logger (logging.Logger): Logger to log events during training.
            pathologies (List[str], optional): Pathologies to classify.
            Defaults to 'Atelectasis', 'Cardiomegaly', 'Consolidation',
            'Edema', and 'Pleural Effusion'.
            train (bool): If true, returns data selected for training, if not,
            returns data selected for validation (dev set), as the CheXpert
            research group splitted.
            resize_shape (tuple): Size every image is resized to.

        Raises:
            ValueError: If ``uncertainty_policy`` is not a known policy.
            RuntimeError: If the csv can be read neither locally nor from the
            cloud bucket.
        """

        if uncertainty_policy not in uncertainty_policies:
            message = ("Unknown uncertainty policy. Known policies: " +
                       f"{uncertainty_policies}")
            logger.error(message)
            # Fail loudly instead of handing back a half-built object.
            raise ValueError(message)

        split = 'train' if train else 'valid'
        csv_path = f"CheXpert-v1.0/{split}.csv"
        path = str(data_path) + csv_path

        self.in_cloud = False

        try:
            data = pd.read_csv(path)
            data['Path'] = data_path + data['Path']
            logger.info("Local database found.")
        except Exception as e:
            logger.warning(f"Couldn't read csv at path {path}.\n{e}")
            try:
                # Find files at gcp
                project_id = 'labshurb'

                storage_client = storage.Client(project=project_id)
                self.bucket = storage_client.bucket(
                    'chexpert_database_stanford')

                # Read the csv from memory so no temporary file is left in
                # the working directory.
                blob = self.bucket.get_blob(csv_path)
                data = pd.read_csv(io.BytesIO(blob.download_as_bytes()))

                self.in_cloud = True
                logger.info("Cloud database found.")

            except Exception as e_:
                logger.error(f"Couldn't reach file at path {path}.\n{e_}")
                raise RuntimeError(
                    f"Couldn't load {csv_path} locally or from the cloud."
                ) from e_

        data = data.set_index('Path')
        data = data.loc[:, pathologies].fillna(0)

        # Select the uncertain rows BEFORE any policy rewrites the -1 labels;
        # filtering afterwards would leave nothing to select for every
        # policy that replaces or re-encodes -1.
        has_uncertain = (data == -1).any(axis=1)
        data = data.loc[has_uncertain]

        n_pathologies = len(pathologies)
        raw_labels = data.to_numpy(dtype=np.float64)
        label_cols = n_pathologies

        if uncertainty_policy == 'U-Ignore':
            # the only change is in the loss function, we mask the -1 labels
            # in the calculation
            labels = raw_labels

        elif uncertainty_policy == 'U-Zeros':
            labels = np.where(raw_labels == -1, 0.0, raw_labels)

        elif uncertainty_policy == 'U-Ones':
            labels = np.where(raw_labels == -1, 1.0, raw_labels)

        elif uncertainty_policy == 'U-SelfTrained':
            logger.warning(
                f"Using {uncertainty_policy} uncertainty policy, " +
                "make sure there are no uncertainty labels in the dataset.")
            # NOTE(review): labels are passed through untouched so the
            # dataset stays usable; confirm this is the intended behaviour.
            labels = raw_labels

        else:  # U-MultiClass
            # One-hot encode each label as (negative, positive, uncertain);
            # anything that is neither 0 nor 1 counts as uncertain.
            class_ids = np.full(raw_labels.shape, 2, dtype=np.int64)
            class_ids[raw_labels == 0] = 0
            class_ids[raw_labels == 1] = 1
            labels = np.eye(3)[class_ids]
            label_cols = 3 * n_pathologies

        self.image_names = data.index.to_numpy()
        self.labels = np.asarray(labels, dtype=np.float64).reshape(
            (-1, label_cols))
        # Resize, convert to a tensor and normalise with a fixed mean/std
        # (the original comment describes them as the dataset mean and std).
        self.transform = T.Compose([
            T.Resize(resize_shape),
            T.ToTensor(),
            T.Normalize(mean=[0.5330], std=[0.0349])
        ])

    def __getitem__(self, index: int) -> dict:
        """ Returns image and label from given index.

        Args:
            index (int): Index of sample in dataset.

        Returns:
            dict: ``"pixel_values"`` holds the transformed image and
            ``"labels"`` holds the float32 label vector of the sample.
        """
        image_name = self.image_names[index]
        if self.in_cloud:
            # Image names are blob names inside the bucket.
            img_bytes = self.bucket.blob(image_name).download_as_bytes()
            img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
        else:
            img = Image.open(image_name).convert('RGB')
        img = self.transform(img)

        label = self.labels[index].astype(np.float32)
        return {"pixel_values": img, "labels": label}

    def __len__(self) -> int:
        """ Return length of dataset.

        Returns:
            int: length of dataset.
        """
        return len(self.image_names)

In [49]:
import torch
from torch.utils.data import DataLoader

import sys
from transformers import (
    ViTForImageClassification,
    AutoConfig
)

import itertools
import pandas as pd
import numpy as np

sys.path.append('..')
from src.chexpert import CheXpertDataset

# Label columns, in the order used for the prediction table columns.
pathologies = [
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Pleural Effusion',
]


def _stack_batches(batches, n_columns):
    """Concatenate per-batch 2-D arrays row-wise; empty input gives (0, n_columns)."""
    if not batches:
        return np.empty((0, n_columns), dtype=np.float32)
    return np.concatenate(batches, axis=0)


def get_predictions(ckpts, approach, data_path, train=True,
                    ckpt_root="../output/25092023/google/vit-base-patch16-224",
                    batch_size=234):
    """Collect labels and per-checkpoint logits for the uncertain samples.

    Args:
        ckpts: Checkpoint step identifiers; each one is loaded from
            ``{ckpt_root}/{approach}/checkpoint-{step}``.
        approach (str): Uncertainty policy name. It selects both the dataset
            preprocessing and the checkpoint sub-directory.
        data_path (str): Forwarded to ``UncetaintyOnlyCheXpertDataset``.
        train (bool): Forwarded to the dataset to select the split.
        ckpt_root (str): Directory that holds the per-approach checkpoints.
        batch_size (int): Number of samples per forward pass.

    Returns:
        pandas.DataFrame: One row per sample. Columns are a two-level index:
        ``('labels', <column>)`` followed by ``('model_<i>', <column>)`` for
        every checkpoint, in the order given by ``ckpts``.
    """
    dataset = UncetaintyOnlyCheXpertDataset(
        data_path=data_path,
        uncertainty_policy=approach,
        train=train,
        resize_shape=(224, 224))
    # shuffle=False keeps rows aligned between labels and every model.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    models = [
        ViTForImageClassification.from_pretrained(
            f"{ckpt_root}/{approach}/checkpoint-{checkpoint}",
        ).eval()
        for checkpoint in ckpts
    ]

    columns = pathologies
    if approach == 'U-MultiClass':
        # Three columns per pathology, e.g. neg_Edema, pos_Edema, unc_Edema.
        columns = [
            state + pathology
            for pathology, state in itertools.product(
                pathologies, ['neg_', 'pos_', 'unc_'])
        ]

    # Every batch is loaded once and fed to all models, so images are not
    # re-read per checkpoint. Results are collected in lists and turned into
    # frames once, instead of growing a DataFrame inside the loop.
    label_batches = []
    logit_batches = [[] for _ in models]
    with torch.no_grad():
        for sample_batched in dataloader:
            label_batches.append(sample_batched['labels'].numpy())
            pixel_values = sample_batched['pixel_values']
            for collected, model in zip(logit_batches, models):
                collected.append(model(pixel_values).logits.numpy())

    n_columns = len(columns)
    general_output = pd.DataFrame(
        _stack_batches(label_batches, n_columns),
        columns=pd.MultiIndex.from_product([['labels'], columns]))

    for i_model, collected in enumerate(logit_batches):
        multiindex = pd.MultiIndex.from_product(
            [[f'model_{i_model}'], columns], names=['model', 'pathology'])
        model_output = pd.DataFrame(
            _stack_batches(collected, n_columns), columns=multiindex)
        general_output = pd.merge(
            general_output, model_output, left_index=True, right_index=True)
    return general_output

In [50]:
# Checkpoints to evaluate: id (as listed in `id_ckpts`) -> checkpoint step.
_checkpoint_steps = {
    4: '1090',
    6: '1526',
    7: '1744',
    8: '1962',
    9: '2180',
    10: '2398',
    11: '2616',
    13: '3052',
    14: '3270',
    15: '3488',
}
id_ckpts = list(_checkpoint_steps)
ckpts = list(_checkpoint_steps.values())

approach = 'U-Ignore'
data_path = r"C:/Users/hurbl/OneDrive/Área de Trabalho/Loon Factory/repository/Chest-X-Ray-Pathology-Classifier/data/raw/"

# Predictions of every checkpoint on the uncertain samples of the train split.
ignore_results = get_predictions(ckpts, approach, data_path=data_path, train=True)

# NOTE(review): the file name says "valid" although train=True is used above;
# confirm which one is intended before renaming.
# The `results/` directory must already exist.
ignore_results.to_parquet(f'results/valid_{approach}_for_selftrained.pqt')

  labels = pd.concat(
  model_output = pd.concat(
  model_output = pd.concat(
  model_output = pd.concat(
  model_output = pd.concat(
  model_output = pd.concat(
  model_output = pd.concat(
  model_output = pd.concat(
  model_output = pd.concat(
  model_output = pd.concat(
  model_output = pd.concat(
  if _pandas_api.is_sparse(col):


In [None]:
ignore_results

Unnamed: 0_level_0,labels,labels,labels,labels,labels,model_0,model_0,model_0,model_0,model_0
Unnamed: 0_level_1,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion
0,0.0,0.0,0.0,0.0,0.0,-1.192969,-2.480345,-3.160272,-1.461431,-1.767798
1,-1.0,-1.0,-1.0,-1.0,-1.0,-1.081272,-2.093229,-2.622588,-1.696654,-0.650424
2,0.0,0.0,-1.0,0.0,0.0,-1.365313,-3.297892,-2.282577,-2.757240,-0.223821
3,0.0,0.0,-1.0,0.0,0.0,-0.801268,-2.108772,-1.996155,-1.621541,1.135071
4,0.0,0.0,0.0,1.0,0.0,-1.406059,-2.037456,-3.590422,-1.265158,-2.009350
...,...,...,...,...,...,...,...,...,...,...
697,0.0,0.0,0.0,0.0,1.0,-1.072884,-2.707569,-2.094149,-0.956540,1.580486
698,0.0,0.0,0.0,0.0,0.0,-2.029691,-4.634951,-3.867409,-3.405506,-2.425060
699,0.0,0.0,0.0,0.0,0.0,-2.613808,-5.143997,-3.915619,-4.781992,-2.297644
700,0.0,0.0,0.0,0.0,0.0,-2.575031,-3.695695,-3.644741,-4.716095,-3.285170
