In [None]:
import os
assert os.environ['COLAB_TPU_ADDR'], 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

!pip install cloud-tpu-client==0.10 torch==2.0.0 torchvision==0.15.1 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp39-cp39-linux_x86_64.whl
!pip install -r ../requirements.txt
!gcloud init
!gcloud auth application-default login

In [None]:
import logging
import io
from pathlib import PurePath
from typing import List, Union, Optional
import pandas as pd
import numpy as np
from PIL import Image

from torch import from_numpy, Tensor
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, dataloader
import torchvision.transforms as T

from google.cloud import storage

import logging
# Root logger: timestamp, logger name, level and message, at INFO verbosity.
log_fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=log_fmt, level=logging.INFO)


# Pathologies evaluated in the original CheXpert paper (default label columns).
pathologies = [
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Pleural Effusion',
]

# Strategies from the original paper for handling uncertain (-1) labels.
uncertainty_policies = [
    'U-Ignore',
    'U-Zeros',
    'U-Ones',
    'U-SelfTrained',
    'U-MultiClass',
]

######################
## Create a Dataset ##
######################
class CheXpertDataset(Dataset):
    """CheXpert labels table plus images fetched on demand from a GCS bucket.

    The label CSV is read from ``data_path`` when one is given and readable,
    otherwise from the bucket. Images are always downloaded from the bucket
    in ``__getitem__``.
    """

    def __init__(self,
                 data_path: Union[str, None] = None,
                 uncertainty_policy: str = 'U-Ones',
                 logger: logging.Logger = logging.getLogger(__name__),
                 pathologies: List[str] = pathologies,
                 train: bool = True,
                 resize_shape: tuple = (384, 384)) -> None:
        """ Initialize dataset and preprocess according to uncertainty policy.

        Args:
            data_path (str, optional): Prefix prepended (by plain string
                concatenation) to ``CheXpert-v1.0/<split>.csv`` to locate a local
                copy of the labels. When None, the CSV is read from the bucket.
            uncertainty_policy (str): Uncertainty policies compared in the original paper.
                Options: 'U-Ignore', 'U-Zeros', 'U-Ones', 'U-SelfTrained', and
                'U-MultiClass'. Only the first three are implemented.
            logger (logging.Logger): Logger to log events during training.
            pathologies (List[str], optional): Pathologies to classify.
                Defaults to 'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', and 'Pleural Effusion'.
            train (bool): If true, loads the training split, otherwise the
                validation (dev) split, as split by the CheXpert research group.
            resize_shape (tuple): Target size passed to ``T.Resize``.

        Raises:
            ValueError: If ``uncertainty_policy`` is not a known policy.
            NotImplementedError: If the policy is known but not implemented.
            Exception: Whatever the bucket download / CSV parsing raised when
                the labels could not be read from any source.
        """
        # Validate before doing any network I/O. Raising (instead of the former
        # `return None`) avoids handing back a half-initialized dataset.
        if uncertainty_policy not in uncertainty_policies:
            message = f"Unknown uncertainty policy. Known policies: {uncertainty_policies}"
            logger.error(message)
            raise ValueError(message)

        if uncertainty_policy in ('U-SelfTrained', 'U-MultiClass'):
            message = f"Uncertainty policy {uncertainty_policy} not implemented."
            logger.error(message)
            raise NotImplementedError(message)

        project_id = 'labshurb'

        storage_client = storage.Client(project=project_id)
        # Kept on the instance because __getitem__ downloads images from it.
        self.bucket = storage_client.bucket('chexpert_database_stanford')

        split = 'train' if train else 'valid'
        csv_path = f"CheXpert-v1.0/{split}.csv"

        data = self._read_labels_csv(data_path, csv_path, logger)

        # Index by image path and keep only the requested label columns.
        label_columns = list(pathologies)
        data = data.set_index('Path')
        data = data.loc[:, label_columns].copy()

        # Missing labels are treated as negative.
        data = data.fillna(0)

        if uncertainty_policy == 'U-Ignore':
            # Drop every sample that has at least one uncertain label.
            data = data.loc[(data != -1).all(axis=1)].copy()
        elif uncertainty_policy == 'U-Zeros':
            data = data.replace({-1: 0})
        elif uncertainty_policy == 'U-Ones':
            data = data.replace({-1: 1})

        self.image_names = data.index.to_numpy()
        self.labels = data.to_numpy()
        # Resize, convert to tensor, then whiten with the dataset mean and std.
        self.transform = T.Compose([
            T.Resize(resize_shape),
            T.ToTensor(),
            T.Normalize(mean=[0.5330], std=[0.0349]),
        ])

    def _read_labels_csv(self,
                         data_path: Union[str, None],
                         csv_path: str,
                         logger: logging.Logger) -> pd.DataFrame:
        """ Read the labels CSV locally if possible, otherwise from the bucket. """
        if data_path is not None:
            path = str(data_path) + csv_path
            try:
                return pd.read_csv(path)
            except Exception as local_error:
                # Best effort: a missing/unreadable local copy is not fatal.
                logger.warning(f"Couldn't read csv at path {path}, trying the bucket.\n{local_error}")

        try:
            # Read in memory; no temporary file is left behind or shared
            # between the train and validation datasets.
            csv_bytes = self.bucket.blob(csv_path).download_as_bytes()
            return pd.read_csv(io.BytesIO(csv_bytes))
        except Exception as bucket_error:
            logger.error(f"Couldn't read csv at path {csv_path}.\n{bucket_error}")
            # Propagate instead of calling quit(), which would kill the kernel
            # and hide the real cause.
            raise

    def __getitem__(self, index: int) -> tuple:
        """ Returns image and label from given index.

        Args:
            index (int): Index of sample in dataset.

        Returns:
            torch.Tensor: Transformed image (decoded as RGB, resized, normalized).
            torch.Tensor: float32 tensor of labels.
        """
        img_bytes = self.bucket.blob(self.image_names[index]).download_as_bytes()
        img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
        img = self.transform(img)

        label = from_numpy(self.labels[index].astype(np.float32))
        return img, label

    def __len__(self) -> int:
        """ Return length of dataset.

        Returns:
            int: length of dataset.
        """
        return len(self.image_names)


#########################
## Create a DataLoader ##
#########################
def get_dataloader(data_path: str,
                   uncertainty_policy: str,
                   logger: logging.Logger,
                   batch_size: int,
                   pathologies: List[str] = pathologies,
                   train: bool = True,
                   shuffle: bool = True,
                   random_seed: int = 123,
                   num_workers: int = 4, 
                   pin_memory: bool = True,
                   apply_transform: bool = True,
                   resize_shape: tuple = (384, 384)) -> DataLoader:
    """Wrap the dataset in a DataLoader to help with parallelization and data
    loading order (for reproducibility), and to make the code a bit cleaner.

    Args:
        data_path (str): Refer to CheXpertDataset class documentation.
        uncertainty_policy (str): Refer to CheXpertDataset class documentation.
        logger (logging.Logger): Refer to CheXpertDataset class documentation.
        batch_size (int): Number of samples per batch.
        pathologies (List[str], optional): Refer to CheXpertDataset class documentation.
        train (bool): Refer to CheXpertDataset class documentation.
        shuffle (bool): If True, samples are drawn in a random order; if False,
            they are served in dataset order.
        random_seed (int): Seed for the loader's random generator, helps with reproducibility.
        num_workers (int): Number of DataLoader worker processes.
        pin_memory (bool): Forwarded to the DataLoader.
        apply_transform (bool): Currently unused; kept for backward compatibility.
        resize_shape (tuple): Refer to CheXpertDataset class documentation.

    Returns:
        torch.utils.data.DataLoader: Data loader from dataset randomly (or not) loaded.
    """
    # Local import: the surrounding cell only imports selected names from torch.
    import torch

    dataset = CheXpertDataset(
        data_path=data_path,
        uncertainty_policy=uncertainty_policy,
        pathologies=pathologies,
        logger=logger,
        train=train,
        resize_shape=resize_shape,
    )

    # A dedicated, seeded generator makes the sampling order reproducible.
    # The former SubsetRandomSampler re-permuted the indices with torch's
    # global RNG, so the seed had no lasting effect and `shuffle=False` still
    # produced a random order.
    generator = torch.Generator()
    generator.manual_seed(random_seed)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        generator=generator,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


In [None]:
import lightning.pytorch as pl

from torchvision.models.efficientnet import efficientnet_v2_l
import torch

from torchmetrics.classification import MultilabelAUROC, MultilabelF1Score, MultilabelPrecisionRecallCurve

class LitEfficientnet(pl.LightningModule):
    """EfficientNetV2-L fine-tuned as a multi-label classifier.

    Each output unit is an independent binary decision, so the loss is binary
    cross-entropy on logits. Validation metrics are accumulated over the whole
    validation epoch and logged once at its end.
    """

    def __init__(self,
                 num_classes:int=5,
                 lr=1e-3) -> None:
        """
        Args:
            num_classes (int): Number of labels predicted per image.
            lr (float): Learning rate for the Adam optimizer.
        """
        super().__init__()
        self.save_hyperparameters()
        model = efficientnet_v2_l(weights="DEFAULT")
        # Replace the classification head, reading the input width from the
        # layer being replaced instead of hard-coding it.
        in_features = model.classifier[1].in_features
        model.classifier[1] = torch.nn.Linear(in_features, num_classes)
        self.model = model
        # Multi-label targets: CrossEntropyLoss would treat the label vector as
        # a single class distribution, which is not what the data encodes.
        self.criterion = torch.nn.BCEWithLogitsLoss()
        self.lr = lr

        # Per-class validation metrics (average=None keeps one value per label).
        self.auc = MultilabelAUROC(num_labels=num_classes, average=None, thresholds=None)
        self.f1 = MultilabelF1Score(num_labels=num_classes, average=None)
        # Not logged by the steps below (its output is not scalar); kept
        # available as an attribute.
        self.pr_curve = MultilabelPrecisionRecallCurve(num_labels=num_classes, thresholds=None)
        # Separate training metric so its state never mixes with validation.
        self.train_f1 = MultilabelF1Score(num_labels=num_classes, average="macro")

    def forward(self, x):
        """Return the raw logits of the wrapped model."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and macro F1 for one batch."""
        data, target = batch
        output = self.model(data)
        train_loss = self.criterion(output, target)

        # Logged as "train_f1": the previous "train_auc" key actually held F1.
        train_f1 = self.train_f1(output, target.long())

        self.log_dict({"train_loss": train_loss, "train_f1": train_f1}, prog_bar=True, logger=True)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and accumulate metric state for one batch."""
        data, target = batch
        output = self.model(data)
        val_loss = self.criterion(output, target)

        # The classification metrics expect integer targets.
        target = target.long()

        # Only accumulate here; values are computed over the full validation
        # set in on_validation_epoch_end.
        self.auc.update(output, target)
        self.f1.update(output, target)

        self.log("val_loss", val_loss, prog_bar=True, logger=True)

    def on_validation_epoch_end(self):
        """Compute, log and reset the validation metrics for the epoch."""
        per_class_auc = self.auc.compute()
        per_class_f1 = self.f1.compute()

        # self.log only accepts single-element tensors, so log the mean under
        # the headline keys and each class under its own key.
        metrics = {
            "val_auc": per_class_auc.mean(),
            "val_f1": per_class_f1.mean(),
        }
        for class_idx, (class_auc, class_f1) in enumerate(zip(per_class_auc, per_class_f1)):
            metrics[f"val_auc_class_{class_idx}"] = class_auc
            metrics[f"val_f1_class_{class_idx}"] = class_f1

        self.log_dict(metrics, prog_bar=True, logger=True)

        self.auc.reset()
        self.f1.reset()

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)



In [None]:
import logging
from tqdm import tqdm
import gc
import time
import numpy as np
import os

#from data import get_dataloader

import torch
import torch.optim as optim
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import WandbLogger

# imports the torch_xla package
import torch_xla
import torch_xla.core.xla_model as xm


import wandb

# Weights & Biases project that receives the runs and sweeps.
WANDB_PROJECT = 'Chest-X-Ray-Pathology-Classifier'

# Default locations for the raw data and for model checkpoints.
RAW_DATA_PATH = r"/project/data/raw/"
CHECKPOINT_PATH = 'models/ckpt/'


def train(input_filepath: str=None,
          uncertainty_policy: str='U-Ones',
          config = None) -> None:
    """Run one training job inside a wandb run (usable as a sweep agent target).

    Hyperparameters (batch size, gradient accumulation steps, learning rate and
    number of epochs) are read from ``wandb.config``.

    Args:
        input_filepath (str, optional): Forwarded to ``get_dataloader`` as ``data_path``.
        uncertainty_policy (str): Forwarded to ``get_dataloader``.
        config: Passed to ``wandb.init`` as the run configuration.
    """
    logger = logging.getLogger(__name__)
    gc.collect()

    with wandb.init(config=config):
        # Hyperparameters
        BATCH_SIZE = wandb.config.batch_size
        GRAD_ACC = wandb.config.gradient_accumulation_steps
        RESIZE_SHAPE = (384,384)
        LEARNING_RATE = wandb.config.learning_rate
        EPOCHS = wandb.config.epochs
        NUM_WORKERS = 0
        PIN_MEMORY = True
        NUM_CLASSES = 5

        model = LitEfficientnet(num_classes=NUM_CLASSES, lr=LEARNING_RATE)

        wandb_logger = WandbLogger(project=WANDB_PROJECT, log_model="all")

        early_stopping = EarlyStopping(
            monitor="val_loss",
            mode="min",
            min_delta=0.01,
            patience=3,
            divergence_threshold=2.,
            check_on_train_epoch_end=False)

        trainer = pl.Trainer(
            accelerator="tpu",
            devices=1,
            callbacks=[early_stopping],
            # Pass the logger instance; the class itself was passed before.
            logger=wandb_logger,
            max_epochs=EPOCHS,
            log_every_n_steps=100,
            enable_progress_bar=True,
            accumulate_grad_batches=GRAD_ACC,
            )

        wandb_logger.watch(model, log="all", log_freq=1, log_graph=False)

        # Settings shared by the training and validation loaders.
        loader_kwargs = dict(data_path=input_filepath,
                             uncertainty_policy=uncertainty_policy,
                             logger=logger,
                             batch_size=BATCH_SIZE,
                             num_workers=NUM_WORKERS,
                             pin_memory=PIN_MEMORY,
                             resize_shape=RESIZE_SHAPE)

        train_dataloader = get_dataloader(train=True, shuffle=True, **loader_kwargs)
        # Evaluation does not benefit from shuffling.
        valid_dataloader = get_dataloader(train=False, shuffle=False, **loader_kwargs)

        wandb.log({
            "Uncertainty policy": uncertainty_policy
        })

        trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader)

        wandb.run.summary["state"] = "completed"
        wandb.finish(quiet=True)


if __name__ == '__main__':
    # Entry point: configure logging, define a Bayesian hyperparameter sweep
    # and launch an agent that calls `train` for each trial.
    log_fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(format=log_fmt, level=logging.INFO)

    # Search space: fixed epoch count, discrete accumulation steps and batch
    # sizes, log-uniform learning rate.
    parameters_dict = {
        'epochs': {'values': [20]},
        'gradient_accumulation_steps': {'values': [16, 32, 64]},
        'batch_size': {'values': [4, 8]},
        'learning_rate': {
            'distribution': 'log_uniform_values',
            'min': 1e-5,
            'max': 1e-3,
        },
    }

    # Objective the sweep tries to maximize.
    sweep_metric = {'name': 'val_auc', 'goal': 'maximize'}

    sweep_config = {
        'method': 'bayes',
        'parameters': parameters_dict,
        'metric': sweep_metric,
    }

    # Close any run left open, then authenticate.
    wandb.finish(quiet=True)
    wandb.login()

    os.environ.update({
        "WANDB_PROJECT": WANDB_PROJECT,
        "WANDB_LOG_MODEL": "true",
    })

    # Register the sweep (an existing sweep id, e.g. 'sla620fp', could be
    # assigned here instead) and run up to 20 trials.
    sweep_id = wandb.sweep(sweep_config, project=WANDB_PROJECT)

    wandb.agent(sweep_id, train, count=20)