The following cells were run in Google Colab to train a Swin Transformer for image classification on the NIH dataset of ~100k chest x-rays.

This is a multiclass, multilabel problem: each image is assigned one or more of 15 labels (14 disease findings plus "No Finding").

The files used in this notebook can be downloaded here: https://www.kaggle.com/datasets/nih-chest-xrays/data. A description of the data can be
found here: https://www.nih.gov/news-events/news-releases/nih-clinical-center-provides-one-largest-publicly-available-chest-x-ray-datasets-scientific-community.

In [None]:
import cv2
import glob
import shutil
import os

# Build loc_dict: image filename -> split name, for every image that is listed in
# train_val_list.txt AND is actually present in the "images" folder in GDrive.

path = '/content/drive/MyDrive/cxr_indiv/images'

# call loc_dict[filename] to see which split an image belongs to
loc_dict = {}

split_list_path = "/content/drive/MyDrive/cxr_indiv/train_val_list.txt"

# The split name is derived from the list file's name ("train_val_list.txt" -> "train").
# It is the same for every line, so compute it once instead of per line.
split_name = os.path.basename(split_list_path).replace("_val_list.txt", "").strip()

# List the image folder once. A set lookup per line replaces one filesystem stat per
# line, and a blank line in the list file can no longer match the folder itself
# (os.path.exists(path + '/') is True, which used to add a bogus '' entry).
available_images = set(os.listdir(path))

with open(split_list_path, "r") as a_file:
  for line in a_file:
    stripped_line = line.strip()
    # keep only images that are in the current image folder
    if stripped_line in available_images:
      loc_dict[stripped_line] = split_name

In [None]:
# Carve a validation subset out of the training images.
# Entries are relabelled 'train' -> 'valid' in dict iteration order until the number
# of remaining training images has dropped to (1 - validation_percent) of the start.
validation_percent = 0.1

train_length = sum(1 for split in loc_dict.values() if split == 'train')
desired_length = (1.0 - validation_percent) * train_length

for image_name in loc_dict:
  if loc_dict[image_name] == 'train':
    loc_dict[image_name] = 'valid'
    train_length -= 1
  # checked on every entry, so the loop also stops immediately when nothing needs moving
  if train_length <= desired_length:
    break


In [None]:
# Split the metadata csv into train and validation frames, using loc_dict to look up
# which split each image belongs to. Rows whose image is not in loc_dict are dropped.

import pandas as pd
full_csv = pd.read_csv("/content/drive/MyDrive/cxr_indiv/Data_Entry_2017_v2020.csv")

# Vectorised lookup: one split name per csv row, NaN when the image is not in loc_dict.
# This replaces the row-by-row DataFrame.append loop, which copies the frame on every
# row (quadratic) and no longer exists in pandas >= 2.0.
split_of_row = full_csv['Image Index'].map(loc_dict)

# Any split name other than 'train' / 'valid' is unexpected; report each such row.
unexpected_rows = split_of_row.notna() & ~split_of_row.isin(['train', 'valid'])
for _ in range(int(unexpected_rows.sum())):
  print("error")

# .copy() gives independent frames, so the column assignment below does not trigger
# SettingWithCopyWarning.
train_df = full_csv[split_of_row == 'train'].copy()
valid_df = full_csv[split_of_row == 'valid'].copy()
test_df = pd.DataFrame({})  # the test split is not populated in this notebook


def _labels_to_set(labels):
  """Turn a 'A|B|C' label string into the set {'A', 'B', 'C'}; sets pass through unchanged."""
  if isinstance(labels, set):
    return labels
  return set(labels.split('|'))


# Convert the pipe-separated label strings into sets of labels.
# Every label is kept as-is (including "No Finding").
train_df['Finding Labels'] = train_df['Finding Labels'].map(_labels_to_set)
valid_df['Finding Labels'] = valid_df['Finding Labels'].map(_labels_to_set)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self._setitem_single_block(indexer, value, name)


In [None]:
# install the necessary dependencies:
# !pip install torch
# !pip install torchvision
# !pip install timm

In [None]:
import os
import pandas as pd
from torchvision.io import read_image
import torch
import io
import numpy as np
from torch.utils.data import Dataset, DataLoader
import sklearn
from sklearn import preprocessing
import torchvision

class CustomImageDataset(Dataset):
    """Multi-label image dataset backed by a metadata DataFrame.

    Each item is ``(image, label)`` where ``image`` is a 3-channel tensor of spatial
    size ``image_size`` and ``label`` is a multi-hot numpy vector with one entry per
    class.

    Args:
        annotations_file: DataFrame with an 'Image Index' column (file names relative
            to ``img_dir``) and a 'Finding Labels' column (an iterable of label names
            per row).
        img_dir: directory that contains the image files.
        transform: optional callable applied to the decoded image tensor.
        target_transform: optional callable applied to the label row (a pandas Series);
            its result must still provide ``to_numpy()``.
        classes: optional fixed, ordered list of class names. Pass the same list to
            every split so the label columns line up; when None the classes are
            inferred from this split's own labels.
        image_size: ``(height, width)`` every returned image is resized to.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None,
                 classes=None, image_size=(224, 224)):
        temp = annotations_file['Finding Labels']
        mlb = sklearn.preprocessing.MultiLabelBinarizer(classes=classes)
        # one 0/1 column per class, one row per image
        self.img_labels = pd.DataFrame(mlb.fit_transform(temp), columns=mlb.classes_)
        # column order of the label vectors, reusable as `classes` for another split
        self.classes_ = list(mlb.classes_)
        print(mlb.classes_)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        self.annot = annotations_file
        self.image_size = tuple(image_size)

    def __len__(self):
        """Number of images (rows of the label matrix)."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``."""
        img_path = os.path.join(self.img_dir, self.annot['Image Index'].iloc[idx])
        # Decode straight to 3 channels so grayscale and RGBA files all come out (3, H, W).
        image = read_image(img_path, mode=torchvision.io.ImageReadMode.RGB)
        label = self.img_labels.iloc[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        # Resample to the target size. Tensor.resize_ must not be used for this: it
        # only reinterprets the underlying storage (truncating it, or exposing
        # uninitialised memory), it does not scale the picture.
        if tuple(image.shape[-2:]) != self.image_size:
            image = torchvision.transforms.functional.resize(
                image, list(self.image_size), antialias=True
            )
        return image, label.to_numpy() # torch tensor, numpy array

In [None]:
# Build the train / validation datasets and wrap each in a shuffling DataLoader.
import torchvision

# Resize pipeline; note that it is defined here but not handed to the datasets below.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
])

# Both splits read their images from the same folder (validation data is still held
# in the train folder).
train_dataset = CustomImageDataset(annotations_file=train_df, img_dir=path)
valid_dataset = CustomImageDataset(annotations_file=valid_df, img_dir=path)

# mini-batch size shared by both loaders
batches = 32
loader_kwargs = {'batch_size': batches, 'shuffle': True}

train_dl = DataLoader(train_dataset, **loader_kwargs)
valid_dl = DataLoader(valid_dataset, **loader_kwargs)

In [None]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [None]:
import torch
import torch.nn as nn
import timm

# Pretrained Swin-Base backbone (patch size 4, window 7, 224x224 input) from timm.
model = timm.create_model(
    'swin_base_patch4_window7_224',
    pretrained=True,
)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Downloading: "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth" to /root/.cache/torch/hub/checkpoints/swin_base_patch4_window7_224_22kto1k.pth


In [None]:
# Binary cross-entropy on probabilities: the sigmoid is applied separately, before
# the loss is called.
criterion = torch.nn.BCELoss().to(device)

# AdamW over every parameter of the backbone.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.00001,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0.05,
)


In [None]:
def optimizer_to(optim, device):
    """Move every tensor held in an optimizer's state (and its grad, if any) to `device`.

    State entries may be tensors themselves or dicts of tensors; anything else is
    left untouched. The optimizer is modified in place and nothing is returned.
    """
    def _relocate(tensor):
        # swap the storage in place so existing references to the tensor stay valid
        tensor.data = tensor.data.to(device)
        if tensor._grad is not None:
            tensor._grad.data = tensor._grad.data.to(device)

    for entry in optim.state.values():
        if isinstance(entry, torch.Tensor):
            _relocate(entry)
        elif isinstance(entry, dict):
            for value in entry.values():
                if isinstance(value, torch.Tensor):
                    _relocate(value)

# Resume from a previously saved checkpoint (model, optimizer and criterion state).
prevpath = 'state_1.pth'
prevpath = '/content/drive/MyDrive/tesim/'+ prevpath

# map_location='cpu': the checkpoint may have been written from a GPU session. Loading
# to CPU first works on any runtime and avoids staging a second copy of the weights on
# the GPU; optimizer_to() below moves the optimizer state to `device`.
# NOTE: torch.load unpickles the file - only load checkpoints you created yourself.
state = torch.load(prevpath, map_location='cpu')
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
optimizer_to(optimizer, device)
criterion.load_state_dict(state['criterion'])

# override the learning rate stored in the checkpoint
optimizer.param_groups[0]['lr'] = 0.00001

In [None]:
# Put the network and the loss module on the selected device.
model, criterion = model.to(device), criterion.to(device)

In [None]:
from torch.autograd.grad_mode import F
from sklearn.metrics import confusion_matrix
# Toy multi-hot predictions / targets (2 samples x 4 labels) for checking the accuracy helpers.
pred = torch.tensor(
    [[0, 1, 1, 0],
     [1, 1, 1, 1]],
    dtype=torch.float32,
)
truth = torch.tensor(
    [[1, 1, 0, 0],
     [1, 1, 1, 1]],
    dtype=torch.float32,
)

def find_single_accuracy(pred, truth):
  """Fraction of positions at which `pred` and `truth` hold the same value.

  Both arguments may be torch tensors (on any device) or numpy arrays. The comparison
  is done in a single vectorised operation on the CPU instead of one `.item()` call
  per element.

  Raises:
    ValueError: if the two inputs do not contain the same number of elements.
    ZeroDivisionError: if the inputs are empty.
  """
  pred_flat = torch.as_tensor(pred).detach().cpu().reshape(-1)
  truth_flat = torch.as_tensor(truth).detach().cpu().reshape(-1)
  if pred_flat.numel() != truth_flat.numel():
    raise ValueError(
        'pred and truth must have the same number of elements, got '
        + str(pred_flat.numel()) + ' and ' + str(truth_flat.numel()))

  # true positives and true negatives both count as a match
  true_neg_and_pos = int((pred_flat == truth_flat).sum())

  # plain int / int division: exact, and raises ZeroDivisionError on empty input
  return true_neg_and_pos / pred_flat.numel()



def find_accuracy(pred: torch.Tensor, truth: torch.Tensor):
  """Mean per-sample label accuracy for a batch of multi-hot predictions.

  Args:
    pred: (k, n_labels) predictions, one row per sample (torch tensor on any device,
      or numpy array).
    truth: targets with the same shape as `pred` (torch tensor or numpy array).

  Returns:
    The average over the k samples of the fraction of labels predicted correctly.

  Raises:
    ValueError: if `pred` and `truth` differ in shape.
    ZeroDivisionError: if there are no samples or no labels.
  """
  pred_t = torch.as_tensor(pred).detach().cpu()
  truth_t = torch.as_tensor(truth).detach().cpu()
  if pred_t.shape != truth_t.shape:
    raise ValueError(
        'pred and truth must have the same shape, got '
        + str(tuple(pred_t.shape)) + ' and ' + str(tuple(truth_t.shape)))

  # One vectorised comparison for the whole batch (a single device-to-host transfer
  # instead of one per element), then the number of matching labels in each row.
  matches = (pred_t == truth_t).flatten(start_dim=1)
  labels_per_sample = matches.shape[1]
  accuracies = [hits / labels_per_sample for hits in matches.sum(dim=1).tolist()]

  return sum(accuracies) / len(accuracies)

# Sanity check on the toy tensors defined above: row 0 matches on 2 of 4 labels and
# row 1 on 4 of 4, so the mean per-sample accuracy printed here is 0.75.
print(find_accuracy(pred, truth))

0.75


In [None]:
# Set the 'capturable' option on the optimizer's first parameter group.
# NOTE(review): presumably a workaround for the Adam/AdamW 'capturable' assertion that
# some PyTorch releases raise after optimizer state has been loaded from a checkpoint
# and moved to CUDA - confirm against the runtime's torch version.
optimizer.param_groups[0]['capturable'] = True

In [None]:
from numpy import vstack
from numpy import argmax
from pandas import read_csv
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
from torch import Tensor
from torch.utils.data import Dataset
from torch.utils.data import random_split
from torch.nn import Linear
from torch.nn import ReLU
from torch.nn import Softmax
from torch.nn import Module
import copy

savepath = 'state_1.pth' # for updating model state dict during validation
optimizer.param_groups[0]['lr'] = 0.001

# train model
from tqdm import tqdm

# Classification head: maps the backbone's 1000 outputs to the 15 labels, followed by
# a sigmoid so that BCELoss receives probabilities.
# The head is created ONCE, outside the loops. Re-creating it for every batch would
# re-initialise it with random weights each step, so it could never be learned.
# It has its own optimizer so that the parameter groups of `optimizer` stay unchanged
# and previously saved optimizer checkpoints remain loadable.
lin = nn.Linear(1000, 15).to(device)
sig = nn.Sigmoid()
head_optimizer = torch.optim.AdamW(
    lin.parameters(),
    lr=optimizer.param_groups[0]['lr'],
    weight_decay=optimizer.param_groups[0]['weight_decay'],
)

running_loss = 0          # sum of per-sample training losses in the current epoch
current_batch_accs = []   # validation batch accuracies of the current epoch
epoch_accs = []           # one validation accuracy per finished epoch
min_valid_loss = float('inf')  # best validation loss so far, tracked ACROSS epochs
torch.cuda.empty_cache()

# enumerate epochs
for epoch in range(3):
    # ---------------------------- training ----------------------------
    model.train()
    lin.train()
    progress = tqdm(train_dl)
    # enumerate mini batches
    for (inputs, targets) in progress:
        # NOTE(review): inputs are raw pixel values cast to float; the pretrained
        # backbone presumably expects normalised inputs - confirm.
        inputs = inputs.to(device).float()
        targets = targets.to(device).float()

        # clear the gradients
        optimizer.zero_grad()
        head_optimizer.zero_grad()

        # backbone -> linear head -> sigmoid
        yhat = sig(lin(model(inputs)))
        loss = criterion(yhat, targets)

        # credit assignment and weight update
        loss.backward()
        optimizer.step()
        head_optimizer.step()

        # criterion returns the batch mean; weight it by the batch size so that the
        # epoch loss below is a true per-sample average (the last batch can be smaller)
        running_loss += loss.item() * inputs.size(0)
        progress.set_postfix(loss=loss.item())

    # ---------------------------- validation ----------------------------
    model.eval()
    lin.eval()
    predictions, actuals = list(), list()
    valid_loss = 0
    # no_grad: no autograd graph is built, which keeps validation memory low
    with torch.no_grad():
        for (inputs, targets) in valid_dl:
            inputs = inputs.to(device).float()
            targets = targets.to(device).float()

            probs = sig(lin(model(inputs)))
            # validation loss comes from the validation batch itself
            valid_loss += criterion(probs, targets).item() * inputs.size(0)

            # threshold the probabilities into hard 0/1 predictions
            yhat = torch.round(probs)
            actual = targets.cpu().numpy()

            predictions.append(yhat)
            actuals.append(actual)

            # calculate batch accuracy
            batch_acc = find_accuracy(yhat, actual)
            print('batch accuracy:', str(batch_acc))
            current_batch_accs.append(batch_acc)
    valid_loss /= len(valid_dl.dataset)

    # keep the checkpoint of the best epoch so far
    if min_valid_loss > valid_loss:
        min_valid_loss = valid_loss

        # Saving State Dict ('head' / 'head_optimizer' are extra keys; code that only
        # reads the first three keys is unaffected)
        state = {
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'criterion': criterion.state_dict(),
            'head': lin.state_dict(),
            'head_optimizer': head_optimizer.state_dict()
        }
        torch.save(state, '/content/drive/MyDrive/tesim/'+ savepath)

    # calculate epoch accuracy
    predictions, actuals = vstack([i.cpu() for i in predictions]), vstack(actuals)
    epoch_acc = sum(current_batch_accs) / len(current_batch_accs)
    print(' for epoch', (epoch+1))
    print('epoch accuracy: ' + str(epoch_acc))
    print('epoch loss: ' + str(running_loss / len(train_dataset)))
    running_loss = 0
    epoch_accs.append(epoch_acc)
    current_batch_accs = []

    # Step the learning rate down by 10x once per EPOCH. Dividing after every batch
    # would shrink it to effectively zero within a handful of steps.
    for group in optimizer.param_groups + head_optimizer.param_groups:
        group['lr'] /= 10

print('====================================')
print('all epoch accuracies:', str(epoch_accs))
print('avg epoch acc:', str(sum(epoch_accs) / len(epoch_accs)))
print('highest epoch acc was epoch', str(epoch_accs.index(max(epoch_accs)) + 1), 'with', str(max(epoch_accs)))
print('lowest epoch acc was epoch', str(epoch_accs.index(min(epoch_accs)) + 1), 'with', str(min(epoch_accs)))

In [None]:
# Write the current model / optimizer / criterion state to the checkpoint file.
checkpoint_file = '/content/drive/MyDrive/tesim/'+ savepath
state = {}
state['state_dict'] = model.state_dict()
state['optimizer'] = optimizer.state_dict()
state['criterion'] = criterion.state_dict()
torch.save(state, checkpoint_file)