In [None]:
# Shell out to `nvidia-smi` to report which GPU(s), if any, this kernel can see.
!nvidia-smi

In [None]:
# Dataset folder name; interpolated into the Kaggle input paths and the W&B project name.
DATASET = 'mimic_cxr'
# Forwarded to `TrainingArguments(report_to=...)`; the W&B login/init code runs only when this is 'wandb'.
REPORT_TO = 'wandb'

In [None]:
from __future__ import annotations

import json
import os
import random
from collections import defaultdict
from itertools import chain
from pathlib import Path

import torch
from PIL import Image
from torch.utils.data import Dataset

from transformers import AutoModelForImageClassification, AutoFeatureExtractor, TrainingArguments, Trainer

# Seed both Python's and PyTorch's global RNGs: `random` drives the dataset
# sub-sampling, while torch drives any weight initialisation that happens when
# the model is loaded (which occurs before the Trainer applies its own seed).
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Class names in index order: a name's position in this list is its class id.
# 'Other' is the fallback class the dataset assigns to records with no labels.
labels = [
    'No Finding',
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Enlarged Cardiomediastinum',
    'Fracture',
    'Lung Lesion',
    'Lung Opacity',
    'Pleural Effusion',
    'Pneumonia',
    'Pneumothorax',
    'Pleural Other',
    'Support Devices',
    'Other',
]

# Forward (name -> id) and reverse (id -> name) lookups derived from the list above.
label2id = {name: index for index, name in enumerate(labels)}

id2label = dict(enumerate(labels))

In [None]:
class XrayReportDataset(Dataset):
    """Multi-label X-ray image dataset backed by a JSON annotation file.

    The annotation file is a JSON object mapping a split name to a list of
    records. Two keys are read from each record: ``image_path`` (a sequence
    whose first entry is a path relative to ``image_dir``) and ``labels``
    (a list of label names that must be keys of ``label2id``).
    """

    def __init__(
        self,
        splits: str | list[str] | tuple[str, ...],
        image_dir: Path,
        ann_path: Path,
        id2label: dict,
        label2id: dict,
        transforms: AutoFeatureExtractor,
        sample: float = 1.0
    ):
        """Load the annotations and collect the records of the requested splits.

        Args:
            splits: one split name, or a list/tuple of split names whose
                records are concatenated in the given order
            image_dir: path to directory of images
            ann_path: path to annotations json
            id2label: mapping from class index to label name; its length
                sets the size of the label vector
            label2id: mapping from label name to class index
            transforms: callable applied to each PIL image as
                ``transforms(image, return_tensors='pt')``
            sample: fraction of records to keep. Only values strictly between
                0 and 1 trigger random sub-sampling (via the global ``random``
                module, so seed it for reproducibility); any other value
                keeps every record.

        Raises:
            KeyError: if a requested split is missing from the annotation file.
        """
        super().__init__()
        # A bare string would otherwise be iterated character by character.
        if isinstance(splits, str):
            splits = [splits]

        self.image_dir = image_dir
        self.ann_path = ann_path
        self.id2label = id2label
        self.label2id = label2id
        self.transforms = transforms
        self.sample = sample
        self.splits = splits

        # JSON is UTF-8 by specification; do not depend on the platform default.
        with open(self.ann_path, 'r', encoding='utf-8') as f:
            self.annotations = json.load(f)

        self.data = list(chain.from_iterable(
            self.annotations[split] for split in self.splits
        ))
        # Skip sub-sampling for an empty split: the "at least one record"
        # floor below would otherwise ask random.sample for more items than
        # exist and raise ValueError.
        if self.data and 0.0 < self.sample < 1.0:
            total = max(int(len(self.data) * self.sample), 1)
            self.data = random.sample(self.data, total)

    def __len__(self) -> int:
        """Return the number of records in the dataset."""
        return len(self.data)

    def __getitem__(
        self, index: int
    ) -> dict:
        """
        Retrieve an item from the dataset.
        Parameters
        ----------
        index
            dataset index
        Returns
        -------
        ret
            dict with ``labels_names`` (the record's label list exactly as
            stored in the annotations), ``labels`` (float multi-hot tensor of
            length ``len(id2label)``) and ``pixel_values`` (first entry of the
            ``pixel_values`` produced by ``transforms``)
        Raises
        ------
        KeyError
            if a label name is not present in ``label2id``
        """
        item = self.data[index]
        # Only the first listed image of a record is used. The context manager
        # releases the file handle as soon as the pixels have been decoded.
        with Image.open(self.image_dir / item['image_path'][0]) as raw_image:
            image = raw_image.convert('RGB')
        image_transformed = self.transforms(image, return_tensors='pt')

        # Records without any label are assigned the catch-all 'Other' class,
        # which therefore has to exist in label2id.
        label_names = item['labels'] or ['Other']

        labels = torch.zeros(len(self.id2label), dtype=torch.float)
        label_ids = [self.label2id[name] for name in label_names]
        labels[label_ids] = 1

        return {
            'labels_names': item['labels'],
            'labels': labels,
            'pixel_values': image_transformed['pixel_values'][0],
        }


class XrayReportData:
    """Builds and holds the training and evaluation ``XrayReportDataset`` objects.

    After construction ``train`` holds the ``'train'`` split and
    ``validation`` holds the ``'val'`` and ``'test'`` splits concatenated.
    """

    def __init__(
        self,
        image_dir: Path,
        ann_path: Path,
        id2label: dict,
        label2id: dict,
        transforms: AutoFeatureExtractor,
        batch_size: int = 32,
        sample: float = 1.0
    ):
        """Store the configuration and eagerly build both datasets.

        Args:
            image_dir: path to directory of images
            ann_path: path to annotations json
            id2label: mapping from class index to label name
            label2id: mapping from label name to class index
            transforms: image preprocessing callable handed to each dataset
            batch_size: stored on the instance; not used by this class itself
            sample: sampling fraction forwarded to every dataset
        """
        super().__init__()
        self.image_dir, self.ann_path = image_dir, ann_path
        self.transform = transforms
        self.id2label, self.label2id = id2label, label2id
        self.batch_size = batch_size
        self.sample = sample

        datasets = self._setup()
        self.train = datasets['train']
        # NOTE(review): evaluation uses val and test together, so anything
        # selected on this set has also seen the test split.
        self.validation = datasets[('val', 'test')]

    def _setup(self):
        """Build one dataset per split group, keyed by that group.

        The groups are ``'train'`` and the tuple ``('val', 'test')``; they are
        built in that order.
        """
        datasets = {}
        for split_group in ['train', ('val', 'test')]:
            datasets[split_group] = XrayReportDataset(
                split_group,
                self.image_dir,
                self.ann_path,
                self.id2label,
                self.label2id,
                self.transform,
                sample=self.sample,
            )
        return datasets

In [None]:
def collate_fn(batch: list[dict]) -> dict[str, torch.Tensor]:
    """Collate per-sample dicts into one batch of model inputs.

    Args:
        batch:
            sequence of per-sample dicts, each holding at least the
            'pixel_values' and 'labels' tensors; other keys are dropped

    Returns:
        dict with 'pixel_values' and 'labels', each stacked along a new
        leading batch dimension

    Raises:
        RuntimeError: raised by ``torch.stack`` when ``batch`` is empty or
            the tensors of one key differ in shape.
    """
    return {
        key: torch.stack([sample[key] for sample in batch])
        for key in ('pixel_values', 'labels')
    }

In [None]:
# Checkpoint identifier handed to both `from_pretrained` calls (feature extractor and classifier).
pre_trained_model = 'google/vit-base-patch16-224-in21k'

# Kaggle input locations; note that the dataset name appears twice in each path.
image_dir = Path(f'/kaggle/input/chestxraycaption/{DATASET}/{DATASET}/images')
annotations = Path(f'/kaggle/input/chestxraycaption/{DATASET}/{DATASET}/annotation.json')

In [None]:
# Image preprocessor matching the checkpoint; the dataset calls it on every PIL image
# and it is also handed to the Trainer.
transforms = AutoFeatureExtractor.from_pretrained(pre_trained_model)

In [None]:
# Build the training dataset and the combined val+test evaluation dataset.
# `batch_size` and `sample` are left at their defaults (32 and 1.0, i.e. no sub-sampling).
data = XrayReportData(
    image_dir=image_dir,
    ann_path=annotations,
    id2label=id2label,
    label2id=label2id,
    transforms=transforms
)

In [None]:
 model = AutoModelForImageClassification.from_pretrained(
    pre_trained_model,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    problem_type='multi_label_classification'
)

In [None]:
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score

def compute_metrics(output):
    """Score top-1 predictions against the multi-hot reference labels.

    Args:
        output: pair ``(predictions, references)`` of 2-D arrays with one row
            per example and one column per class.

    Returns:
        dict with 'accuracy', 'f1', 'precision' and 'recall'; the last three
        are support-weighted averages with ``zero_division=0``.
    """
    scores, targets = output

    # Mark each row's highest-scoring class; tied maxima are all marked.
    top_mask = (scores == scores.max(axis=1, keepdims=True)).astype('int32')
    # Keep only the reference labels that coincide with the marked class, so a
    # row is either identical to its prediction (a hit) or all zeros (a miss).
    # NOTE(review): because the references are masked by the predictions, every
    # remaining positive reference is by construction also predicted - confirm
    # this "top-1 hit" definition is the intended meaning of recall/F1 here.
    hit_mask = targets.astype('int32') & top_mask

    weighted = {'average': 'weighted', 'zero_division': 0}
    return {
        'accuracy': accuracy_score(hit_mask, top_mask),
        'f1': f1_score(hit_mask, top_mask, **weighted),
        'precision': precision_score(hit_mask, top_mask, **weighted),
        'recall': recall_score(hit_mask, top_mask, **weighted),
    }

In [None]:
# Per-device batch size used for both training and evaluation.
batch_size = 64

# The Trainer creates its reporting callbacks inside its constructor, and the
# W&B callback reads WANDB_LOG_MODEL at that moment. The `%env WANDB_LOG_MODEL=true`
# line in the W&B cell further down runs only after `trainer` exists, which is
# too late to have any effect, so the variable is set here instead.
os.environ['WANDB_LOG_MODEL'] = 'true'

# Evaluate and checkpoint every 250 steps (the two intervals must line up for
# `load_best_model_at_end`), keep a single checkpoint on disk, and reload the
# checkpoint with the best 'f1' once training ends.
args = TrainingArguments(
    output_dir='vit-model',
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy='steps',
    save_strategy='steps',
    save_steps=250,
    eval_steps=250,
    num_train_epochs=1,
    logging_steps=50,
    optim='adamw_torch',
    learning_rate=2e-4,
    save_total_limit=1,
    # The dataset returns an extra 'labels_names' entry; column pruning is
    # disabled and `collate_fn` picks the two keys the model consumes.
    remove_unused_columns=False,
    push_to_hub=False,
    report_to=REPORT_TO,
    load_best_model_at_end=True,
    metric_for_best_model='f1',
)

trainer = Trainer(
    model,
    args,
    train_dataset=data.train,
    eval_dataset=data.validation,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
    # Presumably passed so the preprocessor config is saved with each checkpoint - confirm.
    tokenizer=transforms,
)

In [None]:
# Optional Weights & Biases setup; skipped entirely unless REPORT_TO is 'wandb'.
if REPORT_TO == 'wandb':
    import wandb
    from kaggle_secrets import UserSecretsClient

    # Fetch the API key from the Kaggle secrets client instead of hard-coding it in the notebook.
    user_secrets = UserSecretsClient()
    WANDB_KEY = user_secrets.get_secret("WANDB_KEY")

    wandb.login(key=WANDB_KEY)

    # Start the run explicitly and attach the model config and training arguments to it.
    wandb.init(
        project=f"chest-xray-classification-{DATASET}",
        config={
            "model": json.loads(model.config.to_json_string()),
            "args": json.loads(args.to_json_string())
        }
    )
    # Prefix the run's current name.
    wandb.run.name = f'multiclass-{wandb.run.name}'

# NOTE(review): set unconditionally, and only after `trainer` was created in the
# previous cell - confirm the W&B callback still picks this variable up at that point.
%env WANDB_LOG_MODEL=true

In [None]:
# Run training with the Trainer configured above (evaluation and checkpointing follow `args`).
trainer.train()

In [None]:
# `wandb` is imported only when W&B reporting is enabled, so guard the call;
# otherwise this cell raises NameError for any other REPORT_TO value.
if REPORT_TO == 'wandb':
    wandb.finish()