In [None]:
from pathlib import Path

import PIL
import PIL.Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sklearn.metrics as sklm

import torch
import pytorch_lightning as pl

import dataset
import neural_network
import config as cfg

# Installed PyTorch version, displayed so results can be tied to it.
torch.__version__

In [None]:
import sys

# Interpreter version/build string, recorded for reproducibility.
python_version = sys.version
print(python_version)

## Environment (GPU availability)

In [None]:
# Number of CUDA devices PyTorch can see (0 means no CUDA GPU is available).
gpu_count = torch.cuda.device_count()
print(f"Num GPUs Available: {gpu_count}")

## Classes

In [None]:
# Class labels as returned by get_classes for the training labels CSV,
# displayed together with how many there are.
train_labels_csv = cfg.paths.LABELS_CSV["train"]
CLASSES = dataset.metadata.get_classes(train_labels_csv)
(CLASSES, CLASSES.size)

## Preprocessing Dataset

In [None]:
# Show one training image as a quick sanity check.
# Sorting makes the chosen file deterministic (raw glob order depends on the
# filesystem), and an empty directory gets a clear error instead of a bare
# StopIteration from next().
train_filenames = sorted(cfg.paths.IMG_DIR["train"].glob("*.jpg"))
if not train_filenames:
    raise FileNotFoundError(f"No .jpg images found in {cfg.paths.IMG_DIR['train']}")

filename = train_filenames[0]
PIL.Image.open(filename)

In [None]:
# Image shape of the training images, as reported by the dataset helper and
# unpacked in (channels, height, width) order.
train_img_dir = cfg.paths.IMG_DIR["train"]
(IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH) = dataset.metadata.get_image_dimensions(train_img_dir)
(IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH)

In [None]:
# Training split as a dataset object, previewed with a few samples
# (the 2, 2 arguments presumably set the preview grid shape).
# Named `train_dataset` so it is not confused with the `train` metrics frame
# built later in the Metrics section.
train_dataset = dataset.SkinCancerDataset(cfg.paths.LABELS_CSV["train"], cfg.paths.IMG_DIR["train"])
dataset.plot_some_samples(2, 2, train_dataset, CLASSES)

In [None]:
# Test split as a dataset object, previewed the same way as the training split.
test_dataset = dataset.SkinCancerDataset(cfg.paths.LABELS_CSV["test"], cfg.paths.IMG_DIR["test"])
dataset.plot_some_samples(2, 2, test_dataset, CLASSES)

## Modeling the CNN

In [None]:
# Global seed for reproducible training; workers=True also seeds dataloader workers.
SEED = 0
pl.seed_everything(seed=SEED, workers=True)

In [None]:
# TODO(review): no transform is applied here (transform=None). Confirm whether
# SkinCancerDataModule rescales/normalizes pixel values internally; if not,
# decide whether images should be scaled (e.g. to [0, 1]) before training.
datamodule_args = (
    cfg.paths.LABELS_CSV,
    cfg.paths.IMG_DIR,
    cfg.hparams.BATCH_SIZE,
    cfg.hparams.DATALOADER_NUM_WORKERS,
)
data_module = dataset.SkinCancerDataModule(*datamodule_args, transform=None)

In [None]:
# Build the CNN with one output per class and the configured dropout rate.
num_classes = CLASSES.size
model = neural_network.ConvNetwork(num_classes, cfg.hparams.DROPOUT_RATE)
print(model)

In [None]:
# Wrap the network in the Lightning module passed to trainer.fit, together with
# the input image shape, the number of classes and the learning rate.
image_shape = (IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH)
model_module = neural_network.NetworkModule(
    model, *image_shape, CLASSES.size, cfg.hparams.LEARNING_RATE
)

In [None]:
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.callbacks.progress import TQDMProgressBar

# Trainer callbacks. The early-stopping callback gets a descriptive name so it
# is not confused with (or clobbered by) the `validation` metrics frame built
# in the Metrics section.
progress_bar = TQDMProgressBar()
# Stop training once val_loss stops decreasing.
early_stopping = EarlyStopping("val_loss", mode="min")
# Keep the three checkpoints with the lowest val_loss.
checkpoint = ModelCheckpoint(save_top_k=3, monitor="val_loss", mode="min")

trainer_callbacks = [progress_bar, early_stopping, checkpoint]

In [None]:
from pytorch_lightning.loggers import CSVLogger

# Log training/validation metrics as CSV under LOG_DIR; they are read back in
# the Metrics section.
logger = CSVLogger(save_dir=cfg.paths.LOG_DIR)

In [None]:
# Quick-iteration switch: True trains/validates on only a fraction of the
# batches each epoch; set it to False for a full training run.
QUICK_RUN = True
quick_run_limits = (
    {"limit_train_batches": 0.05, "limit_val_batches": 0.1} if QUICK_RUN else {}
)

trainer = pl.Trainer(
    min_epochs=5,
    max_epochs=20,
    accelerator="auto",
    devices="auto",
    logger=logger,
    callbacks=trainer_callbacks,
    deterministic=True,
    log_every_n_steps=25,
    **quick_run_limits,
)

trainer.fit(model=model_module, datamodule=data_module)

## Metrics

In [None]:
# Checkpoint files retained by ModelCheckpoint, mapped to their monitored score.
best_checkpoints = checkpoint.best_k_models
best_checkpoints

In [None]:
# Each training run logs to a new version folder (version_0, version_1, ...).
# Take the folder the CSV logger actually used for this run instead of
# hard-coding version 0, which would silently read the first run's metrics
# after any re-run.
version = logger.version
VERSION_DIR = Path(logger.log_dir)

In [None]:
# Load this run's logged metrics, indexed by (epoch, step).
METRICS_PATH = VERSION_DIR / "metrics.csv"
metrics_raw = pd.read_csv(METRICS_PATH)
metrics = metrics_raw.set_index(["epoch", "step"])

# Training and validation metrics may land on different rows of the CSV, so
# each group is selected on its own and rows lacking those values are dropped.
train_cols = ["train_loss_epoch", "train_acc_epoch"]
val_cols = ["val_loss", "val_acc"]
train = metrics[train_cols].dropna()
validation = metrics[val_cols].dropna()

validation

In [None]:
# Per-epoch training vs. validation accuracy.
# Each series is plotted against its own epoch index, so the curves stay
# aligned (and plotting does not raise) when the two frames have different
# lengths, e.g. if a run stops between a training and a validation pass.
epochs = train.index.get_level_values("epoch")
val_epochs = validation.index.get_level_values("epoch")

fig, ax = plt.subplots()
ax.plot(epochs, train["train_acc_epoch"], label="accuracy")
ax.plot(val_epochs, validation["val_acc"], label="val_accuracy")
ax.set(xlabel="Epoch", ylabel="Accuracy", title="Training vs. validation accuracy")
ax.legend(loc="lower right");

## Confusion Matrix

In [None]:
# Reset the module's prediction buffers, then evaluate on the test split.
# datamodule is passed by keyword (the second positional argument of
# Trainer.test is `dataloaders`). ckpt_path="best" restores the checkpoint the
# ModelCheckpoint callback ranked best on val_loss, instead of the last-epoch
# weights, which after early stopping are typically past the best epoch.
model_module.clear_test_predictions_variables()
trainer.test(model=model_module, datamodule=data_module, ckpt_path="best")

In [None]:
# Collect the labels and class probabilities gathered by the module during
# trainer.test(). Everything is moved to the CPU and converted to NumPy, so the
# integer positions index CLASSES and feed pandas/sklearn as plain arrays
# rather than relying on implicit tensor-to-array coercion.
true_labels_pos = torch.concat(model_module.test_expected).detach().cpu().numpy()

# Rows are samples and columns are classes (argmax over dim=1 picks the class).
predicted_probabilities = torch.concat(model_module.test_probabilities).detach().cpu()
predicted_labels_pos = predicted_probabilities.argmax(dim=1).numpy()

assert len(true_labels_pos) == len(predicted_labels_pos), "label/prediction count mismatch"

true_labels = CLASSES[true_labels_pos]
predicted_labels = CLASSES[predicted_labels_pos]

In [None]:
# Rows: actual class, columns: predicted class.
# Reindexing on CLASSES keeps the matrix square and in class order, so a class
# that is never predicted (or never occurs) still appears as a row/column of
# zeros instead of silently disappearing from the table.
confusion_matrix = (
    pd.crosstab(true_labels, predicted_labels)
    .reindex(index=CLASSES, columns=CLASSES, fill_value=0)
    .rename_axis(index="Actual", columns="Predicted")
)
confusion_matrix

In [None]:
# Per-class precision/recall/F1 on the test split.
# `labels` pins the report to every class index so `target_names` always
# lines up, even when a class is absent from both the true and predicted labels
# (without it sklearn raises a size-mismatch error). zero_division=0 reports 0
# for undefined scores without emitting warnings.
report = sklm.classification_report(
    true_labels_pos,
    predicted_labels_pos,
    labels=np.arange(CLASSES.size),
    target_names=CLASSES,
    zero_division=0,
)
print(report)