# Ensemble 1: Run the de novo model with 5 different seeds and evaluate on a balanced test set containing all biological classes

<b>What is our goal?</b>

The main goal is to develop a model with meaningful uncertainty: low for known and biologically uninteresting classes, but high for unknown and ideally biologically interesting classes. We hope that this new model learns something the original classifier (autophagy_2_1 from SPARCSpy) did not, and thereby identifies something new.
For the multi-class classifier, we test this in two ways: by leaving out different biological conditions during training and then checking the uncertainty on them, and by evaluating the new model on screening data, plotting its 8th-layer activations in a UMAP and comparing the classification scores of the old and new models.
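To make "meaningful uncertainty" concrete: the test loop in this notebook measures it as the Shannon entropy of the softmax output of each cell. For predicted class probabilities $p_c(x)$ over the $C$ output units of the classifier:

$$
H(x) = -\sum_{c=1}^{C} p_c(x)\,\log p_c(x), \qquad 0 \le H(x) \le \log C
$$

$H(x) = 0$ when all probability mass sits on one class, and $H(x) = \log C$ when the prediction is uniform. With the three output units used for Case 1 ($C = 3$), the maximum is $\log 3 \approx 1.099$ nats. In the code, a small $\epsilon$ is added inside the logarithm to avoid $\log 0$.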

<b>What have we done so far?</b>

So far, for Case 1, we have plotted the 8th-layer activations in a UMAP and identified screening hits. Before excising these cells, we need to be confident about which cells to excise. We therefore want to try different ensemble approaches and obtain a list of cell IDs, with their respective slide numbers, that overlap across the ensembling runs; these are the cells we want to excise. This is how we try to correct for the possibility that our screening hits contain technical artifacts: if we excised the 72 current screening hits and they were dominated by technical artifacts, we would likely see no enrichment of any gene. We therefore want to be computationally as confident as possible that these are genuinely interesting screening hits; if nothing is enriched even then, we at least learn something about our computational approach, namely that our model identifies some technical artifact very confidently. Within a single slide we expect batch effects. We found 72 screening hits in a subset of 7,200 cells (a hit rate of 1%); a single slide has approximately 300,000 cells, so we expect approximately 3,000 screening hits on a full slide.

<b>What data do we have now?</b>

1. Stimulated 14h (or 16h) -> labelled as 0
2. Unstimulated -> labelled as 1
3. ATG5 KO (stimulated, but this does not matter because the KO supersedes the stimulation status; it probably looks like unstimulated data) -> labelled as 2
4. Stimulated timecourse data -> labelled as 3
5. EI24 KO timecourse data (more similar to unstim) -> labelled as 4
6. Screening data (similar to stim) -> labelled as 5
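For plotting legends and for defining which classes each case holds out, it may help to keep the label scheme above in one place. A minimal sketch: `LABEL_NAMES`, `CASES`, and `held_out_classes` are hypothetical names, and only Cases 1 and 8 are filled in because those are the ones described in this notebook.

```python
# Hypothetical short names for the six class labels used in this notebook
LABEL_NAMES = {
    0: "stim",        # stimulated 14h (or 16h)
    1: "unstim",      # unstimulated
    2: "ATG5_KO",     # ATG5 knockout (stimulated, but the KO dominates)
    3: "stim_tc",     # stimulated timecourse
    4: "EI24_KO_tc",  # EI24 knockout timecourse
    5: "screen",      # screening data
}

# Classes used for training in each case (the test set always contains all six)
CASES = {
    1: [0, 2],     # Case 1: stimulated + ATG5 KO
    8: [0, 2, 4],  # Case 8: additionally EI24 KO, as a positive control
}


def held_out_classes(case):
    """Return the labels that the given case never sees during training."""
    seen = set(CASES[case])
    return [label for label in LABEL_NAMES if label not in seen]


for case in CASES:
    print(case, [LABEL_NAMES[label] for label in held_out_classes(case)])
```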


<b>What are we doing in this section?</b>

Here we run the de novo model 5 times with different random seeds. For each run, we identify the screening hits in the UMAP and then find the hits that overlap across runs using intersection methods, hoping to identify consistent patterns and reduce the impact of random fluctuations.
We do all of this for Case 1 (train on classes 0 and 2; test on all six classes).
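The intersection step can be sketched in plain Python. This is a minimal sketch under assumptions: each seed's hits are given as a set of `(slide, cell_id)` pairs, and `consensus_hits` is a hypothetical helper, not part of SPARCSpy. A strict intersection keeps only cells flagged in every run; lowering `min_runs` (for example to 3 of 5) keeps cells flagged by a majority of runs.

```python
from collections import Counter


def consensus_hits(hits_per_seed, min_runs=None):
    """Return (slide, cell_id) hits flagged in at least `min_runs` runs.

    hits_per_seed: one iterable of (slide, cell_id) pairs per random seed.
    min_runs=None means a strict intersection (flagged in every run).
    """
    runs = [set(hits) for hits in hits_per_seed]  # drop duplicates within a run
    if min_runs is None:
        min_runs = len(runs)
    counts = Counter(hit for run in runs for hit in run)
    return {hit for hit, n in counts.items() if n >= min_runs}


# Made-up cell IDs for five seeds, for illustration only
hits_per_seed = [
    {("2.3_A002", 101), ("2.3_A002", 205), ("2.3_B004", 17)},
    {("2.3_A002", 101), ("2.3_B004", 17)},
    {("2.3_A002", 101), ("2.3_B004", 17), ("2.3_D001", 9)},
    {("2.3_A002", 101), ("2.3_A002", 205)},
    {("2.3_A002", 101), ("2.3_B004", 17)},
]
strict = consensus_hits(hits_per_seed)                # flagged in all 5 runs
majority = consensus_hits(hits_per_seed, min_runs=3)  # flagged in at least 3 of 5 runs
print(sorted(strict), sorted(majority))
```

The strict intersection is the most conservative choice for deciding what to excise; the majority threshold trades some confidence for more candidate cells.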

<b>Why do we do this?</b>

We run the ensembling to become more certain about whether the screening hits we see are worth excising. Specifically, it lets us explore the variability in classification and uncertainty across runs that differ only in their random seed (weight initialization, batch order, and dropout masks). This can reveal patterns that a single de novo model run might miss, especially with respect to uncertainty. We also evaluate how well the de novo model does on the EI24 knockout.
EI24 KO looks distinctly different from both training classes and is therefore our best positive control: if we can detect it based on uncertainty, we should also be able to detect novel biology.
We want the UMAP on the test set to show all biological classes, and we also want to visualize images of the screening hits in case there is something visually interesting.
We also need to cross-reference with the cells already excised as hits using the autophagy_2_1 model: that binary model flags anything that looks unstimulated as a hit in the screen, so we have to filter our resulting cells down to those that have not already been excised.
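Filtering out cells that were already excised with autophagy_2_1 amounts to a set difference on `(slide, cell_id)` pairs. A minimal sketch, assuming both lists are available as such pairs; `normalize_hit` and `filter_already_excised` are hypothetical helpers, and the IDs in the example are made up. Normalizing the key types first matters because IDs read back from a CSV or HDF5 file may come back as strings or NumPy integers.

```python
def normalize_hit(slide, cell_id):
    """Hypothetical canonical key: slide name as str, cell ID as int."""
    return (str(slide), int(cell_id))


def filter_already_excised(candidate_hits, excised_hits):
    """Return candidate (slide, cell_id) hits that were not already excised, sorted."""
    excised = {normalize_hit(slide, cell_id) for slide, cell_id in excised_hits}
    remaining = {normalize_hit(slide, cell_id) for slide, cell_id in candidate_hits} - excised
    return sorted(remaining)


# Made-up IDs: the excised list stores its cell IDs as strings
candidates = [("2.3_A002", 101), ("2.3_B004", 17), ("2.3_D001", 9)]
already_excised = [("2.3_A002", "101"), ("2.3_H002", 5)]
print(filter_already_excised(candidates, already_excised))
# prints [('2.3_B004', 17), ('2.3_D001', 9)]
```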

<b>What to do from here?</b>

1. We could also try the Local Outlier Factor (LOF) as a score. By applying LOF we can gain an additional layer of confidence in identifying truly novel phenotypes: LOF helps differentiate between points that are genuine anomalies and those that are within expected variability.
2. After training the five models, we could ensemble them by averaging their predictions or by using a majority vote for classification. This should help smooth out individual model variability and may highlight outliers or new patterns.
3. We want to run Case 8 (train on classes 0, 2 and 4; test on all classes) as a positive control.
4. We could run the uncertainty analysis with 3 or 4 training classes (our other cases), collect statistics on how we do as the number of classes increases, and ideally end up with a list of cell IDs to excise that contain a novel phenotype.
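Items 1 and 2 can be sketched with NumPy and scikit-learn. This is a minimal sketch, assuming each seed's softmax outputs on the same test cells are available as `(n_cells, n_classes)` arrays in the same cell order, and that LOF is computed on embeddings such as the 8th-layer activations used for the UMAP; `ensemble_predictions` and `lof_novelty_scores` are hypothetical names. LOF is fit on training-class embeddings with `novelty=True`, so test cells are scored relative to what the model has seen: values close to 1 are inlier-like, larger values are more outlying.

```python
import numpy as np
from sklearn.neighbors import LocalOutlierFactor


def ensemble_predictions(prob_list):
    """Combine softmax outputs of several models (one array per seed).

    prob_list: list of arrays of shape (n_cells, n_classes), same cell order.
    Returns mean probabilities, the soft-vote prediction, the majority-vote
    prediction, and the fraction of models agreeing with the majority vote.
    """
    probs = np.stack(prob_list)                 # (n_models, n_cells, n_classes)
    mean_probs = probs.mean(axis=0)             # soft voting
    soft_vote = mean_probs.argmax(axis=1)
    votes = probs.argmax(axis=2)                # hard label per model and cell
    n_models, _, n_classes = probs.shape
    vote_counts = np.stack([(votes == c).sum(axis=0) for c in range(n_classes)], axis=1)
    majority_vote = vote_counts.argmax(axis=1)  # ties go to the lowest label
    agreement = vote_counts.max(axis=1) / n_models
    return mean_probs, soft_vote, majority_vote, agreement


def lof_novelty_scores(train_embeddings, test_embeddings, n_neighbors=20):
    """Local Outlier Factor of each test cell relative to the training cells."""
    lof = LocalOutlierFactor(n_neighbors=n_neighbors, novelty=True)
    lof.fit(train_embeddings)
    # score_samples returns the negated LOF, so flip the sign
    return -lof.score_samples(test_embeddings)


# Toy example: three models, two cells, two classes
p1 = np.array([[0.9, 0.1], [0.4, 0.6]])
p2 = np.array([[0.8, 0.2], [0.3, 0.7]])
p3 = np.array([[0.2, 0.8], [0.45, 0.55]])
mean_probs, soft_vote, majority_vote, agreement = ensemble_predictions([p1, p2, p3])

# Toy embeddings: one far-away test cell among ordinary ones
rng = np.random.default_rng(0)
train_emb = rng.normal(size=(200, 5))
test_emb = np.vstack([rng.normal(size=(10, 5)), np.full((1, 5), 10.0)])
lof_scores = lof_novelty_scores(train_emb, test_emb)
```

Soft voting keeps information about how confident each model was, whereas the majority vote only counts labels; `agreement` is a cheap per-cell measure of consistency across seeds.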


In [1]:
!pip install torch-intermediate-layer-getter
!pip install umap-learn
!pip install leidenalg
!pip install scanpy==1.9.6
!pip install anndata
!pip install watermark


In [24]:
!pip install nexusformat



In [1]:
%load_ext watermark

# Import 
import os
import wandb
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from tqdm.notebook import tqdm
import pandas as pd
import torch
from torch.utils.data import DataLoader, random_split, SubsetRandomSampler
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix, balanced_accuracy_score
import sys
import seaborn as sn
from torch.utils.tensorboard import SummaryWriter
from torch_intermediate_layer_getter import IntermediateLayerGetter as MidGetter
import umap
import scanpy as sc
import anndata as ad
import re
from collections import Counter
import random
from torch.utils.data import Subset
import h5py
import pickle

from sparcscore.ml.datasets import HDF5SingleCellDataset
# from sparcscore.pipeline.project import TimecourseProject, Project
# from sparcscore.pipeline.workflows import MultithreadedWGATimecourseSegmentation, WGATimecourseSegmentation, MultithreadedCytosolCellposeTimecourseSegmentation, ShardedWGASegmentation, ShardedDAPISegmentationCellpose, WGASegmentation, DAPISegmentationCellpose
from sparcscore.pipeline.extraction import HDF5CellExtraction, TimecourseHDF5CellExtraction
from sparcscore.pipeline.classification import MLClusterClassifier
from sparcscore.ml.pretrained_models import autophagy_classifier2_1



In [3]:
full_hdf5_data = HDF5SingleCellDataset(
    dir_list=['/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/231018_EI24_timecourse_phenix/231018_0317_EI24_fixed_tc/single_cells.h5',
              '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/231018_EI24_timecourse_phenix/231018_0316_EI24_fixed_tc/single_cells.h5',
              '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/231018_EI24_timecourse_phenix/231018_0318_EI24_fixed_tc/single_cells.h5',
              '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/231004_autophagy_screen_6slides/2.3_A002/single_cells.h5', 
              '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/231004_autophagy_screen_6slides/2.3_B004/single_cells.h5',
              '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/231004_autophagy_screen_6slides/2.3_D001/single_cells.h5',
              '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/231004_autophagy_screen_6slides/2.3_F003/single_cells.h5',
              '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/231004_autophagy_screen_6slides/2.3_H002/single_cells.h5',
              '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/231004_autophagy_screen_6slides/2.3_K001/single_cells.h5',
              '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/230714_autophagy_training_data_sample/T_01_stim_Cr203_C6_filtered.h5',
              '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/230714_autophagy_training_data_sample/T_02_stim_Cr203_C6_filtered.h5',
              '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/230714_autophagy_training_data_sample/T_01_stim_wt_filtered.h5',
              '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/230714_autophagy_training_data_sample/T_2.2_stim_wt_filtered.h5',
              '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/230714_autophagy_training_data_sample/T_2.3_stim_wt_filtered.h5',
              '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/230714_autophagy_training_data_sample/T_2.X_stim_wt_filtered.h5',
              '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/230714_autophagy_training_data_sample/T_02_stim_wt_filtered.h5',
              '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/230714_autophagy_training_data_sample/T_01_unstim_wt_filtered.h5',
              '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/230714_autophagy_training_data_sample/T_02_unstim_wt_filtered.h5',
              '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/230714_autophagy_training_data_sample/T_2.2_stim_Cr203_filtered.h5',
              '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/230714_autophagy_training_data_sample/T_2.3_stim_Cr203_filtered.h5',
              '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/230714_autophagy_training_data_sample/T_2.X_stim_Cr203_filtered.h5'],
    dir_labels=[4, 4, 4, 5, 5, 5, 5, 5, 5, 3, 3, 0, 0, 0, 0, 0, 1, 1, 2, 2, 2], 
    root_dir='/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93kux/230714_autophagy_training_data_sample/',
    select_channel=4,  # Select the 5th channel (channel index 4)
    return_id=False
)

Total: 3728222
0: 500000
1: 200000
2: 300000
3: 200000
4: 135131
5: 2393091


In [62]:
# Helper function to extract random cells from each class
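def sample_cells(dataset, class_label, n_samples, exclude_indices=None):
    """Randomly sample `n_samples` dataset indices of class `class_label`, skipping `exclude_indices`."""
    # Avoid a mutable default argument; a set makes the membership test fast
    exclude = set(exclude_indices) if exclude_indices is not None else set()
    
    # Indices of all cells whose label (first data_locator field) matches class_label
    indices = [i for i, cell in enumerate(dataset.data_locator) if cell[0] == class_label and i not in exclude]
    
    # Randomly sample n_samples indices from the available ones
    return random.sample(indices, n_samples)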

# Create separate Test Set files for each class 
def create_testset(dataset, test_filename_template):
    testset_indices = {}
    
    # For each class, sample 2000 cells and collect them
    for class_label in range(6):
        sampled_indices = sample_cells(dataset, class_label, 2000)
        testset_indices[class_label] = sampled_indices
        
        # Save each class' test set to a separate HDF5 file
        test_filename = test_filename_template.format(label=class_label)
        
        with h5py.File(test_filename, 'w') as f:
            # Create a dataset for all the single cells in this class
            cell_data_list = []
            cell_index_list = []
            
            for idx in sampled_indices:
                # Read each cell from the HDF5 source only once
                cell_data, cell_label = dataset[idx]
                cell_data_list.append(cell_data.numpy())
                cell_index_list.append([idx, cell_label.item()])
            
            # Convert lists to numpy arrays
            single_cell_data = np.array(cell_data_list)
            single_cell_index = np.array(cell_index_list, dtype=np.uint64)
            
            # Create the datasets
            f.create_dataset('single_cell_data', data=single_cell_data, dtype='float32')
            f.create_dataset('single_cell_index', data=single_cell_index, dtype='uint64')
    
    return testset_indices

In [74]:
# Create separate Training Set files for selected classes
def create_trainset(dataset, train_filename_template, testset_indices, class_labels, n_samples_per_class=100000):
    trainset_indices = {}
    testset_all_indices = set([idx for indices in testset_indices.values() for idx in indices])  # Flatten testset indices
    
    # Sample cells from the selected classes, excluding testset indices
    for class_label in class_labels:
        sampled_indices = sample_cells(dataset, class_label, n_samples_per_class, exclude_indices=testset_all_indices)
        trainset_indices[class_label] = sampled_indices
        
        # Save each class' training set to a separate HDF5 file
        train_filename = train_filename_template.format(label=class_label)
        
        with h5py.File(train_filename, 'w') as f:
            # Create a dataset for all the single cells in this class
            cell_data_list = []
            cell_index_list = []
            
            for idx in sampled_indices:
                # Read each cell from the HDF5 source only once
                cell_data, cell_label = dataset[idx]
                cell_data_list.append(cell_data.numpy())
                cell_index_list.append([idx, cell_label.item()])
            
            # Convert lists to numpy arrays
            single_cell_data = np.array(cell_data_list)
            single_cell_index = np.array(cell_index_list, dtype=np.uint64)
            
            # Create the datasets
            f.create_dataset('single_cell_data', data=single_cell_data, dtype='float32')
            f.create_dataset('single_cell_index', data=single_cell_index, dtype='uint64')
    
    return trainset_indices

In [75]:
# Check for overlap between testset and trainset indices
def check_overlap(testset_indices, trainset_indices):
    testset_all_indices = set([idx for indices in testset_indices.values() for idx in indices])  # Flatten testset indices
    trainset_all_indices = set([idx for indices in trainset_indices.values() for idx in indices])  # Flatten trainset indices
    overlap = testset_all_indices.intersection(trainset_all_indices)
    
    if overlap:
        print(f"Warning: Overlapping indices found between testset and trainset: {overlap}")
    else:
        print("No overlap between testset and trainset.")

##### Create balanced testset with instances from all biological classes

In [None]:
test_filename_template = '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93quv/balanced_testset_all_classes/testset_{label}.h5'

# Create test set with separate files for each class
testset_indices = create_testset(full_hdf5_data, test_filename_template)

In [None]:
# Check whether the test set was created properly by inspecting the HDF5 file
#import nexusformat.nexus as nx
#f = nx.nxload('./testset_0.h5')
#print(f.tree)

In [64]:
# Now, create balanced dataset from the saved files
balanced_testset_all_classes = HDF5SingleCellDataset(
    dir_list=[f'/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93quv/balanced_testset_all_classes/testset_{label}.h5' for label in range(6)],
    dir_labels=[0, 1, 2, 3, 4, 5],  # class labels
    root_dir='/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93quv/balanced_testset_all_classes/',  
    select_channel=4,  # Select the 5th channel
    return_id=False
)

Total: 12000
0: 2000
1: 2000
2: 2000
3: 2000
4: 2000
5: 2000


##### Case 1: balanced training set containing none of the balanced test set instances

In [76]:
# Define the class labels to include in the training set
train_class_labels = [0, 2] 

# Create training set with separate files for selected class labels
train_filename_template = '/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93quv/training_sets/case1_balanced_trainingset/trainset_{label}.h5'
trainset_indices = create_trainset(full_hdf5_data, train_filename_template, testset_indices, class_labels=train_class_labels)
check_overlap(testset_indices, trainset_indices)

No overlap between testset and trainset.


In [78]:
balanced_trainset_class_0_and_2 = HDF5SingleCellDataset(
    dir_list=[f'/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93quv/training_sets/case1_balanced_trainingset/trainset_{label}.h5' for label in train_class_labels],
    dir_labels=train_class_labels,  # The class labels
    root_dir='/dss/dssfs02/lwp-dss-0001/pn36po/pn36po-dss-0001/di93quv/training_sets/case1_balanced_trainingset/',  
    select_channel=4,  # Select the 5th channel
    return_id=False
)

Total: 200000
0: 100000
2: 100000


## II. Ensemble 1: De novo model with multiple random seeds <a class="anchor" id="ensemble1"></a>

Here we run the de novo multi-class classifier 5 times with different random seeds. For each run, we identify the screening hits in the UMAP and then find the hits that overlap across runs using intersection methods. This should surface consistent patterns and reduce the impact of random fluctuations and technical artefacts. We use Monte Carlo dropout to obtain uncertainty estimates.
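The training cell below enables dropout at test time but uses a single stochastic forward pass per batch; Monte Carlo dropout usually averages several such passes. A minimal sketch, assuming the classifier contains standard `nn.Dropout`-type layers; `enable_mc_dropout` and `mc_dropout_predict` are hypothetical helpers, and the toy network only stands in for `MultiClassClassifier`, whose architecture is not shown in this section. Only the dropout layers are switched to training mode, so layers such as BatchNorm keep their inference behavior. The mutual information isolates the part of the entropy caused by disagreement between dropout samples (model uncertainty).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def enable_mc_dropout(model):
    """Eval mode everywhere, then switch only the dropout layers back to train mode."""
    model.eval()
    for module in model.modules():
        if isinstance(module, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
            module.train()


@torch.no_grad()
def mc_dropout_predict(model, data, n_passes=20, eps=1e-8):
    """Average `n_passes` stochastic forward passes.

    Returns the mean class probabilities, the predictive entropy of that mean,
    and the mutual information (predictive entropy minus mean per-pass entropy).
    """
    enable_mc_dropout(model)
    probs = torch.stack([F.softmax(model(data), dim=1) for _ in range(n_passes)])  # (T, batch, C)
    mean_probs = probs.mean(dim=0)
    predictive_entropy = -(mean_probs * torch.log(mean_probs + eps)).sum(dim=1)
    expected_entropy = -(probs * torch.log(probs + eps)).sum(dim=2).mean(dim=0)
    mutual_information = predictive_entropy - expected_entropy
    return mean_probs, predictive_entropy, mutual_information


# Toy stand-in for the real classifier (hypothetical architecture)
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(32, 3))
x = torch.randn(8, 16)
mean_probs, predictive_entropy, mutual_information = mc_dropout_predict(toy_model, x, n_passes=30)
```

Per-cell `predictive_entropy` or `mutual_information` could then be grouped by true label, as the test loop does for its single-pass entropy.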

In [None]:
##################################################################################

In [None]:
# Seed for reproducibility (change per run; the ensemble uses 5 different seeds)
seed = 42
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

# Redirect print statements to a file
sys.stdout = open("duplicate_multi_class_output.txt", "w")

# Initialize TensorBoard writer
tensorboard_writer = SummaryWriter('runs/VGG2_autophagy_multi_class_training')

# Log into W&B and record the seed of this run
wandb.login()
run = wandb.init(project="VGG2_autophagy_multi_class_training", config={"seed": seed, "case": 1})

# Define the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create an instance of the model (MultiClassClassifier is not defined in this section;
# it is assumed to be defined earlier in the notebook).
# The training labels are 0 and 2, so the network needs at least 3 output units.
num_classes = 3
model = MultiClassClassifier(num_classes)
model.to(device)

# Define the loss function and optimizer
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 5  
batch_size = 256 
log_interval = 50  # Log metrics every 50 batches

epsilon = 1e-8  # Small epsilon value to prevent log(0) in uncertainties

# Case 1: train on classes 0 and 2, test on the balanced test set with all six classes
train_data = balanced_trainset_class_0_and_2
test_data = balanced_testset_all_classes

# Create DataLoaders (the test set does not need shuffling)
train_data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_data_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

stop_training = False

# Training loop
for epoch in range(1, num_epochs + 1):
    if stop_training:
        break  
    
    print("Epoch: ", epoch)
    model.train()  # Set model to training mode

    total_loss = 0.0
    correct = 0
    samples_seen = 0
    batch_counter = 0  # Reset batch counter at the start of each epoch

    for batch_idx, (data, labels) in enumerate(train_data_loader):
        data = data.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = loss_function(output, labels)
        loss.backward() 
        optimizer.step()

        # CrossEntropyLoss returns the batch mean, so weight it by the batch size
        total_loss += loss.item() * labels.size(0)
        _, predicted = output.max(1)
        correct += predicted.eq(labels).sum().item()
        samples_seen += labels.size(0)

        batch_counter += 1
        
        # Running accuracy and average loss over the samples seen so far in this epoch
        accuracy = 100.0 * correct / samples_seen
        average_loss = total_loss / samples_seen

        # Check for the desired accuracy and stop training if reached
        if accuracy >= 99.0:
            stop_training = True
            print("Accuracy over 99% reached and thus stopping training...")
            break

    # Calculate and log training metrics (eval mode, no gradient tracking)
    model.eval()
    all_train_labels = []
    all_train_predicted = []

    with torch.no_grad():
        for data, labels in train_data_loader:
            data = data.to(device)
            labels = labels.to(device)

            output = model(data)
            _, predicted = output.max(1)
            all_train_labels.extend(labels.cpu().numpy())
            all_train_predicted.extend(predicted.cpu().numpy())

    # Macro averages over the two training classes only (class 1 has no training samples)
    train_precision = precision_score(all_train_labels, all_train_predicted, labels=[0, 2], average='macro')
    train_recall = recall_score(all_train_labels, all_train_predicted, labels=[0, 2], average='macro')
    train_f1 = f1_score(all_train_labels, all_train_predicted, labels=[0, 2], average='macro')
    train_balanced_accuracy = balanced_accuracy_score(all_train_labels, all_train_predicted)

    train_accuracy = accuracy_score(all_train_labels, all_train_predicted) * 100.0

    print("Train Precision: " + str(train_precision) + " Recall: " + str(train_recall) + " F1 score: " + str(train_f1))
    print("Train Balanced Accuracy: {:.2f}%".format(train_balanced_accuracy * 100))
    
    # Log train metrics for the epoch
    wandb.log({
        "Train Epoch": epoch,
        "Train_Precision": train_precision,
        "Train_Recall": train_recall,
        "Train_F1-score": train_f1,
        "Train_Balanced_Accuracy": train_balanced_accuracy,
        "Train_Loss": average_loss,
    })
    
    # Log on TensorBoard
    tensorboard_writer.add_scalar('Train_Precision', train_precision, global_step=epoch)
    tensorboard_writer.add_scalar('Train_Recall', train_recall, global_step=epoch)
    tensorboard_writer.add_scalar('Train_F1-score', train_f1, global_step=epoch)
    tensorboard_writer.add_scalar('Train_Balanced_Accuracy', train_balanced_accuracy, global_step=epoch)
    tensorboard_writer.add_scalar('Train_Loss', average_loss, global_step=epoch)

    correct = 0
    total_loss = 0.0

    # Test loop with dropout and aggregated confusion matrix
    model.eval()
    
    # Re-enable only the dropout layers, so BatchNorm and similar layers stay in inference mode
    for module in model.modules():
        if isinstance(module, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
            module.train()

    seen_classes = [0, 2]  # Classes seen during training
    test_correct = 0
    test_loss_sum = 0.0
    test_seen_samples = 0
    all_test_labels = []
    all_test_predicted = []
    # One list of per-cell uncertainties for each of the six test labels (0-5)
    test_class_uncertainties = [[] for _ in range(6)]

    with torch.no_grad():
        for data, labels in test_data_loader:
            data = data.to(device)
            labels = labels.to(device)

            output = model(data)
            _, predicted = output.max(1)
            test_correct += predicted.eq(labels).sum().item()
            all_test_labels.extend(labels.cpu().numpy())
            all_test_predicted.extend(predicted.cpu().numpy())

            # The loss is only defined for labels the model has output units for (seen classes)
            seen_mask = (labels == 0) | (labels == 2)
            if seen_mask.any():
                loss = loss_function(output[seen_mask], labels[seen_mask])
                test_loss_sum += loss.item() * seen_mask.sum().item()
                test_seen_samples += seen_mask.sum().item()

            # Calculate class probabilities
            probs = F.softmax(output, dim=1)

            # Per-cell uncertainty: entropy of the predicted class distribution
            uncertainties = -(probs * torch.log(probs + epsilon)).sum(dim=1)

            # Group the uncertainties by each cell's true label
            for label, uncertainty in zip(labels.cpu().tolist(), uncertainties.cpu().tolist()):
                test_class_uncertainties[label].append(uncertainty)

    # Filter out predictions and labels for classes seen during training
    mask_seen_classes = np.isin(all_test_labels, seen_classes)
    filtered_test_labels = np.array(all_test_labels)[mask_seen_classes]
    filtered_test_predicted = np.array(all_test_predicted)[mask_seen_classes]

    # Calculate accuracy and loss only for the seen classes
    test_accuracy = accuracy_score(filtered_test_labels, filtered_test_predicted)
    test_average_loss = test_loss_sum / max(test_seen_samples, 1)

    print("Test Accuracy: {:.2f}%".format(test_accuracy * 100))
    print("Test Loss: {:.4f}".format(test_average_loss))

    wandb.log({
        "Test_Accuracy": test_accuracy,
        "Test_Loss": test_average_loss,
    })

    tensorboard_writer.add_scalar('Test_Accuracy', test_accuracy, global_step=epoch)
    tensorboard_writer.add_scalar('Test_Loss', test_average_loss, global_step=epoch)
    
    
    # Aggregate and log confusion matrix over all six test labels (rows: true, columns: predicted)
    all_labels = list(range(6))
    aggregated_confusion = confusion_matrix(all_test_labels, all_test_predicted, labels=all_labels)

    # Row-normalize the confusion matrix (epsilon avoids division by zero for empty rows)
    df_cm = pd.DataFrame(aggregated_confusion / (np.sum(aggregated_confusion, axis=1)[:, None] + epsilon),
                         index=all_labels,
                         columns=all_labels)

    # Save confusion matrix to TensorBoard (a fresh figure per epoch; add_figure closes it)
    fig, ax = plt.subplots(figsize=(7, 6))
    sn.heatmap(df_cm, annot=True, fmt='.2f', ax=ax)
    tensorboard_writer.add_figure(f'Aggregated Confusion Matrix - Epoch {epoch}', fig, global_step=epoch)
    
    # Set the model back to training mode
    model.train()
    
    # Plot histogram of uncertainties
    class1_uncertainties = test_class_uncertainties[1]
    plt.hist(class1_uncertainties, bins=50, alpha=0.5, color='blue', label='Class 1 Uncertainties')
    plt.xlabel('Uncertainty')
    plt.ylabel('Frequency')
    plt.title('Uncertainty Distribution for Class 1')
    plt.legend()
    plt.savefig(f'uncertainty_histogram_class1_epoch{epoch}.png')
    plt.close()

# Save model (one file per seed, so the ensemble runs do not overwrite each other)
print("Saving final model now...")
torch.save(model.state_dict(), f'multi_class_VGG2_case1_seed{seed}.pth')

# Close the W&B run
wandb.finish()

# Close the TensorBoard writer
tensorboard_writer.close()