<div style="background-color: #34495e; border-bottom: 5px solid #95a5a6; padding: 10px;">
    <h2 style="color: white;">Installing the Required Libraries</h2>
</div>

In [None]:
!pip install timm grad-cam lime scikit-image scipy

In [None]:
!pip install -U kaleido

<div style="background-color: #34495e; border-bottom: 5px solid #95a5a6; padding: 10px;">
    <h2 style="color: white;">Imports</h2>
</div>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, WeightedRandomSampler, Dataset
import torchvision.models as models
import torch.optim as optim

import os
import time
import copy
import random
import pickle

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import cv2
import plotly.express as px
import plotly.io as pio

from PIL import Image

from collections import OrderedDict
from typing import Tuple, Union

from timm.scheduler import CosineLRScheduler
from sklearn.model_selection import StratifiedKFold
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, \
    MulticlassRecall, MulticlassF1Score, MulticlassConfusionMatrix, \
    MulticlassMatthewsCorrCoef, MulticlassSpecificity

<div style="background-color: #34495e; border-bottom: 5px solid #95a5a6; padding: 10px;">
    <h2 style="color: white;">Config Parameters</h2>
</div>

In [None]:
class Config:
    """Single namespace holding every tunable constant used by the notebook."""

    SEED = 2023

    # ---- Model ----
    RESOLUTION = 224         # images are resized to RESOLUTION x RESOLUTION
    MLP_INPUT_DIM = 2048     # input width of the classifier head's first linear layer
    MLP_HIDDEN_DIM = 512     # hidden width of the classifier head
    MLP_DROPOUT_RATE = 0.2

    # ---- Data loading ----
    BATCH_SIZE = 64
    NUM_WORKERS = 2
    NUM_CLASSES = 8

    # ---- Training schedule ----
    TRAIN_EPOCHS = 25
    COOLDOWN_EPOCHS = 5
    NUM_EPOCHS = TRAIN_EPOCHS + COOLDOWN_EPOCHS
    BASE_LR = 1e-4
    WEIGHT_DECAY = 1e-5
    PRINT_FREQ = 1024        # progress is printed when the sample count is a multiple of this
    SAVE_FREQ = 5            # model/optimizer/scheduler state is stored every SAVE_FREQ epochs

    # ---- Cosine LR decay ----
    WARMUP_EPOCHS = 3
    WARMUP_LR_INIT = 5e-5
    LR_MIN = 1e-6
    CYCLE_DECAY = 0.5
    CYCLE_LIMIT = 1

<div style="background-color: #34495e; border-bottom: 5px solid #95a5a6; padding: 10px;">
    <h2 style="color: white;">Paths</h2>
</div>

In [None]:
# Root of the input datasets and the directory generated artifacts are written to.
INPUT_DIR = '/kaggle/input'
OUTPUT_DIR = '/kaggle/working/'

# Image folders (one sub-folder per class) and the pickled per-fold file lists.
BASE_DIR = f'{INPUT_DIR}/kvasir-v2/kvasir-dataset-v2'
DATA_FILES_DIR = f'{INPUT_DIR}/kvasir-v2-folds'

# Checkpoint loaded by the testing section (directory + '/'-prefixed file name).
PRETRAINED_WEIGHTS_DIR = f'{INPUT_DIR}/kvasir-v2-resnet50-epochs-100'
PRETRAINED_WEIGHTS_FILENAME = '/model_fold_0_epoch_30.pth'

In [None]:
# Class folders and, at the same index, the file name of one sample image from that
# folder. Every entry carries a leading '/' because the paths are built by plain
# string concatenation (root + PATHS[i] + IMAGES[i]).
_SAMPLE_CLASS_DIRS = (
    'dyed-lifted-polyps',
    'dyed-resection-margins',
    'esophagitis',
    'normal-cecum',
    'normal-pylorus',
    'polyps',
    'ulcerative-colitis',
    'normal-z-line',
)
_SAMPLE_IMAGE_STEMS = (
    '008aa3ed-1812-4854-954c-120ae85bb6bd',
    '0062bbf3-58d7-435d-b0ca-381703c39911',
    '001fb927-4814-4ba5-851d-189db99291d8',
    '0ab26ff5-3161-4a17-bcf4-95663033af0a',
    '005959d0-b75b-41ed-8da1-2a5d0666d612',
    '00072d5f-7cd8-434c-8a5a-1a0bb2c9711d',
    '005b9962-41ed-4ae4-8aae-395bbab93fd7',
    '00bee375-36d2-4ba9-89e5-bd6132d79c0c',
)

PATHS = ['/' + class_dir for class_dir in _SAMPLE_CLASS_DIRS]
IMAGES = ['/' + stem + '.jpg' for stem in _SAMPLE_IMAGE_STEMS]

<div style="background-color: #34495e; border-bottom: 5px solid #95a5a6; padding: 10px;">
    <h2 style="color: white;">EDA: Exploratory Data Analysis</h2>
</div>

In [None]:
# Function to load and display one image from each class
def display_sample_images(root_folder, num_classes=Config.NUM_CLASSES):
    """Show one hand-picked sample image per class and save the grid as a PDF.

    The i-th image is read from ``root_folder + PATHS[i] + IMAGES[i]`` and titled
    with the class folder name (``PATHS[i]`` without its leading '/').

    Args:
        root_folder: Dataset root that contains the class folders.
        num_classes: Number of classes (and therefore images) to display.

    Raises:
        ValueError: If ``num_classes`` exceeds the number of configured samples.
        FileNotFoundError: If a sample image cannot be read.

    Side effects:
        Writes ``sample_images.pdf`` into ``OUTPUT_DIR`` and shows the figure.
    """
    available = min(len(PATHS), len(IMAGES))
    if num_classes > available:
        raise ValueError(
            f'num_classes={num_classes} exceeds the {available} configured sample images'
        )

    # Four images per row; the row count follows from num_classes (2 rows for 8 classes).
    n_cols = 4
    n_rows = max(1, (num_classes + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 3 * n_rows), dpi=600)
    axes = axes.flatten()

    for i in range(num_classes):
        img_path = root_folder + PATHS[i] + IMAGES[i]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals a missing/undecodable file with None, not an exception.
            raise FileNotFoundError(f'Could not read sample image: {img_path}')

        # OpenCV loads BGR; matplotlib expects RGB.
        axes[i].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        axes[i].set_title(PATHS[i][1:])
        axes[i].axis('off')

    # Blank out grid cells that have no image (num_classes not a multiple of n_cols).
    for unused_ax in axes[num_classes:]:
        unused_ax.axis('off')

    plt.tight_layout()
    plt.savefig(OUTPUT_DIR + 'sample_images.pdf', bbox_inches='tight')
    plt.show()

In [None]:
# Render the per-class sample grid for the dataset root.
display_sample_images(root_folder=BASE_DIR)

In [None]:
# Function to display data distribution across classes using plotly
def display_data_distribution(root_folder):
    """Plot the number of files in every class folder as a bar chart.

    Only sub-directories of ``root_folder`` are treated as classes, and they are
    listed in sorted order so the chart is the same on every run.

    Args:
        root_folder: Dataset root that contains one folder per class.

    Side effects:
        Writes ``data_distribution.pdf`` into ``OUTPUT_DIR`` and shows the figure.
    """
    classes = sorted(
        entry for entry in os.listdir(root_folder)
        if os.path.isdir(os.path.join(root_folder, entry))
    )
    num_images = [len(os.listdir(os.path.join(root_folder, class_folder))) for class_folder in classes]

    df = pd.DataFrame({'Class': classes, 'Number of Images': num_images})
    fig = px.bar(df, x='Class', y='Number of Images', title='Data Distribution Across Classes',
                 text='Number of Images')
    fig.update_traces(texttemplate='%{text}', textposition='outside')
    # The final figure size is set here (a height passed to px.bar would be overridden).
    fig.update_layout(height=600, width=800, xaxis=dict(tickangle=45))
    pio.write_image(fig, OUTPUT_DIR + 'data_distribution.pdf')
    fig.show()

In [None]:
# Bar chart of the number of images per class.
display_data_distribution(root_folder=BASE_DIR)

In [None]:
# Function to display image size distribution for each class using plotly
def display_image_size_distribution(root_folder):
    """Plot the average image width and height of every class as stacked bars.

    Every file of every class folder is read with OpenCV. Files that cannot be
    decoded are skipped (and reported), as are entries of ``root_folder`` that are
    not directories and class folders without any readable image.

    Args:
        root_folder: Dataset root that contains one folder per class.

    Side effects:
        Writes ``image_size_distribution.pdf`` into ``OUTPUT_DIR`` and shows the figure.
    """
    data = []

    for class_folder in sorted(os.listdir(root_folder)):
        class_path = os.path.join(root_folder, class_folder)
        if not os.path.isdir(class_path):
            continue

        image_sizes = []
        skipped = 0
        for img_file in os.listdir(class_path):
            img = cv2.imread(os.path.join(class_path, img_file))
            if img is None:
                # cv2.imread yields None for files it cannot decode.
                skipped += 1
                continue
            image_sizes.append(img.shape[:2])

        if skipped:
            print(f'{class_folder}: skipped {skipped} unreadable file(s)')
        if not image_sizes:
            continue

        # shape[:2] is (rows, columns), i.e. (height, width) - in that order.
        avg_height, avg_width = np.mean(np.array(image_sizes), axis=0)
        data.append({'Class': class_folder, 'Width': avg_width, 'Height': avg_height})

    df = pd.DataFrame(data)

    fig = px.bar(df, x='Class', y=['Width', 'Height'], title='Image Size Distribution Across Classes',
                 labels={'value': 'Average Image Size (pixels)', 'variable': 'Dimension'},
                 color_discrete_sequence=['skyblue', 'salmon'])

    # Add text annotations inside each bar: the width label is centred in the lower
    # segment, the height label in the segment stacked on top of it.
    for i, class_label in enumerate(df['Class']):
        width = df.loc[i, 'Width']
        height = df.loc[i, 'Height']
        fig.add_annotation(
            x=class_label,
            y=width / 2,
            text=f'{round(width, 1)}',
            showarrow=False,
            font=dict(color='white', size=14)
        )
        fig.add_annotation(
            x=class_label,
            y=width + (height / 2),
            text=f'{round(height, 1)}',
            showarrow=False,
            font=dict(color='white', size=14)
        )

    # Update layout
    fig.update_layout(xaxis=dict(tickangle=45))
    pio.write_image(fig, OUTPUT_DIR + 'image_size_distribution.pdf')
    fig.show()

In [None]:
# Average image width/height per class.
display_image_size_distribution(root_folder=BASE_DIR)

In [None]:
# Function to display color distribution using plotly
def display_color_distribution(root_folder):
    """Plot the average red, green and blue value of every class as stacked bars.

    For each class the per-image channel means are averaged over the images that
    could actually be read. Unreadable files, non-directory entries of
    ``root_folder`` and class folders without any readable image are skipped.

    Args:
        root_folder: Dataset root that contains one folder per class.

    Side effects:
        Writes ``color_distribution.pdf`` into ``OUTPUT_DIR`` and shows the figure.
    """
    channels = ('Red', 'Green', 'Blue')
    data = []

    for class_folder in sorted(os.listdir(root_folder)):
        class_path = os.path.join(root_folder, class_folder)
        if not os.path.isdir(class_path):
            continue

        channel_sum = np.zeros(3)
        num_read = 0
        for img_file in os.listdir(class_path):
            img = cv2.imread(os.path.join(class_path, img_file))
            if img is None:
                # cv2.imread yields None for files it cannot decode.
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            channel_sum += np.mean(img, axis=(0, 1))
            num_read += 1

        if num_read == 0:
            continue

        # Average over the images that were read, not over all directory entries.
        channel_mean = channel_sum / num_read
        data.append({'Class': class_folder, 'Red': channel_mean[0],
                     'Green': channel_mean[1], 'Blue': channel_mean[2]})

    df = pd.DataFrame(data)

    fig = px.bar(df, x='Class', y=['Red', 'Green', 'Blue'], title='Color Distribution Across Classes',
                 labels={'value': 'Average Color Value', 'variable': 'Color Channel'},
                 color_discrete_sequence=['red', 'green', 'blue'])

    # Add text annotations inside each color bar: every label is centred in its own
    # segment, i.e. at the height of the segments below it plus half its own value.
    for i, class_label in enumerate(df['Class']):
        stacked_below = 0.0
        for channel in channels:
            value = df.loc[i, channel]
            fig.add_annotation(
                x=class_label,
                y=stacked_below + (value / 2),
                text=f'{round(value, 1)}',
                showarrow=False,
                font=dict(color='white', size=14)
            )
            stacked_below += value

    # Update layout
    fig.update_layout(xaxis=dict(tickangle=45))
    pio.write_image(fig, OUTPUT_DIR + 'color_distribution.pdf')
    fig.show()

In [None]:
# Average colour channel values per class.
display_color_distribution(root_folder=BASE_DIR)

In [None]:
# Index the image folders once; the dataset exposes the class names via `.classes`.
kvasir_v2_dataset = ImageFolder(BASE_DIR)
class_names = kvasir_v2_dataset.classes

# Per class: number of entries found in its folder (every file is counted).
class_counts = [
    len(os.listdir(os.path.join(BASE_DIR, class_name)))
    for class_name in class_names
]

In [None]:
# `ImageFolder` is imported directly (there is no `datasets` name in this notebook).
kvasir_v2_dataset = ImageFolder(BASE_DIR, transform=transforms.ToTensor())
num_images = len(kvasir_v2_dataset)

# Compute mean and std
# Pass 1: per-channel mean, averaged over images (each image counts equally,
# regardless of its resolution).
mean = 0.0
for images, _ in kvasir_v2_dataset:
    mean += images.mean([1, 2])
mean /= num_images

# Pass 2: per-channel mean squared deviation from the dataset mean, averaged over
# images in the same way; the square root of that average is the std.
std = 0.0
for images, _ in kvasir_v2_dataset:
    std += ((images - mean.unsqueeze(1).unsqueeze(2)) ** 2).mean([1, 2])
std = torch.sqrt(std / num_images)

print(mean, std)

<div style="background-color: #34495e; border-bottom: 5px solid #95a5a6; padding: 10px;">
    <h2 style="color: white;">Defining Data Transforms</h2>
</div>

In [None]:
# Per-channel normalisation with fixed (hard-coded) mean and std values.
normalize = transforms.Normalize(
    mean=[0.4857, 0.3460, 0.2983],
    std=[0.3348, 0.2456, 0.2369],
)

_TARGET_SIZE = (Config.RESOLUTION, Config.RESOLUTION)


def _sometimes(transform):
    """Wrap `transform` so that it is applied with probability 0.5."""
    return transforms.RandomApply([transform], p=0.5)


def _eval_pipeline():
    """Resize -> tensor -> normalise, with no augmentation."""
    return transforms.Compose([
        transforms.Resize(_TARGET_SIZE),
        transforms.ToTensor(),
        normalize,
    ])


# Define data augmentation transformations
# NOTE(review): RandomHorizontalFlip / RandomVerticalFlip are already random on their
# own (default p=0.5 in torchvision - confirm); wrapped in RandomApply(p=0.5) the
# effective flip rate is therefore lower than 0.5. Confirm this is intended.
train_transforms = transforms.Compose([
    transforms.Resize(_TARGET_SIZE),
    _sometimes(transforms.RandomHorizontalFlip()),
    _sometimes(transforms.RandomVerticalFlip()),
    _sometimes(transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)),
    _sometimes(transforms.RandomResizedCrop(Config.RESOLUTION, scale=(0.8, 1.0))),
    _sometimes(transforms.GaussianBlur(kernel_size=7)),
    _sometimes(transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1))),
    _sometimes(transforms.RandomAffine(degrees=45)),
    _sometimes(transforms.RandomPerspective(distortion_scale=0.5)),
    _sometimes(transforms.RandomAffine(degrees=10, shear=10)),
    transforms.ToTensor(),
    normalize,
])

# Validation and test transformations without augmentation
valid_transforms = _eval_pipeline()
test_transforms = _eval_pipeline()

<div style="background-color: #34495e; border-bottom: 5px solid #95a5a6; padding: 10px;">
    <h2 style="color: white;">Creating K Folds</h2>
</div>

In [None]:
# Parallel lists: file_paths[i] is a .jpg file, labels[i] the name of its class folder.
file_paths, labels = [], []

# Walk over the entries of the dataset root; only directories are class folders.
for category in os.listdir(BASE_DIR):
    category_folder = os.path.join(BASE_DIR, category)
    if not os.path.isdir(category_folder):
        continue

    # Full paths of all .jpg files in this class folder
    image_files = [
        os.path.join(category_folder, file)
        for file in os.listdir(category_folder)
        if file.endswith('.jpg')
    ]

    file_paths += image_files
    labels += [category] * len(image_files)

In [None]:
# Sanity check: both lists must have the same length (one label per file).
print(len(file_paths), len(labels), sep='\n')

In [None]:
# # Set the number of folds for k-fold cross-validation
# k_folds = 5

# # Create a k-fold cross-validation splitter
# kf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=Config.SEED)

# # Iterate over the k-folds
# for fold_idx, (train_indices, val_test_indices) in enumerate(kf.split(file_paths, labels)):
#     num_val_test = len(val_test_indices)
#     num_val = int(0.5 * num_val_test)
#     num_test = num_val_test - num_val
    
#     # Initialize empty lists for validation and test indices
#     val_indices = []
#     test_indices = []
    
#     # Count the number of samples for each class
#     class_counts = {label: 0 for label in set(labels)}
#     for idx in val_test_indices:
#         label = labels[idx]
#         class_counts[label] += 1
    
#     test_class_counts = {label: 0 for label in set(labels)}
    
#     for idx in val_test_indices:
#         label = labels[idx]
#         if test_class_counts[label] < class_counts[label] // 2:
#             test_indices.append(idx)
#             test_class_counts[label] += 1
#         else:
#             val_indices.append(idx)
    
#     train_files = [file_paths[i] for i in train_indices]
#     valid_files = [file_paths[i] for i in val_indices]
#     test_files = [file_paths[i] for i in test_indices]
    
#     with open(OUTPUT_DIR + f'/data_files_fold_{fold_idx}.pkl', 'wb') as f:
#         pickle.dump((train_files, valid_files, test_files), f)
    
#     print(len(train_indices))
#     print(len(val_indices))
#     print(len(test_indices))
    
#     print(len(train_files))
#     print(len(valid_files))
#     print(len(test_files))

<div style="background-color: #34495e; border-bottom: 5px solid #95a5a6; padding: 10px;">
    <h2 style="color: white;">Populate Data Loaders</h2>
</div>

In [None]:
# Weighted Random Sampling: To balance the dataset
def build_sampler(train_dataset):
    """Build a sampler that draws every class with equal total probability.

    Each sample is weighted with ``1 / (number of samples of its class)``. The
    weight is looked up through the sample's position among the unique labels
    that are present, so the result stays correct when a class id is missing
    from ``train_dataset`` (label values need not be 0..k-1 without gaps).

    Args:
        train_dataset: Object with a ``samples`` attribute holding
            ``(path, label)`` pairs.

    Returns:
        A ``WeightedRandomSampler`` that draws ``len(train_dataset.samples)``
        indices per epoch using double-precision weights.

    Raises:
        ValueError: If the dataset has no samples.
    """
    y_train = np.asarray([label for (_, label) in train_dataset.samples])
    if y_train.size == 0:
        raise ValueError('Cannot build a weighted sampler for an empty dataset.')

    # inverse[i] is the index of y_train[i] within the sorted unique labels, which is
    # exactly the index needed to address the per-class counts.
    _, inverse, class_sample_count = np.unique(y_train, return_inverse=True, return_counts=True)
    samples_weight = torch.from_numpy(1. / class_sample_count[inverse.reshape(-1)])

    sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'), len(samples_weight))
    return sampler


def create_data_loader(dataset, batch_size, shuffle=True, num_workers=Config.NUM_WORKERS, \
                        pin_memory=True, sampler=None):
    """Create a ``DataLoader`` for ``dataset``.

    Args:
        dataset: Dataset to load from.
        batch_size: Number of samples per batch.
        shuffle: Whether to shuffle. Ignored when ``sampler`` is given, because the
            sampler then decides the order and ``DataLoader`` rejects the combination.
        num_workers: Number of worker processes.
        pin_memory: Forwarded to ``DataLoader``.
        sampler: Optional sampler that defines how indices are drawn.

    Returns:
        The configured ``DataLoader``.
    """
    # A sampler and shuffle=True are mutually exclusive in DataLoader.
    effective_shuffle = shuffle and sampler is None
    return DataLoader(dataset, batch_size=batch_size, shuffle=effective_shuffle, num_workers=num_workers, \
                      pin_memory=pin_memory, sampler=sampler)

In [None]:
# Number of cross-validation folds; each list collects one data loader per fold.
k_folds = 5
TRAIN_DATALOADERS, VALID_DATALOADERS, TEST_DATALOADERS = [], [], []

In [None]:
# Start from empty lists so that re-running this cell does not append the loaders
# of every fold a second time.
TRAIN_DATALOADERS.clear()
VALID_DATALOADERS.clear()
TEST_DATALOADERS.clear()

for fold_idx in range(k_folds):
    # NOTE: pickle.load can execute arbitrary code - only load fold files from a
    # trusted source.
    with open(DATA_FILES_DIR + f'/data_files_fold_{fold_idx}.pkl', 'rb') as f:
        train_files, valid_files, test_files = pickle.load(f)

    # Sets make the per-sample membership tests below O(1) instead of a list scan.
    train_file_set = set(train_files)
    valid_file_set = set(valid_files)
    test_file_set = set(test_files)

    # Initialize and load the train, validation and test datasets using ImageFolder
    train_dataset = ImageFolder(root=BASE_DIR, transform=train_transforms)
    valid_dataset = ImageFolder(root=BASE_DIR, transform=valid_transforms)
    test_dataset = ImageFolder(root=BASE_DIR, transform=test_transforms)

    # Keep only the samples whose path belongs to the respective split of this fold
    train_dataset.samples = [(path, label) for path, label in train_dataset.samples if path in train_file_set]
    valid_dataset.samples = [(path, label) for path, label in valid_dataset.samples if path in valid_file_set]
    test_dataset.samples = [(path, label) for path, label in test_dataset.samples if path in test_file_set]

    # Training draws through the class-balancing sampler; shuffle stays off whenever
    # a sampler is used. Validation and test keep their order.
    sampler = build_sampler(train_dataset)
    train_dataloader = create_data_loader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=(sampler is None), num_workers=Config.NUM_WORKERS, pin_memory=True, sampler=sampler)
    valid_dataloader = create_data_loader(valid_dataset, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=Config.NUM_WORKERS, pin_memory=True)
    test_dataloader = create_data_loader(test_dataset, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=Config.NUM_WORKERS, pin_memory=True)

    TRAIN_DATALOADERS.append(train_dataloader)
    VALID_DATALOADERS.append(valid_dataloader)
    TEST_DATALOADERS.append(test_dataloader)

<div style="background-color: #34495e; border-bottom: 5px solid #95a5a6; padding: 10px;">
    <h2 style="color: white;">Defining ResNet-50 Classifier</h2>
</div>

In [None]:
class ResNet50Classifier(nn.Module):
    """Pre-trained ResNet-50 backbone followed by a two-layer MLP head.

    The backbone is the torchvision ResNet-50 without its last child module; its
    output is flattened per sample and passed through
    ``Linear(mlp_input_dim, mlp_hidden_dim) -> ReLU -> Linear(mlp_hidden_dim, num_classes)``.

    NOTE(review): ``mlp_dropout_rate`` is accepted but no dropout layer is created,
    so the argument currently has no effect.
    """

    def __init__(self, mlp_input_dim=Config.MLP_INPUT_DIM, mlp_hidden_dim=Config.MLP_HIDDEN_DIM, \
                 mlp_dropout_rate=Config.MLP_DROPOUT_RATE, num_classes=Config.NUM_CLASSES):
        super().__init__()

        # Pre-trained ResNet-50 with its final classification layer dropped
        # (all child modules except the last one are kept).
        backbone = models.resnet50(pretrained=True)
        self.resnet50 = nn.Sequential(*list(backbone.children())[:-1])

        # MLP classification head
        self.fc = nn.Sequential(
            nn.Linear(mlp_input_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(mlp_hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return the class scores computed by the head for the input batch ``x``."""
        features = self.resnet50(x)

        # Collapse everything but the batch dimension before the linear layers.
        batch_size = features.size(0)
        flat_features = features.view(batch_size, -1)

        return self.fc(flat_features)

In [None]:
# Run on the GPU whenever CUDA is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

In [None]:
model = ResNet50Classifier().to(device)

# Parameter counts: all parameters vs. only those that receive gradients.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")

<div style="background-color: #34495e; border-bottom: 5px solid #95a5a6; padding: 10px;">
    <h2 style="color: white;">Plot Cosine Annealing Learning Rate Schedule</h2>
</div>

In [None]:
# Optimizer updates per epoch, taken from the training set of the last fold that was
# built (read through its data loader rather than through a leftover loop variable).
NUM_STEPS_PER_EPOCH = len(TRAIN_DATALOADERS[-1].dataset) // Config.BATCH_SIZE

def plot_lrs_for_timm_scheduler(scheduler):
    """Step ``scheduler`` through a full training run and record the learning rates.

    For every epoch the scheduler receives one ``step_update`` per optimizer update
    followed by ``step(epoch + 1)``; the learning rate of the first parameter group
    is recorded after that epoch-level step.

    Args:
        scheduler: timm scheduler; the learning rate is read from the optimizer it
            was constructed with (``scheduler.optimizer``).

    Returns:
        List with one learning rate per epoch (``Config.NUM_EPOCHS`` entries).

    Note:
        This advances the state of the scheduler and of its optimizer; use fresh
        instances for the actual training.
    """
    lrs = []

    for epoch in range(Config.NUM_EPOCHS):
        num_updates = epoch * NUM_STEPS_PER_EPOCH

        for _ in range(NUM_STEPS_PER_EPOCH):
            num_updates += 1
            scheduler.step_update(num_updates=num_updates)

        scheduler.step(epoch + 1)

        lrs.append(scheduler.optimizer.param_groups[0]["lr"])
    return lrs

In [None]:
# Loss, optimizer and scheduler used only to visualise the learning-rate schedule.
# Simulating the schedule below advances this optimizer/scheduler pair; the training
# section builds its own instances.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=Config.BASE_LR, weight_decay=Config.WEIGHT_DECAY)

# Cosine annealing with warm-up (timm)
scheduler = CosineLRScheduler(
    optimizer,
    t_initial=Config.TRAIN_EPOCHS,
    cycle_decay=Config.CYCLE_DECAY,
    lr_min=Config.LR_MIN,
    t_in_epochs=True,
    warmup_t=Config.WARMUP_EPOCHS,
    warmup_lr_init=Config.WARMUP_LR_INIT,
    cycle_limit=Config.CYCLE_LIMIT,
    warmup_prefix=True,
)

lrs = plot_lrs_for_timm_scheduler(scheduler)

# One point per epoch, epochs numbered from 1.
fig, ax = plt.subplots()
ax.plot(range(1, Config.NUM_EPOCHS + 1), lrs)

ax.set_xlabel('Epochs')
ax.set_ylabel('Learning Rate')
ax.set_xticks(np.arange(0, Config.NUM_EPOCHS, 5))

# Major and (lighter) minor grid lines
ax.grid(True, which='major', color='#666666', linestyle='-', alpha=0.4)
ax.minorticks_on()
ax.grid(True, which='minor', color='#999999', linestyle='-', alpha=0.2)

fig.savefig(OUTPUT_DIR + 'learning_rate.pdf', dpi=600, bbox_inches='tight')
plt.show()

<div style="background-color: #34495e; border-bottom: 5px solid #95a5a6; padding: 10px;">
    <h2 style="color: white;">Defining Train, Validation, and Test Functions</h2>
</div>

In [None]:
# Training loop
def train_model(epoch, model, train_dataloader, device, metrics_data, metrics_keys):
    """Train ``model`` for one epoch and return the epoch's loss and metrics.

    Uses the module-level ``criterion``, ``optimizer`` and ``scheduler``.

    Args:
        epoch: Zero-based index of the epoch being run.
        model: Model to train (switched to train mode here).
        train_dataloader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to.
        metrics_data: Dict whose ``'train'`` entry maps metric names to metric
            objects providing ``reset``/``update``/``compute``.
        metrics_keys: Names of the metrics in ``metrics_data['train']`` to use.

    Returns:
        ``(avg_loss, cur_metrics)``: the sample-weighted mean loss of the epoch and
        a dict with the computed value of every metric in ``metrics_keys``.
    """
    print()
    print('*****Train*****')

    model.train()

    # Clear the metric state so the reported values describe this epoch only
    # instead of everything seen since the metric objects were created.
    for metric in metrics_keys:
        metrics_data['train'][metric].reset()

    total_loss = 0.0
    total_samples = 0

    # Global update counter handed to the scheduler; continues from previous epochs.
    num_steps_per_epoch = len(train_dataloader)
    num_updates = epoch * num_steps_per_epoch

    for inputs, labels in train_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Count the update that was just performed before notifying the scheduler.
        num_updates += 1
        scheduler.step_update(num_updates=num_updates)

        # loss is a batch mean; weight it by the batch size for a per-sample average.
        total_loss += loss.item() * inputs.size(0)
        total_samples += inputs.size(0)

        # Update metrics
        for metric in metrics_keys:
            metrics_data['train'][metric].update(outputs, labels)

        if total_samples % Config.PRINT_FREQ == 0:
            print(f'images count: {total_samples}')

    # Epoch-level scheduler step: sets the learning rate for the next epoch.
    scheduler.step(epoch + 1)

    # Calculate average loss
    avg_loss = total_loss / total_samples

    cur_metrics = dict()
    for metric in metrics_keys:
        cur_metrics[metric] = metrics_data['train'][metric].compute()

    return avg_loss, cur_metrics


# Validation loop
def validate_model(epoch, model, val_loader, device, best_val_acc, best_model_wts, metrics_data, metrics_keys):
    """Evaluate ``model`` on the validation set and return its loss and metrics.

    Uses the module-level ``criterion``.

    Args:
        epoch: Zero-based epoch index (not used by the computation).
        model: Model to evaluate (switched to eval mode here).
        val_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to.
        best_val_acc: Accepted for interface compatibility; not used. Rebinding it
            inside this function could never reach the caller.
        best_model_wts: Accepted for interface compatibility; not used, for the
            same reason.
        metrics_data: Dict whose ``'valid'`` entry maps metric names to metric
            objects providing ``reset``/``update``/``compute``.
        metrics_keys: Names of the metrics in ``metrics_data['valid']`` to use.

    Returns:
        ``(avg_loss, cur_metrics)``. To keep track of the best model, the caller has
        to compare ``cur_metrics['acc']`` with its own best value and snapshot the
        weights itself.
    """
    print()
    print('*****Validation*****')

    model.eval()

    # Clear the metric state so the reported values describe this pass only
    # instead of accumulating over all previous epochs.
    for metric in metrics_keys:
        metrics_data['valid'][metric].reset()

    total_loss = 0.0
    total_samples = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            preds = model(inputs)

            # Compute loss
            loss = criterion(preds, labels)

            # loss is a batch mean; weight it by the batch size for a per-sample average.
            total_loss += loss.item() * inputs.size(0)
            total_samples += inputs.size(0)

            # Update metrics
            for metric in metrics_keys:
                metrics_data['valid'][metric].update(preds, labels)

            if total_samples % Config.PRINT_FREQ == 0:
                print(f'images count: {total_samples}')

    # Calculate average loss
    avg_loss = total_loss / total_samples

    cur_metrics = dict()
    for metric in metrics_keys:
        cur_metrics[metric] = metrics_data['valid'][metric].compute()

    return avg_loss, cur_metrics


# Test loop
def test_model(model, test_dataloader, device, fold, metrics_data, metrics_keys):
    """Evaluate ``model`` on the test set, then save and print the metrics.

    Args:
        model: Model to evaluate (switched to eval mode here).
        test_dataloader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to.
        fold: Fold identifier; only used in the name of the saved metrics file.
        metrics_data: Dict whose ``'test'`` entry maps metric names to metric
            objects providing ``reset``/``update``/``compute``.
        metrics_keys: Names of the metrics in ``metrics_data['test']`` to use.

    Side effects:
        Saves the computed metrics with ``torch.save`` to
        ``OUTPUT_DIR + '/kvasir-v2-fold_<fold>.pt'`` and prints them together with
        the elapsed time.
    """
    print()
    print('*****Test*****')

    model.eval()
    since = time.time()
    total_samples = 0

    # Clear the metric state so that calling this function again with the same
    # metric objects does not count the test set twice.
    for metric in metrics_keys:
        metrics_data['test'][metric].reset()

    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            preds = model(inputs)

            total_samples += inputs.size(0)

            for metric in metrics_keys:
                metrics_data['test'][metric].update(preds, labels)

            if total_samples % Config.PRINT_FREQ == 0:
                print(f'images count: {total_samples}')

    test_metrics = dict()
    for metric in metrics_keys:
        test_metrics[metric] = metrics_data['test'][metric].compute()

    torch.save(test_metrics, OUTPUT_DIR + f'/kvasir-v2-fold_{fold}.pt')

    print('Test Metrics: ')
    print(test_metrics)

    time_elapsed = time.time() - since
    print('Testing complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

<div style="background-color: #34495e; border-bottom: 5px solid #95a5a6; padding: 10px;">
    <h2 style="color: white;">Perform Training</h2>
</div>

In [None]:
since = time.time()

# Only the first fold is trained here; widen the range to train further folds.
for fold in range(1):
    print()
    print(f'---------------------------Fold {fold}---------------------------')

    model = ResNet50Classifier().to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"{total_params:,} total parameters.")
    total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{total_trainable_params:,} training parameters.")

#     MODEL_PATH = PRETRAINED_WEIGHTS_DIR + PRETRAINED_WEIGHTS_FILENAME
#     checkpoint = torch.load(MODEL_PATH, map_location=device)
#     model.load_state_dict(checkpoint)

    # Independent snapshot of the weights (state_dict() alone would keep following
    # the live parameters); replaced whenever the validation accuracy improves.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_val_acc = 0.0

    # One separate set of metric objects per phase.
    metrics_data = dict()
    metrics_keys = ['acc', 'precision', 'micro-precision', 'recall', 'f1score', 'specificity', 'mcc', 'confusion_mat']
    for phase in ['train', 'valid', 'test']:
        metrics_data[phase] = {
            'acc': MulticlassAccuracy(num_classes=Config.NUM_CLASSES).to(device),
            'precision': MulticlassPrecision(num_classes=Config.NUM_CLASSES).to(device),
            'micro-precision':  MulticlassPrecision(num_classes=Config.NUM_CLASSES, average='micro').to(device),
            'recall': MulticlassRecall(num_classes=Config.NUM_CLASSES).to(device),
            'f1score': MulticlassF1Score(num_classes=Config.NUM_CLASSES).to(device),
            'specificity': MulticlassSpecificity(num_classes=Config.NUM_CLASSES).to(device),
            'mcc': MulticlassMatthewsCorrCoef(num_classes=Config.NUM_CLASSES).to(device),
            'confusion_mat': MulticlassConfusionMatrix(num_classes=Config.NUM_CLASSES).to(device)
        }


    # Define loss function
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    # Define optimizer
    optimizer = optim.AdamW(model.parameters(), lr=Config.BASE_LR, weight_decay=Config.WEIGHT_DECAY)

    # Cosine Annealing Learning Rate Scheduler
    scheduler = CosineLRScheduler(optimizer, t_initial=Config.TRAIN_EPOCHS, cycle_decay=Config.CYCLE_DECAY, lr_min=Config.LR_MIN,
                                  t_in_epochs=True, warmup_t=Config.WARMUP_EPOCHS, warmup_lr_init=Config.WARMUP_LR_INIT, 
                                  cycle_limit=Config.CYCLE_LIMIT, warmup_prefix=True)

    # The loaders of a fold do not change between epochs.
    train_loader = TRAIN_DATALOADERS[fold]
    validation_loader = VALID_DATALOADERS[fold]

    for epoch in range(Config.NUM_EPOCHS):
        # Train the model
        train_loss, train_metrics = train_model(
            epoch, model, train_loader, device, metrics_data, metrics_keys
        )

        print(
            f"Epoch {epoch + 1}/{Config.NUM_EPOCHS} - "
            f"LR: {optimizer.state_dict()['param_groups'][0]['lr']}, "
            f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_metrics['acc']:.4f}, Train Mirco-P: {train_metrics['micro-precision']:.4f}, Train MCC: {train_metrics['mcc']:.4f}"
        )

        # Validate the model
        val_loss, val_metrics = validate_model(epoch, model, validation_loader, device, best_val_acc, best_model_wts, metrics_data, metrics_keys)

        # Track the best model here: assignments made inside validate_model cannot
        # change the variables of this cell.
        val_acc = float(val_metrics['acc'])
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())

        # Print epoch information
        print(
            f"Epoch {epoch + 1}/{Config.NUM_EPOCHS} - "
            f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_metrics['acc']:.4f}, Val Mirco-P: {val_metrics['micro-precision']:.4f}, Val MCC: {val_metrics['mcc']:.4f}"
        )

        # Elapsed time is measured from the start of the cell (cumulative).
        time_elapsed = time.time() - since
        print('Epoch {} complete in {:.0f}m {:.0f}s'.format(epoch+1, time_elapsed // 60, time_elapsed % 60))

        epoch_data = {
            'epoch': epoch+1,
            'lr': optimizer.state_dict()['param_groups'][0]['lr'],
            'train_loss': train_loss,
            'train_metrics': train_metrics,
            'valid_loss': val_loss, 
            'valid_metrics': val_metrics
        }

        # Losses and metrics are written every epoch; the (large) model, optimizer and
        # scheduler states only every SAVE_FREQ epochs.
        if (epoch+1) % Config.SAVE_FREQ == 0:
            epoch_data['model_state_dict'] = model.state_dict()
            epoch_data['optimizer_state_dict'] = optimizer.state_dict()
            epoch_data['scheduler_state_dict'] = scheduler.state_dict()

        torch.save(epoch_data, OUTPUT_DIR + '/model_fold_{}_epoch_{}.pth'.format(fold, epoch+1))
        print()

    time_elapsed = time.time() - since
    print('Training Complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best Validation Accuracy: {:.4f}'.format(best_val_acc))

    # Persist the best weights and evaluate exactly those on the test split.
    torch.save(best_model_wts, OUTPUT_DIR + f'/best_model_fold_{fold}.pth')
    model.load_state_dict(best_model_wts)

    test_dataloader = TEST_DATALOADERS[fold]
    test_model(model, test_dataloader, device, fold, metrics_data, metrics_keys)

    print(f'-----------------------------------------------------------------')

<div style="background-color: #34495e; border-bottom: 5px solid #95a5a6; padding: 10px;">
    <h2 style="color: white;">Perform Testing</h2>
</div>

In [None]:
model = ResNet50Classifier().to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")

# Load the stored checkpoint; the model weights are kept under 'model_state_dict'.
MODEL_PATH = PRETRAINED_WEIGHTS_DIR + PRETRAINED_WEIGHTS_FILENAME
checkpoint = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Fresh metric objects for the test phase.
metrics_keys = ['acc', 'precision', 'micro-precision', 'recall', 'f1score', 'specificity', 'mcc', 'confusion_mat']
metrics_data = {
    'test': {
        'acc': MulticlassAccuracy(num_classes=Config.NUM_CLASSES).to(device),
        'precision': MulticlassPrecision(num_classes=Config.NUM_CLASSES).to(device),
        'micro-precision':  MulticlassPrecision(num_classes=Config.NUM_CLASSES, average='micro').to(device),
        'recall': MulticlassRecall(num_classes=Config.NUM_CLASSES).to(device),
        'f1score': MulticlassF1Score(num_classes=Config.NUM_CLASSES).to(device),
        'specificity': MulticlassSpecificity(num_classes=Config.NUM_CLASSES).to(device),
        'mcc': MulticlassMatthewsCorrCoef(num_classes=Config.NUM_CLASSES).to(device),
        'confusion_mat': MulticlassConfusionMatrix(num_classes=Config.NUM_CLASSES).to(device)
    }
}


# Evaluate on the test split of fold 0 (the configured checkpoint file name refers to fold 0).
fold = 0
test_dataloader = TEST_DATALOADERS[fold]
test_model(model, test_dataloader, device, fold, metrics_data, metrics_keys)