The following cells were run in Google Colab to train a Swin Transformer for image classification on the NIH dataset of ~100k chest x-rays.

This is a multi-label classification problem: each image is labelled with zero or more of 14 disease findings (15 classes when "No Finding" is counted as its own class).

The files used in this notebook can be downloaded here: https://www.kaggle.com/datasets/nih-chest-xrays/data. A description of the data can be
found here: https://www.nih.gov/news-events/news-releases/nih-clinical-center-provides-one-largest-publicly-available-chest-x-ray-datasets-scientific-community.

Achieved ~80% accuracy when trained from scratch; the state of the art (SoTA) is ~82%.

In [None]:
import cv2
import glob
import shutil
import os

# Directory that holds the image files to be evaluated.
path = '/content/drive/MyDrive/buffer/testing_swin'

# Map every file name listed in the split file to the name of that split.
# The split name is derived from the list file's own path
# ("test_list.txt" -> "test"), so loc_dict[filename] == 'test' for each listed
# image; file names that are not listed are simply absent from the dict.
loc_dict = {}

with open("/content/drive/MyDrive/cxr_indiv/test_list.txt", "r") as list_file:
  split_name = (
      list_file.name
      .replace("/content/drive/MyDrive/cxr_indiv/", "")
      .replace("_list.txt", "")
      .strip()
  )
  # One entry per line of the list file, with surrounding whitespace removed.
  loc_dict.update((entry.strip(), split_name) for entry in list_file)

In [None]:
# split csv into train and test
# loop through csv file
# with conditional to separate train and test images

import pandas as pd
full_csv = pd.read_csv("/content/drive/MyDrive/cxr_indiv/Data_Entry_2017_v2020.csv")

# Look up the split of every image in one pass. Images that are not in
# loc_dict map to NaN and are skipped, which is what the former KeyError
# handler did.
split_of_image = full_csv['Image Index'].map(loc_dict)

# An image that is listed in loc_dict under any split other than 'test' is
# unexpected here; report each one the same way as before.
unexpected = int((split_of_image.notna() & (split_of_image != 'test')).sum())
for _ in range(unexpected):
  print("error")

# Select the test rows with a boolean mask instead of appending row by row:
# DataFrame.append was removed in pandas 2.0 and copied the whole frame on
# every call. .copy() makes test_df independent of full_csv so the column
# assignment below is a plain write rather than a write into a slice.
test_df = full_csv[split_of_image == 'test'].copy()

# create sets of labels and remove no finding
# =========================== For Testing ===========================


def _labels_to_set(labels):
  """Convert an 'A|B|C' label string to a set of findings without 'No Finding'.

  Values that are already sets are returned unchanged so the cell can be
  re-run safely.
  """
  if isinstance(labels, set):
    return labels
  return set(labels.split('|')) - {'No Finding'}


# Assign the whole column at once; the former element-wise
# test_df['Finding Labels'].iloc[i] = ... was a chained assignment, which can
# end up writing to a temporary copy instead of test_df.
test_df['Finding Labels'] = test_df['Finding Labels'].map(_labels_to_set)


In [None]:
# install dependencies
# !pip install torch
# !pip install torchvision
# !pip install timm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import os
import pandas as pd
from torchvision.io import read_image
import torch
import io
import numpy as np
from torch.utils.data import Dataset, DataLoader
import sklearn
from sklearn import preprocessing
import torchvision

class CustomImageDataset(Dataset):
    """Multi-label image dataset driven by an annotations DataFrame.

    ``annotations_file`` is indexed like a DataFrame: its 'Image Index' column
    supplies file names relative to ``img_dir`` and its 'Finding Labels'
    column supplies one collection of labels per row. The labels are
    binarised with ``MultiLabelBinarizer``, so every target is a 0/1 vector
    with one entry per class found in this particular DataFrame.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.annot = annotations_file
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

        binarizer = sklearn.preprocessing.MultiLabelBinarizer()
        encoded = binarizer.fit_transform(annotations_file['Finding Labels'])
        self.img_labels = pd.DataFrame(encoded, columns=binarizer.classes_)
        # Print the class order so the columns of the targets can be read.
        print(binarizer.classes_)

    def __len__(self):
        """Number of samples (one per row of the binarised label table)."""
        return self.img_labels.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``.

        ``image`` is the tensor produced by ``read_image`` (after the optional
        transform) and ``label`` is the binarised label row as a NumPy array.
        """
        file_name = self.annot['Image Index'].iloc[idx]
        image = read_image(os.path.join(self.img_dir, file_name))
        label = self.img_labels.iloc[idx]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        # NOTE(review): Tensor.resize_ changes the tensor's shape in place; it
        # does not look like an interpolating image resize, so the result may
        # not be a scaled copy of the picture. Kept as-is because the
        # checkpoint may have been trained with this exact preprocessing —
        # TODO confirm against the training notebook.
        image.resize_(3, 224, 224)
        return image, label.to_numpy()  # torch tensor, numpy array

In [None]:
# create the dataset
import torchvision
# Wrap the test annotations and the image directory in a Dataset.
test_dataset = CustomImageDataset(test_df, path)

# create a data loader for train, valid, and test sets
# Number of images per evaluation batch.
batches = 100

# Evaluation gains nothing from shuffling, and a fixed order keeps the
# per-batch metrics printed later reproducible from run to run.
test_dl = DataLoader(test_dataset, batch_size=batches, shuffle=False)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')
# Bare expression so the notebook displays the chosen device.
device

In [None]:
import torch
import torch.nn as nn
import timm

import collections
# Fall back to a plain dict if OrderedDict cannot be imported from
# collections. OrderedDict is not referenced again in the cells shown here.
try:
    from collections import OrderedDict
except ImportError:
    OrderedDict = dict

# Build an untrained (pretrained=False) Swin-Tiny model from a torch.hub repo.
# NOTE(review): torch.hub.load fetches and runs code from this third-party
# GitHub repository — make sure the source is trusted / pinned.
HUB_URL = "SharanSMenon/swin-transformer-hub:main"
MODEL_NAME = "swin_tiny_patch4_window7_224"
model = torch.hub.load(HUB_URL, MODEL_NAME, pretrained=False)
# 14-output sigmoid classifier sized from the input width of the model's
# existing `head` layer.
classifier = nn.Sequential(nn.Linear(model.head.in_features, 14), nn.Sigmoid())
# Attach the classifier under the attribute name `classifier`; `model.head`
# itself is left untouched.
# NOTE(review): nothing visible here makes the model's forward pass call
# `classifier`; the evaluation cells feed the model output through a
# Linear(1000, 14), which suggests forward still ends at `head` — TODO confirm.
model.classifier = classifier

In [None]:
# Checkpoint file to evaluate, located in the Drive checkpoint folder.
modelpath = 'state_12.pth'
modelpath = '/content/drive/MyDrive/tesim/'+ modelpath
# Load the tensors onto the CPU regardless of the device they were saved from,
# so the checkpoint also opens on a CPU-only runtime (the notebook falls back
# to CPU when CUDA is unavailable). The model is moved to `device` afterwards.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted
# source.
state = torch.load(modelpath, map_location='cpu')
# The checkpoint is a dict whose 'state_dict' entry holds the model weights.
model.load_state_dict(state['state_dict'])

In [None]:
# Move the model to the selected device; the evaluation loops send their
# inputs to the same device.
model = model.to(device)

In [None]:
from torch.autograd.grad_mode import F
from sklearn.metrics import confusion_matrix
# Toy multi-label example (2 samples x 4 labels) used to sanity-check the
# accuracy helpers defined below.
pred = torch.Tensor([[0,1, 1, 0], [1,1,1,1]])
truth = torch.Tensor([[1, 1, 0, 0], [1,1,1,1]])

def find_single_accuracy(pred, truth):
  """Return the fraction of positions at which ``pred`` and ``truth`` agree.

  Both arguments are indexed from 0 to ``len(pred) - 1`` and every element
  must provide ``.item()`` (for example 1-D torch tensors or NumPy arrays).
  """
  n_labels = len(pred)

  # Count the positions where the prediction equals the target, i.e. the
  # true positives plus the true negatives.
  n_matching = sum(
      1 for k in range(n_labels) if pred[k].item() == truth[k].item()
  )

  return n_matching / n_labels



def find_accuracy(pred: torch.Tensor, truth: torch.Tensor):
  """Return the mean per-sample label accuracy for a batch of predictions.

  Each row is one sample and each column one label (multi-hot, not one-hot:
  several labels may be set per sample). A sample's accuracy is the fraction
  of its labels where prediction and target are equal; the result is the
  average of those fractions over all samples.

  Args:
    pred: 2-D (samples x labels) predictions. May live on any device; anything
      accepted by ``torch.as_tensor`` (e.g. a NumPy array) also works.
    truth: targets with the same shape as ``pred`` (tensor or NumPy array).

  Returns:
    The accuracy as a Python float in [0, 1].

  Raises:
    ValueError: if the inputs are not 2-D or their shapes differ.
    ZeroDivisionError: if there are no samples or no labels.
  """
  # Bring both operands to the CPU once; comparing element by element with
  # .item() would force one device synchronisation per label.
  pred_t = torch.as_tensor(pred).detach().cpu()
  truth_t = torch.as_tensor(truth).detach().cpu()

  if pred_t.dim() != 2 or pred_t.shape != truth_t.shape:
    raise ValueError(
        "pred and truth must be 2-D with identical shapes, got "
        f"{tuple(pred_t.shape)} and {tuple(truth_t.shape)}"
    )

  n_samples, n_labels = pred_t.shape
  # Integer match counts per sample; the divisions are done in Python floats
  # so the value is the same as averaging the per-sample fractions directly.
  matches_per_sample = (pred_t == truth_t).sum(dim=1).tolist()
  return sum(m / n_labels for m in matches_per_sample) / n_samples

# Sanity check on the toy tensors: the two rows agree on 2 of 4 and 4 of 4
# labels, so the mean accuracy is 0.75.
print(find_accuracy(pred, truth))

ROC_AUC Score:

In [None]:
from numpy import vstack
from numpy import argmax
from pandas import read_csv
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import roc_auc_score
from torch import Tensor
from torch.utils.data import Dataset
from torch.utils.data import random_split
from torch.nn import Linear
from torch.nn import ReLU
from torch.nn import Softmax
from torch.nn import Module
import copy
from tqdm import tqdm

running_loss = 0
current_batch_accs = []  # per-batch ROC-AUC scores
epoch_accs = []
torch.cuda.empty_cache()

# testing loop
model.eval()
predictions, actuals = list(), list()
min_valid_loss = float('inf')
valid_loss = 0

# compute the model output and establish layers
# NOTE(review): this 1000 -> 14 projection is freshly (randomly) initialised
# here and is not restored from the checkpoint, so the scores it yields are
# not those of a trained classification head. It is built once, outside the
# loop, so that every batch is at least scored with the same weights.
# TODO: replace it with the head that was actually trained.
lin = nn.Linear(1000, 14)
lin = lin.to(device)
sig = nn.Sigmoid()

# Synthetic sample added in the fallback below so that the flattened targets
# always contain at least one positive entry.
pad_row = np.array([1,0,0,0,0,0,0,0,0,0,0,0,0,0])

with torch.no_grad():  # inference only: no autograd graph is needed
  for (inputs, targets) in tqdm(test_dl):
    inputs = inputs.to(device)
    targets = targets.to(device)

    probs = sig(lin(model(inputs.float())))
    # Hard 0/1 decisions are what gets stored in `predictions`; ROC-AUC is
    # computed from the unrounded probabilities because rounding discards the
    # ranking information the metric is based on.
    yhat = torch.round(probs, decimals=0)

    actual = targets.cpu().float().numpy()
    # scikit-learn needs host memory; passing a CUDA tensor raises, which
    # previously sent every GPU batch into the fallback branch.
    scores = probs.cpu().numpy()

    predictions.append(yhat)
    actuals.append(actual)
    torch.cuda.empty_cache()

    print('\n')
    print(actual)
    try:
      batch_auc = roc_auc_score(y_true=actual, y_score=scores)
    except ValueError:
      # Raised when a label column of this batch contains a single class.
      batch_auc = float('nan')
    # Presumably some scikit-learn versions return NaN instead of raising in
    # that situation, so both outcomes take the same fallback.
    if np.isnan(batch_auc):
      print('error')
      padded_actual = np.vstack((actual, pad_row))
      padded_scores = np.vstack((scores, pad_row))
      # Score all (sample, label) pairs together so both classes are present.
      batch_auc = roc_auc_score(y_true=padded_actual.flatten(), y_score=padded_scores.flatten())
    print('batch roc_auc_score:', str(batch_auc))
    current_batch_accs.append(batch_auc)
    print('to-date average roc_auc_score', str(sum(current_batch_accs) / len(current_batch_accs)))

# calculate epoch accuracy
# The value printed here is the mean of the per-batch ROC-AUC scores.
print('total test accuracy: ' + str(sum(current_batch_accs) / len(current_batch_accs)))

Confusion Matrix:

In [None]:
from numpy import vstack
from numpy import argmax
from pandas import read_csv
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
from torch import Tensor
from torch.utils.data import Dataset
from torch.utils.data import random_split
from torch.nn import Linear
from torch.nn import ReLU
from torch.nn import Softmax
from torch.nn import Module
import copy
from tqdm import tqdm

running_loss = 0
current_batch_accs = []  # per-batch mean label accuracy
epoch_accs = []
torch.cuda.empty_cache()

# testing loop
model.eval()
predictions, actuals = list(), list()
min_valid_loss = float('inf')
valid_loss = 0

# compute the model output and establish layers
# NOTE(review): this 1000 -> 14 projection is freshly (randomly) initialised
# here and is not restored from the checkpoint, so the predictions it yields
# are not those of a trained classification head. It is built once, outside
# the loop, so that every batch is at least scored with the same weights.
# TODO: replace it with the head that was actually trained.
lin = nn.Linear(1000, 14)
lin = lin.to(device)
sig = nn.Sigmoid()

with torch.no_grad():  # inference only: no autograd graph is needed
  for (inputs, targets) in tqdm(test_dl):
    inputs = inputs.to(device)
    targets = targets.to(device)

    # Threshold the sigmoid outputs at 0.5 to get hard 0/1 predictions.
    yhat = torch.round(sig(lin(model(inputs.float()))), decimals=0)

    actual = targets.cpu().float().numpy()

    predictions.append(yhat)
    actuals.append(actual)
    torch.cuda.empty_cache()

    # calculate batch accuracy (once; the helper is called a single time and
    # the value reused for printing and for the running list)
    print('\n')
    batch_acc = find_accuracy(yhat, actual)
    print('batch accuracy:', str(batch_acc))
    current_batch_accs.append(batch_acc)
    print('to-date average accuracy', str(sum(current_batch_accs) / len(current_batch_accs)))

# calculate epoch accuracy
# Stack the per-batch results into (samples x labels) arrays.
predictions, actuals = vstack([i.cpu() for i in predictions]), vstack(actuals)
# `acc` is the (targets, predictions) pair, not a score; kept under this name
# for any later cell that reads it.
acc = (actuals, predictions)
# The value printed here is the unweighted mean of the per-batch accuracies.
print('total test accuracy: ' + str(sum(current_batch_accs) / len(current_batch_accs)))