In [66]:
from datasets import load_dataset
from tqdm import tqdm
from datasets import Dataset
import torch 
import matplotlib.pyplot as plt 
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import tqdm
from tqdm import trange, tqdm
import numpy as np
from sklearn.metrics import accuracy_score
import pandas as pd
from sklearn.metrics import multilabel_confusion_matrix
import os 
import seaborn as sns
from PIL import Image

In [67]:
# Small development subset: only the first 500 training examples are loaded.
ds_train = load_dataset(
    "alkzar90/NIH-Chest-X-ray-dataset",
    "image-classification",
    split="train[:500]",
)

# The test split is held back purely for final evaluation; it is never used
# inside the training loop.
ds_test = load_dataset(
    "alkzar90/NIH-Chest-X-ray-dataset",
    "image-classification",
    split="test[:500]",
)



In [68]:
from transformers import ViTForImageClassification
# Load the pretrained checkpoint with a 15-output classification head.
# `ignore_mismatched_sizes=True` allows the checkpoint's 1000-output classifier
# to be replaced by a newly initialised 15-output layer (see the warning below).
# Left as a bare expression so the notebook displays the model architecture.
ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=15,
    ignore_mismatched_sizes=True,
)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([15]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([15, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [69]:
from torch.nn import functional as F
from torch import optim 

from transformers import ViTForImageClassification
import torchmetrics 

import lightning as L

In [132]:
# Continue using PyTorch Lightning to fine-tune the model, following the approach described in:

# https://towardsdatascience.com/how-to-fine-tune-a-pretrained-vision-transformer-on-satellite-data-d0ddd8359596/#:~:text=Under%20the%20hood%2C%20the%20trainer,is%20completed%20within%20few%20epochs.


class VisionTransformerPretrained(L.LightningModule):
    '''Lightning wrapper that fine-tunes a pretrained Vision Transformer for multilabel classification.

    Args:
        model: checkpoint name passed to ``ViTForImageClassification.from_pretrained``.
        num_classes: number of labels; sizes the classification head and every metric.
        learning_rate: learning rate used by the Adam optimizer in ``configure_optimizers``.
    '''

    def __init__(self, model = "google/vit-base-patch16-224", num_classes = 15, learning_rate = 1e-4):

        super().__init__()
        self.learning_rate = learning_rate
        self.num_classes = num_classes

        # ignore_mismatched_sizes lets the checkpoint's classifier be swapped for a
        # freshly initialised head with `num_classes` outputs.
        self.backbone = ViTForImageClassification.from_pretrained(model,
                                                                  num_labels = num_classes,
                                                                  ignore_mismatched_sizes=True)

        # One independent sigmoid/BCE term per label (multilabel task).
        self.loss_fn = nn.BCEWithLogitsLoss()

        # Accuracy used for the training/validation steps.
        self.acc = torchmetrics.Accuracy("multilabel", num_labels=num_classes, threshold = 0.5)

        # Per-label metrics used at test time (average=None -> one value per label).
        self.f1 = torchmetrics.F1Score(task="multilabel", num_labels=num_classes, average=None)
        self.precision = torchmetrics.Precision(task="multilabel", num_labels=num_classes, average=None)
        self.recall = torchmetrics.Recall(task="multilabel", num_labels=num_classes, average=None)
        self.accuracy = torchmetrics.Accuracy(task="multilabel", num_labels=num_classes, average=None)

    def forward(self, x):
        '''Return the raw (pre-sigmoid) logits of the backbone for the image batch ``x``.'''
        return self.backbone(x).logits

    def _binarize(self, logits):
        '''Turn logits into hard 0/1 predictions by thresholding sigmoid probabilities at 0.5.'''
        return (torch.sigmoid(logits) > 0.5).float()

    def step(self, batch, stage = "train"):
        '''Shared train/validation step.

        Args:
            batch: ``(x, y)`` pair of images and multi-hot targets.
            stage: kept for interface compatibility; not used.

        Returns:
            ``(loss, acc, y_hat, y)`` where ``y_hat`` holds the thresholded predictions.
        '''

        x, y = batch

        logits = self.forward(x)
        y_hat = self._binarize(logits)

        # The loss works on raw logits and float targets; the metric receives the
        # thresholded predictions and integer targets.
        loss = self.loss_fn(logits, y.float())
        acc = self.acc(y_hat, y.long())

        return loss, acc, y_hat, y

    def training_step(self, batch, batch_idx):
        '''Compute and log the training loss for one batch.'''
        loss, acc, y_hat, y = self.step(batch)

        self.log("train_loss", loss)

        return loss

    def validation_step(self, batch, batch_idx):
        '''Log the validation accuracy, aggregated over the epoch as ``valid_acc``.'''
        loss, acc, y_hat, y = self.step(batch)

        self.log("valid_acc", acc, on_epoch = True, on_step = False)

    def configure_optimizers(self):
        '''Build the Adam optimizer using the learning rate given to ``__init__``.'''
        # Previously hard-coded to 1e-4, which silently ignored `learning_rate`.
        optimizer = optim.Adam(self.parameters(), lr = self.learning_rate)
        return optimizer

    def test_step(self, batch, batch_idx):
        '''Accumulate per-label F1, precision, recall and accuracy for one test batch.

        Returns:
            dict with the batch-level per-label values under the keys
            ``"f1"``, ``"precision"``, ``"recall"`` and ``"accuracy"``.
            The epoch-level values are logged in ``on_test_epoch_end``.
        '''

        x, y = batch
        y_hat = self._binarize(self.forward(x))
        targets = y.long()

        # Calling a metric both updates its accumulated state and returns the
        # value for this batch only.
        f1 = self.f1(y_hat, targets)
        precision = self.precision(y_hat, targets)
        recall = self.recall(y_hat, targets)
        accuracy = self.accuracy(y_hat, targets)

        return {"f1": f1, "precision": precision, "recall": recall, "accuracy": accuracy}

    def on_test_epoch_end(self):
        '''Log per-label test metrics computed over the whole test set.

        Averaging batch-level F1/precision/recall does not give the dataset-level
        value, so the accumulated metric state is evaluated once here instead.
        '''
        per_label_metrics = {
            "f1": self.f1,
            "precision": self.precision,
            "recall": self.recall,
            "accuracy": self.accuracy,
        }

        for name, metric in per_label_metrics.items():
            values = metric.compute()
            for i in range(self.num_classes):
                self.log(f"test_{name}_label_{i}", values[i], logger=True)
            # Clear the state so a later test run starts from scratch.
            metric.reset()

    def predict_step(self, batch, batch_idx, dataloader_idx = 0):
        '''Return per-label sigmoid probabilities for a batch.

        Accepts either a bare image tensor or an ``(x, y)`` pair as produced by
        the labelled dataloaders; any targets are ignored.
        '''
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        return torch.sigmoid(self.forward(x))


In [133]:
class HuggingFaceCXR(Dataset):
    """
    PyTorch Dataset wrapper around a Hugging Face dataset of chest X-rays.

    Each item is returned as ``(image, one_hot)``:
      * ``image``: 3-channel float tensor resized to ``image_size`` and divided
        by 255, then passed through ``transform`` if one was given.
      * ``one_hot``: float32 vector of length ``num_classes`` with ones at the
        indices listed in the sample's ``"labels"`` field.

    Args:
        hf_dataset: object exposing ``with_format("torch")``, ``__len__`` and
            integer indexing that yields dicts with ``"image"`` and ``"labels"``.
        image_size: ``(height, width)`` target size for the resize.
        num_classes: length of the multi-hot label vector.
        transform: optional callable applied to the image as the last step.

    NOTE(review): images with a channel axis are assumed to be channel-first
    (C, H, W) - confirm against the installed `datasets` version.
    """
    def __init__(self, hf_dataset, image_size, num_classes=15, transform=None):
        self.dataset = hf_dataset.with_format("torch")  # items come back as torch tensors
        self.image_size = image_size
        self.num_classes = num_classes
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        image = item["image"]

        # A single-channel image may arrive without a channel axis, as (H, W).
        # Add one so the resize below always receives a 4-D (N, C, H, W) input.
        if image.ndim == 2:
            image = image.unsqueeze(0)

        # Resize in floating point: integer (e.g. uint8) pixel data is converted
        # first so the interpolation does not depend on integer-dtype support.
        if not image.is_floating_point():
            image = image.float()
        image = F.interpolate(image.unsqueeze(0), size = self.image_size, mode = "bilinear").squeeze(0)

        # Handle 4-channel (RGBA) images: keep only RGB
        if image.shape[0] == 4:
            image = image[:3, :, :]

        # Handle grayscale (1-channel) images: replicate to 3 channels
        if image.shape[0] == 1:
            image = image.repeat(3, 1, 1)

        # Scale pixel values by 1/255 (maps 0-255 data to [0, 1])
        image = image / 255.0

        # Convert the list of label indices to a multi-hot vector
        labels = item["labels"]
        one_hot = torch.zeros(self.num_classes, dtype=torch.float32)
        one_hot[labels] = 1  # set every listed index to 1

        # Apply optional transformations last, on the resized 3-channel image
        if self.transform:
            image = self.transform(image)

        return image, one_hot

In [134]:
import torch
import numpy as np

from sklearn.model_selection import train_test_split

from torchvision.datasets import ImageFolder
from torchvision.transforms import v2

import lightning as L
from torch.utils.data import DataLoader, Subset

class nih_cxr_datamodule(L.LightningDataModule):
    '''Lightning data module for the NIH chest X-ray dataset.

    Args:
        batch_size: batch size used by every dataloader.
        data_root: Hugging Face dataset identifier passed to ``load_dataset``.
        train_split: split expression for the training pool, which is further
            divided into train and validation sets.
        test_split: split expression for the held-out test set.
        valid_fraction: fraction of the training pool reserved for validation.
    '''

    def __init__(self, batch_size, data_root="alkzar90/NIH-Chest-X-ray-dataset",
                 train_split="train[:500]", test_split="test[:500]", valid_fraction=0.2):
        super().__init__()
        self.data_root = data_root
        self.batch_size = batch_size
        self.train_split = train_split
        self.test_split = test_split
        self.valid_fraction = valid_fraction
        self.num_classes = 15

    def setup(self, stage = None):
        '''Build the train, validation and test datasets all at once (``stage`` is ignored).'''

        # Applied by HuggingFaceCXR as its final per-item step.
        transforms = v2.Compose([v2.ToImage(),
                                 v2.Resize(size=(224,224), interpolation=2),
                                 v2.Grayscale(num_output_channels=3), # need to ensure 3 channel grayscale for vit
                                 v2.ToDtype(torch.float32, scale=True),
                                 v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
                                ])

        train_pool = load_dataset(self.data_root, 'image-classification', split = self.train_split)

        # Carve the validation set out of the training pool; the test split is
        # loaded separately and never mixed in.
        train_valid_split = train_pool.train_test_split(test_size = self.valid_fraction)

        ds_train = train_valid_split['train']
        ds_valid = train_valid_split['test']

        ds_test = load_dataset(self.data_root, 'image-classification', split = self.test_split)

        self.train_data = HuggingFaceCXR(ds_train, image_size = (224, 224), transform = transforms)
        self.valid_data = HuggingFaceCXR(ds_valid, image_size = (224, 224), transform = transforms)
        self.test_data = HuggingFaceCXR(ds_test, image_size = (224, 224), transform = transforms)

    def train_dataloader(self):
        '''Shuffled loader over the training set.'''
        return DataLoader(self.train_data, batch_size = self.batch_size, shuffle = True)

    def val_dataloader(self):
        '''Unshuffled loader over the validation set.

        ``val_dataloader`` is the hook name Lightning looks up, so defining it lets
        ``trainer.fit(model, datamodule=...)`` run validation.
        '''
        return DataLoader(dataset = self.valid_data, batch_size = self.batch_size, shuffle = False)

    def valid_dataloader(self):
        '''Alias of ``val_dataloader`` kept for existing callers.'''
        return self.val_dataloader()

    def test_dataloader(self):
        '''Unshuffled loader over the held-out test set.'''
        return DataLoader(self.test_data, batch_size = self.batch_size, shuffle = False)
        

In [None]:

import lightning as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import CSVLogger

def main(args):
    '''Fine-tune the pretrained ViT on the chest X-ray subset and evaluate it.

    Args:
        args: currently unused; kept so existing ``main(args)`` calls still work.

    Returns:
        The list of metric dictionaries returned by ``trainer.test`` for the
        held-out test dataloader.
    '''
    L.seed_everything(42)

    # set up data
    datamodule = nih_cxr_datamodule(batch_size=8)
    datamodule.prepare_data()
    datamodule.setup()

    train_dataloader = datamodule.train_dataloader()
    valid_dataloader = datamodule.valid_dataloader()
    test_dataloader = datamodule.test_dataloader()

    # setup model
    model = VisionTransformerPretrained('google/vit-base-patch16-224', datamodule.num_classes, learning_rate= 1e-4)

    # Stop once validation accuracy has not improved for 6 consecutive epochs.
    early_stopping = EarlyStopping(monitor = 'valid_acc', patience = 6, mode = 'max')

    logger = CSVLogger("tensorboard_logs", name = 'nih_cxr_pretrained_vit')

    # train
    trainer = L.Trainer(devices = 1, max_epochs = 20, callbacks = [early_stopping], logger =logger)
    trainer.fit(model = model, train_dataloaders=train_dataloader, val_dataloaders = valid_dataloader)

    # Evaluate with trainer.test so the model's test_step computes the per-label
    # metrics on the held-out split, and hand the results back to the caller.
    test_results = trainer.test(model, dataloaders=test_dataloader)

    return test_results




In [137]:
# `args` is only a placeholder: main() accepts the argument but never reads it.
args = "arg"
# Run the full pipeline defined in main() and keep whatever it returns.
run = main(args)

Seed set to 42
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([15]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([15, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name      | Type                      | Params | Mode 
----------------------------------------------------------------
0 | backbone  | ViTForImageClassification | 85.8 M | eval 
1 | loss_fn   | BCEWithLogitsLoss         | 0      | train
2 | acc       | MultilabelAccuracy        | 0      | train
3 | f1        | MultilabelF1Score  

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/Users/oliverfox/venvs/ml_venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]