### 1. Import Dependencies and Setup 📥

#### 1.1 Dependencies

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms as tt
from torch.utils.data import DataLoader, random_split, Dataset
import torch.nn.functional as F
from cnn import Net
from utils import save_model, load_model, display_metrics, plot_graphs
from collections import Counter
from sklearn.metrics import (
    confusion_matrix,
    ConfusionMatrixDisplay,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_curve, 
    roc_auc_score
)
from datasetLoader.MergedDataset import MergedDataset, to_device

#### 1.2 Device

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

#### 1.3 Parameters

Set the parameter *loadExistingModel* to `True` if you want to load an existing model, and use *modelName* to specify which model you want to use.

In [None]:
# Image preprocessing: resize, centre-crop to 224, convert to a tensor and
# normalize with a single mean/std pair.
transform = tt.Compose(
    [
        tt.Resize(255),
        tt.CenterCrop(224),
        tt.ToTensor(),
        tt.Normalize(mean=0.482, std=0.236, inplace=True),
    ]
)

# Fraction of the data used for training (the rest is held out)
train_perc = 0.8

# Mini-batch size
batch_size = 64

# Optimizer learning rate
learning_rate = 1e-4

# Seed PyTorch's random number generator for reproducibility
seed = 2024
torch.manual_seed(seed)

# Model settings: where the model lives, whether to load it, and whether to
# save it afterwards (saving is the opposite of loading)
loadExistingModel = True
modelName = "./models/modelChestXray.pth"
saveModel = not loadExistingModel

### 2. Load Data 📚

#### 2.1 Dataset loading

In [None]:
# Enable only the chest_xray source and apply the transform defined above to it.
source_flags = {'chest_xray': True, 'cheX': False, 'kaggle_rsna': False}
dataset = MergedDataset(device, transformLoadingChest=transform, **source_flags)

# Train and test loaders for the selected source
train_dl, test_dl = dataset.getDataLoader()

#### 2.2 Dataset sizes

In [None]:
# Report the size of the loaded dataset.
dataset_size = dataset.getSize()
print(f'Merged dataset lenght: {dataset_size}')

#### 2.3 Dataset classes

In [None]:
# Per-class counts for each split; index 0 is plotted as 'Normal' and
# index 1 as 'Pneumonia' in the chart below.
train_class_count, test_class_count = (
    dataset.getTrainClasses(),
    dataset.getTestClasses(),
)

In [None]:
# One bar chart per split showing the class balance, followed by the raw counts.
fig, ax = plt.subplots(ncols=2, figsize=(10, 5))

# Data for the categories
categories = ['Normal', 'Pneumonia']
splits = [
    ('Train', train_class_count),
    ('Test', test_class_count),
]

for axis, (split_name, counts) in zip(ax, splits):
    # Bar chart of the two class counts for this split
    axis.bar(categories, [counts[0], counts[1]], color=['green', 'red'])
    # Titles and labels
    axis.set_title(f'{split_name} Categories')
    axis.set_ylabel('Count')
    axis.set_xlabel('Category')

# Show the plot
plt.show()

for split_name, counts in splits:
    print(f'{split_name} classes:\n\tNormal:\t\t{counts[0]}\n\tPneumonia:\t{counts[1]}')

#### 2.4 Data Samples

In [None]:
# Pull one batch from the training loader to inspect what the model receives.
batch = next(iter(train_dl))
images, labels = batch
# The second value printed is the label of the first sample itself (not its
# shape), so the message names it accordingly.
print(f'Image shape: {images[0].shape}\nLabel: {labels[0]}')

In [None]:
# Show the first four images of the sampled batch with their labels as titles.
fig, ax = plt.subplots(ncols=4, figsize=(20, 20))
sample_images, sample_labels = batch[0][:4], batch[1]
for axis, img, label in zip(ax, sample_images, sample_labels):
    # Move channels last (C, H, W -> H, W, C) for matplotlib
    axis.imshow(img.cpu().permute(1, 2, 0))
    axis.title.set_text(label.cpu().numpy())

### 5. Model Building 🏗️

#### 5.1 Train Model

In [None]:
# Load a previously saved model when requested; if loading was not requested
# or did not produce a model, build a new network and train it from scratch.
# `trained` records whether training happened in this session (used by the
# loss-plot cell below).
trained = False
net = None
num_epochs = 50

if loadExistingModel:
    # NOTE(review): load_model presumably returns None when the model cannot
    # be loaded -- the check below relies on that; confirm in utils.
    net = load_model(modelName)

if net is None:
    # Only claim the model is missing when a load was actually attempted.
    if loadExistingModel:
        print('The model does not exist!')
    print('Creating and training model...')
    net = to_device(Net(), device)

    # Define optimizer
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    # Class weights for the cross entropy loss: each class is weighted by the
    # share of the *other* class, so the rarer class contributes more.
    # Index 0 = Normal, index 1 = Pneumonia.
    n_normal, n_pneumonia = train_class_count[0], train_class_count[1]
    n_total = n_normal + n_pneumonia
    weight = torch.tensor(
        [n_pneumonia / n_total, n_normal / n_total],
        dtype=torch.float32,
        device=device,
    )

    # Train the model
    loss_values = []
    net.train()
    trained = True
    for epoch in range(num_epochs):
        running_loss = 0.0
        for inputs, labels in train_dl:
            inputs, labels = to_device(inputs, device), to_device(labels, device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = F.cross_entropy(outputs, labels, weight=weight)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        # Mean loss over the batches of this epoch
        epoch_loss = running_loss / len(train_dl)
        loss_values.append(epoch_loss)
        print(f'Epoch: {epoch}, loss: {epoch_loss}')

    # Leave the network in inference mode so the evaluation cells that follow
    # do not run it in training mode.
    net.eval()
else:
    print('The model exists and has been loaded')
    net.to(device)
    net.eval()
    print('Model info:')
    for param_name, param_tensor in net.state_dict().items():
        print("\t", param_name, "\t", param_tensor.size())

#### 5.2 Save Model

In [None]:
# Persist the network when saving was requested, and also when the model had
# to be trained in this session: `saveModel` is False whenever
# `loadExistingModel` is True, so without the `trained` check a model trained
# as a fallback for a failed load would be thrown away.
if saveModel or trained:
    save_model(modelName, net)

#### 5.3 Plot Loss 📈

In [None]:
# The loss history only exists when the model was trained in this session.
if trained:
    loss_fig, loss_ax = plt.subplots(figsize=(8, 6))
    loss_ax.plot(loss_values, marker='o', label='Loss')
    loss_ax.set_title('Loss Values Over Epochs')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss Value')
    loss_ax.legend()
    loss_ax.grid(True)
    plt.show()

### 7. Testing Model 🧪

#### 7.1 Test on chest_xray-3

In [None]:
# Test the model
# Evaluate the model on the chest_xray test split.
# Switch to inference mode first so that layers such as dropout or batch norm
# (if the network has any) behave deterministically; this is a no-op when the
# network is already in eval mode.
net.eval()
y_test = []
prob = []
with torch.no_grad():
    for images, labels in test_dl:
        images, labels = to_device(images, device), to_device(labels, device)

        outputs = net(images)
        y_test.extend(labels.cpu().numpy())
        # Probability of class 1, stored as plain floats so scikit-learn
        # receives numbers rather than 0-d tensors.
        prob.extend(torch.nn.functional.softmax(outputs.cpu(), dim=1)[:, 1].tolist())

prob = np.asarray(prob)
fprChest, tprChest, thresholds = roc_curve(y_test, prob)
roc_aucChest = roc_auc_score(y_test, prob)

# Pick the ROC operating point closest to the ideal corner (FPR=0, TPR=1).
# NOTE(review): the threshold is chosen on the same data the metrics are
# reported on, which makes those metrics optimistic.
distances = np.sqrt(fprChest**2 + (1 - tprChest)**2)
best_threshold = thresholds[np.argmin(distances)]
# roc_curve treats a sample as positive when score >= threshold, so the same
# comparison is used here to reproduce the selected operating point.
new_preds = (prob >= best_threshold).astype(int)

cm = confusion_matrix(y_test, new_preds)
dispChest = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["NEGATIVE", "POSITIVE"])
accuracyChest = accuracy_score(y_test, new_preds)
precisionChest = precision_score(y_test, new_preds)
recallChest = recall_score(y_test, new_preds)
f1Chest = f1_score(y_test, new_preds)

#### 7.2 Test on cheX

##### 7.2.0 Load Dataset

In [None]:
# Preprocessing for the CheX images: same geometry as before, with this
# dataset's own normalization constants.
transform = tt.Compose(
    [
        tt.Resize(255),
        tt.CenterCrop(224),
        tt.ToTensor(),
        tt.Normalize(mean=0.5017, std=0.2905, inplace=True),
    ]
)

# Enable only the CheX source; train_percentage=0 presumably sends every
# sample to the test loader -- confirm in MergedDataset.
cheX_options = dict(
    chest_xray=False,
    cheX=True,
    kaggle_rsna=False,
    train_percentage=0,
    split_seed=2024,
)
dataset = MergedDataset(device, transformLoadingCheX=transform, **cheX_options)
_, test_dl = dataset.getDataLoader()

##### 7.2.1 Test

In [None]:
# Test the model
# Evaluate the model on the CheX test loader.
# Switch to inference mode first so that layers such as dropout or batch norm
# (if the network has any) behave deterministically; this is a no-op when the
# network is already in eval mode.
net.eval()
y_test = []
prob = []
with torch.no_grad():
    for images, labels in test_dl:
        images, labels = to_device(images, device), to_device(labels, device)

        outputs = net(images)
        y_test.extend(labels.cpu().numpy())
        # Probability of class 1, stored as plain floats so scikit-learn
        # receives numbers rather than 0-d tensors.
        prob.extend(torch.nn.functional.softmax(outputs.cpu(), dim=1)[:, 1].tolist())

prob = np.asarray(prob)
fprCheX, tprCheX, thresholds = roc_curve(y_test, prob)
roc_aucCheX = roc_auc_score(y_test, prob)

# Pick the ROC operating point closest to the ideal corner (FPR=0, TPR=1).
# NOTE(review): the threshold is chosen on the same data the metrics are
# reported on, which makes those metrics optimistic.
distances = np.sqrt(fprCheX**2 + (1 - tprCheX)**2)
best_threshold = thresholds[np.argmin(distances)]
# roc_curve treats a sample as positive when score >= threshold, so the same
# comparison is used here to reproduce the selected operating point.
new_preds = (prob >= best_threshold).astype(int)

cm = confusion_matrix(y_test, new_preds)
dispCheX = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["NEGATIVE", "POSITIVE"])
accuracyCheX = accuracy_score(y_test, new_preds)
precisionCheX = precision_score(y_test, new_preds)
recallCheX = recall_score(y_test, new_preds)
f1CheX = f1_score(y_test, new_preds)

#### 7.3 Test on kaggle-rsna

##### 7.3.0 Load dataset

In [None]:
# Preprocessing for the kaggle-rsna images: same geometry as before, with
# this dataset's own normalization constants.
transform = tt.Compose(
    [
        tt.Resize(255),
        tt.CenterCrop(224),
        tt.ToTensor(),
        tt.Normalize(mean=0.4841, std=0.2428, inplace=True),
    ]
)

# Enable only the kaggle-rsna source; train_percentage=0 presumably sends
# every sample to the test loader -- confirm in MergedDataset.
rsna_options = dict(
    chest_xray=False,
    cheX=False,
    kaggle_rsna=True,
    train_percentage=0,
    kaggleRsna_drop_normal_percentage=0.50,
    split_seed=2024,
)
dataset = MergedDataset(device, transformLoadingRsna=transform, **rsna_options)
_, test_dl = dataset.getDataLoader()

##### 7.3.1 Test

In [None]:
# Evaluate the model on the kaggle-rsna test loader.
# Switch to inference mode first so that layers such as dropout or batch norm
# (if the network has any) behave deterministically; this is a no-op when the
# network is already in eval mode.
net.eval()
y_test = []
prob = []
with torch.no_grad():
    for images, labels in test_dl:
        images, labels = to_device(images, device), to_device(labels, device)

        outputs = net(images)
        y_test.extend(labels.cpu().numpy())
        # Probability of class 1, stored as plain floats so scikit-learn
        # receives numbers rather than 0-d tensors.
        prob.extend(torch.nn.functional.softmax(outputs.cpu(), dim=1)[:, 1].tolist())

prob = np.asarray(prob)
fprRsna, tprRsna, thresholds = roc_curve(y_test, prob)
roc_aucRsna = roc_auc_score(y_test, prob)

# Pick the ROC operating point closest to the ideal corner (FPR=0, TPR=1).
# NOTE(review): the threshold is chosen on the same data the metrics are
# reported on, which makes those metrics optimistic.
distances = np.sqrt(fprRsna**2 + (1 - tprRsna)**2)
best_threshold = thresholds[np.argmin(distances)]
# roc_curve treats a sample as positive when score >= threshold, so the same
# comparison is used here to reproduce the selected operating point.
new_preds = (prob >= best_threshold).astype(int)

cm = confusion_matrix(y_test, new_preds)
dispRsna = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["NEGATIVE", "POSITIVE"])
accuracyRsna = accuracy_score(y_test, new_preds)
precisionRsna = precision_score(y_test, new_preds)
recallRsna = recall_score(y_test, new_preds)
f1Rsna = f1_score(y_test, new_preds)

#### 7.4 Test Results

##### 7.4.1 Metrics

In [None]:
# Metrics per test set, in the order: ROC AUC, accuracy, precision, recall, F1.
# The model was trained on 'CXr', hence the constant first argument.
results = (
    ('CXr', (roc_aucChest, accuracyChest, precisionChest, recallChest, f1Chest)),
    ('CheX', (roc_aucCheX, accuracyCheX, precisionCheX, recallCheX, f1CheX)),
    ('RSNA', (roc_aucRsna, accuracyRsna, precisionRsna, recallRsna, f1Rsna)),
)
for test_name, scores in results:
    display_metrics('CXr', test_name, list(scores))

# Flat list of every metric, in the same order as above
print([score for _, scores in results for score in scores])

##### 7.4.2 Graphs

In [None]:
# One set of graphs per test set: ROC data, confusion-matrix display and the
# scalar metrics (accuracy, precision, recall, F1).
graph_inputs = (
    ('CXr', roc_aucChest, fprChest, tprChest, dispChest,
     (accuracyChest, precisionChest, recallChest, f1Chest)),
    ('CheX', roc_aucCheX, fprCheX, tprCheX, dispCheX,
     (accuracyCheX, precisionCheX, recallCheX, f1CheX)),
    ('RSNA', roc_aucRsna, fprRsna, tprRsna, dispRsna,
     (accuracyRsna, precisionRsna, recallRsna, f1Rsna)),
)
for test_name, auc_value, fpr, tpr, disp, scores in graph_inputs:
    plot_graphs(test_name, auc_value, fpr, tpr, disp, list(scores))