Imports and constants

In [None]:
import os

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch import nn, optim, float32, IntTensor, FloatTensor
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, models
from torchvision.io import read_image, ImageReadMode

In [None]:
# Switch for the extra sanity-check output (prints and the image preview)
# in the "Sanity check data" cell.
debug = True

Important paths

In [None]:
# Dataset root; every other path in the notebook is derived from it.
db_path = "/cluster/home/larsira/tdt4900/databases/chest_xray14/"
img_dir_name = "images"

# Folder holding the image files, and the CSV holding the annotations.
img_dir = os.path.join(db_path, img_dir_name)
annotations_file = os.path.join(db_path, "data_list.csv")

Sanity check data

In [None]:
# Open one image by name as a quick check that the image folder is readable
image = Image.open(os.path.join(img_dir, "00000001_000.png"))

# annotations_file is already a full path (it was built from db_path), so it
# is handed to pandas as-is instead of being joined with db_path a second time.
annotations = pd.read_csv(annotations_file)

if debug:
    # Inspect one row and report that row's label rather than the whole column.
    # NOTE(review): the image above is opened by a fixed file name, so it is
    # not guaranteed to belong to this row - confirm against the CSV.
    sample_row = annotations.iloc[1]
    print(sample_row)
    print("This patient is aflicted with:", sample_row["Finding Labels"])
    print(image)
    plt.imshow(image)

Fetch relevant labels from dataset

In [None]:
# Collect every distinct finding in the annotations. Each "Finding Labels"
# cell is split on "|", so a cell may contribute several findings.
unique_findings = {
    finding
    for cell in annotations["Finding Labels"]
    for finding in cell.split("|")
}
# "No Finding" is not treated as a class of its own. discard() does not raise
# when the CSV happens not to contain it.
unique_findings.discard("No Finding")
# Keep the labels as a sorted list: the iteration order of a set of strings
# can differ between interpreter runs, which would silently permute the
# positions of the label vector (and of the model outputs) from run to run.
labels = sorted(unique_findings)
print(labels, "There are", len(labels), "labels available")

Define transform operation for label

In [None]:
def to_numeric_label(in_label):
    """Turn a "|"-separated label string into a multi-hot FloatTensor.

    The result has one entry per name in the module-level ``labels``, in
    iteration order: 1 where that name occurs in ``in_label``, 0 otherwise.
    """
    present = in_label.split("|")
    flags = []
    for name in labels:
        flags.append(1 if name in present else 0)
    return FloatTensor(flags)

Define custom dataset

In [None]:
class ChestXRayDataset(Dataset):
    """Dataset pairing image files with the labels listed in a CSV file.

    The CSV is loaded with pandas. Column 0 is used as the image file name
    (relative to ``img_dir``) and column 1 as the raw label value.

    Args:
        annotations_file: Path to the CSV file.
        img_dir: Directory containing the image files.
        transform: Optional callable applied to the loaded image.
        target_transform: Optional callable applied to the raw label.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of rows in the annotations table."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load the image and label for row ``idx``.

        Returns:
            ``(image, label)`` after the optional transforms have been applied.
        """
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # Always read as RGB so every sample has three channels.
        image = read_image(path=img_path, mode=ImageReadMode.RGB)
        label = self.img_labels.iloc[idx, 1]
        # The per-sample debug print of image.shape was removed: it fired for
        # every item fetched and flooded the output during training.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

Load into dataset, define a data loader

In [None]:
batch_size = 8

# Image pipeline: convert to float32, then normalise with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.ConvertImageDtype(float32),
    transforms.Normalize((0.5,), (0.5,)),
])
dataset = ChestXRayDataset(
    annotations_file=annotations_file,
    img_dir=img_dir,
    transform=transform,
    target_transform=to_numeric_label,
)
loader = DataLoader(dataset, batch_size=batch_size)

# Smoke test: pull at most 101 batches (indices 0..100) through the loader,
# keeping the first image/label of the most recent batch for inspection.
for idx, value in zip(range(101), loader):
    test_img, test_lab = value
    img, label = test_img[0], test_lab[0]

Define the model: a ResNet-101 whose final layer has one output per label

In [None]:
# ResNet-101 without pretrained weights (weights=None).
model = models.resnet101(weights=None)
# Replace the classification head with one output per label. The input width
# is read from the existing head instead of hard-coding 2048, so the line
# keeps working if the backbone is swapped for another ResNet variant.
model.fc = nn.Linear(model.fc.in_features, len(labels))

In [None]:
# Hyper params
lr = 1e-3  # same value as the earlier "10e-4" spelling, written unambiguously

optimizer = optim.SGD(model.parameters(), lr=lr)
# The targets are multi-hot vectors (to_numeric_label sets one flag per
# finding, so several flags - or none - can be set for one image). That is a
# multi-label problem, so each output is scored as an independent binary
# decision on its logit. CrossEntropyLoss would instead treat the outputs as
# one distribution over mutually exclusive classes.
loss_fn = nn.BCEWithLogitsLoss()

def train_epoch(idx, data_loader):
    """Run one optimisation pass over ``data_loader``.

    Uses the module-level ``model``, ``optimizer`` and ``loss_fn``.

    Args:
        idx: Epoch index. Not used inside the function; kept so existing
            calls keep working.
        data_loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        Mean loss over the most recently completed window of 1000 batches.
        If the loader yields fewer than 1000 batches, the mean over all
        batches seen; 0.0 if it yields no batches at all.
    """
    report_every = 1000
    running_loss = 0.
    # Initialised up front so the return below is always defined, even when
    # no full reporting window is ever completed.
    last_loss = 0.
    batches_seen = 0

    # "targets" rather than "labels" to avoid shadowing the module-level
    # list of label names.
    for i, (inputs, targets) in enumerate(data_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        batches_seen = i + 1
        if batches_seen % report_every == 0:
            last_loss = running_loss / report_every
            print("batch {} loss: {}".format(batches_seen, last_loss))
            running_loss = 0.

    # Short epoch: no window was completed, so report the mean of what was seen.
    if 0 < batches_seen < report_every:
        last_loss = running_loss / batches_seen

    return last_loss

In [None]:
epochs = 10
best_loss = 10_000_000.

for epoch in range(epochs):
    model.train(True)

    avg_loss = train_epoch(epoch, loader)

    # NOTE(review): validation_loader is not defined anywhere in the visible
    # notebook. It has to be created from a held-out split before this cell
    # can run.
    running_validation_loss = 0.
    validation_batches = 0
    model.eval()
    with torch.no_grad():
        for validation_inputs, validation_labels in validation_loader:
            validation_outputs = model(validation_inputs)
            # Compare the predictions with the labels (not with the inputs).
            validation_loss = loss_fn(validation_outputs, validation_labels)
            # .item() keeps the accumulator a plain Python float.
            running_validation_loss += validation_loss.item()
            validation_batches += 1

    if validation_batches == 0:
        raise ValueError("validation_loader yielded no batches")
    avg_validation_loss = running_validation_loss / validation_batches
    print("LOSS train {} valid {}".format(avg_loss, avg_validation_loss))

    # Checkpoint on validation improvement, and remember the validation loss
    # (not the training loss) as the value to beat.
    if avg_validation_loss < best_loss:
        best_loss = avg_validation_loss
        torch.save(model.state_dict(), "model_{}_{}".format(epoch, epoch))