<a href="https://colab.research.google.com/drive/1y10c6h2j--cnvQaCwZqiFLLKS0aD4fbN?usp=sharing" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

This Colab notebook presents a comprehensive approach to building a powerful neural network model for semantic segmentation, a fundamental task in computer vision. Semantic segmentation involves assigning a specific class label to each pixel in an image, enabling detailed understanding and precise localization of objects or semantic concepts within the visual scene.

The notebook provides step-by-step instructions and code for each stage of the project. We begin by loading and preprocessing the dataset, a crucial step for making the data suitable for training our model. Images and their RGB label masks are resized to fixed-size patches, and every label pixel is binarised per channel and mapped to one of six class indices, so that noisy or slightly inconsistent annotation colours still receive a consistent label. Class imbalance is addressed by weighting the loss with inverse class frequencies.

Next, we delve into the architecture design of our neural network model. Leveraging convolutional neural networks (CNNs), we construct a deep learning model — a U-Net — capable of capturing fine-grained details and spatial dependencies present in the images. The notebook walks through the implementation of this encoder–decoder architecture, which is designed to deliver strong semantic-segmentation performance.

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
# from torch.autograd import Variable
import torchvision.transforms as tr
from skimage import io
from torchvision.transforms.functional import InterpolationMode


import os
import numpy as np
import random

# from scipy.ndimage import zoom
import matplotlib.pyplot as plt

from tqdm import tqdm as tqdm
# from pandas import read_csv
# from math import floor, ceil, sqrt, exp
# from IPython import display
# import time
# from itertools import chain
import time
# import warnings
# from pprint import pprint

import random

# Function for setting the seed
def set_seed(seed, deterministic=False):
    """Seed every random number generator used in this notebook.

    Args:
        seed (int): seed passed to Python's ``random``, NumPy and PyTorch.
        deterministic (bool): if True, additionally force cuDNN to use
            deterministic algorithms and disable its autotuner. This is
            slower, but removes run-to-run GPU nondeterminism. The default
            (False) leaves the cuDNN settings untouched.
    """
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the CPU and every CUDA device; the explicit
    # CUDA call is kept for older PyTorch releases and is cheap.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


set_seed(42)

print('IMPORTS OK')

In [None]:
# Mount Google Drive so the dataset paths under /content/drive resolve.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# ---------------------------------------------------------------------------
# Global variables' definitions
# ---------------------------------------------------------------------------

# Dataset locations on the mounted Google Drive; change these paths to match
# where your copy of the data is stored.
PATH_TO_DATASET_IMAGES = '/content/drive/MyDrive/Start'
PATH_TO_DATASET_LABELS = '/content/drive/MyDrive/Finish'

# Per-split folders: one folder of images and one folder of ground-truth
# label masks for each of train / validation / test.
PATH_TO_TRAIN_DATASET_IMAGES = '/content/drive/MyDrive/Dataset2/train/images'
PATH_TO_TRAIN_DATASET_LABEL = '/content/drive/MyDrive/Dataset2/train/ground_truth'
PATH_TO_VAL_DATASET_IMAGES = '/content/drive/MyDrive/Dataset2/val/images'
PATH_TO_VAL_DATASET_LABEL = '/content/drive/MyDrive/Dataset2/val/ground_truth'
PATH_TO_TEST_DATASET_IMAGES = '/content/drive/MyDrive/Dataset2/test/images'
PATH_TO_TEST_DATASET_LABEL = '/content/drive/MyDrive/Dataset2/test/ground_truth'

# Hyper-parameters; change the constants below to the desired values.
BATCH_SIZE = 64
PATCH_SIDE = 96  # images and labels are resized to PATCH_SIDE x PATCH_SIDE
N_EPOCHS = 500

# Feature switches.
NORMALISE_IMGS = True  # standardise the input channels with mean / std
LOAD_TRAINED = False
DATA_AUG = False  # build rotated / flipped copies of the training images

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print('DEFINITIONS OK')

In [None]:
def get_tif_images(tif_path):
    """Return the names of the ``.tif`` files directly inside *tif_path*.

    The names are sorted. ``os.listdir`` returns entries in an arbitrary,
    filesystem-dependent order, which would make the index of each sample
    differ between runs and machines.

    Args:
        tif_path (str): directory to scan (not recursive).

    Returns:
        list[str]: bare file names ending in ``.tif`` (case-sensitive match).
    """
    return sorted(name for name in os.listdir(tif_path) if name.endswith('.tif'))

# print(get_tif_images(PATH_TO_DATASET))


class MyDataset(Dataset):
    """
    In-memory segmentation dataset built from a folder of ``.tif`` images and
    a folder of matching RGB label masks.

    Every image/label pair is read from disk once and passed through each
    transform in ``transform``. The results are stored in two tensors:

    - ``self.data``: float, shape ``(N, 3, PATCH_SIDE, PATCH_SIDE)``.
    - ``self.labels``: long, shape ``(N, 1, PATCH_SIDE, PATCH_SIDE)``.

    Here ``N = len(self.names) * len(self.transform)``. Sample
    ``t * len(self.names) + i`` is image ``i`` under transform ``t``.
    """

    def __init__(self, img_path, label_path, transform=None, allDataset=False, verbose=True):
        """
        Args:
            img_path (string): path to the folder containing the images
            label_path (string): path to the folder containing the labels
            transform (callable or list of callables): transform(s) applied to
                both each image and its label. Each one is expected to turn the
                array returned by ``io.imread`` into a
                ``(3, PATCH_SIDE, PATCH_SIDE)`` tensor. A single callable is
                treated as a one-element list.
            allDataset (boolean): if True, the dataset is the whole dataset,
                otherwise it is the train dataset. This selects how label file
                names are derived (see ``_label_file``).
            verbose (boolean): print each image name and its per-class pixel
                counts while loading (default True).

        Raises:
            ValueError: if ``transform`` is None. The raw arrays cannot be
                stored in the fixed-size tensors without a transform.
        """
        if transform is None:
            raise ValueError('MyDataset requires a transform (or list of transforms) '
                             'producing (3, PATCH_SIDE, PATCH_SIDE) tensors')
        if callable(transform):
            transform = [transform]

        self.transform = list(transform)
        self.names = get_tif_images(img_path)
        # binarised RGB colour -> class index, and its inverse
        self.color_map = {(0, 0, 1): 0,
                          (0, 1, 1): 1,
                          (0, 1, 0): 2,
                          (1, 1, 1): 3,
                          (1, 1, 0): 4,
                          (1, 0, 0): 5}
        self.reverse_color_map = {0: (0, 0, 1),
                                  1: (0, 1, 1),
                                  2: (0, 1, 0),
                                  3: (1, 1, 1),
                                  4: (1, 1, 0),
                                  5: (1, 0, 0)}
        self.allDataset = allDataset

        n_names = len(self.names)
        n_samples = n_names * len(self.transform)
        self.data = torch.zeros(n_samples, 3, PATCH_SIDE, PATCH_SIDE)
        self.labels = torch.zeros(n_samples, 1, PATCH_SIDE, PATCH_SIDE, dtype=torch.long)

        # Read each file once and apply every transform to the in-memory
        # arrays. Previously the same image was re-read once per transform.
        for name_idx, name in enumerate(self.names):
            image = io.imread(os.path.join(img_path, name))
            label = io.imread(self._label_file(label_path, name))
            if verbose:
                print('Image name: {} \t'.format(name))
            for t_idx, transform_element in enumerate(self.transform):
                idx = t_idx * n_names + name_idx
                self.data[idx] = transform_element(image)
                self.labels[idx] = self.convert_image(transform_element(label), verbose=verbose)

    def _label_file(self, label_path, name):
        """Return the path of the label file paired with image file *name*."""
        if self.allDataset:
            # Drop the last 9 characters of the image name and append
            # '_label.tif'. Presumably this strips a suffix like
            # '_IRRG.tif' -- TODO confirm.
            return os.path.join(label_path, name[:-9] + '_label.tif')
        return os.path.join(label_path, self.modify_string(name))

    def modify_string(self, input_str):
        """
        Map an image file name to its label file name.

        The name is split on '_', the 'IRRG' part is removed, and 'label' is
        inserted before the last part. For example, 'a_IRRG_0.tif' becomes
        'a_label_0.tif'.

        Raises:
            ValueError: if 'IRRG' is not one of the '_'-separated parts.
        """
        parts = input_str.split("_")
        parts.remove('IRRG')
        parts.insert(-1, "label")
        return "_".join(parts)

    def __len__(self):
        return len(self.data)

    def convert_image(self, tensor, verbose=True):
        """
        Convert a ``(C, H, W)`` RGB label tensor into a ``(1, H, W)`` long
        tensor of class indices.

        Each of the first three channels is binarised: values below 0.5
        become 0, everything else becomes 1. The resulting colour is then
        looked up in ``color_map``, and colours missing from the map fall back
        to class 0. Extra channels (e.g. alpha) are ignored. The input tensor
        is not modified.

        Args:
            tensor: label tensor, presumably with values in [0, 1] (e.g. the
                output of ``ToTensor``) and at least 3 channels.
            verbose (boolean): print the number of pixels of each class.
        """
        # ~(x < 0.5) rather than x >= 0.5 so NaN maps to 1, as the original
        # per-pixel threshold did.
        bits = (~(tensor[:3] < 0.5)).long()
        codes = bits[0] * 4 + bits[1] * 2 + bits[2]  # (H, W), values 0..7

        # 8-entry lookup table from the binary colour code to the class
        # index. Codes absent from color_map stay 0.
        lookup = torch.zeros(8, dtype=torch.long, device=tensor.device)
        for (r, g, b), cls in self.color_map.items():
            lookup[r * 4 + g * 2 + b] = cls
        converted_tensor = lookup[codes].unsqueeze(0)

        if verbose:
            counts = torch.bincount(converted_tensor.flatten(), minlength=6)
            for i in range(6):
                print('Number of pixels for class {}: {}'.format(i, counts[i]))

        return converted_tensor

    def revert_image(self, tensor):
        """
        Convert an ``(H, W)`` tensor of class indices into a ``(3, H, W)``
        float RGB image using ``reverse_color_map``.

        Indices without an entry in ``reverse_color_map`` become black
        ``(0, 0, 0)``. The result is always on the CPU.
        """
        tensor = tensor.cpu()
        converted_tensor = torch.zeros(3, tensor.size(0), tensor.size(1))
        for cls, color in self.reverse_color_map.items():
            mask = tensor == cls
            converted_tensor[:, mask] = torch.tensor(color, dtype=converted_tensor.dtype).unsqueeze(1)
        return converted_tensor

    def normalize(self, mean, std):
        """Standardise all stored images in place: ``data = (data - mean) / std``."""
        self.data.sub_(mean).div_(std)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

In [None]:
# Transform for non-augmented samples: resize to a square patch using
# nearest-neighbour interpolation, so label colours are never blended.
transform_list_test = [tr.Compose([tr.ToPILImage(),
                                   tr.Resize((PATCH_SIDE, PATCH_SIDE), interpolation=InterpolationMode.NEAREST),
                                   tr.ToTensor()])]

if DATA_AUG:
    # The 8 symmetries of a square: 4 rotations by multiples of 90 degrees,
    # each with and without a horizontal flip.
    # - Resizing to a square patch happens *before* rotating, so a 90-degree
    #   rotation is an exact pixel permutation and never crops corners.
    # - Vertical and both-axis flips are omitted because they only duplicate
    #   these views: V-flip = H-flip after a 180-degree rotation, and
    #   H+V flip = 180-degree rotation.
    transform_list_train = []
    for flip in (False, True):
        for quarter_turns in range(4):
            steps = [tr.ToPILImage(),
                     tr.Resize((PATCH_SIDE, PATCH_SIDE), interpolation=InterpolationMode.NEAREST),
                     tr.RandomRotation((90 * quarter_turns, 90 * quarter_turns)),
                     tr.ToTensor()]
            if flip:
                steps.append(tr.RandomHorizontalFlip(p=1))
            transform_list_train.append(tr.Compose(steps))
    data_transform_list = transform_list_train
else:
    transform_list_train = list(transform_list_test)


dataset = MyDataset(img_path=PATH_TO_DATASET_IMAGES, label_path=PATH_TO_DATASET_LABELS, transform=transform_list_test, allDataset = True)

In [None]:
# Build the three splits. The training split uses transform_list_train when
# DATA_AUG is on; validation and test always use transform_list_test.
train_dataset = MyDataset(
    img_path=PATH_TO_TRAIN_DATASET_IMAGES,
    label_path=PATH_TO_TRAIN_DATASET_LABEL,
    transform=transform_list_train if DATA_AUG else transform_list_test,
)
val_dataset = MyDataset(
    img_path=PATH_TO_VAL_DATASET_IMAGES,
    label_path=PATH_TO_VAL_DATASET_LABEL,
    transform=transform_list_test,
)
test_dataset = MyDataset(
    img_path=PATH_TO_TEST_DATASET_IMAGES,
    label_path=PATH_TO_TEST_DATASET_LABEL,
    transform=transform_list_test,
)

In [None]:
# Per-channel mean / std used to standardise the network inputs.
# The statistics come from the training split only, so nothing about the
# validation or test images leaks into preprocessing.
if NORMALISE_IMGS:
    mean_train_set = torch.mean(train_dataset.data, dim=(0, 2, 3), keepdim=True)
    # clamp guards against a constant channel (std == 0) producing inf / NaN
    std_train_set = torch.std(train_dataset.data, dim=(0, 2, 3), keepdim=True).clamp_min(1e-8)

    train_dataset.data = (train_dataset.data - mean_train_set) / std_train_set
    val_dataset.data = (val_dataset.data - mean_train_set) / std_train_set
    test_dataset.data = (test_dataset.data - mean_train_set) / std_train_set
else:
    # Identity statistics, so later cells that normalise with
    # mean_train_set / std_train_set still run and leave the data unchanged.
    mean_train_set = torch.zeros(1, 3, 1, 1)
    std_train_set = torch.ones(1, 3, 1, 1)

In [None]:
# Create dataloaders. Only the training loader shuffles. Evaluation order
# does not change the metrics, and a fixed order keeps the val/test
# predictions (and the plots built from them) aligned with dataset order.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


print('DATALOADERS OK')

## Model Architecture

The **U-Net** is a popular convolutional neural network (CNN) architecture specifically designed for semantic segmentation tasks. It was first introduced by Olaf Ronneberger, Philipp Fischer, and Thomas Brox in 2015. The U-Net architecture is widely used and has achieved state-of-the-art results in various image segmentation challenges.

The unique aspect of the U-Net architecture is its U-shaped or symmetric structure, which consists of an encoder path and a corresponding decoder path. The encoder path captures context and extracts high-level features from the input image, while the decoder path performs precise localization by upsampling and concatenating the features from the encoder path.

The encoder path of the U-Net architecture follows a typical CNN design, with successive convolutional and pooling layers to downsample the spatial dimensions of the input image. This allows the network to capture increasingly abstract and high-level features while reducing the spatial resolution.

The decoder path of the U-Net architecture employs upsampling operations, such as transposed convolutions or bilinear interpolation, to gradually recover the spatial resolution lost during the encoding phase. The feature maps from the encoder path are concatenated with the corresponding feature maps in the decoder path using skip connections. These skip connections help to preserve fine-grained spatial information and provide local context for accurate segmentation.

The U-Net architecture is particularly effective for tasks like biomedical image segmentation, where the objects of interest are often small and surrounded by complex backgrounds. By combining both local and global information, the U-Net can produce segmentation maps with precise object boundaries and handle class imbalance issues.

Due to its architecture's effectiveness and versatility, the U-Net has been widely adopted in various domains, including medical image analysis, satellite image segmentation, and general computer vision applications. It has become a popular choice for researchers and practitioners working on semantic segmentation tasks, offering a powerful tool for pixel-wise classification and accurate object delineation.

<div style="text-align:center">
<img src="https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png" width="600"  />
</div>

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF


class DoubleConvolution(nn.Module):
    """Two stages of (3x3 conv -> BatchNorm -> ReLU); spatial size is preserved.

    Args:
        in_ch (int): number of input channels.
        out_ch (int): number of output channels of both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super(DoubleConvolution, self).__init__()
        # bias=False: the following BatchNorm adds its own learnable shift.
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    """U-Net encoder/decoder that outputs per-pixel class scores.

    Args:
        input_channels (int): channels of the input image.
        output_channels (int): number of classes (channels of the output).
        features (sequence of int): channel counts of the encoder levels,
            from shallowest to deepest. The bottleneck uses
            ``2 * features[-1]``. Any positive counts are accepted; they do
            not have to double from level to level.

    The output has the same height and width as the input. When pooling
    floors an odd size, the upsampled decoder features are resized
    (nearest) to match the skip connection.

    Raises:
        ValueError: if ``features`` is empty.
    """

    def __init__(self, input_channels=3, output_channels=6,
                 features=(64, 128, 256, 512)):
        super(UNet, self).__init__()
        features = list(features)
        if not features:
            raise ValueError('features must contain at least one channel count')

        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConvolution per level.
        channels = input_channels
        for f in features:
            self.downs.append(DoubleConvolution(channels, f))
            channels = f

        # Bottleneck doubles the channel count of the deepest level.
        self.bottleneck = DoubleConvolution(features[-1], features[-1] * 2)
        channels = features[-1] * 2

        # Decoder: for each level, upsample + conv down to f channels, then
        # a DoubleConvolution over the concatenation [skip (f), x (f)].
        # `channels` tracks the real input width of each upsampling conv.
        # For doubling feature lists it equals 2 * f, matching the classic
        # layout and its parameter names.
        for f in reversed(features):
            self.ups.append(
                nn.Sequential(
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(in_channels=channels, out_channels=f, kernel_size=3,
                              padding=1),
                ))
            self.ups.append(DoubleConvolution(2 * f, f))
            channels = f

        self.final_convolution = nn.Conv2d(in_channels=features[0],
                                           out_channels=output_channels,
                                           kernel_size=3, padding=1)

    def forward(self, x):
        skip_connections = []
        for module in self.downs:
            x = module(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # ups holds (upsample, double-conv) pairs, deepest level first.
        for level, skip_connection in enumerate(reversed(skip_connections)):
            x = self.ups[2 * level](x)
            if skip_connection.shape[2:] != x.shape[2:]:
                x = nn.functional.interpolate(x, size=skip_connection.shape[2:],
                                              mode='nearest')
            x = torch.cat([skip_connection, x], dim=1)
            x = self.ups[2 * level + 1](x)

        return self.final_convolution(x)

In [None]:
# train
def train(model, optimizer, loss_fn, train_loader, epochs=20, device=None):
    """Train *model* for *epochs* epochs and record the loss of each epoch.

    Args:
        model (nn.Module): network producing (B, C, H, W) class scores.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: criterion called as ``loss_fn(output, targets.squeeze(1))``.
        train_loader (DataLoader): yields ``(inputs, targets)`` batches.
            Targets are expected as (B, 1, H, W) class indices.
        epochs (int): number of passes over ``train_loader``.
        device: device to move each batch to. ``None`` (default) uses the
            device of the model's parameters, falling back to the CPU.

    Returns:
        tuple: ``(model, optimizer, (train_loss_list, epoch_list))``. Each
        loss is the per-sample average training loss of that epoch. The
        model is left in eval mode.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    train_loss_list = []
    epoch_list = []
    dataset_size = len(train_loader.dataset)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            output = model(inputs)
            # drop the singleton channel dim: (B, 1, H, W) -> (B, H, W)
            loss = loss_fn(output, targets.squeeze(1))
            loss.backward()
            optimizer.step()
            # weight by batch size so a smaller last batch is averaged correctly
            running_loss += loss.item() * inputs.size(0)
        training_loss = running_loss / dataset_size

        print('Epoch: {}, Training Loss: {:.4f}'.format(epoch, training_loss))
        train_loss_list.append(training_loss)
        epoch_list.append(epoch)

    model.eval()
    return model, optimizer, (train_loss_list, epoch_list)

In [None]:
def test(model, test_loader, device='cuda'):
    model.eval()
    test_loss = 0.0
    num_correct = 0
    num_examples = 0
    predicted_labels = []
    true_labels = []
    for batch in test_loader:
        inputs, targets = batch
        inputs = inputs.to(device)
        output = model(inputs)
        targets = targets.to(device)
        loss = loss_fn(output, targets.squeeze(1))  # modify
        test_loss += loss.data.item() * inputs.size(0)

        # Convert output probabilities to predicted labels
        _, predicted = torch.max(output, dim=1)
        predicted_labels.extend(predicted.cpu())
        true_labels.extend(targets.cpu())

    test_loss /= len(test_loader.dataset)
    print('Test Loss: {:.4f}'.format(test_loss))

    return test_loss, predicted_labels, true_labels

## Validation with a Small Training Set

In [None]:
# Fresh U-Net and Adam optimiser.
model = UNet(input_channels=3, output_channels=6).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Inverse-frequency class weights for the loss.
# - Computed from the training labels only; using the whole dataset would
#   leak validation/test label statistics into training.
# - minlength=6 guarantees one weight per output class even if a class is
#   absent.
# - clamp_min(1) avoids dividing by zero for such absent classes.
class_counts = torch.bincount(train_dataset.labels.flatten(), minlength=6)
print(class_counts)

class_weights = 1.0 / class_counts.clamp_min(1).float()
class_weights /= torch.sum(class_weights)
class_weights = class_weights.to(device)

print(class_weights)

loss_fn = nn.CrossEntropyLoss(weight=class_weights)

In [None]:
# Train the model. Returns the trained model/optimizer and the per-epoch
# training-loss history.
model, optimizer, (train_loss_list, epoch_list) = train(
    model, optimizer, loss_fn, train_dataloader,
    epochs=N_EPOCHS, device=device,
)

In [None]:
# Evaluate the trained model on the validation split.
val_loss, predicted_labels_val, true_labels_val = test(
    model, val_dataloader, device=device,
)

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Flatten the per-image label tensors returned by test() into 1-D pixel
# arrays. New names are used so the cell can be re-run safely.
y_true_val = np.concatenate([np.asarray(t).reshape(-1) for t in true_labels_val])
y_pred_val = np.concatenate([np.asarray(p).reshape(-1) for p in predicted_labels_val])

# Always report all 6 classes, so the confusion matrix keeps a fixed 6x6
# layout even if a class never occurs.
class_ids = list(range(6))

# zero_division=0: a class with no predictions scores 0 instead of warning
print("Accuracy:", accuracy_score(y_true_val, y_pred_val))
print("Precision:", precision_score(y_true_val, y_pred_val, average='weighted', zero_division=0))
print("Recall:", recall_score(y_true_val, y_pred_val, average='weighted', zero_division=0))
print("F1:", f1_score(y_true_val, y_pred_val, average='weighted', zero_division=0))

# confusion matrix
cm = confusion_matrix(y_true_val, y_pred_val, labels=class_ids)
plt.figure()
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_ids, yticklabels=class_ids)
plt.xlabel('Predicted label')
plt.ylabel('True label')

# classification report
print(classification_report(y_true_val, y_pred_val, labels=class_ids, zero_division=0))

## Validation with a Larger (Augmented) Training Set

In [None]:
# Rebuild the training set with the 8 symmetries of a square: 4 rotations by
# multiples of 90 degrees, each with and without a horizontal flip.
# - Resizing happens *before* rotating, so each 90-degree rotation is an
#   exact pixel permutation (no cropped corners on non-square inputs).
# - Vertical and both-axis flips are omitted because they only duplicate
#   these views: V-flip = H-flip after a 180-degree rotation, and
#   H+V flip = 180-degree rotation.
transform_list_train = []
for flip in (False, True):
    for quarter_turns in range(4):
        steps = [tr.ToPILImage(),
                 tr.Resize((PATCH_SIDE, PATCH_SIDE), interpolation=InterpolationMode.NEAREST),
                 tr.RandomRotation((90 * quarter_turns, 90 * quarter_turns)),
                 tr.ToTensor()]
        if flip:
            steps.append(tr.RandomHorizontalFlip(p=1))
        transform_list_train.append(tr.Compose(steps))


train_dataset = MyDataset(img_path=PATH_TO_TRAIN_DATASET_IMAGES, label_path=PATH_TO_TRAIN_DATASET_LABEL, transform=transform_list_train)
# mean_train_set / std_train_set only exist when normalisation is enabled
if NORMALISE_IMGS:
    train_dataset.data = (train_dataset.data - mean_train_set) / std_train_set
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Fresh U-Net and Adam optimiser for the augmented-training run.
model = UNet(input_channels=3, output_channels=6).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Inverse-frequency class weights for the loss.
# - Computed from the training labels only; using the whole dataset would
#   leak validation/test label statistics into training.
# - minlength=6 guarantees one weight per output class even if a class is
#   absent.
# - clamp_min(1) avoids dividing by zero for such absent classes.
class_counts = torch.bincount(train_dataset.labels.flatten(), minlength=6)
print(class_counts)

class_weights = 1.0 / class_counts.clamp_min(1).float()
class_weights /= torch.sum(class_weights)
class_weights = class_weights.to(device)

print(class_weights)

loss_fn = nn.CrossEntropyLoss(weight=class_weights)

In [None]:
# Train the freshly created model on the augmented training set.
model, optimizer, (train_loss_list, epoch_list) = train(
    model, optimizer, loss_fn, train_dataloader,
    epochs=N_EPOCHS, device=device,
)

In [None]:
# Evaluate the model trained on augmented data on the validation split.
val_loss, predicted_labels_val, true_labels_val = test(
    model, val_dataloader, device=device,
)

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Flatten the per-image label tensors returned by test() into 1-D pixel
# arrays. New names are used so the cell can be re-run safely.
y_true_val = np.concatenate([np.asarray(t).reshape(-1) for t in true_labels_val])
y_pred_val = np.concatenate([np.asarray(p).reshape(-1) for p in predicted_labels_val])

# Always report all 6 classes, so the confusion matrix keeps a fixed 6x6
# layout even if a class never occurs.
class_ids = list(range(6))

# zero_division=0: a class with no predictions scores 0 instead of warning
print("Accuracy:", accuracy_score(y_true_val, y_pred_val))
print("Precision:", precision_score(y_true_val, y_pred_val, average='weighted', zero_division=0))
print("Recall:", recall_score(y_true_val, y_pred_val, average='weighted', zero_division=0))
print("F1:", f1_score(y_true_val, y_pred_val, average='weighted', zero_division=0))

# confusion matrix
cm = confusion_matrix(y_true_val, y_pred_val, labels=class_ids)
plt.figure()
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_ids, yticklabels=class_ids)
plt.xlabel('Predicted label')
plt.ylabel('True label')

# classification report
print(classification_report(y_true_val, y_pred_val, labels=class_ids, zero_division=0))

## Test Set

In [None]:
# Evaluate on the held-out test split (test_dataloader, not the validation
# loader), so the "test" metrics below really describe unseen test data.
test_loss, predicted_labels_test, true_labels_test = test(model, test_dataloader, device=device)

In [None]:
import matplotlib.pyplot as plt


# Convert class-index maps back to RGB images for display.
# Predictions are (H, W); ground-truth labels are (1, H, W).
predicted_labels_test_rgb = [dataset.revert_image(label) for label in predicted_labels_test]
true_labels_test_rgb = [dataset.revert_image(label[0]) for label in true_labels_test]

n_rows = len(predicted_labels_test_rgb)
if n_rows == 0:
    print('No test images to display.')
else:
    # squeeze=False always returns a 2-D grid of axes, so a single image and
    # many images go through the same loop. The right-hand column always
    # shows the ground truth of the same test image.
    fig, axes = plt.subplots(nrows=n_rows, ncols=2, figsize=(10, n_rows * 5), squeeze=False)

    for row, (pred_rgb, true_rgb) in enumerate(zip(predicted_labels_test_rgb, true_labels_test_rgb)):
        axes[row, 0].imshow(pred_rgb.permute(1, 2, 0))
        axes[row, 0].set_title("Predicted Label Test")
        axes[row, 0].axis('off')

        axes[row, 1].imshow(true_rgb.permute(1, 2, 0))
        axes[row, 1].set_title("Original Label Test")
        axes[row, 1].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Only the per-epoch training loss is recorded by train(), so that is all
# this plot shows.
plt.plot(epoch_list, train_loss_list, label='Train Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Train Loss')
plt.legend()
plt.show()

In [None]:
import numpy as np


def _stack_label_maps(labels):
    """Stack per-image label maps (tensors or arrays) into one numpy array.

    An ``np.ndarray`` (e.g. left over from a previous run of this cell) is
    returned unchanged, so re-running the cell does not fail.
    """
    if isinstance(labels, np.ndarray):
        return labels
    return np.stack([np.asarray(label) for label in labels])


# (N, 1, H, W) -> (N, H, W); np.squeeze also drops N when there is one image
true_labels_test = np.squeeze(_stack_label_maps(true_labels_test))
print(np.shape(true_labels_test))

# Flatten both to 1-D pixel arrays (the metrics cell uses these names)
predicted_labels_test = _stack_label_maps(predicted_labels_test).flatten()
true_labels_test = true_labels_test.flatten()

# Count the occurrences of each class label. minlength=6 lists every class,
# including ones that never occur.
predicted_counts = np.bincount(predicted_labels_test, minlength=6)
true_counts = np.bincount(true_labels_test, minlength=6)

# Print the counts
for label, count in enumerate(predicted_counts):
    print(f"Predicted label {label}: {count} occurrences")

for label, count in enumerate(true_counts):
    print(f"True label {label}: {count} occurrences")

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# true_labels_test / predicted_labels_test are expected to be the flattened
# 1-D pixel arrays produced by the label-count cell.

# Always report all 6 classes, so the confusion matrix keeps a fixed 6x6
# layout even if a class never occurs.
class_ids = list(range(6))

# zero_division=0: a class with no predictions scores 0 instead of warning
print("Accuracy:", accuracy_score(true_labels_test, predicted_labels_test))
print("Precision:", precision_score(true_labels_test, predicted_labels_test, average='weighted', zero_division=0))
print("Recall:", recall_score(true_labels_test, predicted_labels_test, average='weighted', zero_division=0))
print("F1:", f1_score(true_labels_test, predicted_labels_test, average='weighted', zero_division=0))

# confusion matrix
cm = confusion_matrix(true_labels_test, predicted_labels_test, labels=class_ids)
plt.figure()
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_ids, yticklabels=class_ids)
plt.xlabel('Predicted label')
plt.ylabel('True label')

# classification report
print(classification_report(true_labels_test, predicted_labels_test, labels=class_ids, zero_division=0))