In [None]:
from pneumonia_detector.preprocess import XrayDataset
from pneumonia_detector.model import PneumoniaClassifier
from pneumonia_detector.train import train_one_epoch
import os
import torch
import torch.nn as nn 
import torch.optim as optim
import pandas as pd
# from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms, utils
from torchvision.io import read_image
from typing import List

# from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

# Ignore warnings
# NOTE(review): with no category or module filter this silences every warning
# for the rest of the session, including deprecation notices and numerical
# warnings (e.g. divide-by-zero) raised by later cells.
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Seed PyTorch's global random number generator so repeated runs draw the same
# sequence of random numbers.
# NOTE(review): only torch is seeded here; NumPy and Python's `random` module
# are left unseeded.
torch.manual_seed(55)

In [None]:
# Root of the chest X-ray dataset. Defaults to the dev-container location this
# notebook was written for, but can be redirected on another machine by setting
# the CHEST_XRAY_DATA_DIR environment variable (no code edit required).
data_root = os.environ.get(
    "CHEST_XRAY_DATA_DIR", "/workspaces/chest_xray_challenge/data/chest_xray"
)

# The trailing empty component makes os.path.join append a path separator, so
# both directories keep a trailing slash.
train_dir = os.path.join(data_root, "train", "")
val_dir = os.path.join(data_root, "val", "")

In [None]:
# Baseline (augmentation-free) preprocessing: resize to 256x256, convert to a
# tensor, then normalise all three channels with the same mean and std.
# Grayscale conversion and random horizontal flipping are left disabled here.
# NOTE: the next cell rebinds `train_transforms` to an augmented pipeline, so
# under "Run All" this baseline is replaced before any dataset is built.
train_transforms = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4823] * 3,
            std=[0.2363] * 3,
        ),
    ]
)

In [None]:
# Augmented training pipeline. Random geometric and photometric perturbations
# are applied first; the image is then resized, converted to a tensor and
# normalised with the same per-channel statistics on all three channels.
train_transforms = transforms.Compose(
    [
        # --- geometric / photometric augmentation ---
        transforms.RandomRotation(degrees=20),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomResizedCrop(size=(256, 256), scale=(0.8, 1.0)),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        # Each of the next two is applied to only about half of the samples.
        transforms.RandomApply(
            [transforms.RandomAffine(degrees=0, translate=(0.1, 0.1))],
            p=0.5,
        ),
        transforms.RandomApply(
            [transforms.RandomPerspective(distortion_scale=0.2)],
            p=0.5,
        ),
        # --- final size guarantee, tensor conversion, normalisation ---
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4823] * 3,
            std=[0.2363] * 3,
        ),
    ]
)

In [None]:
# Build the training dataset rooted at `train_dir`, handing it `train_transforms`
# as its transform. The trailing `len(...)` is displayed as the cell output
# (number of items in the dataset).
xray_train_data = XrayDataset(root_dir=train_dir, transform=train_transforms)
len(xray_train_data)

In [None]:
def create_weighted_sampler(dataset):
    """Build a class-balancing ``WeightedRandomSampler`` for ``dataset``.

    Every sample is weighted by the inverse of its class frequency, so each
    class contributes the same total weight to the sampler regardless of how
    many files it has.

    Args:
        dataset: Object exposing ``files``, an iterable of file paths whose
            parent directory name (case-insensitive) is the class name. If the
            object also exposes ``label_map`` (class name -> integer label) it
            is used; otherwise ``XrayDataset.label_map`` is used.

    Returns:
        A ``WeightedRandomSampler`` that draws ``len(dataset.files)`` indices
        per pass, using the sampler's default replacement setting.

    Raises:
        ValueError: If ``dataset.files`` is empty.
        KeyError: If a parent directory name is not a key of the label map.
    """
    # Prefer the mapping carried by the dataset itself so the helper is not
    # tied to one concrete class; fall back to XrayDataset's mapping otherwise.
    label_map = getattr(dataset, "label_map", None)
    if label_map is None:
        label_map = XrayDataset.label_map

    # The class name is the directory that directly contains the file.
    # os.path handles repeated/alternate separators and str() accepts Path objects.
    targets = [
        label_map[os.path.basename(os.path.dirname(str(file))).lower()]
        for file in dataset.files
    ]
    if not targets:
        raise ValueError("Cannot build a weighted sampler for an empty dataset")

    class_counts = np.bincount(targets)
    # Index the counts by each sample's own label: only classes that actually
    # occur are ever divided by, so absent classes cannot cause a division by zero.
    weights = [1.0 / class_counts[label] for label in targets]
    return WeightedRandomSampler(weights, len(weights))

In [None]:
# Class-balancing sampler for the training set; it is passed to the training
# DataLoader below in place of `shuffle=True`.
sampler = create_weighted_sampler(xray_train_data)

In [None]:
# Sanity check: show the shape of the first element of dataset item 0
# (presumably the transformed image tensor — TODO confirm against XrayDataset).
xray_train_data[0][0].shape

In [None]:
# Fetch the sample once and reuse it for every statistic. `train_transforms`
# contains random augmentations, so indexing the dataset separately for each
# statistic would presumably yield a differently-augmented image each time and
# the four numbers would not describe the same tensor (and the image would be
# loaded and transformed four times).
sample_image = xray_train_data[0][0]

# Per-channel mean and std (reduced over dims 1 and 2), then the global min/max.
print(torch.mean(sample_image, dim=[1, 2], keepdim=True))
print(torch.std(sample_image, dim=[1, 2], keepdim=True))
print(torch.min(sample_image))
print(torch.max(sample_image))

In [None]:
# Validation must be deterministic: `train_transforms` applies random rotation,
# flipping, cropping, colour jitter, affine and perspective changes, which would
# make the validation loss noisy and not comparable between epochs. Use only the
# fixed resize / tensor conversion / normalisation steps, with the same size and
# statistics as the training pipeline.
val_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4823, 0.4823, 0.4823], std=[0.2363, 0.2363, 0.2363]),
])

xray_val_data = XrayDataset(root_dir=val_dir, transform=val_transforms)
len(xray_val_data)

In [None]:
# Training batches of 16, loaded in the main process (num_workers=0).
# Ordering comes from the class-balancing `sampler`; `shuffle` is therefore
# left unset, since DataLoader does not accept both a sampler and shuffle=True.
train_dataloader_xray = DataLoader(
    xray_train_data,
    batch_size=16,
    sampler=sampler,
    num_workers=0,
)

In [None]:
# Validation batches of 16, loaded in the main process (num_workers=0).
# Shuffling is disabled: evaluation does not benefit from a random order, and a
# fixed order keeps batch composition (and thus the per-batch averaged loss)
# identical from epoch to epoch.
val_dataloader_xray = DataLoader(
    dataset=xray_val_data,
    batch_size=16,
    num_workers=0,
    shuffle=False,
)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# The bare `device` on the last line displays the selection as the cell output.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
# Instantiate the classifier and move its parameters to the selected device.
model = PneumoniaClassifier()
model = model.to(device)

# Cross-entropy loss over the model outputs and the integer class labels.
criterion = nn.CrossEntropyLoss()

# Adam with a learning rate of 1e-3. (An SGD configuration with lr=0.0005 and
# momentum=0.9 is the disabled alternative.)
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
# Number of mini-batches between progress reports; also the divisor used to
# average the accumulated loss for each report.
LOG_EVERY = 100

# Select training mode explicitly. A later cell switches the model to inference
# mode, so re-running this cell afterwards would otherwise train with any
# mode-dependent layers (e.g. dropout or batch norm, if the model has them)
# still in their evaluation behaviour.
model.train()

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(train_dataloader_xray, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        # Move both tensors to the model's device; inputs are cast to float32
        # in the same call.
        inputs = inputs.to(device=device, dtype=torch.float32)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics: mean loss over the last LOG_EVERY mini-batches
        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / LOG_EVERY:.3f}')
            running_loss = 0.0

print('Finished Training')

In [None]:
# Run bookkeeping. `timestamp` tags every checkpoint written by this run, and
# `epoch_number` is tracked separately from the loop variable so it can keep
# counting if the loop below is run again on the same model.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
epoch_number = 0

EPOCHS = 10

# Sentinel larger than any expected loss, so the first epoch always checkpoints.
best_vloss = 1_000_000.

# torch.save does not create missing directories, so make sure the checkpoint
# folder exists before the first save.
os.makedirs('../models', exist_ok=True)

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(train_dataloader_xray, optimizer, model, criterion, epoch_number)

    # Switch to evaluation mode for the validation pass.
    model.eval()

    running_vloss = 0.0
    num_vbatches = 0
    # Gradients are not needed for reporting; disabling them avoids building
    # autograd graphs during validation.
    with torch.no_grad():
        for vinputs, vlabels in val_dataloader_xray:
            # The model lives on `device`, so the batch must be moved there too
            # (inputs are cast to float32, matching the training loop above).
            vinputs = vinputs.to(device=device, dtype=torch.float32)
            vlabels = vlabels.to(device)
            voutputs = model(vinputs)
            # .item() accumulates a plain float rather than a tensor.
            running_vloss += criterion(voutputs, vlabels).item()
            num_vbatches += 1

    # Count batches locally instead of relying on a loop index, which would be
    # undefined (or stale from an earlier cell) if the loader yielded nothing.
    if num_vbatches == 0:
        raise ValueError('Validation dataloader produced no batches')

    avg_vloss = running_vloss / num_vbatches
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # NOTE: TensorBoard logging of the train/validation losses is disabled in
    # this notebook (the SummaryWriter import is commented out).

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = '../models/model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1