In [29]:
import logging
import io
from typing import List, Union
import pandas as pd
import numpy as np
from PIL import Image

from torch import from_numpy, Tensor
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import torchvision.transforms as T

from google.cloud import storage

import logging
# Root logging setup: records at INFO and above are emitted with a timestamp,
# the logger name, the level and the message.
log_fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=log_fmt)


# CheXpert pathologies evaluated in the original paper
pathologies = ['Atelectasis',
               'Cardiomegaly',
               'Consolidation',
               'Edema',
               'Pleural Effusion']

# Uncertainty policies compared in the original paper
uncertainty_policies = ['U-Ignore',
                        'U-Zeros',
                        'U-Ones',
                        'U-SelfTrained',
                        'U-MultiClass']

######################
## Create a Dataset ##
######################
class CheXpertDataset(Dataset):
    """CheXpert images and labels served from a Google Cloud Storage bucket.

    The label CSV is read once at construction time; images are downloaded
    from the bucket lazily, one per ``__getitem__`` call.
    """

    # Policies that are listed in `uncertainty_policies` but not implemented.
    _UNIMPLEMENTED_POLICIES = ('U-SelfTrained', 'U-MultiClass')

    def __init__(self,
                 data_path: Union[str, None] = None,
                 uncertainty_policy: str = 'U-Ones',
                 logger: logging.Logger = logging.getLogger(__name__),
                 pathologies: List[str] = pathologies,
                 train: bool = True,
                 resize_shape: tuple = (384, 384),
                 project_id: str = 'labshurb',
                 bucket_name: str = 'chexpert_database_stanford') -> None:
        """ Initialize dataset and preprocess according to uncertainty policy.

        Args:
            data_path (str, optional): Prefix prepended to
            'CheXpert-v1.0/<split>.csv' to locate the label csv. When None, or
            when reading from that path fails, the csv is read from the bucket.
            uncertainty_policy (str): Uncertainty policies compared in the original paper.
            Options: 'U-Ignore', 'U-Zeros', 'U-Ones' (implemented);
            'U-SelfTrained' and 'U-MultiClass' (not implemented).
            logger (logging.Logger): Logger used to report problems.
            pathologies (List[str], optional): Pathologies to classify.
            Defaults to 'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', and 'Pleural Effusion'.
            train (bool): If true, uses the 'train' split; otherwise the 'valid'
            split (dev set), as split by the CheXpert research group.
            resize_shape (tuple): (height, width) every image is resized to.
            project_id (str): Google Cloud project used to create the storage client.
            bucket_name (str): Bucket holding the csv files and the images.

        Raises:
            ValueError: If `uncertainty_policy` is not a known policy.
            NotImplementedError: If the policy is known but not implemented.
            RuntimeError: If the label csv cannot be read from any source.
        """
        # Fail loudly and early: returning from __init__ would hand the caller
        # an object without `bucket`, `labels` or `transform`.
        if uncertainty_policy not in uncertainty_policies:
            message = f"Unknown uncertainty policy. Known policies: {uncertainty_policies}"
            logger.error(message)
            raise ValueError(message)
        if uncertainty_policy in self._UNIMPLEMENTED_POLICIES:
            message = f"Uncertainty policy {uncertainty_policy} not implemented."
            logger.error(message)
            raise NotImplementedError(message)

        # Private copy so the shared default list is never aliased.
        pathologies = list(pathologies)

        storage_client = storage.Client(project=project_id)
        self.bucket = storage_client.bucket(bucket_name)

        split = 'train' if train else 'valid'
        csv_path = f"CheXpert-v1.0/{split}.csv"

        data = self._read_labels_csv(data_path, csv_path, logger)
        data = data.set_index('Path')

        # Keep only the requested pathologies; a missing mention counts as negative.
        data = data.loc[:, pathologies].fillna(0)
        data = self._apply_uncertainty_policy(data, uncertainty_policy, pathologies)

        self.image_names = data.index.to_numpy()
        self.labels = data.to_numpy()
        self.transform = T.Compose([
            T.Resize(resize_shape),
            T.ToTensor(),
            # Whiten with the dataset mean and std.
            T.Normalize(mean=[0.5330], std=[0.0349])
        ])

    def _read_labels_csv(self, data_path, csv_path, logger) -> pd.DataFrame:
        """ Read the label csv from `data_path`, falling back to the bucket.

        Raises:
            RuntimeError: If neither source can be read; the bucket error is
            chained as the cause.
        """
        local_error = None
        path = None
        if data_path is not None:
            path = str(data_path) + csv_path
            try:
                return pd.read_csv(path)
            except Exception as err:
                # Remember the failure; it is reported if the fallback fails too.
                local_error = err
                logger.info(f"Couldn't read csv at path {path}; trying bucket blob {csv_path}.")

        try:
            # Read straight from memory instead of leaving a tmp.csv behind.
            payload = self.bucket.blob(csv_path).download_as_bytes()
            return pd.read_csv(io.BytesIO(payload))
        except Exception as bucket_error:
            message = f"Couldn't read csv from bucket blob {csv_path}.\n{bucket_error}"
            if local_error is not None:
                message = f"Couldn't read csv at path {path}.\n{local_error}\n{message}"
            logger.error(message)
            raise RuntimeError(message) from bucket_error

    @staticmethod
    def _apply_uncertainty_policy(data: pd.DataFrame,
                                  uncertainty_policy: str,
                                  pathologies: List[str]) -> pd.DataFrame:
        """ Return a new frame with the uncertain (-1) labels resolved.

        'U-Ignore' drops every row containing a -1, 'U-Zeros' maps -1 to 0 and
        'U-Ones' maps -1 to 1. The input frame is not modified.

        Raises:
            NotImplementedError: For any other policy.
        """
        if uncertainty_policy == 'U-Ignore':
            certain_rows = (data[pathologies] != -1).all(axis=1)
            return data.loc[certain_rows].copy()
        if uncertainty_policy == 'U-Zeros':
            return data.replace({-1: 0})
        if uncertainty_policy == 'U-Ones':
            return data.replace({-1: 1})
        raise NotImplementedError(f"Uncertainty policy {uncertainty_policy} not implemented.")

    def __getitem__(self, index: int) -> dict:
        """ Returns image and label from given index.

        Args:
            index (int): Index of sample in dataset.

        Returns:
            dict: "pixel_values" holds the transformed image tensor and
            "labels" holds the float32 label vector of the sample.
        """
        # The csv 'Path' value is used directly as the blob name in the bucket.
        img_bytes = self.bucket.blob(self.image_names[index]).download_as_bytes()
        img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
        img = self.transform(img)

        label = self.labels[index].astype(np.float32)
        return {"pixel_values": img, "labels": label}

    def __len__(self) -> int:
        """ Return length of dataset.

        Returns:
            int: length of dataset.
        """
        return len(self.image_names)


In [30]:
# Both splits share the same source, uncertainty policy and input resolution;
# only the `train` flag differs.
_shared_dataset_kwargs = dict(
    data_path="gcs://chexpert_database_stanford/",
    uncertainty_policy="U-Ones",
    resize_shape=(224, 224),
)

train_dataset = CheXpertDataset(train=True, **_shared_dataset_kwargs)

val_dataset = CheXpertDataset(train=False, **_shared_dataset_kwargs)

In [31]:
from transformers import ViTForImageClassification

# Size the classification head from the dataset's label matrix so the model
# cannot drift out of sync with the pathologies the dataset actually serves.
num_labels = train_dataset.labels.shape[1]

# The checkpoint ships a 1000-way head; ignore_mismatched_sizes lets it be
# replaced by a freshly initialised head of `num_labels` outputs.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    problem_type="multi_label_classification",
    num_labels=num_labels,
    ignore_mismatched_sizes=True,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([5, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([5]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [32]:
import evaluate
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

# Bundle of accuracy, F1, precision and recall, computed together in compute_metrics.
METRICS = evaluate.combine(["accuracy", "f1", "precision", "recall"])

def compute_metrics(eval_pred, threshold: float = 0.5):
    """Compute micro-averaged multi-label metrics from raw logits.

    Each label is an independent binary decision, so every logit is
    thresholded on its own instead of taking an argmax across labels (which
    would yield one class index per sample and not line up with the
    (n_samples, n_labels) reference matrix). Predictions and references are
    then flattened so the binary metrics in METRICS score every
    (sample, label) pair.

    Args:
        eval_pred: Pair ``(predictions, labels)``. ``predictions`` holds the
            logits, or a tuple whose first element is the logits.
        threshold (float): Probability at or above which a label counts as
            predicted positive. Must lie strictly between 0 and 1.

    Returns:
        Whatever ``METRICS.compute`` returns for the flattened decisions.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    logits = np.asarray(predictions, dtype=np.float64)
    # sigmoid(x) >= t  <=>  x >= log(t / (1 - t)); comparing in logit space
    # avoids overflow in exp() for large-magnitude logits.
    logit_threshold = np.log(threshold) - np.log1p(-threshold)
    decisions = (logits >= logit_threshold).astype(np.int64)
    references = np.asarray(labels).astype(np.int64)

    return METRICS.compute(predictions=decisions.ravel(),
                           references=references.ravel())


# Hyper-parameters gathered in one mapping so they can be inspected or tweaked
# before the TrainingArguments object is built.
training_config = dict(
    output_dir="vit_xray",
    # The dataset yields ready-made dicts, so no column may be dropped.
    remove_unused_columns=False,
    # Evaluate and checkpoint once per epoch; required for best-model reload.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Optimisation schedule.
    optim='adamw_torch',
    learning_rate=5e-5,
    warmup_ratio=0.1,
    num_train_epochs=3,
    # 16 samples per step x 4 accumulated steps per optimizer update.
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    logging_steps=10,
)

args = TrainingArguments(**training_config)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)



In [33]:
# Run the fine-tuning loop; the result is kept for the optional persistence
# calls below.
train_results = trainer.train()

# Optional persistence of the model, training metrics and trainer state (disabled).
#trainer.save_model()
#trainer.log_metrics("train", train_results.metrics)
#trainer.save_metrics("train", train_results.metrics)
#trainer.save_state()


[A
  0%|          | 0/10473 [06:14<?, ?it/s]
  0%|          | 0/10473 [04:37<?, ?it/s]


In [None]:
# Evaluate on the validation split, then log and persist the results
# under the "eval" split name.
metrics = trainer.evaluate()
for _report in (trainer.log_metrics, trainer.save_metrics):
    _report("eval", metrics)

In [1]:
import torch
# Release cached GPU memory, then display whether CUDA is usable in this kernel.
torch.cuda.empty_cache()
torch.cuda.is_available()

True

In [2]:
# Smoke test: create a one-element tensor and move it to the CUDA device.
torch.zeros(1).cuda()

tensor([0.], device='cuda:0')

In [66]:
from torch import nn
import torch
from torch import masked_select

# Toy batch: 10 samples x 5 classes, every logit 0.5 and every target positive.
logits = torch.full([10, 5], 0.5)
labels = torch.ones_like(logits)

# Mark a handful of entries as "uncertain" (-1) so they can be excluded.
for sample_idx, class_idx in ((2, 3), (3, 3), (6, 1), (8, 4)):
    labels[sample_idx, class_idx] = -1.

# True wherever the label is greater than -1, False for the -1 entries.
mask = labels.gt(-1.)

# masked_select returns a flat 1-D tensor holding only the kept entries.
masked_labels = masked_select(labels, mask)
masked_logits = masked_select(logits, mask)

# Mean binary cross-entropy over the kept entries only.
loss_fct = nn.BCEWithLogitsLoss()
loss = loss_fct(masked_logits, masked_labels)

In [67]:
# Display the masked loss. With every kept target equal to 1 at logit 0.5 the
# expected value is log(1 + exp(-0.5)) ~= 0.4741; a larger value suggests the
# -1 entries were not masked out (NOTE(review): re-run the cell above to confirm).
loss

tensor(0.5541)

In [65]:
# Compare the full logits grid with its masked, flattened counterpart.
logits, masked_logits

(tensor([[0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000]]),
 tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000]))

In [85]:
# Flatten the logits to 1-D without masking (all 10 * 5 = 50 entries are kept).
logits.view(-1)

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000])