# Classifier Training

## Initial Setup

In [None]:
import sys
import os
from dotenv import load_dotenv

# Read a local .env file (if present) so PARENT_FOLDER can be configured per machine.
load_dotenv()
# Folder added to the import path below; falls back to the current directory
# when PARENT_FOLDER is not set.
parent_folder_path = os.getenv('PARENT_FOLDER', '.')
# Appended to sys.path before the later imports of `libraries.*` and `dataset`
# — presumably those packages live under this folder; verify against the repo layout.
sys.path.append(parent_folder_path)

In [None]:
# Imports

import gc
from joblib import Parallel, delayed
from collections import OrderedDict
import platform

import numpy as np
import cv2
from datasets import load_dataset
import torch
from torch import nn, optim
import torch.nn.functional as functional
from torch.utils.data import Dataset, DataLoader
from torchvision.models.resnet import resnet18
from torchsummary import summary
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, precision_score, recall_score, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import clear_output
from tqdm.auto import tqdm
# import torch_directml


from libraries.segmentation.k_means import kmeans, KmeansFlags, KmeansTermCrit, KmeansTermOpt
from libraries.improving.filtering import conv2d
from dataset import TEST_DATASET, TRAIN_DATASET, get_elements_from_indexes, LABELS

In [None]:
# MatPlotLib Configuration
# IPython magic selecting the "widget" backend (notebook-only syntax).
%matplotlib widget
# Disable interactive mode: figures are displayed only when a cell shows them explicitly.
plt.ioff()

In [None]:
# Set device
def get_device():
    """Select the compute device used by the rest of the notebook.

    Preference order: CUDA when available; otherwise DirectML on Windows (only
    if ``torch_directml`` can be imported), MPS on macOS (only if available);
    the CPU in every remaining case.
    """
    cpu = torch.device("cpu")

    if torch.cuda.is_available():
        return torch.device("cuda")

    os_name = platform.system()
    if os_name == "Windows":
        # torch_directml is optional: without it Windows falls back to the CPU.
        try:
            import torch_directml
            return torch_directml.device()
        except ImportError:
            return cpu
    if os_name == "Darwin":
        # Metal Performance Shaders backend, when this build/OS supports it.
        return torch.device("mps") if torch.backends.mps.is_available() else cpu
    return cpu


device = get_device()

In [None]:
# Show which device was selected above.
device

In [None]:
# Pre-processed dataset files
# Folder holding the cached, pre-processed arrays. makedirs(exist_ok=True)
# replaces the check-then-create pair, which could fail if the folder appeared
# between the check and the mkdir call.
os.makedirs('pre_processed_dataset', exist_ok=True)
# Cached image stacks and their labels, one pair per split.
TRAIN_DATA_FILE = 'pre_processed_dataset/train.npy'
TRAIN_LABELS_FILE = 'pre_processed_dataset/train_labels.npy'
TEST_DATA_FILE = 'pre_processed_dataset/test.npy'
TEST_LABELS_FILE = 'pre_processed_dataset/test_labels.npy'

In [None]:
# Functions
# Plot single-channel images with a fixed configuration
def draw_grayscale_image(image, ax):
    """Draw ``image`` on ``ax`` as a grayscale picture with the axes hidden.

    The colour range is pinned to [0, 255] so brightness is comparable across
    images instead of being rescaled per image.
    """
    display_options = {'cmap': 'gray', 'vmin': 0, 'vmax': 255}
    ax.imshow(image, **display_options)
    ax.axis('off')
    
# Clear output and then execute decorator
def clear_and_execute(func):
    """Decorator that clears the notebook cell output before calling ``func``.

    The wrapper forwards all positional and keyword arguments and returns
    whatever ``func`` returns. ``functools.wraps`` copies the name, docstring
    and qualified name of ``func`` onto the wrapper, so the decorated function
    still identifies itself correctly (tracebacks, ``help()``, serialisation
    by reference).
    """
    import functools  # local import: keeps this cell self-contained

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        clear_output()
        return func(*args, **kwargs)
    return wrapper

# Image pre-processing
# Filter definition
# Threshold used to size the kernel: the expression below solves
# exp(-x^2 / (2*sigma^2)) / (sigma*sqrt(2*pi)) = epsilon for x.
epsilon = 0.1
# Standard deviation of the Gaussian smoothing filter.
sigma = 3
# Distance x (rounded up) at which the 1-D Gaussian density falls to `epsilon`.
# NOTE(review): this distance is passed below as the kernel size itself, and with
# epsilon=0.1, sigma=3 it evaluates to 3 — confirm a 3x3 kernel is intended
# rather than a size of 2*x + 1.
gaussianDim = int(np.ceil(np.sqrt(-2 * sigma ** 2 * np.log(epsilon * sigma * np.sqrt(2 * np.pi)))))
# 1-D Gaussian coefficients computed by OpenCV for that size and sigma.
gaussianKernel1D = cv2.getGaussianKernel(gaussianDim, sigma)
# 2-D kernel obtained as the outer product of the 1-D kernel with itself.
gaussianKernel = np.outer(gaussianKernel1D, gaussianKernel1D)
@clear_and_execute
def process_element(element):
    """Turn one dataset element into a 3-channel sample plus its label.

    ``element[0]`` is the image and ``element[1]`` its label. The image is
    smoothed with the module-level Gaussian kernel, its intensities are
    clustered into 3 groups with the project's ``kmeans``, and two binary
    masks are derived from the brightest and the middle cluster centre.

    Returns:
        ``(np.array((image, white_matter_mask, grey_matter_mask)), label)``,
        where both masks are uint8 arrays of 0/1 with the image's shape.
    """
    image = element[0]
    label = element[1]
    # Smooth before clustering (project-local convolution helper).
    filtered_image = conv2d(image, gaussianKernel)
    # 3 clusters over the flattened intensities. The criteria/flags/attempts
    # arguments mirror OpenCV's kmeans API — presumably with the same meaning
    # (max 20 iterations or epsilon 0.5, k-means++ seeding, 5 attempts); verify
    # against libraries.segmentation.k_means.
    compactness, labels, centers = kmeans(
        filtered_image.flatten(), 3, 
        criteria=KmeansTermCrit(KmeansTermOpt.BOTH, 20, 0.5),
        flags=KmeansFlags.KMEANS_PP_CENTERS, attempts=5
    )
    # Cast centres to uint8 so they can be compared with the rebuilt image below.
    centers = centers.astype(np.uint8)
    # Replace every pixel by the centre of its cluster and restore the image shape.
    segmented_kmeans = centers[labels].reshape(image.shape)
    # Rank the centres by intensity: index 2 is the brightest, index 1 the middle one.
    # NOTE(review): assumes `centers` is effectively 1-D and the three uint8 centres
    # are distinct — if two collapse to the same value both masks select the same pixels.
    sorted_centers = sorted(centers)
    white_matter_idx = np.argmax(centers == sorted_centers[2])
    grey_matter_idx = np.argmax(centers == sorted_centers[1])
    # Binary masks: 1 where the pixel belongs to the selected cluster, 0 elsewhere.
    segmented_white_matter = np.where(segmented_kmeans == centers[white_matter_idx], 1, 0).astype(np.uint8)
    segmented_grey_matter = np.where(segmented_kmeans == centers[grey_matter_idx], 1, 0).astype(np.uint8)
    return np.array((image, segmented_white_matter, segmented_grey_matter)), label

# Parallel processing of images
def parallel_processing(dataset_elements, img_shape, n_jobs: int = 8):
    """Pre-process every dataset element in parallel and stack the results.

    Args:
        dataset_elements: Sized iterable of elements accepted by
            ``process_element`` (image at index 0, label at index 1).
        img_shape: Shape of a single image; every processed sample is stored
            as ``(3, *img_shape)``.
        n_jobs: Number of joblib workers. Defaults to 8, the value that was
            previously hard-coded.

    Returns:
        ``(data, labels)``: a uint8 array of shape ``(N, 3, *img_shape)`` and
        a uint8 array of shape ``(N,)``, in the input order.
    """
    dataset_length = len(dataset_elements)
    data = np.zeros((dataset_length, 3, *img_shape), dtype=np.uint8)
    labels = np.zeros((dataset_length,), dtype=np.uint8)
    # tqdm tracks how many elements have been handed to the workers.
    results = Parallel(n_jobs=n_jobs)(
        delayed(process_element)(element) for element in tqdm(dataset_elements)
    )
    # joblib returns results in submission order, so positions line up with the input.
    for index, (channels, label) in enumerate(results):
        data[index] = channels
        labels[index] = label
    return data, labels

## Dataset Preprocessing

In [None]:
# Train Data Processing
# Rebuild the cached arrays only when at least one of the two files is missing.
if not os.path.exists(TRAIN_DATA_FILE) or not os.path.exists(TRAIN_LABELS_FILE):
    train_dataset = load_dataset(**TRAIN_DATASET)
    train_dataset_elements = get_elements_from_indexes(
        train_dataset,
        np.arange(len(train_dataset)),
    )
    # Every image is assumed to share the shape of the first one.
    img_shape = train_dataset_elements[0][0].shape
    train_data, train_labels = parallel_processing(train_dataset_elements, img_shape)
    print("Saving training data")
    np.save(TRAIN_DATA_FILE, train_data)
    np.save(TRAIN_LABELS_FILE, train_labels)

In [None]:
# Test Data Processing
# Rebuild the cached arrays only when at least one of the two files is missing.
if not os.path.exists(TEST_DATA_FILE) or not os.path.exists(TEST_LABELS_FILE):
    test_dataset = load_dataset(**TEST_DATASET)
    test_dataset_elements = get_elements_from_indexes(
        test_dataset,
        np.arange(len(test_dataset)),
    )
    # Every image is assumed to share the shape of the first one.
    img_shape = test_dataset_elements[0][0].shape
    test_data, test_labels = parallel_processing(test_dataset_elements, img_shape)
    print("Saving testing data")
    np.save(TEST_DATA_FILE, test_data)
    np.save(TEST_LABELS_FILE, test_labels)

In [None]:
# Validate the results
# Load the cached training arrays and display one random sample: the image and
# its two segmentation masks.
train_data = np.load(TRAIN_DATA_FILE)
train_labels = np.load(TRAIN_LABELS_FILE)
random_index = np.random.randint(train_data.shape[0])
sample = train_data[random_index]
sample_image, sample_white_matter, sample_grey_matter = sample[0], sample[1], sample[2]
sample_label = train_labels[random_index]
if 'sample_fig' in globals():
    # The label must match the `num` used below exactly (it is case-sensitive);
    # otherwise the old figure stays open and is reused, stacking new axes on
    # top of the previous ones every time the cell is re-run.
    plt.close('Sample Image')
sample_fig = plt.figure(figsize=(6, 2.3), num="Sample Image")
sample_axs = sample_fig.subplots(1, 3)
draw_grayscale_image(sample_image, sample_axs[0])
# Masks hold 0/1; scale to 0/255 to match the fixed display range.
draw_grayscale_image(sample_white_matter * 255, sample_axs[1])
draw_grayscale_image(sample_grey_matter * 255, sample_axs[2])
sample_fig.suptitle(f"Example Figure, label: {LABELS[sample_label]}")
sample_fig.tight_layout()

In [None]:
# Render the sample figure built in the previous cell.
sample_fig

## Dataset Handlers & Dataloaders

In [None]:
# Dataset definition
class DatasetHandler(Dataset):
    """In-memory dataset built from the cached ``.npy`` files.

    Args:
        path_to_data: ``.npy`` file with the samples; channel 0 holds the
            image in the 0-255 range, the remaining channels are left as stored.
        path_to_labels: ``.npy`` file with one integer class index per sample.
        num_classes: Width of the one-hot label rows (defaults to 4, the value
            that was previously hard-coded).

    Items are ``(sample, one_hot_label)`` float tensor pairs. Indexing with a
    slice returns the corresponding stacked tensors.
    """

    def __init__(self, path_to_data: str, path_to_labels: str, num_classes: int = 4):
        self.data = torch.tensor(np.load(path_to_data), dtype=torch.float)
        self.data[:, 0] /= 255  # Must be divided to be between 0 and 1
        # Class indices must be int64 to be usable as a tensor index (a uint8
        # tensor would be interpreted as a mask).
        class_indices = torch.as_tensor(np.load(path_to_labels), dtype=torch.long)
        # Vectorised one-hot encoding: one write per row in a single indexed assignment.
        self.labels = torch.zeros((len(class_indices), num_classes))
        self.labels[torch.arange(len(class_indices)), class_indices] = 1

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

    def __len__(self):
        return len(self.data)

In [None]:
# Creation of dataloaders
# Both splits are loaded fully in memory by DatasetHandler.
train_dataset_handler = DatasetHandler(TRAIN_DATA_FILE, TRAIN_LABELS_FILE)
test_dataset_handler = DatasetHandler(TEST_DATA_FILE, TEST_LABELS_FILE)
batch_size_train = 32
batch_size_test = 128
# NOTE(review): n_workers is defined but the loaders below use num_workers=0.
n_workers = 4
train_dataloader = DataLoader(
    train_dataset_handler,
    batch_size=batch_size_train,
    shuffle=True,
    num_workers=0,
    pin_memory=False,
)
test_dataloader = DataLoader(
    test_dataset_handler,
    batch_size=batch_size_test,
    shuffle=True,
    num_workers=0,
    pin_memory=False,
)

## Module Initialization Functions

In [None]:
def init_weights(net, init='norm', gain=0.02, verbose: bool = True):
    """Re-initialise the convolution and BatchNorm2d layers of ``net`` in place.

    Args:
        net: Module whose sub-modules are visited with ``net.apply``.
        init: Scheme for convolution weights: ``'norm'`` (normal, std=``gain``),
            ``'xavier'`` (Xavier normal with ``gain``) or ``'kaiming'``
            (Kaiming normal, fan-in). Any other value leaves convolution
            weights untouched.
        gain: Standard deviation / gain used by the schemes above and by the
            BatchNorm2d weight initialisation (normal around 1).
        verbose: Print a confirmation line when done.

    Returns:
        The same ``net`` object.

    Layers of any other type (e.g. ``nn.Linear``) are not modified.
    """
    def init_func(m):
        classname = m.__class__.__name__
        if 'Conv' in classname and getattr(m, 'weight', None) is not None:
            # The nn.init functions run without gradient tracking, so the
            # parameters are passed directly.
            if init == 'norm':
                nn.init.normal_(m.weight, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight, gain=gain)
            elif init == 'kaiming':
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias, 0.0)
        elif 'BatchNorm2d' in classname:
            # weight/bias are None when the layer is built with affine=False.
            if getattr(m, 'weight', None) is not None:
                nn.init.normal_(m.weight, 1., gain)
            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias, 0.)
    net.apply(init_func)
    if verbose:
        print(f"Model initialized with {init} initialization.")
    return net


def init_model(model, use_current_weights: bool = True, verbose: bool = True):
    """Return ``model`` either untouched or re-initialised via ``init_weights``.

    Args:
        model: Module to prepare.
        use_current_weights: Keep the weights as they are when True; otherwise
            run ``init_weights`` with its default scheme.
        verbose: Print which path was taken.
    """
    if not use_current_weights:
        model = init_weights(model, verbose=verbose)
        if verbose:
            print("Initializing model weights.")
        return model

    if verbose:
        print("Using current weights for the model.")
    return model

## Architecture

### CNN

#### Initialization with Pre-Trained Weights

In [None]:
from torchvision.models import efficientnet_b0

# EfficientNet-B0 built with the 'IMAGENET1K_V1' pre-trained weights.
efficientnet_model = efficientnet_b0(weights='IMAGENET1K_V1')

# Keep every top-level child except the last one — presumably dropping the
# ImageNet classification head; confirm with the summary printed in the next cell.
# NOTE: this single module instance is passed to every MainModel built later in
# the notebook, so those models share (and mutate) the same skeleton weights.
efficientnet_skeleton = nn.Sequential(*list(efficientnet_model.children())[:-1])

In [None]:
# Print the layer-by-layer summary of the skeleton for a 3x224x224 input.
summary(efficientnet_skeleton, input_size=(3,224,224))

#### Custom CNN

In [None]:
def get_classification_layer_efficientnet(in_channels=1280, out_channels=1000):
    """Build the classification head placed after the feature extractor.

    The head pools the feature map to 1x1, flattens it, applies dropout
    (p=0.2) and maps ``in_channels`` features to ``out_channels`` logits.
    """
    head_layers = [
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Dropout(0.2),
        nn.Linear(in_features=in_channels, out_features=out_channels),
    ]
    return nn.Sequential(*head_layers)

class CustomEfficientNetB0(nn.Module):
    """Feature-extraction skeleton followed by a new 4-class head.

    Args:
        skeleton: Module producing 1280-channel features.
        use_current_weights: Forwarded to ``init_model`` for the new head.
    """

    def __init__(self, skeleton: nn.Module, use_current_weights: bool = True):
        super().__init__()
        self.skeleton = skeleton
        head = get_classification_layer_efficientnet(1280, 4)
        self.classifier = init_model(head, use_current_weights=use_current_weights)

    def forward(self, input):
        """Run the skeleton, then the classification head."""
        features = self.skeleton(input)
        return self.classifier(features)

    def set_requires_grad_skeleton(self, requires_grad: bool = True):
        """Freeze (False) or unfreeze (True) every skeleton parameter."""
        for weight in self.skeleton.parameters():
            weight.requires_grad = requires_grad

#### Main Model Definition

In [None]:
class MainModel(nn.Module):
    """Training wrapper bundling the network, its loss and its optimizer.

    Args:
        skeleton: Feature extractor handed to ``CustomEfficientNetB0``.
        lr: Adam learning rate.
        beta1: First Adam beta coefficient.
        beta2: Second Adam beta coefficient.
        use_current_weights: Forwarded to ``CustomEfficientNetB0``.
        device: Device the network and every batch are moved to.

    Typical use: ``setup_input(batch)`` followed by ``optimize()`` for a
    training step, or by ``forward()`` for inference; the logits are left in
    ``self.prediction`` and the last training loss in ``self.ce_loss``.
    """

    def __init__(self, skeleton: nn.Module, lr=1e-3, beta1=0.9, beta2=0.999, use_current_weights: bool = True, device=device):
        super().__init__()
        self.device = device
        self.net = CustomEfficientNetB0(skeleton, use_current_weights=use_current_weights).to(self.device)
        # Targets are the per-class float rows produced by DatasetHandler.
        self.criterion = nn.CrossEntropyLoss()
        self.opt = optim.Adam(self.net.parameters(), lr=lr, betas=(beta1, beta2))
        # State of the current step, filled by setup_input / forward / backward.
        self.input = None
        self.target = None
        self.prediction = None
        self.ce_loss = None

    def set_requires_grad(self, requires_grad=True):
        """Enable or disable gradients for every parameter of the network."""
        for p in self.net.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, data):
        """Store a ``(samples, targets)`` batch on the model's device."""
        self.input = data[0].to(self.device)
        self.target = data[1].to(self.device)

    def forward(self):
        """Compute logits for the stored input and keep them in ``self.prediction``."""
        self.prediction = self.net(self.input)

    def backward(self):
        """Compute the cross-entropy loss of the last prediction and backpropagate it."""
        self.ce_loss = self.criterion(self.prediction, self.target)
        self.ce_loss.backward()

    def optimize(self):
        """Run one full training step on the stored batch."""
        # Training mode must be set BEFORE the forward pass so Dropout and
        # BatchNorm behave as training layers for the batch being optimised.
        self.net.train()
        self.opt.zero_grad()
        self.forward()
        self.backward()
        self.opt.step()

## Training

### Model Initialization

In [None]:
# Resume from saved weights (True) or start from the freshly built model (False).
from_checkpoint = False
# Checkpoint number used to build the "model_chckpt{N}.pth" file name when resuming.
checkpoint_num = 1
# Value applied to requires_grad on every skeleton parameter (False freezes the skeleton).
requires_grad_skeleton = True

In [None]:
# Build the model and, when requested, restore previously saved weights.
classifier = MainModel(efficientnet_skeleton, lr=5e-4)
test_dataloader_iter = iter(test_dataloader)
# `last` is the checkpoint number training resumes from (0 = fresh start).
last = 0
if from_checkpoint:
    checkpoint = checkpoint_num
    if checkpoint is None:
        # No checkpoint number: fall back to the deployment model.
        weights_path = os.path.join(parent_folder_path, "model", "model.pth")
    else:
        last = checkpoint
        weights_path = os.path.join(parent_folder_path, "model", "checkpoints", f"model_chckpt{checkpoint}.pth")
    # map_location="cpu" lets weights saved from an accelerator be loaded on a
    # machine without it; load_state_dict then copies them onto the model's device.
    classifier.load_state_dict(torch.load(weights_path, map_location="cpu"))
# Freeze or unfreeze the skeleton according to the flag.
classifier.net.set_requires_grad_skeleton(bool(requires_grad_skeleton))

### Training Functions

In [None]:
class ModelCheckpoint:
    """Callback that saves the model's ``state_dict`` whenever the metric improves.

    Args:
        filepath: Destination of the saved weights.
        monitor: Metric name, used only in log messages.
        mode: ``'max'`` if larger values are better, ``'min'`` if smaller are.
        save_best_only: Only affects logging: when False (and ``verbose``),
            non-improving calls are reported. Weights are saved on improvement
            only in both cases.
        verbose: Print a line when the metric improves.
        current_value: Initial best value; ``None`` means "no baseline yet".

    Raises:
        ValueError: If ``mode`` is neither ``'max'`` nor ``'min'``.
    """

    def __init__(self, filepath='best_model.pth', monitor='val_f1', mode='max', save_best_only=True, verbose=True, current_value=None):
        if mode == 'max':
            worst_value = float('-inf')
            self.monitor_op = lambda current, best: current > best
        elif mode == 'min':
            worst_value = float('inf')
            self.monitor_op = lambda current, best: current < best
        else:
            raise ValueError(f"Mode {mode} is unknown, please use 'max' or 'min'")

        self.filepath = filepath
        self.monitor = monitor
        self.mode = mode
        self.save_best_only = save_best_only
        self.verbose = verbose
        # Compare against None explicitly: 0.0 is a legitimate baseline (e.g. an
        # F1 of zero or a loss of zero) and must not be mistaken for "not provided".
        self.best = worst_value if current_value is None else current_value

    def __call__(self, current_value, model):
        """Save ``model`` if ``current_value`` beats the best so far.

        Returns:
            True when the weights were saved, False otherwise.
        """
        if self.monitor_op(current_value, self.best):
            if self.verbose:
                print(f'\nEpoch: {self.monitor} improved from {self.best:.5f} to {current_value:.5f}, saving model to {self.filepath}')
            self.best = current_value
            torch.save(model.state_dict(), self.filepath)
            return True

        if self.verbose and not self.save_best_only:
            print(f'\nEpoch: {self.monitor} did not improve from {self.best:.5f}')
        return False

def evaluate_model(model, test_dataloader, device):
    """Evaluate ``model`` on every batch of ``test_dataloader``.

    Args:
        model: Object exposing ``net``, ``setup_input``, ``forward`` and
            ``prediction`` (see ``MainModel``).
        test_dataloader: Iterable of ``(samples, one_hot_targets)`` batches.
        device: Unused; batches are moved by ``model.setup_input``. Kept for
            interface compatibility.

    Returns:
        Dict with ``'val_f1'`` (weighted F1), ``'val_accuracy'``, and the flat
        lists ``'predictions'`` and ``'labels'`` of class indices.

    The network is switched to eval mode for the pass and restored afterwards
    to the mode it was in when the function was called.
    """
    was_training = model.net.training
    model.net.eval()
    all_predictions = []
    all_labels = []

    try:
        with torch.no_grad():
            for data in test_dataloader:
                model.setup_input(data)
                model.forward()
                # argmax of the logits equals argmax of their softmax, so the
                # softmax is skipped; .tolist() yields plain Python ints.
                all_predictions.extend(model.prediction.argmax(dim=1).cpu().tolist())
                all_labels.extend(data[1].argmax(dim=1).cpu().tolist())
    finally:
        # Restore the caller's mode instead of unconditionally forcing train
        # mode, so evaluating an inference-only model does not re-enable dropout.
        model.net.train(was_training)

    # Calculate metrics
    f1 = f1_score(all_labels, all_predictions, average='weighted')
    accuracy = accuracy_score(all_labels, all_predictions)

    return {
        'val_f1': f1,
        'val_accuracy': accuracy,
        'predictions': all_predictions,
        'labels': all_labels
    }

class AverageMeter:
    """Running weighted average of a scalar (e.g. a per-batch loss)."""

    def __init__(self):
        # All three start as floats so the first update already divides floats.
        self.count = 0.
        self.avg = 0.
        self.sum = 0.

    def update(self, val, count=1):
        """Add ``val`` with weight ``count`` and refresh the running average."""
        self.sum += count * val
        self.count += count
        self.avg = self.sum / self.count


def create_loss_meters():
    """Return a fresh meter per tracked loss, keyed by the model attribute name."""
    return {'ce_loss': AverageMeter()}


def update_losses(model, loss_meter_dict, count):
    """Feed each meter with the model attribute of the same name.

    Args:
        model: Object holding the current loss values as attributes.
        loss_meter_dict: Mapping from attribute name to its meter.
        count: Weight of this update (the batch size).
    """
    for loss_name in loss_meter_dict:
        current_loss = getattr(model, loss_name).item()
        loss_meter_dict[loss_name].update(current_loss, count=count)


def visualize(model, data, fig = None, ax = None):
    """Plot the confusion matrix of ``model`` on a single batch.

    Args:
        model: Object exposing ``setup_input``, ``forward`` and ``prediction``.
        data: ``(samples, one_hot_targets)`` batch.
        fig: Figure to show after drawing (optional).
        ax: Axes to draw on; cleared first when given.

    Each cell shows its share of all samples in the batch, as a percentage.
    """
    labels = np.argmax(data[1], axis=1)
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    # argmax of the logits equals argmax of their softmax, so no softmax is needed.
    predicted = model.prediction.detach().cpu().argmax(dim=1).numpy()
    c_mat = confusion_matrix(labels, predicted, labels=np.arange(4))
    # Avoid a 0/0 division (NaN cells) when the batch is empty.
    total = c_mat.sum()
    shares = c_mat / total if total else np.zeros(c_mat.shape)
    if ax is not None:
        ax.clear()
    categories = [name for _, name in sorted(LABELS.items(), key=lambda item: item[0])]
    # `fmt` only takes effect together with annot=True; without it the cells
    # were drawn without any value.
    sns.heatmap(shares,
                xticklabels=categories, yticklabels=categories,
                cmap='Blues',
                fmt='.2%',
                ax=ax,
                cbar=False,
                annot=True)
    if fig is not None:
        fig.show()


def log_results(loss_meter_dict):
    """Format every meter as ``"<name>: <avg with 5 decimals> "`` and concatenate them."""
    return "".join(
        f"{loss_name}: {loss_meter.avg:.5f} "
        for loss_name, loss_meter in loss_meter_dict.items()
    )

def visualize_from_predictions(predictions, labels, fig=None, ax=None):
    """Visualize confusion matrix from predictions and labels.

    Args:
        predictions: Predicted class indices.
        labels: Ground-truth class indices.
        fig: Figure whose canvas is asked to redraw afterwards (optional).
        ax: Axes to draw on; cleared first when given.

    Each cell shows its share of all samples, as a percentage.
    """
    c_mat = confusion_matrix(labels, predictions, labels=np.arange(4))
    # Avoid a 0/0 division (NaN cells) when there are no samples.
    total = c_mat.sum()
    shares = c_mat / total if total else np.zeros(c_mat.shape)
    if ax is not None:
        ax.clear()
    categories = [name for _, name in sorted(LABELS.items(), key=lambda item: item[0])]
    # `fmt` only takes effect together with annot=True; without it the cells
    # were drawn without any value.
    sns.heatmap(shares,
                xticklabels=categories, yticklabels=categories,
                cmap='Blues',
                fmt='.2%',
                ax=ax,
                cbar=False,
                annot=True)
    if fig is not None:
        fig.canvas.draw_idle()

def load_best_model(model: "MainModel", filepath='best_model.pth'):
    """Load the best model weights into ``model`` if ``filepath`` exists.

    Args:
        model: Module receiving the weights (modified in place).
        filepath: Location of the saved ``state_dict``.

    Returns:
        The same ``model`` object, with or without the loaded weights; a
        message states which of the two happened.
    """
    if os.path.exists(filepath):
        # map_location="cpu" lets weights saved from an accelerator be read on a
        # machine without it; load_state_dict copies them to the model's device.
        model.load_state_dict(torch.load(filepath, map_location="cpu"))
        print(f"Loaded best model weights from {filepath}")
    else:
        print(f"No saved model found at {filepath}")
    return model

def train_model_with_callbacks(model: MainModel, train_dl: DataLoader, test_dl: DataLoader,
                               epochs: int = 100, display_every: int = 1, 
                               check_point_start: int = 0, checkpoint_callback=None,
                               fig=None, ax=None):
    """Enhanced training function with callback support.

    Args:
        model: Model wrapper to train.
        train_dl: Loader providing the training batches.
        test_dl: Loader evaluated after every epoch.
        epochs: Number of epochs to run.
        display_every: Redraw the confusion matrix every this many epochs.
        check_point_start: Checkpoint counter, returned unchanged.
        checkpoint_callback: Callable ``(metric, model) -> bool`` with a
            ``best`` attribute. When omitted, a ``ModelCheckpoint`` on
            ``val_f1`` is created, seeded with the model's score before training.
        fig: Figure used for the confusion matrix (both ``fig`` and ``ax``
            are required for drawing).
        ax: Axes used for the confusion matrix.

    Returns:
        ``(check_point, best_metrics, checkpoint_callback.best)`` where
        ``best_metrics`` is the metrics dict of the best epoch seen here.
    """
    check_point = check_point_start
    best_metrics = {'val_f1': 0, 'val_accuracy': 0}

    # The baseline evaluation is only needed to seed the default callback, so
    # the full validation pass is skipped when the caller supplies one.
    if checkpoint_callback is None:
        baseline_metrics = evaluate_model(model, test_dl, model.device)
        checkpoint_callback = ModelCheckpoint(
            filepath='best_model.pth',
            monitor='val_f1',
            mode='max',
            verbose=True,
            current_value=baseline_metrics['val_f1']
        )

    can_draw = fig is not None and ax is not None
    epoch_tqdm = tqdm(range(epochs), "Epochs", position=0, leave=True)

    for epoch in epoch_tqdm:
        # Training phase
        model.net.train()
        loss_meter_dict = create_loss_meters()
        batch_tqdm = tqdm(train_dl, "Batches", position=1, leave=False)

        for data in batch_tqdm:
            model.setup_input(data)
            model.optimize()
            update_losses(model, loss_meter_dict, count=len(data[0]))

        batch_tqdm.close()

        # Validation phase
        val_metrics = evaluate_model(model, test_dl, model.device)

        # Update best metrics
        if val_metrics['val_f1'] > best_metrics['val_f1']:
            best_metrics = val_metrics.copy()

        # Checkpoint callback
        is_best = checkpoint_callback(val_metrics['val_f1'], model)

        # Logging
        train_loss = log_results(loss_meter_dict)
        epoch_desc = f"{train_loss}val_f1: {val_metrics['val_f1']:.5f} val_acc: {val_metrics['val_accuracy']:.5f}"
        if is_best:
            epoch_desc += " [BEST]"
        epoch_tqdm.set_description(epoch_desc)

        # Visualization (display_every is only consulted when drawing is possible)
        if can_draw and (epoch + 1) % display_every == 0:
            visualize_from_predictions(val_metrics['predictions'], val_metrics['labels'], fig, ax)

    return check_point, best_metrics, checkpoint_callback.best


### Validation Set Performance Supervision During Training

In [None]:
## Training Supervision
# Close the figure left over from a previous run of this cell (same label as
# the `num` used below) so that a fresh one is created.
if 'train_figure' in globals():
    plt.close('Test Validation')
# Figure and axes later handed to the training function for the confusion matrix.
train_figure = plt.figure(num='Test Validation', figsize=(8,8))
train_ax = train_figure.subplots(1,1)
train_figure.show()

### Training

In [None]:
# Training with callbacks
N_epochs = 100
Show_Every_N_epochs = 1

# Checkpoint callback: keeps the weights with the highest validation F1.
best_model_callback = ModelCheckpoint(
    filepath='best_model.pth',
    monitor='val_f1',
    mode='max',
    save_best_only=True,
    verbose=True,
)

# Run the training loop; the whole test dataloader is evaluated after each epoch.
last, best_metrics, best_f1_score = train_model_with_callbacks(
    model=classifier,
    train_dl=train_dataloader,
    test_dl=test_dataloader,
    epochs=N_epochs,
    display_every=Show_Every_N_epochs,
    check_point_start=last,
    checkpoint_callback=best_model_callback,
    fig=train_figure,
    ax=train_ax,
)

print("\nTraining completed!")
print(f"Best validation F1 score: {best_f1_score:.5f}")
print(f"Best validation accuracy: {best_metrics['val_accuracy']:.5f}")

## Evaluation

### Checkpoint Evaluation

In [None]:
# Load the best model automatically
best_classifier = MainModel(efficientnet_skeleton)
best_classifier = load_best_model(best_classifier, 'best_model.pth')
# A freshly built module starts in training mode; switch to eval mode so that
# Dropout is disabled and BatchNorm uses its running statistics. Without this
# the reported metrics are noisy and depend on the batch composition.
best_classifier.net.eval()

# Get final predictions (the whole test split in a single forward pass)
validation_data = test_dataset_handler[:]
with torch.no_grad():
    best_classifier.setup_input(validation_data)
    best_classifier.forward()
# argmax of the logits equals argmax of their softmax; results are NumPy arrays
# of class indices, as expected by the scikit-learn metrics below.
predicted = best_classifier.prediction.detach().cpu().argmax(dim=1).numpy()
labels = validation_data[1].argmax(dim=1).numpy()

print(f"Final evaluation on best model:")
print(f"F1 Score: {f1_score(labels, predicted, average='weighted'):.5f}")
print(f"Accuracy: {accuracy_score(labels, predicted):.5f}")

### Best Model Results

In [None]:
# Function to plot the confusion matrix
def plot_confusion_matrix(ground_truth, predictions, fig = None, ax = None, mode: str = "recall"):
    """Draw a normalised, annotated 4-class confusion matrix.

    Args:
        ground_truth: True class indices.
        predictions: Predicted class indices.
        fig: Figure to lay out and show after drawing (optional).
        ax: Axes to draw on; cleared first when given.
        mode: ``"recall"`` normalises each row (true class) to sum to 1,
            ``"precision"`` normalises each column (predicted class).

    Raises:
        ValueError: If ``mode`` is neither ``"recall"`` nor ``"precision"``.
    """
    if mode == "recall":
        normalise_axis = 1
    elif mode == "precision":
        normalise_axis = 0
    else:
        raise ValueError("Mode should be either 'recall' or 'precision'")

    c_mat = confusion_matrix(ground_truth, predictions, labels=np.arange(4)).astype(np.float64)
    totals = c_mat.sum(axis=normalise_axis, keepdims=True)
    # Rows/columns without any sample stay at 0 instead of becoming NaN (0/0).
    c_mat = np.divide(c_mat, totals, out=np.zeros_like(c_mat), where=totals != 0)

    if ax is not None:
        ax.clear()
    categories = [name for _, name in sorted(LABELS.items(), key=lambda item: item[0])]
    sns.heatmap(c_mat,
                xticklabels=categories, yticklabels=categories,
                cmap='Blues',
                fmt='.2%',
                ax=ax,
                cbar=False,
                annot=True)
    if fig is not None:
        fig.tight_layout()
        fig.show()

In [None]:
# Show Confusion Matrix
# Close the figure left over from a previous run of this cell (same label as
# the `num` used below) so that a fresh one is created.
if 'fig_results' in globals():
    plt.close('Results')
fig_results = plt.figure(num='Results', figsize=(7,7))
ax_results = fig_results.subplots(1,1)
fig_results.suptitle('Recall', fontsize=14, fontweight='bold')
# Columns are the model predictions, rows the ground truth.
fig_results.supxlabel('Model Predictions')
fig_results.supylabel('Ground Truth')
# Row-normalised matrix for the predictions of the best model.
plot_confusion_matrix(labels, predicted, fig_results, ax_results, "recall")
# Written to the current working directory.
fig_results.savefig("confusion_matrix_recall.png")

In [None]:
# Show Report
# Class names ordered by their numeric label, as classification_report expects.
target_names = [name for _, name in sorted(LABELS.items(), key=lambda item: item[0])]
report = classification_report(labels, predicted, target_names=target_names)
# Aggregate metrics; the per-class ones are averaged weighted by class support.
accuracy = accuracy_score(labels, predicted)
score = f1_score(labels, predicted, average='weighted')
precision = precision_score(labels, predicted, average='weighted')
recall = recall_score(labels, predicted, average='weighted')

In [None]:
# Per-class precision / recall / F1 table built in the previous cell.
print(report)

In [None]:
# Summary of the aggregate metrics.
# NOTE(review): the weighted F1 stored in `score` is computed above but not printed here.
print(f'Model results:\nAccuracy: {accuracy:.2f}\nRecall: {recall:.2f}\nPrecision: {precision:.2f}')

In [None]:
# Save checkpoint as the deployment model
# Written as 'model.pth' in the current working directory.
deployment_state = best_classifier.state_dict()
torch.save(deployment_state, 'model.pth')

In [None]:
# Run this cell when the system's memory is running low
# Force a garbage-collection pass, then clear this cell's output.
gc.collect()
clear_output()