# Import Libraries

In [1]:
import os
import random
import sys
from datetime import datetime
from itertools import cycle
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Subset
from torchinfo import summary
from tqdm import tqdm
from torchvision.datasets import ImageFolder
from torchvision.models import vit_b_16
from torchvision import transforms

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

# Datasets

In [3]:
# mean = [0.485, 0.456, 0.406]
# std = [0.229, 0.224, 0.225]

# train_transforms = T.Compose([
#     #T.ToPILImage(),
#     T.Resize((48, 48)),
#     #T.RandomHorizontalFlip(p=0.5),
#     T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
#     T.ToTensor(),
#     T.Normalize(mean, std)
# ])

# val_transforms = T.Compose([
#     T.ToPILImage(),
#     T.Resize((48, 48)),
#     T.ToTensor(),
#     T.Normalize(mean, std)
# ])


mean = [0.485]  # Single channel
std = [0.229]

train_transforms = T.Compose([
    T.Grayscale(num_output_channels=3),  # Keep 3 channels but use grayscale
    T.RandomApply([T.RandomRotation(15)], p=0.5),
    T.RandomPerspective(distortion_scale=0.3, p=0.3),
    T.RandomResizedCrop(48, scale=(0.8, 1.2)),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean, std),
    T.RandomErasing(p=0.2)  # Helps with occlusion
])

val_transforms = T.Compose([
    T.Resize((48, 48)),
    T.ToTensor(),
    T.Normalize(mean, std)
])

In [4]:
class RandomPatchMask:
    """This class creates a random patch mask for an image.
    """
    def __init__(self, mask_ratio=0.15, patch_size=16):
        self.mask_ratio = mask_ratio
        self.patch_size = patch_size
    
    def __call__(self, img):
        # img is a Tensor [C, H, W]
        C, H, W = img.shape
        # number of patches horizontally and vertically
        num_patches_h = H // self.patch_size
        num_patches_w = W // self.patch_size
        total_patches = num_patches_h * num_patches_w
        
        # how many to mask
        num_masked = int(total_patches * self.mask_ratio)
        
        # choose random patches
        patch_indices = list(range(total_patches))
        random.shuffle(patch_indices)
        mask_indices = patch_indices[:num_masked]
        
        # create a copy to mask
        masked_img = img.clone()
        
        for idx in mask_indices:
            row = idx // num_patches_w
            col = idx % num_patches_w
            y_start = row * self.patch_size
            x_start = col * self.patch_size
            # set patch to 0
            masked_img[:, y_start:y_start+self.patch_size, x_start:x_start+self.patch_size] = 0
        
        return masked_img

### Unlabeled faces: MultiTaskDataset; FER2013

In [5]:
def generate_random_permutations(num_patches=9, num_permutations=30):
    permutations = set()
    while len(permutations) < num_permutations:
        perm = tuple(random.sample(range(num_patches), num_patches))
        permutations.add(perm)
    return [list(p) for p in permutations]

In [6]:
class MultiTaskDataset(Dataset):
    """
    A single dataset that returns a dictionary with keys:
       {
         'denoise':  (denoise_input, denoise_target),
         'rotation': (rotation_input, rotation_label),
         'jigsaw':   (jigsaw_input,   jigsaw_label),
         'mask':     (mask_input,     mask_target)
       }

    Each item in the batch corresponds to one original image, from which
    we generate the different task inputs/targets on-the-fly.
    """
    def __init__(self, image_paths, base_transforms, image_size=224, 
                 jigsaw_permutations=None,  # e.g. a precomputed list of permutations
                 mask_ratio=0.75):
        super().__init__()
        self.image_paths = image_paths
        self.image_size = image_size
        
        # Basic transforms (resize + to tensor + normalization)
        # self.base_transform = T.Compose([
        #     T.Resize((image_size, image_size)),
        #     T.ToTensor(),
        #     T.Normalize(mean=[0.485, 0.456, 0.406],
        #                 std=[0.229, 0.224, 0.225])
        # ])
        self.base_transform = base_transforms
        
        # If you have a fixed set of jigsaw permutations, store them here
        # For example: jigsaw_permutations = [[0,1,2,3,4,5,6,7,8], [3,0,1,4,2,5,7,8,6], ...]
        self.jigsaw_permutations = jigsaw_permutations
        self.mask_ratio = mask_ratio

    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        # 1. Load the image
        img_path = self.image_paths[idx]
        img = Image.open(img_path).convert('RGB')  # or 'L' for grayscale
        img = img.resize((self.image_size, self.image_size), Image.BILINEAR)
        
        #print('Current image: ', img_path, end='\n')
        
        # Convert to tensor but *don't* normalize yet, because for denoising
        # we might want the original pixel scale in [0,1].
        # We'll do basic conversion:
        img_tensor = T.ToTensor()(img)  # shape: [3, H, W], range ~[0,1]

        # ---------------------------------------
        # (a) Denoise: Add noise -> (noisy_input, clean_target)
        # ---------------------------------------
        noisy_input = self.add_gaussian_noise(img_tensor)
        # The target is the *clean* version (we can still do normal transforms after)
        denoise_input = self.base_transform(T.functional.to_pil_image(noisy_input))
        denoise_target = self.base_transform(img)  
        # (B, 3, H, W) after transform

        # ---------------------------------------
        # (b) Rotation: rotate the image randomly among {0, 90, 180, 270}
        # ---------------------------------------
        angles = [0, 90, 180, 270]
        angle = random.choice(angles)
        rotation_label = angles.index(angle)  # e.g. 0->0°, 1->90°, ...
        rotated_img = img.rotate(angle)
        rotation_input = self.base_transform(rotated_img)
        # rotation_label is just an integer 0..3

        # ---------------------------------------
        # (c) Jigsaw puzzle: shuffle patches, produce a class label
        # ---------------------------------------
        # We'll do a simplistic approach: pick a random permutation from jigsaw_permutations
        # and rearrange the patches accordingly. The "class label" is the index of that permutation.
        # In a real scenario, you might just feed the shuffled patches as input
        # and the label is the permutation index.
        if self.jigsaw_permutations is not None:
            perm_index = random.randint(0, len(self.jigsaw_permutations) - 1)
            perm = self.jigsaw_permutations[perm_index]  # e.g. [3,0,1,2,4,5,6,7,8]
            jigsaw_img = self.create_jigsaw(img, perm)
            jigsaw_input = self.base_transform(jigsaw_img)
            jigsaw_label = torch.tensor(perm_index, dtype=torch.long)
        else:
            # If you haven't defined permutations, fallback
            jigsaw_input = self.base_transform(img)
            jigsaw_label = torch.tensor(0, dtype=torch.long)

        # ---------------------------------------
        # (d) Masked Patch: randomly mask a fraction of patches
        # ---------------------------------------
        # We'll do a naive approach: zero out a random fraction of pixels. 
        # The "target" is the original unmasked image. 
        mask_input, mask_target = self.create_masked_input(img_tensor, self.mask_ratio)
        mask_input = self.base_transform(T.functional.to_pil_image(mask_input))
        mask_target = self.base_transform(T.functional.to_pil_image(mask_target))

        # Return a dictionary of tasks
        return {
            'denoise':  (denoise_input,  denoise_target),
            'rotation': (rotation_input, torch.tensor(rotation_label, dtype=torch.long)),
            'jigsaw':   (jigsaw_input,   jigsaw_label),
            'mask':     (mask_input,     mask_target),
        }

    # -----------------------------
    # Helper functions
    # -----------------------------
    def add_gaussian_noise(self, img_tensor, std=0.1):
        """
        Add random Gaussian noise to a [C, H, W] tensor, range ~[0,1].
        std controls the noise level.
        """
        noise = torch.randn_like(img_tensor) * std
        noisy_img = img_tensor + noise
        return torch.clamp(noisy_img, 0.0, 1.0)

    def create_jigsaw(self, img_pil, permutation, grid_size=3):
        """
        Slices the image into grid_size x grid_size patches, rearranges them
        according to 'permutation' (list of length 9 if grid_size=3),
        then reassembles them into a single PIL image.
        """
        w, h = img_pil.size
        patch_w = w // grid_size
        patch_h = h // grid_size
        patches = []
        
        # Cut into patches
        for row in range(grid_size):
            for col in range(grid_size):
                left = col * patch_w
                upper = row * patch_h
                box = (left, upper, left+patch_w, upper+patch_h)
                patch = img_pil.crop(box)
                patches.append(patch)
        
        # Rearrange using permutation
        shuffled_patches = [patches[i] for i in permutation]
        
        # Reassemble
        new_img = Image.new('RGB', (w, h))
        idx = 0
        for row in range(grid_size):
            for col in range(grid_size):
                new_img.paste(shuffled_patches[idx], (col*patch_w, row*patch_h))
                idx += 1
        
        return new_img

    def create_masked_input(self, img_tensor, mask_ratio=0.75):
        """
        Masks a random fraction (mask_ratio) of the image pixels by setting them to 0.
        Returns (masked_img, original_img) as Tensors in [0,1].
        """
        c, h, w = img_tensor.shape
        num_pixels = h * w
        num_mask = int(num_pixels * mask_ratio)

        # Flatten the image, choose which pixels to mask
        flat_img = img_tensor.view(c, -1).clone()  # shape [3, H*W]
        mask_indices = random.sample(range(num_pixels), num_mask)
        # zero out those pixels in all channels
        for mi in mask_indices:
            flat_img[:, mi] = 0.0
        
        masked_img = flat_img.view(c, h, w)
        # The target is the original unmasked image
        target_img = img_tensor.clone()
        return masked_img, target_img

### Unlabeled faces DataLoader

In [7]:
fer_2013_dir_train = Path(os.getcwd(), 'datasets', 'fer2013', 'train')

In [None]:
# jigsaw_permutations = generate_random_permutations(num_patches=9, num_permutations=30)
# fer2013_train_files = [p for p in Path(fer_2013_dir_train).rglob('*') if p.is_file()]

# dataset = MultiTaskDataset(fer2013_train_files, train_transforms, image_size=48,
#                            jigsaw_permutations=jigsaw_permutations,
#                            mask_ratio=0.75)

# ssl_dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

In [None]:
# class UnlabeledFacesDataset(torch.utils.data.Dataset):
#     def __init__(self, img_paths, transform=None):
#         self.img_paths = img_paths
#         self.transform = transform
    
#     def __len__(self):
#         return len(self.img_paths)
    
#     def __getitem__(self, idx):
#         img_path = self.img_paths[idx]
#         # load, detect face, align, etc.
#         face_img = cv2.imread(img_path).astype('float32')
        
#         if face_img is None:
#             # handle missing face or skip
#             # for simplicity, just return a dummy or raise an error
#             raise RuntimeError(f"No face found in {img_path}")
        
#         if self.transform:
#             face_img = self.transform(face_img)
        
#         return face_img

### Labeled faces: CK+

In [None]:
# class LabeledFERDataset(torch.utils.data.Dataset):
#     def __init__(self, img_paths, labels, transform=None):
#         self.img_paths = img_paths
#         self.labels = labels
#         self.transform = transform
    
#     def __len__(self):
#         return len(self.img_paths)
    
#     def __getitem__(self, idx):
#         img_path = self.img_paths[idx]
#         label = self.labels[idx]
        
#         face_img = cv2.imread(img_path).astype('float32')
#         if face_img is None:
#             raise RuntimeError(f"No face found in {img_path}")
        
#         if self.transform:
#             face_img = self.transform(face_img)
        
#         return face_img, label

### FewShot Image Dataset: Manga faces

In [None]:
# class FewShotFERDataset(torch.utils.data.Dataset):
#     def __init__(self, class_folders, transform=None, n_way=5, k_shot=1, k_query=5):
#         # class_folders: e.g. {class_name: [list_of_image_paths]}
#         self.class_folders = class_folders
#         self.transform = transform
#         self.n_way = n_way
#         self.k_shot = k_shot
#         self.k_query = k_query
    
#     def __len__(self):
#         # The "length" might be the number of episodes you want
#         return 1000  # or some large number for meta-training
    
#     def __getitem__(self, idx):
#         # Sample n_way classes
#         sampled_classes = random.sample(list(self.class_folders.keys()), self.n_way)
        
#         support_imgs = []
#         support_labels = []
#         query_imgs = []
#         query_labels = []
        
#         label_map = {cls_name: i for i, cls_name in enumerate(sampled_classes)}
        
#         for cls_name in sampled_classes:
#             paths = self.class_folders[cls_name]
#             selected_paths = random.sample(paths, self.k_shot + self.k_query)
#             s_paths = selected_paths[:self.k_shot]
#             q_paths = selected_paths[self.k_shot:]
            
#             for sp in s_paths:
#                 img = cv2.imread(sp).astype('float32')
#                 if img is not None and self.transform:
#                     img = self.transform(img)
#                 support_imgs.append(img)
#                 support_labels.append(label_map[cls_name])
            
#             for qp in q_paths:
#                 img = cv2.imread(qp).astype('float32')
#                 if img is not None and self.transform:
#                     img = self.transform(img)
#                 query_imgs.append(img)
#                 query_labels.append(label_map[cls_name])
        
#         # Convert lists to tensors
#         support_imgs = torch.stack(support_imgs)  # shape: [n_way*k_shot, C, H, W]
#         support_labels = torch.tensor(support_labels)  # [n_way*k_shot]
#         query_imgs = torch.stack(query_imgs)      # shape: [n_way*k_query, C, H, W]
#         query_labels = torch.tensor(query_labels) # [n_way*k_query]
        
#         return (support_imgs, support_labels), (query_imgs, query_labels)

# Self-supervised Learning

In [8]:
class Identity(nn.Module):
    def forward(self, x):
        return x

### ViT_SelfSupervised class

In [9]:
class ViT_SelfSupervised(nn.Module):
    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 in_chans=3,
                 embedding_dim=768,
                 num_rotation_classes=4,
                 num_jigsaw_classes=30):
        """
        image_size: input image resolution (assumed square)
        patch_size: size of each non-overlapping patch
        in_chans: number of input channels (e.g., 3 for RGB)
        embedding_dim: dimension of patch embeddings from ViT
        num_rotation_classes: number of rotation categories (0°, 90°, 180°, 270°)
        num_jigsaw_classes: number of predefined jigsaw permutation orders
        """
        super(ViT_SelfSupervised, self).__init__()
        
        # Calculate the number of patches (ViT divides the image into non-overlapping patches)
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embedding_dim = embedding_dim
        self.image_size = image_size

        # --- ViT Encoder ---
        # Instantiate the ViT model from TorchVision.
        # We set weights=False to train from scratch.
        self.encoder = vit_b_16(weights=None, image_size=image_size, num_classes=7)
        # Remove its default classification head by replacing with Identity.
        # IMPORTANT: In this example we assume that the modified encoder returns the CLS token and the patch tokens.
        self.encoder.heads = nn.Identity()

        # --- Pretext Task Heads ---

        # (a) Denoising & Reconstruction head.
        # For each patch token, we want to predict the original patch pixels.
        # Here we use a simple linear projection: output dim = patch_size x patch_size x in_chans.
        self.denoise_decoder = nn.Linear(embedding_dim, patch_size * patch_size * in_chans)

        # (b) Rotation Prediction head.
        # Uses the CLS token from the encoder.
        self.rotation_head = nn.Linear(embedding_dim, num_rotation_classes)

        # (c) Jigsaw Puzzle head.
        # For the jigsaw task, we assume that the image (with patches shuffled) is processed by the encoder.
        # We then flatten the patch tokens (excluding the CLS token) and feed them to an MLP that predicts one of
        # num_jigsaw_classes possible permutation orders.
        self.jigsaw_head = nn.Sequential(
            nn.Linear(embedding_dim * self.num_patches, 1024),
            nn.ReLU(),
            nn.Linear(1024, num_jigsaw_classes)
        )

        # (d) Masked Patch Prediction head.
        # Similar to the denoising decoder, but applied only to the masked patches.
        self.mask_decoder = nn.Linear(embedding_dim, patch_size * patch_size * in_chans)

    def forward_encoder(self, x):
        """
        Forward pass through the ViT encoder.
        We assume that the encoder returns a tensor of shape:
           (B, num_tokens, embedding_dim)
        where token 0 is the CLS token and tokens 1: are the patch tokens.
        """
        
        # x: [B, C, H, W]
        # Use conv_proj to get patch embeddings.
        # This produces an output of shape [B, embed_dim, H_p, W_p],
        # where H_p * W_p equals the total number of patches.
        x = self.encoder.conv_proj(x)  # shape: [B, embedding_dim, H_p, W_p]

        B, C, H_p, W_p = x.shape
        # Flatten the spatial dimensions and transpose to get shape [B, num_patches, embedding_dim]
        x = x.flatten(2).transpose(1, 2)  # shape: [B, H_p*W_p, embedding_dim]

        # Expand the CLS token to batch size.
        #cls_tokens = self.encoder.cls_token.expand(B, -1, -1)  # shape: [B, 1, embedding_dim]
        cls_tokens = self.encoder.class_token.expand(B, -1, -1)


        # Concatenate the CLS token with the patch tokens.
        x = torch.cat((cls_tokens, x), dim=1)  # shape: [B, 1 + num_patches, embedding_dim]
        
        ### option 1
        # Retrieve positional embeddings from state_dict
        # pos_embed = self.encoder.state_dict()['pos_embed']  # shape: [1, 1+num_patches, embedding_dim]
        
        # #pos_embed = self.encoder.pos_embed
        # x = x + pos_embed.to(x.device)
        
        # x = self.encoder.dropout(x)

        # # Pass through transformer encoder blocks.
        # for blk in self.encoder.encoder.blocks:
        #     x = blk(x)
        # # Apply final normalization.
        # x = self.encoder.encoder.norm(x)
        
        
        ### option 2
        # Check where the correct dropout layer is located
        if hasattr(self.encoder, 'dropout') and callable(self.encoder.dropout):
            x = self.encoder.dropout(x)  # Apply dropout if it exists
        elif hasattr(self.encoder.encoder, 'dropout') and callable(self.encoder.encoder.dropout):
            x = self.encoder.encoder.dropout(x)  # Apply encoder-level dropout
            
        for blk in self.encoder.encoder.layers:
            x = blk(x)
        x = self.encoder.encoder.ln(x)
        
        return x

    # def forward_denoise(self, x):
    #     """
    #     Image denoising and reconstruction task.
    #     x: input noisy images.
    #     Returns: reconstructed patches for all patch tokens.
    #     """
    #     tokens = self.forward_encoder(x)  # shape: (B, 1 + num_patches, D)
    #     # Exclude the CLS token (first token)
    #     patch_tokens = tokens[:, 1:, :]  # shape: (B, num_patches, D)
    #     # Reconstruct each patch.
    #     # The output will be of shape (B, num_patches, patch_size*patch_size*in_chans)
    #     recon_patches = self.denoise_decoder(patch_tokens)
    #     return recon_patches
    def forward_denoise(self, x):
        """Denoising task with proper reshaping"""
        B = x.shape[0]
        tokens = self.forward_encoder(x)
        patch_tokens = tokens[:, 1:, :]  # [B, num_patches, D]
        
        # Reconstruct patches
        recon_patches = self.denoise_decoder(patch_tokens)  # [B, 9, 768]
        
        # Reshape to original image dimensions [B, 3, 48, 48]
        recon_image = self.patches_to_image(recon_patches, B)
        return recon_image

    def forward_masked(self, x):
        """Masked patch prediction with proper reshaping"""
        B = x.shape[0]
        tokens = self.forward_encoder(x)
        patch_tokens = tokens[:, 1:, :]  # [B, 9, D]
        
        # Reconstruct patches
        mask_predictions = self.mask_decoder(patch_tokens)  # [B, 9, 768]
        
        # Reshape to original image dimensions [B, 3, 48, 48]
        recon_image = self.patches_to_image(mask_predictions, B)
        return recon_image

    def patches_to_image(self, patches, batch_size):
        """Convert patch sequence to image tensor"""
        # patches shape: [B, num_patches, patch_size^2 * 3]
        patch_size = self.patch_size
        num_patches = self.num_patches
        channels = self.in_chans
        
        # Reshape to [B, num_patches, C, patch_size, patch_size]
        patches = patches.view(
            batch_size, 
            num_patches, 
            channels, 
            patch_size, 
            patch_size
        )
        
        # Reshape to image grid
        grid_size = int(num_patches ** 0.5)  # 3 for 9 patches
        image = patches.permute(0, 2, 1, 3, 4)  # [B, C, num_patches, p, p]
        image = image.contiguous().view(
            batch_size, 
            channels, 
            grid_size * patch_size, 
            grid_size * patch_size
        )
        return image

    def forward_rotation(self, x):
        """
        Image rotation prediction task.
        x: input rotated images.
        Returns: rotation logits predicted from the CLS token.
        """
        tokens = self.forward_encoder(x)  # shape: (B, 1 + num_patches, D)
        cls_token = tokens[:, 0, :]         # shape: (B, D)
        rotation_logits = self.rotation_head(cls_token)
        return rotation_logits

    def forward_jigsaw(self, x):
        """
        Jigsaw puzzle task.
        x: input images with shuffled patches.
        Returns: logits for predicting the permutation order (classification over num_jigsaw_classes).
        """
        tokens = self.forward_encoder(x)      # shape: (B, 1 + num_patches, D)
        patch_tokens = tokens[:, 1:, :]         # remove CLS token, shape: (B, num_patches, D)
        # Flatten patch tokens for each image: (B, num_patches * D)
        flat_tokens = patch_tokens.reshape(x.size(0), -1)
        jigsaw_logits = self.jigsaw_head(flat_tokens)
        return jigsaw_logits

    # def forward_masked(self, x):
    #     """
    #     Masked patch prediction task.
    #     x: input images with some patches masked out.
    #        (The masking operation should be performed in the data pre-processing or transform.)
    #     Returns: reconstructed predictions for the (unobserved) masked patches.
    #     For simplicity, here we process the entire set of patch tokens, and later you would compare
    #     the output for masked locations with the ground truth.
    #     """
    #     tokens = self.forward_encoder(x)      # shape: (B, 1 + num_patches, D)
    #     patch_tokens = tokens[:, 1:, :]         # shape: (B, num_patches, D)
    #     # Predict patch pixel values for each token.
    #     mask_predictions = self.mask_decoder(patch_tokens)
    #     return mask_predictions

    def forward(self, x, task):
        """
        A unified forward method that selects the appropriate pretext task.
        task: a string specifying the task type: 'denoise', 'rotation', 'jigsaw', or 'mask'.
        """
        if task == 'denoise':
            return self.forward_denoise(x)
        elif task == 'rotation':
            return self.forward_rotation(x)
        elif task == 'jigsaw':
            return self.forward_jigsaw(x)
        elif task == 'mask':
            return self.forward_masked(x)
        else:
            raise ValueError("Invalid task specified. Choose from 'denoise', 'rotation', 'jigsaw', or 'mask'.")

### MultiTaskLossWrapper

In [10]:
class MultiTaskLossWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        # Initialize the log-variances (or deltas) as learnable parameters.
        # We'll store them in log-space for numerical stability, e.g. log(delta^2).
        self.log_var_den = nn.Parameter(torch.zeros(1))
        self.log_var_rot = nn.Parameter(torch.zeros(1))
        self.log_var_puz = nn.Parameter(torch.zeros(1))
        self.log_var_msk = nn.Parameter(torch.zeros(1))
    
    def forward(self, L_den, L_rot, L_puz, L_msk):
        # Convert log_var_* to delta = exp(log_var / 2)
        # But the formula in the paper effectively uses 1/(2*delta^2), so we can do:
        # Weighted losses:
        w_den = 1.0 / (2.0 * torch.exp(self.log_var_den))
        w_rot = 1.0 / (2.0 * torch.exp(self.log_var_rot))
        w_puz = 1.0 / (2.0 * torch.exp(self.log_var_puz))
        w_msk = 1.0 / (2.0 * torch.exp(self.log_var_msk))

        # Combined loss
        loss = (w_den * L_den 
                + w_rot * L_rot
                + w_puz * L_puz
                + w_msk * L_msk
                + (self.log_var_den + self.log_var_rot 
                   + self.log_var_puz + self.log_var_msk) * 0.5)
        
        return loss

In [11]:
multi_task_loss_fn = MultiTaskLossWrapper().to(device)

### Train SSL Model function

In [12]:
def train_ssl_model(model, dataloader, num_epochs=10, device='cuda', learning_rate=0.01):
    """
    This is a simplified training loop that assumes:
      - The dataloader returns a dictionary containing inputs for each task.
        For example, each batch might be a dict with keys: 'denoise', 'rotation', 'jigsaw', 'mask'
      - Each key maps to a tuple: (input_tensor, target) where target is the ground truth for that task.
    In practice, you may need separate dataloaders or combine losses with appropriate weighting.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.1)
    
    # For simplicity, we use fixed weights for each task’s loss.
    weight_denoise = 1.0
    weight_rotation = 1.0
    weight_jigsaw = 1.0
    weight_mask = 1.0
    
    best_loss = np.inf
    patience_counter = 0   # Tracks the number of epochs without improvement
    early_stop = False # Flag to indicate whether to stop training
    save_weights_patience = 3  # Stop training if no improvement after this many epochs
    
    metrics_loss = []

    for epoch in range(num_epochs):
        if early_stop:
            print("Early stopping triggered. Stopping training.")
            break
        print(f"Epoch {epoch + 1}, LR: {scheduler.optimizer.param_groups[0]['lr']}")
        
        model.train()
        total_loss = 0.0
        tk = tqdm(dataloader, desc="EPOCH" + "[TRAIN]" + str(epoch + 1) + "/" + str(num_epochs))
        #print('DATALOADER LEN \n', len(dataloader))
        
        for t, batch in enumerate(tk):
            optimizer.zero_grad()
            #loss = 0.0
            L_den = 0.0
            L_rot = 0.0
            L_jig = 0.0
            L_mask = 0.0
            #print('BATCH IS \n', batch)
            #break

            # --- Denoising & Reconstruction Task ---
            if 'denoise' in batch:
                inputs_denoise, targets_denoise = batch['denoise']  # targets: original patches or full image patches
                inputs_denoise = inputs_denoise.to(device)
                targets_denoise = targets_denoise.to(device)
                recon_patches = model(inputs_denoise, task='denoise')
                # Reshape targets to match the output if necessary.
                loss_denoise = F.mse_loss(recon_patches, targets_denoise)
                L_den += weight_denoise * loss_denoise

            # --- Rotation Prediction Task ---
            if 'rotation' in batch:
                inputs_rotation, targets_rotation = batch['rotation']
                inputs_rotation = inputs_rotation.to(device)
                targets_rotation = targets_rotation.to(device)
                rotation_logits = model(inputs_rotation, task='rotation')
                loss_rotation = F.cross_entropy(rotation_logits, targets_rotation)
                L_rot += weight_rotation * loss_rotation

            # --- Jigsaw Puzzle Task ---
            if 'jigsaw' in batch:
                inputs_jigsaw, targets_jigsaw = batch['jigsaw']  # targets: permutation labels (integer class index)
                inputs_jigsaw = inputs_jigsaw.to(device)
                targets_jigsaw = targets_jigsaw.to(device)
                jigsaw_logits = model(inputs_jigsaw, task='jigsaw')
                loss_jigsaw = F.cross_entropy(jigsaw_logits, targets_jigsaw)
                L_jig += weight_jigsaw * loss_jigsaw

            # --- Masked Patch Prediction Task ---
            if 'mask' in batch:
                inputs_mask, targets_mask = batch['mask']
                inputs_mask = inputs_mask.to(device)
                targets_mask = targets_mask.to(device)
                mask_predictions = model(inputs_mask, task='mask')
                loss_mask = F.mse_loss(mask_predictions, targets_mask)
                L_mask += weight_mask * loss_mask
                
            # Combine them with learned weights:
            L_ssl = multi_task_loss_fn(L_den, L_rot, L_jig, L_mask)

            L_ssl.backward()
            optimizer.step()
            total_loss += L_ssl.item()
            tk.set_postfix({'loss': '%6f' % float(L_ssl / (t + 1))})
        
        avg_loss = total_loss / len(dataloader)
        #print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
        
        metrics_loss.append(avg_loss)
        scheduler.step(avg_loss)
        
        if avg_loss < best_loss:
            # Save the pretrained encoder weights (and optionally heads) for later fine-tuning.
            torch.save(model.state_dict(), "weights/vit_ssl_pretrained.pth")
            print("SAVED-BEST-WEIGHTS!")
            best_loss = avg_loss
            patience_counter = 0 # Reset early stopping
        else:
            patience_counter += 1
            print(f"No improvement in validation loss for {patience_counter} epoch(s).")

        if patience_counter >= save_weights_patience:
            print("Patience exceeded. Early stopping at epoch " +str(epoch + 1))
            early_stop = True

    print("")
    return metrics_loss

In [None]:
# -------------------------------
# Example usage:
# Assume we have a combined dataloader that yields a dict with keys: 'denoise', 'rotation', 'jigsaw', 'mask'.
# Each entry is a tuple: (input_tensor, target_tensor)
# In practice, you need to implement or combine datasets that perform the corresponding data augmentation.
# -------------------------------
if __name__ == "__main__":
    # Create model instance 
    model_ssl = ViT_SelfSupervised(image_size=48, patch_size=16, in_chans=3,
                                   embedding_dim=768, num_rotation_classes=4,
                                   num_jigsaw_classes=30)
    #print(dir(model_ssl))
    # Here, "ssl_dataloader" should be defined by you.
    # For demonstration purposes, assume it's provided.
    # ssl_dataloader = ...
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # train_ssl_model(model_ssl, ssl_dataloader, num_epochs=10, device=device)
    
    jigsaw_permutations = generate_random_permutations(num_patches=9, num_permutations=30)
    fer2013_train_files = [p for p in Path(fer_2013_dir_train).rglob('*') if p.is_file()]

    dataset = MultiTaskDataset(fer2013_train_files, train_transforms, image_size=48,
                            jigsaw_permutations=jigsaw_permutations,
                            mask_ratio=0.75)

    ssl_dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)
    
    pretrained_metrics_loss = train_ssl_model(model_ssl, ssl_dataloader, num_epochs=10, device=device, learning_rate=0.01)

# ViT encoder fine-tuning

### Data Preprocessing

In [13]:
mean = [0.485]  # Single channel
std = [0.229]

train_transforms = T.Compose([
    T.Grayscale(num_output_channels=3),  # Keep 3 channels but use grayscale
    T.RandomApply([T.RandomRotation(15)], p=0.5),
    T.RandomPerspective(distortion_scale=0.3, p=0.3),
    T.RandomResizedCrop(48, scale=(0.8, 1.2)),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean, std),
    T.RandomErasing(p=0.2)  # Helps with occlusion
])

val_transforms = T.Compose([
    T.Resize((48, 48)),
    T.ToTensor(),
    T.Normalize(mean, std)
])

fer2013_dir_train = Path(os.getcwd(), 'datasets', 'fer2013')

train_dataset = ImageFolder(root=fer2013_dir_train.joinpath('train'), transform=train_transforms)
val_dataset   = ImageFolder(root=fer2013_dir_train.joinpath('test'), transform=val_transforms)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

### FineTuningModel class

In [14]:
class FineTuningModel(nn.Module):
    def __init__(self, num_classes=7, image_size=224, use_pretrained_ssl=True, pretrained_path="vit_ssl_pretrained.pth"):
        super(FineTuningModel, self).__init__()
        
        # Load the ViT model from torchvision.
        self.encoder = vit_b_16(weights=None, image_size=image_size)
        # Remove the default classification head.
        self.encoder.heads = nn.Identity()
        
        # Optionally load the pretrained weights from SSL stage.
        if use_pretrained_ssl:
            state_dict = torch.load(pretrained_path, map_location='cuda')
            self.encoder.load_state_dict(state_dict, strict=False)
        
        # Define a new classification head.
        embedding_dim = 768  # This is the standard for vit_b_16.
        self.classifier = nn.Linear(embedding_dim, num_classes)
    
    def forward(self, x, return_embeddings=False):
        # Assume the encoder outputs a tensor of shape [B, 1 + num_patches, D].
        tokens = self.encoder(x)
        # We use the CLS token (first token) for classification.
        # Handle different dimensional outputs
        if tokens.dim() == 3:
            # Standard ViT output: [B, num_tokens, D]
            cls_token = tokens[:, 0, :]  # Extract CLS token
        else:
            # Direct feature output: [B, D]
            cls_token = tokens
        
        # Get final logits
        logits = self.classifier(cls_token)
        
        return (logits, cls_token) if return_embeddings else logits

### Create Train and Test functions

In [15]:
def calculate_accuracy(y_pred,y_true):
    top_p,top_class = y_pred.topk(1, dim = 1)
    equals = top_class == y_true.view(*top_class.shape)
    return torch.mean(equals.type(torch.cuda.FloatTensor))

In [16]:
def train_one_epoch(model, dataloader, criterion, optimizer, device, current_epoch, epochs):
    """
    Train one epoch of the model.

    Args:
        model (nn.Module): The  model.
        dataloader (DataLoader): DataLoader for training data.
        device (torch.device): Device to train the model on (CPU/GPU).

    Returns:
        training_loss (float): Returns epoch_loss / len(dataloader)
    """
    model.train()  # Set model to training mode
    epoch_loss = 0.0
    epoch_accuracy = 0.0
    tk = tqdm(dataloader, desc="EPOCH" + "[TRAIN]" + str(current_epoch + 1) + "/" + str(epochs))

    for t, data in enumerate(tk):
        images, labels = data

        images = images.to(device)
        labels = labels.to(device)

        # Zero gradients
        optimizer.zero_grad()

        # Compute log probabilities from model
        logits = model(images)
        loss = criterion(logits, labels)

        # Backpropagation
        loss.backward()

        # Update model parameters
        optimizer.step()

        # Accumulate loss for logging; Total loss
        epoch_loss += loss.item()
        
        epoch_accuracy += calculate_accuracy(logits, labels)

        # Print/log training loss and accuracy for this epoch
        tk.set_postfix({
            'loss': '%6f' % float(epoch_loss / (t + 1)), 
            'acc': '%6f' % float(epoch_accuracy / (t + 1))
        })

    return epoch_loss / len(dataloader), epoch_accuracy / len(dataloader)

In [17]:
def test_one_epoch(model, dataloader, criterion, device, current_epoch, epochs):
    """
    Test one epoch of the model

    Args:
        model (nn.Module): The model.
        dataloader (DataLoader): DataLoader for training data.
        learning_rate (float): Learning rate for optimizer.
        device (torch.device): Device to train the model on (CPU/GPU).

    Returns:
        training_loss (float): Returns epoch_loss / len(dataloader)
        
        running_acc (float): Returns running accuracy
    """
    model.eval()  # Set model to evaluation mode
    epoch_loss = 0.0
    epoch_accuracy = 0.0

    tk = tqdm(dataloader, desc="EPOCH" + "[VALID]" + str(current_epoch + 1) + "/" + str(epochs))

    correct = 0
    total = 0
    with torch.no_grad():  # Disable gradient calculation for testing
        for t, data in enumerate(tk):          
            images, labels = data

            images = images.to(device)
            labels = labels.to(device)

            # Compute log probabilities from model
            logits = model(images)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += images.size(0)            

            # Compute loss
            loss = criterion(logits, labels)

            # Accumulate loss for logging; Total loss
            epoch_loss += loss.item()
            
            epoch_accuracy += calculate_accuracy(logits, labels)
            
            #running_acc = correct / total * 100

            #tk.set_postfix({'loss': '%6f' % float(epoch_loss / (t + 1)), 'acc': '%2f%%' %float(running_acc),})
            tk.set_postfix({
                'loss': '%6f' % float(epoch_loss / (t + 1)), 
                'acc': '%6f' % float(epoch_accuracy / (t + 1))
            })


    return epoch_loss / len(dataloader), epoch_accuracy / len(dataloader)


In [18]:
def train_and_validate_model(model, training_dataloader, testing_dataloader, epochs, learning_rate, device):
    """
    Train and Test the model using loss.

    Args:
        model (nn.Module): The model.
        training_dataloader (DataLoader): DataLoader for training data.
        testing_dataloader (DataLoader): DataLoader for testing data.
        epochs (int): Number of training epochs.
        learning_rate (float): Learning rate for optimizer.
        device (torch.device): Device to train the model on (CPU/GPU).
    """
    # Define Loss function
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.1)

    # Move model to device
    model.to(device)

    best_valid_loss = np.inf
    patience_counter = 0   # Tracks the number of epochs without improvement
    early_stop = False # Flag to indicate whether to stop training
    save_weights_patience = 5

    # Dictionary to store loss and accuracy values over epochs
    history_metrics = {
        'training_loss': [],
        'training_accuracy': [],
        'validation_loss': [],
        'validation_accuracy': []
    }

    for epoch in range(epochs):
        if early_stop:
            print("Early stopping triggered. Stopping training.")
            break

        print(f"Epoch {epoch + 1}, LR: {scheduler.optimizer.param_groups[0]['lr']}")

        # Training step
        train_loss, train_accuracy = train_one_epoch(model, training_dataloader, criterion, optimizer, device, epoch, epochs)
        
        # Testing step
        valid_loss, valid_accuracy = test_one_epoch(model, testing_dataloader, criterion, device, epoch, epochs) 

        history_metrics['training_loss'].append(train_loss)
        history_metrics['validation_loss'].append(valid_loss)
        history_metrics['training_accuracy'].append(train_accuracy)
        history_metrics['validation_accuracy'].append(valid_accuracy)

        # Update the learning rate based on validation loss and print
        scheduler.step(valid_loss)

        if valid_loss < best_valid_loss:
            torch.save(model.state_dict(), 'weights/vit_fine_tuned_with_fer2013_NO_PRETRAIN.pt')
            print("SAVED-BEST-WEIGHTS!")
            best_valid_loss = valid_loss
            patience_counter = 0 # Reset early stopping
        else:
            patience_counter += 1
            print(f"No improvement in validation loss for {patience_counter} epoch(s).")

        if patience_counter >= save_weights_patience:
            print("Patience exceeded. Early stopping at epoch " +str(epoch + 1))
            early_stop = True
            
        
    print("")
    #return model
    return history_metrics


### Create model instance and call train functions

In [25]:
time1 = datetime.now()

In [22]:
# Create the model instance:
model_fer = FineTuningModel(num_classes=7, image_size=48, use_pretrained_ssl=False)

In [None]:
# Train the model:
finetuning_losses = train_and_validate_model(model_fer, train_loader, val_loader, epochs=20, learning_rate=0.001, device=device)

Epoch 1, LR: 0.001


EPOCH[TRAIN]1/20: 100%|██████████| 898/898 [01:28<00:00, 10.17it/s, loss=1.797473, acc=0.251246]
EPOCH[VALID]1/20: 100%|██████████| 225/225 [00:29<00:00,  7.66it/s, loss=1.780289, acc=0.262944]


SAVED-BEST-WEIGHTS!
Epoch 2, LR: 0.001


EPOCH[TRAIN]2/20: 100%|██████████| 898/898 [01:05<00:00, 13.65it/s, loss=1.769565, acc=0.266815]
EPOCH[VALID]2/20: 100%|██████████| 225/225 [00:17<00:00, 13.19it/s, loss=1.779265, acc=0.257472]


SAVED-BEST-WEIGHTS!
Epoch 3, LR: 0.001


EPOCH[TRAIN]3/20: 100%|██████████| 898/898 [01:05<00:00, 13.73it/s, loss=1.758142, acc=0.278410]
EPOCH[VALID]3/20: 100%|██████████| 225/225 [00:17<00:00, 13.13it/s, loss=1.733075, acc=0.293611]


SAVED-BEST-WEIGHTS!
Epoch 4, LR: 0.001


EPOCH[TRAIN]4/20: 100%|██████████| 898/898 [01:06<00:00, 13.56it/s, loss=1.748842, acc=0.284222]
EPOCH[VALID]4/20: 100%|██████████| 225/225 [00:16<00:00, 13.26it/s, loss=1.720345, acc=0.298028]


SAVED-BEST-WEIGHTS!
Epoch 5, LR: 0.001


EPOCH[TRAIN]5/20: 100%|██████████| 898/898 [01:07<00:00, 13.27it/s, loss=1.742977, acc=0.284883]
EPOCH[VALID]5/20: 100%|██████████| 225/225 [00:16<00:00, 13.32it/s, loss=1.727594, acc=0.308472]


No improvement in validation loss for 1 epoch(s).
Epoch 6, LR: 0.001


EPOCH[TRAIN]6/20: 100%|██████████| 898/898 [01:06<00:00, 13.41it/s, loss=1.732255, acc=0.295476]
EPOCH[VALID]6/20: 100%|██████████| 225/225 [00:30<00:00,  7.38it/s, loss=1.738679, acc=0.274861]


No improvement in validation loss for 2 epoch(s).
Epoch 7, LR: 0.001


EPOCH[TRAIN]7/20: 100%|██████████| 898/898 [01:08<00:00, 13.08it/s, loss=1.731515, acc=0.294119]
EPOCH[VALID]7/20: 100%|██████████| 225/225 [00:17<00:00, 13.04it/s, loss=1.699793, acc=0.319444]


SAVED-BEST-WEIGHTS!
Epoch 8, LR: 0.001


EPOCH[TRAIN]8/20: 100%|██████████| 898/898 [01:05<00:00, 13.70it/s, loss=1.721614, acc=0.302332]
EPOCH[VALID]8/20: 100%|██████████| 225/225 [00:17<00:00, 12.86it/s, loss=1.691885, acc=0.311361]


SAVED-BEST-WEIGHTS!
Epoch 9, LR: 0.001


EPOCH[TRAIN]9/20: 100%|██████████| 898/898 [01:05<00:00, 13.73it/s, loss=1.721803, acc=0.301051]
EPOCH[VALID]9/20: 100%|██████████| 225/225 [00:16<00:00, 13.41it/s, loss=1.699319, acc=0.317611]


No improvement in validation loss for 1 epoch(s).
Epoch 10, LR: 0.001


EPOCH[TRAIN]10/20: 100%|██████████| 898/898 [01:06<00:00, 13.54it/s, loss=1.711502, acc=0.309069]
EPOCH[VALID]10/20: 100%|██████████| 225/225 [00:16<00:00, 13.40it/s, loss=1.736935, acc=0.304028]


No improvement in validation loss for 2 epoch(s).
Epoch 11, LR: 0.001


EPOCH[TRAIN]11/20: 100%|██████████| 898/898 [01:04<00:00, 14.02it/s, loss=1.707864, acc=0.309312]
EPOCH[VALID]11/20: 100%|██████████| 225/225 [00:16<00:00, 13.27it/s, loss=1.659956, acc=0.336250]


SAVED-BEST-WEIGHTS!
Epoch 12, LR: 0.001


EPOCH[TRAIN]12/20: 100%|██████████| 898/898 [01:03<00:00, 14.05it/s, loss=1.707028, acc=0.310962]
EPOCH[VALID]12/20: 100%|██████████| 225/225 [00:17<00:00, 12.72it/s, loss=1.669878, acc=0.327361]


No improvement in validation loss for 1 epoch(s).
Epoch 13, LR: 0.001


EPOCH[TRAIN]13/20: 100%|██████████| 898/898 [01:04<00:00, 14.02it/s, loss=1.703111, acc=0.314129]
EPOCH[VALID]13/20: 100%|██████████| 225/225 [00:16<00:00, 13.49it/s, loss=1.646916, acc=0.346806]


SAVED-BEST-WEIGHTS!
Epoch 14, LR: 0.001


EPOCH[TRAIN]14/20: 100%|██████████| 898/898 [01:03<00:00, 14.13it/s, loss=1.699556, acc=0.312368]
EPOCH[VALID]14/20: 100%|██████████| 225/225 [00:16<00:00, 13.57it/s, loss=1.648841, acc=0.346722]


No improvement in validation loss for 1 epoch(s).
Epoch 15, LR: 0.001


EPOCH[TRAIN]15/20: 100%|██████████| 898/898 [01:03<00:00, 14.18it/s, loss=1.695200, acc=0.319265]
EPOCH[VALID]15/20: 100%|██████████| 225/225 [00:16<00:00, 13.46it/s, loss=1.674029, acc=0.324056]


No improvement in validation loss for 2 epoch(s).
Epoch 16, LR: 0.001


EPOCH[TRAIN]16/20: 100%|██████████| 898/898 [01:04<00:00, 13.82it/s, loss=1.689902, acc=0.321123]
EPOCH[VALID]16/20: 100%|██████████| 225/225 [00:17<00:00, 12.61it/s, loss=1.634338, acc=0.349028]


SAVED-BEST-WEIGHTS!
Epoch 17, LR: 0.001


EPOCH[TRAIN]17/20: 100%|██████████| 898/898 [01:04<00:00, 13.89it/s, loss=1.684782, acc=0.327304]
EPOCH[VALID]17/20: 100%|██████████| 225/225 [00:16<00:00, 13.39it/s, loss=1.631200, acc=0.354583]


SAVED-BEST-WEIGHTS!
Epoch 18, LR: 0.001


EPOCH[TRAIN]18/20: 100%|██████████| 898/898 [01:04<00:00, 13.82it/s, loss=1.685150, acc=0.325425]
EPOCH[VALID]18/20: 100%|██████████| 225/225 [00:16<00:00, 13.52it/s, loss=1.680742, acc=0.322694]


No improvement in validation loss for 1 epoch(s).
Epoch 19, LR: 0.001


EPOCH[TRAIN]19/20: 100%|██████████| 898/898 [01:04<00:00, 13.89it/s, loss=1.678458, acc=0.327408]
EPOCH[VALID]19/20: 100%|██████████| 225/225 [00:16<00:00, 13.29it/s, loss=1.645004, acc=0.337528]


No improvement in validation loss for 2 epoch(s).
Epoch 20, LR: 0.001


EPOCH[TRAIN]20/20: 100%|██████████| 898/898 [01:04<00:00, 13.88it/s, loss=1.676877, acc=0.327666]
EPOCH[VALID]20/20: 100%|██████████| 225/225 [00:17<00:00, 13.20it/s, loss=1.627575, acc=0.352111]


SAVED-BEST-WEIGHTS!



In [27]:
time2 = datetime.now()
print("Total training time: ", time2 - time1)

Total training time:  0:28:24.946426


In [28]:
# Prepare the data
data = {
    "Epoch": list(range(1, len(finetuning_losses['training_loss']) + 1)),
    "Training Loss": finetuning_losses['training_loss'],
    "Validation Loss": finetuning_losses['validation_loss'],
    "Training Accuracy": [acc.cpu().item() for acc in finetuning_losses['training_accuracy']],
    "Validation Accuracy": [acc.cpu().item() for acc in finetuning_losses['validation_accuracy']]
}

# Create a DataFrame
df = pd.DataFrame(data)

# Save to CSV
df.to_csv("stats/vit_model_stats_001_50_epochs.csv", index=False)
print("Losses and accuracy saved")

Losses and accuracy saved


# Test Model Accuracy on Out of Distribution Data set (Manga Faces)

In [23]:
def test_out_of_distribution(model, testing_dataloader, epochs, device):
    """
    Train and Test the speech recognition model using CTC loss.

    Args:
        model (nn.Module): The model.
        training_dataloader (DataLoader): DataLoader for training data.
        testing_dataloader (DataLoader): DataLoader for testing data.
        epochs (int): Number of training epochs.
        learning_rate (float): Learning rate for optimizer.
        device (torch.device): Device to train the model on (CPU/GPU).
        
    Returns:
        history_metrics (dict): Dictionary containing validation loss and accuracy over epochs.
    """
    # Define Loss function
    criterion = nn.CrossEntropyLoss()
    
    # Move model to device
    model.to(device)

    # Dictionary to store loss and accuracy values over epochs
    history_metrics = {
        'validation_loss': [],
        'validation_accuracy': []
    }

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}")
        
        # Testing step
        valid_loss, valid_accuracy = test_one_epoch(model, testing_dataloader, criterion, device, epoch, epochs) 
        
        history_metrics['validation_loss'].append(valid_loss)
        history_metrics['validation_accuracy'].append(valid_accuracy)
                
    print("")
    return history_metrics

In [24]:
params={
        "initial_filters": 8,    
        "dropout_rate": 0.2,
        "num_classes": 7}

In [25]:
vit_model_2 = model_fer.to(device)
vit_model_2.load_state_dict(torch.load('weights/vit_fine_tuned_with_fer2013_NO_PRETRAIN.pt', weights_only=True))

<All keys matched successfully>

## Import MangaFaces Dataset

In [26]:
BATCH_SIZE = 32

In [27]:
# Train set
manga_faces_train_dir = Path(os.getcwd(), 'datasets', 'manga', 'train')
manga_faces_train_images = ImageFolder(root=manga_faces_train_dir, transform=train_transforms)
manga_faces_train_images_loader = DataLoader(manga_faces_train_images, batch_size=BATCH_SIZE, shuffle=True)

# Test Set
manga_faces_test_dir = Path(os.getcwd(), 'datasets', 'manga', 'test')
manga_faces_test_images = ImageFolder(root=manga_faces_test_dir, transform=val_transforms)
manga_faces_test_images_loader = DataLoader(manga_faces_test_images, batch_size=BATCH_SIZE, shuffle=True)

## Run 'test_out_of_distribution' function

In [28]:
# Train the model:
test_out_of_distribution_metrics = test_out_of_distribution(vit_model_2, manga_faces_train_images_loader, epochs=5, device=device)

Epoch 1


EPOCH[VALID]1/5: 100%|██████████| 7/7 [00:02<00:00,  3.32it/s, loss=2.658994, acc=0.095536]


Epoch 2


EPOCH[VALID]2/5: 100%|██████████| 7/7 [00:00<00:00,  8.82it/s, loss=2.675275, acc=0.113393]


Epoch 3


EPOCH[VALID]3/5: 100%|██████████| 7/7 [00:00<00:00,  8.70it/s, loss=2.626171, acc=0.093750]


Epoch 4


EPOCH[VALID]4/5: 100%|██████████| 7/7 [00:00<00:00,  9.59it/s, loss=2.660418, acc=0.116071]


Epoch 5


EPOCH[VALID]5/5: 100%|██████████| 7/7 [00:00<00:00, 11.18it/s, loss=2.655772, acc=0.108036]







In [37]:
# Store the metrics from when the model was tested on the out-of-distribution dataset
data = {
    "Epoch": list(range(1, len(test_out_of_distribution_metrics['validation_loss']) + 1)),
    "Validation Loss": test_out_of_distribution_metrics['validation_loss'],
    "Validation Accuracy": [acc.cpu().item() for acc in test_out_of_distribution_metrics['validation_accuracy']]
}

# Create a DataFrame
df = pd.DataFrame(data)

# Save to CSV
df.to_csv("stats/vit_model_out_of_distribution_stats_001_5_epochs.csv", index=False)
print("Losses and accuracy saved")

Losses and accuracy saved


In [38]:
# Train the model:
test_out_of_distribution_metrics = test_out_of_distribution(vit_model_2, manga_faces_test_images_loader, epochs=5, device=device)

Epoch 1


EPOCH[VALID]1/5: 100%|██████████| 5/5 [00:01<00:00,  2.85it/s, loss=2.480916, acc=0.218750]


Epoch 2


EPOCH[VALID]2/5: 100%|██████████| 5/5 [00:00<00:00, 13.60it/s, loss=2.793143, acc=0.087500]


Epoch 3


EPOCH[VALID]3/5: 100%|██████████| 5/5 [00:00<00:00, 19.59it/s, loss=2.590372, acc=0.175000]


Epoch 4


EPOCH[VALID]4/5: 100%|██████████| 5/5 [00:00<00:00, 20.10it/s, loss=2.689698, acc=0.087500]


Epoch 5


EPOCH[VALID]5/5: 100%|██████████| 5/5 [00:00<00:00, 19.24it/s, loss=2.762226, acc=0.087500]







# FSL Domain Adaptation Prototypical Network

### Dataset class

In [29]:
class FewShotFERDataset(Dataset):
    """
    Dataset for few-shot FER, where images are organized by class in folders.
    This dataset generates episodes (tasks) on-the-fly.
    """
    def __init__(self, root_dir, n_way=5, k_shot=1, k_query=5, transform=None):
        """
        root_dir: Root folder containing one folder per class.
        n_way: number of classes per episode.
        k_shot: number of support examples per class.
        k_query: number of query examples per class.
        transform: transformation to apply to images.
        """
        self.root_dir = root_dir
        self.n_way = n_way
        self.k_shot = k_shot
        self.k_query = k_query
        self.transform = transform
        
        # Build a mapping: class -> list of image paths.
        self.class_to_imgs = {}
        for cls_name in os.listdir(root_dir):
            cls_folder = Path.joinpath(root_dir, cls_name)
            if Path.is_dir(cls_folder):
                self.class_to_imgs[cls_name] = [Path.joinpath(cls_folder, img)                                                 
                                                 for img in Path(cls_folder).rglob('*')
                                                 if str(img).endswith('.jpg') or str(img).endswith('.png')]        
        self.classes = list(self.class_to_imgs.keys())
    
    def __len__(self):
        # Define the number of episodes arbitrarily.
        return 1000  # or any number representing episodes
    
    def __getitem__(self, idx):
        # Randomly sample n_way classes for this episode.
        sampled_classes = random.sample(self.classes, self.n_way)
        support_imgs, support_labels = [], []
        query_imgs, query_labels = [], []
        
        label_map = {cls_name: i for i, cls_name in enumerate(sampled_classes)}
        
        for cls_name in sampled_classes:
            imgs = self.class_to_imgs[cls_name]
            # Ensure there are enough examples in this class.
            selected_imgs = random.sample(imgs, self.k_shot + self.k_query)
            support_paths = selected_imgs[:self.k_shot]
            query_paths = selected_imgs[self.k_shot:]
            
            for sp in support_paths:
                img = Image.open(sp).convert('RGB')
                if self.transform:
                    img = self.transform(img)
                support_imgs.append(img)
                support_labels.append(label_map[cls_name])
            
            for qp in query_paths:
                img = Image.open(qp).convert('RGB')
                if self.transform:
                    img = self.transform(img)
                query_imgs.append(img)
                query_labels.append(label_map[cls_name])
        
        # Convert lists to tensors.
        support_imgs = torch.stack(support_imgs)  # shape: [n_way*k_shot, C, H, W]
        support_labels = torch.tensor(support_labels, dtype=torch.long)
        query_imgs = torch.stack(query_imgs)      # shape: [n_way*k_query, C, H, W]
        query_labels = torch.tensor(query_labels, dtype=torch.long)
        
        return (support_imgs, support_labels), (query_imgs, query_labels)

### Constructing DataLoader

In [30]:
# Define transforms (should match what the encoder expects)
transform = T.Compose([
    T.Resize((48, 48)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
])

# Root folder with classes as subfolders.
few_shot_dataset = FewShotFERDataset(root_dir=manga_faces_train_dir, n_way=4, k_shot=10, k_query=22, transform=transform)
few_shot_loader = DataLoader(few_shot_dataset, batch_size=1, shuffle=True, num_workers=0)

### Prototypical Network Inference function

In [31]:
def evaluate_episode(model, support_imgs, support_labels, query_imgs, query_labels, device):
    """
    model: the fine-tuned FER model, used as feature extractor.
    support_imgs: [n_way*k_shot, C, H, W]
    query_imgs: [n_way*k_query, C, H, W]
    support_labels: [n_way*k_shot]
    query_labels: [n_way*k_query]
    """
    model.eval()
    with torch.no_grad():
        support_imgs = support_imgs.to(device)
        query_imgs = query_imgs.to(device)
        
        # Extract features using the pretrained encoder.
        # Adjusted to handle both 2D (CNN) and 3D (ViT) feature outputs
        def get_cls_features(x):
            features = model.encoder(x)
            # If features are 3D (ViT), extract CLS token; else, use features directly (CNN)
            if features.dim() == 3:
                return features[:, 0, :]
            else:
                return features
        
        support_feats = get_cls_features(support_imgs)  # shape: [n_way*k_shot, D]
        query_feats = get_cls_features(query_imgs)      # shape: [n_way*k_query, D]
        
        # Compute prototypes: mean feature for each class.
        n_way = len(torch.unique(support_labels))
        prototypes = []
        for cls in range(n_way):
            cls_indices = (support_labels == cls).nonzero(as_tuple=True)[0]
            cls_feats = support_feats[cls_indices]
            prototype = cls_feats.mean(dim=0)
            prototypes.append(prototype)
        prototypes = torch.stack(prototypes)  # shape: [n_way, D]
        
        # Compute distances between each query feature and prototypes.
        dists = torch.cdist(query_feats, prototypes, p=2)  # shape: [n_way*k_query, n_way]
        probs = F.softmax(-dists, dim=1)
        preds = torch.argmax(probs, dim=1)
        correct = (preds.cpu() == query_labels).sum().item()
        total = query_labels.size(0)
    
    return correct, total


### Run inference

In [52]:
# ------------------------------------------
# Evaluate on a few episodes from the DataLoader
# ------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Assume model_fer is the fine-tuned FER model we defined previously.
model_fer = model_fer.to(device)
total_correct = 0
total_samples = 0
num_episodes = 50  # Evaluate on 50 episodes for instance.

for i, ((support_imgs, support_labels), (query_imgs, query_labels)) in enumerate(few_shot_loader):
    if i >= num_episodes:
        break
    correct, total = evaluate_episode(model_fer, support_imgs.squeeze(0), support_labels.squeeze(0),
                                      query_imgs.squeeze(0), query_labels.squeeze(0), device)
    total_correct += correct
    total_samples += total

episode_accuracy = 100.0 * total_correct / total_samples
print("Few-Shot Episode Accuracy: {:.2f}%".format(episode_accuracy))

Few-Shot Episode Accuracy: 32.07%


# Contrastive Learning

### Align Label spaces

In [32]:
label_map = {
    "angry": 0,
    "disgust": 1,
    "fear": 2,
    "happy": 3,
    "neutral": 4,
    "sad": 5,
    "surprise": 6
}

In [33]:
class MappedImageFolder(ImageFolder):
    def __init__(self, root, label_map, transform=None):
        super().__init__(root, transform=transform)
        self.samples = [
            (path, label_map[self.classes[label]])
            for path, label in self.samples
            if self.classes[label] in label_map
        ]
        self.targets = [s[1] for s in self.samples]
        
        inverse_label_map = {v: k for k, v in label_map.items()}
        self.classes = [inverse_label_map[i] for i in sorted(inverse_label_map)]


In [34]:
mean=[0.485]
std=[0.229]

manga_transforms = T.Compose([
    T.Grayscale(num_output_channels=3),  # Keep 3 channels but use grayscale
    T.RandomApply([T.GaussianBlur(3), T.RandomSolarize(0.5)], p=0.5),
    T.RandomPerspective(distortion_scale=0.4, p=0.3),
    T.RandomApply([T.RandomRotation(15)], p=0.5),
    T.RandomPerspective(distortion_scale=0.3, p=0.3),
    T.RandomResizedCrop(48, scale=(0.8, 1.2)),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean, std),
    T.RandomErasing(p=0.2)  # Helps with occlusion
])


manga_faces_train_dir = Path(os.getcwd(), 'datasets', 'manga', 'train')
manga_faces_train_images = MappedImageFolder(root=manga_faces_train_dir, label_map=label_map  , transform=manga_transforms)
manga_faces_train_images_loader = DataLoader(manga_faces_train_images, batch_size=32, shuffle=True)


# Test Set
test_transforms = T.Compose([
    T.Grayscale(num_output_channels=3),
    T.Resize((48, 48)),
    T.ToTensor(),
    T.Normalize(mean=[0.485], std=[0.229])
])
manga_faces_test_dir = Path(os.getcwd(), 'datasets', 'manga', 'test')
manga_faces_test_images = MappedImageFolder(root=manga_faces_test_dir, label_map=label_map  , transform=test_transforms)
manga_faces_test_images_loader = DataLoader(manga_faces_test_images, batch_size=32, shuffle=True)

In [35]:
print(set(manga_faces_test_images.targets))

{0, 3, 5, 6}


### Contrastive Loss Class Implementation

In [36]:
class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.5):  # Increased temperature
        super().__init__()
        self.tau = temperature # hyperparameter for scaling the similarity scores
        
    def forward(self, source_emb, source_labels, target_emb, target_labels):
        device = source_emb.device
        
        # Normalize embeddings
        source_emb = F.normalize(source_emb, p=2, dim=1)
        target_emb = F.normalize(target_emb, p=2, dim=1)
        
        embeddings = torch.cat([source_emb, target_emb], dim=0)
        labels = torch.cat([source_labels, target_labels], dim=0)
        
        # Similarity matrix
        sim_matrix = torch.mm(target_emb, embeddings.T) / self.tau
        
        # Masks
        pos_mask = torch.zeros_like(sim_matrix, dtype=torch.bool)
        for i, label in enumerate(target_labels):
            pos_mask[i, :len(source_labels)] = (source_labels == label)
            
        neg_mask = (labels != target_labels.unsqueeze(1))
        neg_mask[:, len(source_labels):] &= ~torch.eye(
            len(target_labels), dtype=torch.bool, device=device
        )
        
        # Compute terms with stability
        pos_term = (sim_matrix.exp() * pos_mask.float()).sum(dim=1) + 1e-8
        neg_term = (sim_matrix.exp() * neg_mask.float()).sum(dim=1) + 1e-8
        
        loss = -torch.log(pos_term / (pos_term + neg_term))
        return loss.mean()

### Few-shot sampling function

In [37]:
# Few-shot sampling function
def get_few_shot_indices(dataset, shots_per_class=5):
    """
    Returns a balanced list of indices for few-shot learning by randomly selecting
    a fixed number of samples per class.

    Args:
        dataset (ImageFolder): A PyTorch ImageFolder dataset (or any dataset with a `.samples` attribute 
                              containing (path, label) tuples).
        shots_per_class (int, optional): Number of samples to select per class. Defaults to 5.

    Returns:
        List[int]: A list of selected indices, ensuring `shots_per_class` samples per class.

    Example:
        >>> target_set = ImageFolder(root='data/target', transform=transforms.ToTensor())
        >>> few_shot_indices = get_few_shot_indices(target_set, shots_per_class=3)
        >>> few_shot_loader = DataLoader(Subset(target_set, few_shot_indices), batch_size=3)
    """
    
    class_indices = {}
    for idx, (_, label) in enumerate(dataset.samples):
        class_indices.setdefault(label, []).append(idx)
    
    selected_indices = []
    for label, indices in class_indices.items():
        selected_indices.extend(np.random.choice(indices, shots_per_class, replace=False))
    return selected_indices

In [38]:
model_fer = FineTuningModel(num_classes=7, image_size=48, use_pretrained_ssl=False)
vit_model_2 = model_fer.to(device)
vit_model_2.load_state_dict(torch.load('weights/vit_fine_tuned_with_fer2013_NO_PRETRAIN.pt', weights_only=True))

<All keys matched successfully>

In [39]:
# Initialize model and losses
params = {'num_classes': 7, 'dropout_rate': 0.2}  # Example for FER2013
cls_criterion = nn.CrossEntropyLoss()
cont_criterion = ContrastiveLoss(temperature=0.2)
optimizer = torch.optim.AdamW(vit_model_2.parameters(), lr=1e-4, weight_decay=1e-4)

# Prepare few-shot target loader training set
few_shot_indices = get_few_shot_indices(manga_faces_train_images, shots_per_class=15)
few_shot_loader = DataLoader(
    Subset(manga_faces_train_images, few_shot_indices),
    batch_size=10,
    shuffle=False,
    drop_last=True  # Avoid partial batches
)

### Modified Training Loop for Contrastive Learning

#### Using both CrossEntropyLoss and ContrastiveLoss

In [40]:
# Training loop with domain adaptation
def train_epoch(model, source_loader, target_loader, optimizer, epoch, epochs):
    model.train()
    target_iter = cycle(target_loader)  # Infinite iterator
    
    # Initialize metrics
    total_cls_loss = 0.0
    total_cont_loss = 0.0
    running_total_loss = 0.0
    source_correct = 0
    target_correct = 0
    total_source_samples = 0
    total_target_samples = 0
    
    
    tk = tqdm(source_loader, desc="EPOCH" + "[TRAIN]" + str(epoch) + "/" + str(epochs))
    
    for batch_idx, (source_imgs, source_lbls) in enumerate(tk):
        # Get target batch
        target_imgs, target_lbls = next(target_iter)
        
        # Move to device
        source_imgs = source_imgs.to(device)
        source_lbls = source_lbls.to(device)
        target_imgs = target_imgs.to(device)
        target_lbls = target_lbls.to(device)
        
        # Forward pass with embeddings
        source_logits, source_emb = model(source_imgs, return_embeddings=True)
        target_logits, target_emb = model(target_imgs, return_embeddings=True)
        
        # Calculate accuracy
        source_preds = source_logits.argmax(dim=1)
        target_preds = target_logits.argmax(dim=1)
        
        # Update counters
        batch_source_correct = (source_preds == source_lbls).sum().item()
        batch_target_correct = (target_preds == target_lbls).sum().item()
        
        source_correct += batch_source_correct
        target_correct += batch_target_correct
        total_source_samples += source_lbls.size(0)
        total_target_samples += target_lbls.size(0)
        
        # Loss calculation
        cls_loss = cls_criterion(source_logits, source_lbls) + \
                cls_criterion(target_logits, target_lbls)
        
        cont_loss = cont_criterion(source_emb, source_lbls,
                                 target_emb, target_lbls)
        
        current_loss = cls_loss + 0.9 * cont_loss # Adjusted weight
        
        # Update metrics
        total_cls_loss += cls_loss.item()
        total_cont_loss += cont_loss.item() * 0.9 # Adjusted weight
        running_total_loss += current_loss.item()
        
        # Backpropagation
        optimizer.zero_grad()
        current_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        # Calculate batch-level accuracies
        batch_source_acc = 100 * batch_source_correct / source_lbls.size(0)
        batch_target_acc = 100 * batch_target_correct / target_lbls.size(0)
        
        # Update progress bar
        tk.set_postfix({
            'CLS Loss': f'{total_cls_loss / (batch_idx + 1):.4f}',
            'CONT Loss': f'{total_cont_loss / (batch_idx + 1):.4f}',
            'Total Loss': f'{running_total_loss / (batch_idx + 1):.4f}',
            'Source Acc': f'{batch_source_acc:.2f}%',
            'Target Acc': f'{batch_target_acc:.2f}%'
        })
            
    # Calculate epoch-level metrics
    epoch_cls_loss = total_cls_loss / len(source_loader)
    epoch_cont_loss = total_cont_loss / len(source_loader)
    epoch_total_loss = running_total_loss / len(source_loader)
    
    epoch_source_acc = 100 * source_correct / total_source_samples
    epoch_target_acc = 100 * target_correct / total_target_samples
    
    print(f"\nEpoch {epoch}/{epochs} Summary:")
    print(f"CLS Loss: {epoch_cls_loss:.4f} | CONT Loss: {epoch_cont_loss:.4f} | Total Loss: {epoch_total_loss:.4f}")
    print(f"Source Acc: {epoch_source_acc:.2f}% | Target Acc: {epoch_target_acc:.2f}%")
    
    return epoch_cls_loss, epoch_cont_loss, epoch_total_loss, epoch_source_acc, epoch_target_acc

In [41]:
EPOCHS = 50

In [42]:
contrastive_loss_metrics = {
    'cls_loss': [],
    'cont_loss': [],
    'total_loss': [],
    'source_accuracy': [],
    'target_accuracy': []
}

validation_loss_accuracy = {
    'validation_loss': [],
    'validation_accuracy': []
}

best_valid_loss = np.inf
patience_counter = 0   # Tracks the number of epochs without improvement
early_stop = False # Flag to indicate whether to stop training
save_weights_patience = 5

for epoch in range(1, EPOCHS + 1):
    if early_stop:
        print("Early stopping triggered. Stopping training.")
        break
        
    cls_loss, cont_loss, total_loss, source_acc, target_acc = \
        train_epoch(vit_model_2, train_loader, few_shot_loader, optimizer, epoch, EPOCHS)
    
    contrastive_loss_metrics['cls_loss'].append(cls_loss)
    contrastive_loss_metrics['cont_loss'].append(cont_loss)
    contrastive_loss_metrics['total_loss'].append(total_loss)
    contrastive_loss_metrics['source_accuracy'].append(source_acc)
    contrastive_loss_metrics['target_accuracy'].append(target_acc)
    
    print()
    current_val_loss_accuracy = test_out_of_distribution(vit_model_2, manga_faces_test_images_loader, epochs=1, device=device)
    validation_loss_accuracy['validation_loss'].append(float(current_val_loss_accuracy['validation_loss'][0]))
    validation_loss_accuracy['validation_accuracy'].append(float(current_val_loss_accuracy['validation_accuracy'][0]))
    
    
    if total_loss < best_valid_loss:
        torch.save(vit_model_2.state_dict(), 'weights/vit_model_contrastive_learning_weights.pt')
        print("SAVED-BEST-WEIGHTS!")
        best_valid_loss = total_loss
        patience_counter = 0 # Reset early stopping
    else:
        patience_counter += 1
        print(f"No improvement in validation loss for {patience_counter} epoch(s).")

    if patience_counter >= save_weights_patience:
        print("Patience exceeded. Early stopping at epoch " +str(epoch + 1))
        early_stop = True
        
    print()

EPOCH[TRAIN]1/50: 100%|██████████| 898/898 [01:55<00:00,  7.79it/s, CLS Loss=2.0659, CONT Loss=1.8462, Total Loss=3.9121, Source Acc=20.00%, Target Acc=100.00%]



Epoch 1/50 Summary:
CLS Loss: 2.0659 | CONT Loss: 1.8462 | Total Loss: 3.9121
Source Acc: 24.26% | Target Acc: 94.90%

Epoch 1


EPOCH[VALID]1/1: 100%|██████████| 5/5 [00:01<00:00,  4.26it/s, loss=1.823862, acc=0.325000]



SAVED-BEST-WEIGHTS!



EPOCH[TRAIN]2/50: 100%|██████████| 898/898 [01:51<00:00,  8.05it/s, CLS Loss=2.0372, CONT Loss=1.8076, Total Loss=3.8447, Source Acc=0.00%, Target Acc=100.00%] 



Epoch 2/50 Summary:
CLS Loss: 2.0372 | CONT Loss: 1.8076 | Total Loss: 3.8447
Source Acc: 25.29% | Target Acc: 94.89%

Epoch 1


EPOCH[VALID]1/1: 100%|██████████| 5/5 [00:00<00:00, 15.95it/s, loss=1.818009, acc=0.287500]



SAVED-BEST-WEIGHTS!



EPOCH[TRAIN]3/50: 100%|██████████| 898/898 [01:45<00:00,  8.51it/s, CLS Loss=2.0547, CONT Loss=1.7252, Total Loss=3.7799, Source Acc=0.00%, Target Acc=100.00%] 



Epoch 3/50 Summary:
CLS Loss: 2.0547 | CONT Loss: 1.7252 | Total Loss: 3.7799
Source Acc: 25.64% | Target Acc: 93.57%

Epoch 1


EPOCH[VALID]1/1: 100%|██████████| 5/5 [00:00<00:00, 20.46it/s, loss=1.925983, acc=0.337500]



SAVED-BEST-WEIGHTS!



EPOCH[TRAIN]4/50: 100%|██████████| 898/898 [01:46<00:00,  8.41it/s, CLS Loss=2.0035, CONT Loss=1.7638, Total Loss=3.7673, Source Acc=20.00%, Target Acc=100.00%]



Epoch 4/50 Summary:
CLS Loss: 2.0035 | CONT Loss: 1.7638 | Total Loss: 3.7673
Source Acc: 25.84% | Target Acc: 95.50%

Epoch 1


EPOCH[VALID]1/1: 100%|██████████| 5/5 [00:00<00:00, 20.63it/s, loss=2.471120, acc=0.381250]



SAVED-BEST-WEIGHTS!



EPOCH[TRAIN]5/50: 100%|██████████| 898/898 [01:47<00:00,  8.36it/s, CLS Loss=2.0278, CONT Loss=1.7559, Total Loss=3.7837, Source Acc=20.00%, Target Acc=100.00%]



Epoch 5/50 Summary:
CLS Loss: 2.0278 | CONT Loss: 1.7559 | Total Loss: 3.7837
Source Acc: 26.16% | Target Acc: 94.51%

Epoch 1


EPOCH[VALID]1/1: 100%|██████████| 5/5 [00:00<00:00, 15.50it/s, loss=2.329355, acc=0.450000]



No improvement in validation loss for 1 epoch(s).



EPOCH[TRAIN]6/50: 100%|██████████| 898/898 [01:45<00:00,  8.48it/s, CLS Loss=2.0083, CONT Loss=1.7901, Total Loss=3.7984, Source Acc=40.00%, Target Acc=100.00%]



Epoch 6/50 Summary:
CLS Loss: 2.0083 | CONT Loss: 1.7901 | Total Loss: 3.7984
Source Acc: 26.27% | Target Acc: 94.88%

Epoch 1


EPOCH[VALID]1/1: 100%|██████████| 5/5 [00:00<00:00, 20.64it/s, loss=2.998486, acc=0.343750]



No improvement in validation loss for 2 epoch(s).



EPOCH[TRAIN]7/50: 100%|██████████| 898/898 [01:50<00:00,  8.11it/s, CLS Loss=2.1156, CONT Loss=1.6480, Total Loss=3.7636, Source Acc=60.00%, Target Acc=100.00%]



Epoch 7/50 Summary:
CLS Loss: 2.1156 | CONT Loss: 1.6480 | Total Loss: 3.7636
Source Acc: 25.79% | Target Acc: 90.82%

Epoch 1


EPOCH[VALID]1/1: 100%|██████████| 5/5 [00:00<00:00, 20.98it/s, loss=3.205936, acc=0.281250]



SAVED-BEST-WEIGHTS!



EPOCH[TRAIN]8/50: 100%|██████████| 898/898 [02:03<00:00,  7.26it/s, CLS Loss=1.9773, CONT Loss=1.6692, Total Loss=3.6465, Source Acc=0.00%, Target Acc=100.00%] 



Epoch 8/50 Summary:
CLS Loss: 1.9773 | CONT Loss: 1.6692 | Total Loss: 3.6465
Source Acc: 26.38% | Target Acc: 96.20%

Epoch 1


EPOCH[VALID]1/1: 100%|██████████| 5/5 [00:00<00:00, 18.43it/s, loss=2.615782, acc=0.231250]



SAVED-BEST-WEIGHTS!



EPOCH[TRAIN]9/50: 100%|██████████| 898/898 [02:02<00:00,  7.31it/s, CLS Loss=2.0217, CONT Loss=1.7233, Total Loss=3.7450, Source Acc=0.00%, Target Acc=100.00%] 



Epoch 9/50 Summary:
CLS Loss: 2.0217 | CONT Loss: 1.7233 | Total Loss: 3.7450
Source Acc: 25.88% | Target Acc: 95.11%

Epoch 1


EPOCH[VALID]1/1: 100%|██████████| 5/5 [00:00<00:00, 20.76it/s, loss=2.005509, acc=0.287500]



No improvement in validation loss for 1 epoch(s).



EPOCH[TRAIN]10/50: 100%|██████████| 898/898 [02:02<00:00,  7.33it/s, CLS Loss=1.9849, CONT Loss=1.7393, Total Loss=3.7242, Source Acc=40.00%, Target Acc=100.00%]



Epoch 10/50 Summary:
CLS Loss: 1.9849 | CONT Loss: 1.7393 | Total Loss: 3.7242
Source Acc: 26.63% | Target Acc: 96.24%

Epoch 1


EPOCH[VALID]1/1: 100%|██████████| 5/5 [00:00<00:00, 20.82it/s, loss=1.679839, acc=0.400000]



No improvement in validation loss for 2 epoch(s).



EPOCH[TRAIN]11/50: 100%|██████████| 898/898 [02:02<00:00,  7.36it/s, CLS Loss=2.0353, CONT Loss=1.8179, Total Loss=3.8531, Source Acc=20.00%, Target Acc=100.00%]



Epoch 11/50 Summary:
CLS Loss: 2.0353 | CONT Loss: 1.8179 | Total Loss: 3.8531
Source Acc: 26.42% | Target Acc: 94.71%

Epoch 1


EPOCH[VALID]1/1: 100%|██████████| 5/5 [00:00<00:00, 20.73it/s, loss=1.940639, acc=0.287500]



No improvement in validation loss for 3 epoch(s).



EPOCH[TRAIN]12/50: 100%|██████████| 898/898 [02:00<00:00,  7.47it/s, CLS Loss=2.0468, CONT Loss=1.7537, Total Loss=3.8005, Source Acc=20.00%, Target Acc=100.00%]



Epoch 12/50 Summary:
CLS Loss: 2.0468 | CONT Loss: 1.7537 | Total Loss: 3.8005
Source Acc: 25.60% | Target Acc: 94.18%

Epoch 1


EPOCH[VALID]1/1: 100%|██████████| 5/5 [00:01<00:00,  3.49it/s, loss=2.031717, acc=0.387500]



No improvement in validation loss for 4 epoch(s).



EPOCH[TRAIN]13/50: 100%|██████████| 898/898 [01:47<00:00,  8.35it/s, CLS Loss=2.0326, CONT Loss=1.8070, Total Loss=3.8396, Source Acc=0.00%, Target Acc=100.00%] 



Epoch 13/50 Summary:
CLS Loss: 2.0326 | CONT Loss: 1.8070 | Total Loss: 3.8396
Source Acc: 25.99% | Target Acc: 95.31%

Epoch 1


EPOCH[VALID]1/1: 100%|██████████| 5/5 [00:00<00:00, 15.44it/s, loss=2.056227, acc=0.337500]


No improvement in validation loss for 5 epoch(s).
Patience exceeded. Early stopping at epoch 14

Early stopping triggered. Stopping training.





#### Store Metrics

In [43]:
# Store the metrics from when the model was tested on the out-of-distribution dataset
data = {
    "Epoch": list(range(1, len(contrastive_loss_metrics['cls_loss']) + 1)),
    "CLS_LOSS": contrastive_loss_metrics['cls_loss'],
    "CONT_LOSS": contrastive_loss_metrics['cont_loss'],
    "Total Loss": contrastive_loss_metrics['total_loss'],
    "Source Accuracy": contrastive_loss_metrics['source_accuracy'],
    "Target Accuracy": contrastive_loss_metrics['target_accuracy']
}

# Create a DataFrame
df = pd.DataFrame(data)

# Save to CSV
df.to_csv("stats/vit_model_contrastive_learning_stats_training.csv", index=False)
print("Losses and accuracy saved")

Losses and accuracy saved


In [44]:
# Store the metrics from when the model was tested on the out-of-distribution dataset
data = {
    "Epoch": list(range(1, len(validation_loss_accuracy['validation_loss']) + 1)),
    "Validation Loss": validation_loss_accuracy['validation_loss'],
    "Validation Accuracy": validation_loss_accuracy['validation_accuracy']
}

# Create a DataFrame
df = pd.DataFrame(data)

# Save to CSV
df.to_csv("stats/vit_model_contrastive_learning_TESTINGSET_stats_training.csv", index=False)
print("Losses and accuracy saved")

Losses and accuracy saved


In [45]:
# tEST the model:
cont_test_out_of_distribution_metrics = test_out_of_distribution(vit_model_2, manga_faces_test_images_loader, epochs=10, device=device)

Epoch 1


EPOCH[VALID]1/10: 100%|██████████| 5/5 [00:10<00:00,  2.10s/it, loss=2.126303, acc=0.293750]


Epoch 2


EPOCH[VALID]2/10: 100%|██████████| 5/5 [00:00<00:00,  7.42it/s, loss=2.066677, acc=0.381250]


Epoch 3


EPOCH[VALID]3/10: 100%|██████████| 5/5 [00:00<00:00,  8.76it/s, loss=1.975406, acc=0.381250]


Epoch 4


EPOCH[VALID]4/10: 100%|██████████| 5/5 [00:00<00:00,  9.22it/s, loss=1.982801, acc=0.337500]


Epoch 5


EPOCH[VALID]5/10: 100%|██████████| 5/5 [00:00<00:00,  8.24it/s, loss=2.290566, acc=0.337500]


Epoch 6


EPOCH[VALID]6/10: 100%|██████████| 5/5 [00:00<00:00,  8.16it/s, loss=2.092210, acc=0.425000]


Epoch 7


EPOCH[VALID]7/10: 100%|██████████| 5/5 [00:00<00:00,  9.18it/s, loss=2.093665, acc=0.381250]


Epoch 8


EPOCH[VALID]8/10: 100%|██████████| 5/5 [00:00<00:00,  8.82it/s, loss=2.157536, acc=0.381250]


Epoch 9


EPOCH[VALID]9/10: 100%|██████████| 5/5 [00:00<00:00,  8.35it/s, loss=1.971365, acc=0.381250]


Epoch 10


EPOCH[VALID]10/10: 100%|██████████| 5/5 [00:00<00:00,  9.97it/s, loss=2.018247, acc=0.381250]







In [46]:
from sklearn.metrics import classification_report

def evaluate_model(model, test_loader, target=False):
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for imgs, lbls in tqdm(test_loader, desc="Evaluating"):
            imgs = imgs.to(device)
            lbls = lbls.to(device)
            
            logits, _ = model(imgs, return_embeddings=True)
            preds = logits.argmax(dim=1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(lbls.cpu().numpy())
    
    # Calculate metrics
    accuracy = 100 * (np.array(all_preds) == np.array(all_labels)).mean()
    class_report = classification_report(all_labels, all_preds, zero_division=0)
    
    print(f"{'Target' if target else 'Source'} Test Accuracy: {accuracy:.2f}%")
    print("\nClassification Report:\n", class_report)
    
    return accuracy, class_report

#### Test on Source Domain

In [47]:
source_test_accuracy, source_report = evaluate_model(vit_model_2, val_loader)
# Contrastive learning model does not forget the source domain

Evaluating: 100%|██████████| 225/225 [00:44<00:00,  5.05it/s]

Source Test Accuracy: 28.89%

Classification Report:
               precision    recall  f1-score   support

           0       0.17      0.02      0.04       958
           1       0.00      0.00      0.00       111
           2       0.20      0.12      0.15      1024
           3       0.30      0.73      0.42      1774
           4       0.28      0.12      0.17      1233
           5       0.25      0.22      0.23      1247
           6       0.42      0.25      0.32       831

    accuracy                           0.29      7178
   macro avg       0.23      0.21      0.19      7178
weighted avg       0.27      0.29      0.24      7178






#### Test on Target Domain

In [48]:
target_test_accuracy, target_report = evaluate_model(vit_model_2, manga_faces_test_images_loader, target=True)

Evaluating: 100%|██████████| 5/5 [00:00<00:00,  6.34it/s]


Target Test Accuracy: 35.61%

Classification Report:
               precision    recall  f1-score   support

           0       0.18      0.10      0.12        21
           2       0.00      0.00      0.00         0
           3       0.49      0.73      0.59        49
           4       0.00      0.00      0.00         0
           5       0.00      0.00      0.00        22
           6       0.45      0.23      0.30        40

    accuracy                           0.36       132
   macro avg       0.19      0.18      0.17       132
weighted avg       0.35      0.36      0.33       132

