In [None]:
import random

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.utils import make_grid

import utils
from models import Classificator

## Hyperparameters

In [None]:
# DataLoader params
BATCH_SIZE = 32
NUM_WORKERS = 7  # value suggested by a warning message; likely depends on the machine's CPU count
PERSISTENT_WORKERS = True  # suggested by a warning message for faster worker start-up
# NOTE(review): PERSISTENT_WORKERS is not passed to utils.loadData in the "Load data" cell — confirm it is meant to be used.
USE_AUGMENT = False  # mutually exclusive with CUSTOM_TRAIN_VAL_SPLIT
CUSTOM_TRAIN_VAL_SPLIT = True  # mutually exclusive with USE_AUGMENT
DATA_DIR = "chest_xray"  # path to the chest_xray dataset folder
USE_SAMPLER = False
SHOW_ANALYTICS = False

# Fail fast instead of silently running an unsupported combination.
assert not (USE_AUGMENT and CUSTOM_TRAIN_VAL_SPLIT), (
    "USE_AUGMENT and CUSTOM_TRAIN_VAL_SPLIT are mutually exclusive; enable at most one"
)

# Reproducibility: seed the Python, NumPy and torch RNGs before any data loading
# or model construction, so re-running the notebook gives the same results.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA

# LightningModule params
LEARNING_RATE = 1e-3
EPOCHS = 100
CLASS_LABELS = ["Normal", "Pneumonia"]
NUM_CLASSES = len(CLASS_LABELS)  # derived so the label list and class count cannot disagree

config = {
    "lr": LEARNING_RATE,
    "loss": "CrossEntropyLoss"
}


## Load data

In [None]:
# Build the train / validation / test DataLoaders from DATA_DIR using the
# options set in the hyperparameter cell. Split and transform details live in
# the project helper utils.loadData.
# NOTE(review): PERSISTENT_WORKERS is not forwarded here — confirm whether loadData should receive it.
load_options = dict(
    batchSize=BATCH_SIZE,
    numWorkers=NUM_WORKERS,
    dataDir=DATA_DIR,
    customSplit=CUSTOM_TRAIN_VAL_SPLIT,
    # The original author warned that changing USE_AUGMENT's meaning affects other
    # code; e.g. the TensorBoard logger name in the training cell also reads it.
    useAugment=USE_AUGMENT,
    useSampler=USE_SAMPLER,
    showAnalytics=SHOW_ANALYTICS,
)
train_loader, val_loader, test_loader = utils.loadData(**load_options)


## Plot some example images

In [None]:
# Sanity-check the input pipeline: show the first few images of one training
# batch in a grid, followed by their integer labels laid out row by row.
PREVIEW_COUNT = 9
PREVIEW_COLS = 3

images, labels = next(iter(train_loader))
n_preview = min(PREVIEW_COUNT, len(images))  # the batch may hold fewer than PREVIEW_COUNT images

# normalize=True rescales the grid to [0, 1]. Without it, imshow clips values outside
# that range, e.g. if the loader applies mean/std normalization (TODO confirm in utils).
grid = make_grid(images[:n_preview], nrow=PREVIEW_COLS, normalize=True)

fig, ax = plt.subplots()
ax.imshow(grid.permute(1, 2, 0))
ax.set_title("Training batch preview (labels listed row by row below)")
ax.set_axis_off()
plt.show()

# One inner list per grid row. Unlike reshape(3, 3), this also works when
# n_preview is not a multiple of PREVIEW_COLS.
[labels[start:start + PREVIEW_COLS].tolist() for start in range(0, n_preview, PREVIEW_COLS)]

## Load ResNet50 model

In [None]:
# Load ResNet-50 with the best available ImageNet-pretrained weights.
model = resnet50(weights=ResNet50_Weights.DEFAULT)

# Freeze the whole pretrained network in one call, rather than matching
# parameter names as substrings.
model.requires_grad_(False)

# Replace the ImageNet classification head with a NUM_CLASSES-way head.
# Reading in_features from the existing head (2048 for ResNet-50) keeps this
# correct if the backbone is swapped. A freshly built nn.Linear has
# requires_grad=True, so after this line only model.fc is trainable.
model.fc = nn.Linear(
    in_features=model.fc.in_features,
    out_features=NUM_CLASSES,
    bias=True,
)

assert {name for name, param in model.named_parameters() if param.requires_grad} == {"fc.weight", "fc.bias"}

# NOTE(review): requires_grad=False does not stop BatchNorm layers from updating
# their running statistics while the model is in train mode. If a pure linear
# probe is intended, the backbone's BN layers may need to be kept in eval mode.

## Run training and validation loops

In [None]:
# Wrap the (mostly frozen) ResNet in the project's LightningModule and train it.
classifier = Classificator(model, CLASS_LABELS, config, NUM_CLASSES)

# Must match the key the LightningModule logs. EarlyStopping (strict by default)
# raises if the monitored metric is missing.
MONITOR_METRIC = "Validation loss"

early_stop_callback = EarlyStopping(monitor=MONITOR_METRIC, mode="min", min_delta=1e-6, patience=10)

# Without `monitor`, ModelCheckpoint keeps only the checkpoint from the last epoch.
# With patience=10, that can be up to 10 epochs past the best validation loss.
# Track the same metric as early stopping and keep the single best checkpoint.
checkpoint = L.pytorch.callbacks.ModelCheckpoint(
    dirpath="pneumonia_model/ResNet/",
    monitor=MONITOR_METRIC,
    mode="min",
    save_top_k=1,
)
callbacks = [early_stop_callback, checkpoint]

logger = TensorBoardLogger(
    "lightning_logs",
    name=f"resnet/{'augment' if USE_AUGMENT else 'original'}",
)

trainer = L.Trainer(
    accelerator="auto",
    devices=1,
    logger=logger,
    max_epochs=EPOCHS,
    # NOTE(review): the loaders are passed in below as fixed objects, so the reason for
    # reloading them every 3 epochs is not visible here — confirm it is still needed.
    reload_dataloaders_every_n_epochs=3,
    callbacks=callbacks,
)
trainer.fit(
    model=classifier,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

## Test model
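Implemented and tested, but do not run this until the final model has been chosen: evaluating on the test set during model selection would leak test information into those choices and bias the reported result.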

In [None]:
# Final held-out evaluation, disabled until the final model is chosen.
# Flip RUN_TEST to True instead of (un)commenting code.
RUN_TEST = False
if RUN_TEST:
    # ckpt_path="best" reloads the checkpoint kept by ModelCheckpoint, rather than
    # the weights left in `classifier` at the epoch where training stopped.
    trainer.test(model=classifier, dataloaders=test_loader, ckpt_path="best")

In [None]:
# Show TensorBoard inline over "lightning_logs", the root directory the
# TensorBoardLogger in the training cell writes to. %reload_ext also works when
# the extension is already loaded in this kernel, so the cell can be re-run.
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/