# Import Libraries

In [None]:
import os
import random
import sys
from datetime import datetime
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchinfo import summary
from tqdm import tqdm
from torchvision.datasets import ImageFolder
from torchvision.models import vit_b_16
from torchvision import transforms

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

# Datasets

In [None]:
# ImageNet per-channel mean / std used for normalisation.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# SSL pre-training pipeline. It is applied to PIL images (MultiTaskDataset
# passes PIL inputs): resize to 48x48, random colour jitter, convert to a
# [0, 1] tensor, then normalise.
train_transforms = T.Compose(
    [
        T.Resize((48, 48)),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        T.ToTensor(),
        T.Normalize(mean, std),
    ]
)

# Evaluation pipeline. It begins with ToPILImage, so its input must be a
# tensor or ndarray rather than a PIL image.
val_transforms = T.Compose(
    [
        T.ToPILImage(),
        T.Resize((48, 48)),
        T.ToTensor(),
        T.Normalize(mean, std),
    ]
)

In [None]:
class RandomPatchMask:
    """Zero out a random subset of square patches in a [C, H, W] image tensor.

    The image is tiled into non-overlapping ``patch_size`` x ``patch_size``
    patches. Leftover rows/columns at the bottom/right edge are never masked.
    ``int(total_patches * mask_ratio)`` patches are chosen with Python's
    ``random`` module, and those patches are set to 0 in every channel of a
    copy of the input. The input tensor itself is left untouched.
    """

    def __init__(self, mask_ratio=0.15, patch_size=16):
        self.mask_ratio = mask_ratio
        self.patch_size = patch_size

    def __call__(self, img):
        _, height, width = img.shape
        size = self.patch_size
        rows, cols = height // size, width // size
        n_total = rows * cols
        n_masked = int(n_total * self.mask_ratio)

        # Shuffle all patch indices and keep the first n_masked of them.
        order = list(range(n_total))
        random.shuffle(order)

        result = img.clone()
        for flat_index in order[:n_masked]:
            r, c = divmod(flat_index, cols)
            top, left = r * size, c * size
            result[:, top:top + size, left:left + size] = 0
        return result

### Unlabeled faces: MultiTaskDataset; FER2013

In [None]:
def generate_random_permutations(num_patches=9, num_permutations=30):
    """Sample ``num_permutations`` distinct random orderings of ``range(num_patches)``.

    The result serves as the label set for the jigsaw pretext task: a
    permutation's index in the returned list is its class label. Orderings
    are drawn with Python's ``random`` module, so seed it for reproducibility.

    Args:
        num_patches: number of puzzle pieces (9 for a 3x3 grid).
        num_permutations: how many distinct permutations to return.

    Returns:
        A list of ``num_permutations`` lists, each a permutation of
        ``range(num_patches)``.

    Raises:
        ValueError: if more distinct permutations are requested than exist
            (``num_patches!``). Without this check the rejection-sampling
            loop below could never terminate.
    """
    import math  # local import keeps this notebook cell self-contained

    available = math.factorial(num_patches)
    if num_permutations > available:
        raise ValueError(
            f"Requested {num_permutations} distinct permutations, but only "
            f"{available} exist for {num_patches} patches"
        )

    # Rejection sampling: the set discards duplicates until enough unique ones exist.
    permutations = set()
    while len(permutations) < num_permutations:
        permutations.add(tuple(random.sample(range(num_patches), num_patches)))
    return [list(p) for p in permutations]

In [None]:
class MultiTaskDataset(Dataset):
    """
    Self-supervised multi-task dataset built from unlabeled image files.

    Each item corresponds to one source image. It is a dict holding one
    ``(input, target)`` pair per pretext task, generated on the fly:

       {
         'denoise':  (denoise_input, denoise_target),
         'rotation': (rotation_input, rotation_label),
         'jigsaw':   (jigsaw_input,   jigsaw_label),
         'mask':     (mask_input,     mask_target)
       }

    Args:
        image_paths: sequence of paths accepted by ``PIL.Image.open``.
        base_transforms: callable applied to every PIL image that becomes a
            model input or reconstruction target (e.g. resize + ToTensor +
            Normalize).
        image_size: side length the raw image is resized to before any task
            input is built.
        jigsaw_permutations: optional list of patch permutations. The jigsaw
            label is the index of the permutation used. If ``None`` or empty,
            the jigsaw task falls back to the unshuffled image with label 0.
        mask_ratio: fraction of pixels zeroed for the masked-prediction task.

    NOTE(review): ``base_transforms`` is called separately for each input and
    for its target. If it contains random augmentations (e.g. ColorJitter),
    the denoise/mask inputs and their targets receive *different* random
    parameters. Confirm this is intended.
    """
    def __init__(self, image_paths, base_transforms, image_size=224,
                 jigsaw_permutations=None,  # e.g. a precomputed list of permutations
                 mask_ratio=0.75):
        super().__init__()
        self.image_paths = image_paths
        self.image_size = image_size
        self.base_transform = base_transforms
        # Optional fixed orderings, e.g. [[0,1,2,3,4,5,6,7,8], [3,0,1,4,2,5,7,8,6], ...]
        self.jigsaw_permutations = jigsaw_permutations
        self.mask_ratio = mask_ratio

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # 1. Load the image. ``convert`` returns a new, fully loaded image, so
        # the file handle can be closed right away by the context manager.
        img_path = self.image_paths[idx]
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        img = img.resize((self.image_size, self.image_size), Image.BILINEAR)

        # Un-normalised tensor in [0, 1], shape [3, H, W]. Noise and masking
        # are applied at this pixel scale before base_transform runs.
        img_tensor = T.ToTensor()(img)

        # (a) Denoise: noisy version as input, clean image as target.
        noisy_input = self.add_gaussian_noise(img_tensor)
        denoise_input = self.base_transform(T.functional.to_pil_image(noisy_input))
        denoise_target = self.base_transform(img)

        # (b) Rotation: PIL rotates counter-clockwise. The label is the angle's
        # index in `angles` (0 -> 0 deg, 1 -> 90 deg, ...).
        angles = [0, 90, 180, 270]
        angle = random.choice(angles)
        rotation_label = angles.index(angle)
        rotation_input = self.base_transform(img.rotate(angle))

        # (c) Jigsaw: shuffle patches with a randomly chosen permutation; the
        # label is that permutation's index.
        if self.jigsaw_permutations:
            perm_index = random.randint(0, len(self.jigsaw_permutations) - 1)
            perm = self.jigsaw_permutations[perm_index]
            jigsaw_input = self.base_transform(self.create_jigsaw(img, perm))
            jigsaw_label = torch.tensor(perm_index, dtype=torch.long)
        else:
            # No permutations configured: unshuffled image, constant label 0.
            jigsaw_input = self.base_transform(img)
            jigsaw_label = torch.tensor(0, dtype=torch.long)

        # (d) Masked prediction: zero a random fraction of pixels; the target
        # is the unmasked image.
        mask_input, mask_target = self.create_masked_input(img_tensor, self.mask_ratio)
        mask_input = self.base_transform(T.functional.to_pil_image(mask_input))
        mask_target = self.base_transform(T.functional.to_pil_image(mask_target))

        return {
            'denoise':  (denoise_input,  denoise_target),
            'rotation': (rotation_input, torch.tensor(rotation_label, dtype=torch.long)),
            'jigsaw':   (jigsaw_input,   jigsaw_label),
            'mask':     (mask_input,     mask_target),
        }

    # -----------------------------
    # Helper functions
    # -----------------------------
    def add_gaussian_noise(self, img_tensor, std=0.1):
        """
        Add zero-mean Gaussian noise with standard deviation ``std`` to a
        tensor in [0, 1], then clamp the result back into [0, 1].
        """
        noise = torch.randn_like(img_tensor) * std
        return torch.clamp(img_tensor + noise, 0.0, 1.0)

    def create_jigsaw(self, img_pil, permutation, grid_size=3):
        """
        Cut ``img_pil`` into ``grid_size x grid_size`` patches (row-major),
        place patch ``permutation[i]`` at grid position ``i``, and return the
        reassembled RGB PIL image.

        If the image side is not divisible by ``grid_size``, the leftover
        right/bottom strip is left black (``Image.new`` default fill).
        """
        w, h = img_pil.size
        patch_w = w // grid_size
        patch_h = h // grid_size

        patches = []
        for row in range(grid_size):
            for col in range(grid_size):
                left = col * patch_w
                upper = row * patch_h
                patches.append(img_pil.crop((left, upper, left + patch_w, upper + patch_h)))

        shuffled_patches = [patches[i] for i in permutation]

        new_img = Image.new('RGB', (w, h))
        for pos, patch in enumerate(shuffled_patches):
            row, col = divmod(pos, grid_size)
            new_img.paste(patch, (col * patch_w, row * patch_h))
        return new_img

    def create_masked_input(self, img_tensor, mask_ratio=0.75):
        """
        Zero ``int(H * W * mask_ratio)`` randomly chosen pixel locations
        (in all channels) of a [C, H, W] tensor.

        Returns:
            (masked_img, original_img): both new tensors; the input is not
            modified.
        """
        c, h, w = img_tensor.shape
        num_pixels = h * w
        num_mask = int(num_pixels * mask_ratio)

        flat_img = img_tensor.reshape(c, -1).clone()  # [C, H*W]
        mask_indices = random.sample(range(num_pixels), num_mask)
        # A single indexed assignment replaces the former per-pixel Python
        # loop; the same pixels are chosen for the same random state.
        flat_img[:, torch.as_tensor(mask_indices, dtype=torch.long)] = 0.0

        masked_img = flat_img.view(c, h, w)
        target_img = img_tensor.clone()
        return masked_img, target_img

### Unlabeled faces DataLoader

In [None]:
# Root of the (unlabeled, for SSL) FER2013 training images, relative to the working directory.
fer_2013_dir_train = Path.cwd() / 'datasets' / 'fer2013' / 'train'

In [None]:
# jigsaw_permutations = generate_random_permutations(num_patches=9, num_permutations=30)
# fer2013_train_files = [p for p in Path(fer_2013_dir_train).rglob('*') if p.is_file()]

# dataset = MultiTaskDataset(fer2013_train_files, train_transforms, image_size=48,
#                            jigsaw_permutations=jigsaw_permutations,
#                            mask_ratio=0.75)

# ssl_dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

In [None]:
class UnlabeledFacesDataset(torch.utils.data.Dataset):
    """Dataset of unlabeled face images read from disk with OpenCV.

    Args:
        img_paths: sequence of image file paths (str or path-like).
        transform: optional callable applied to the float32 image array.

    NOTE(review): ``cv2.imread`` decodes to BGR channel order with values in
    [0, 255]. The supplied transform must account for that; ImageNet
    mean/std normalisation assumes RGB in [0, 1]. Confirm.
    """
    def __init__(self, img_paths, transform=None):
        self.img_paths = img_paths
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        # cv2.imread returns None instead of raising on a missing/undecodable
        # file, so check before calling .astype (which would raise
        # AttributeError on None). str() so pathlib.Path inputs work on OpenCV
        # builds that only accept str.
        face_img = cv2.imread(str(img_path))
        if face_img is None:
            raise RuntimeError(f"No face found in {img_path}")
        face_img = face_img.astype('float32')

        if self.transform:
            face_img = self.transform(face_img)

        return face_img

### Labeled faces: CK+

In [None]:
class LabeledFERDataset(torch.utils.data.Dataset):
    """Dataset of (face image, label) pairs read from disk with OpenCV.

    Args:
        img_paths: sequence of image file paths (str or path-like).
        labels: sequence of labels aligned with ``img_paths``.
        transform: optional callable applied to the float32 image array.

    NOTE(review): ``cv2.imread`` yields BGR channel order with values in
    [0, 255]; confirm the transform handles this.
    """
    def __init__(self, img_paths, labels, transform=None):
        self.img_paths = img_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        label = self.labels[idx]

        # cv2.imread returns None on failure, so check before .astype.
        face_img = cv2.imread(str(img_path))
        if face_img is None:
            raise RuntimeError(f"No face found in {img_path}")
        face_img = face_img.astype('float32')

        if self.transform:
            face_img = self.transform(face_img)

        return face_img, label

### Few-shot image dataset: manga faces

In [None]:
class FewShotFERDataset(torch.utils.data.Dataset):
    """Episode sampler for N-way K-shot few-shot classification.

    Each ``__getitem__`` call ignores ``idx`` and draws a fresh random
    episode. It picks ``n_way`` classes, then for each class samples
    ``k_shot`` support images and ``k_query`` query images without
    replacement. Labels are re-indexed 0..n_way-1 within the episode.

    Args:
        class_folders: mapping ``{class_name: [image_path, ...]}``.
        transform: optional callable applied to each float32 image array.
        n_way: classes per episode.
        k_shot: support images per class.
        k_query: query images per class.
        num_episodes: value reported by ``__len__`` (episodes per "epoch").
    """
    def __init__(self, class_folders, transform=None, n_way=5, k_shot=1, k_query=5,
                 num_episodes=1000):
        self.class_folders = class_folders
        self.transform = transform
        self.n_way = n_way
        self.k_shot = k_shot
        self.k_query = k_query
        self.num_episodes = num_episodes

    def __len__(self):
        return self.num_episodes

    def _load_image(self, path):
        """Read one image with OpenCV as float32 and apply the transform."""
        # cv2.imread returns None on failure; check before .astype.
        img = cv2.imread(str(path))
        if img is None:
            raise RuntimeError(f"No face found in {path}")
        img = img.astype('float32')
        if self.transform:
            img = self.transform(img)
        return img

    def __getitem__(self, idx):
        # Raises ValueError if fewer than n_way classes exist.
        sampled_classes = random.sample(list(self.class_folders.keys()), self.n_way)
        label_map = {cls_name: i for i, cls_name in enumerate(sampled_classes)}

        support_imgs, support_labels = [], []
        query_imgs, query_labels = [], []

        for cls_name in sampled_classes:
            # list() so non-sequence collections (e.g. sets) of paths also work.
            selected_paths = random.sample(list(self.class_folders[cls_name]),
                                           self.k_shot + self.k_query)
            label = label_map[cls_name]
            for path in selected_paths[:self.k_shot]:
                support_imgs.append(self._load_image(path))
                support_labels.append(label)
            for path in selected_paths[self.k_shot:]:
                query_imgs.append(self._load_image(path))
                query_labels.append(label)

        # as_tensor leaves tensors untouched and converts ndarrays (no transform).
        support_imgs = torch.stack([torch.as_tensor(im) for im in support_imgs])  # [n_way*k_shot, ...]
        support_labels = torch.tensor(support_labels)                            # [n_way*k_shot]
        query_imgs = torch.stack([torch.as_tensor(im) for im in query_imgs])      # [n_way*k_query, ...]
        query_labels = torch.tensor(query_labels)                                # [n_way*k_query]

        return (support_imgs, support_labels), (query_imgs, query_labels)

# Self-supervised Learning

In [None]:
class Identity(nn.Module):
    """No-op module whose ``forward`` returns its input object unchanged.

    It replaces the torchvision ViT classification head so that backbone
    features pass straight through.
    """

    def forward(self, x):
        output = x  # no computation and no copy
        return output

In [None]:
class ViT_SelfSupervised(nn.Module):
    """
    torchvision ViT-B/16 backbone shared by four self-supervised pretext heads.

    The heads are denoising reconstruction, rotation classification, jigsaw
    permutation classification and masked-image reconstruction. Select one
    with ``forward(x, task)``.
    """

    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 in_chans=3,
                 embedding_dim=768,
                 num_rotation_classes=4,
                 num_jigsaw_classes=30):
        """
        Args:
            image_size: input resolution (square images assumed).
            patch_size: ViT patch side; must be 16 for vit_b_16.
            in_chans: input channels; must be 3 for vit_b_16.
            embedding_dim: token width; must be 768 for vit_b_16.
            num_rotation_classes: number of rotation categories.
            num_jigsaw_classes: number of predefined jigsaw permutations.

        Raises:
            ValueError: if patch_size / in_chans / embedding_dim do not match
                the fixed ViT-B/16 architecture. The heads are sized from
                these arguments, so a mismatch would otherwise only surface as
                a shape error at forward time.
        """
        super(ViT_SelfSupervised, self).__init__()
        if (patch_size, in_chans, embedding_dim) != (16, 3, 768):
            raise ValueError(
                "vit_b_16 requires patch_size=16, in_chans=3, embedding_dim=768; got "
                f"patch_size={patch_size}, in_chans={in_chans}, embedding_dim={embedding_dim}"
            )

        self.num_patches = (image_size // patch_size) ** 2
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embedding_dim = embedding_dim
        self.image_size = image_size

        # Randomly initialised ViT (weights=None). Its classification head is
        # replaced by a pass-through; only the backbone pieces are used.
        self.encoder = vit_b_16(weights=None, image_size=image_size, num_classes=7)
        self.encoder.heads = Identity()

        # (a) Denoising: linear map from each patch token to that patch's pixels.
        self.denoise_decoder = nn.Linear(embedding_dim, patch_size * patch_size * in_chans)
        # (b) Rotation: classify from the CLS token.
        self.rotation_head = nn.Linear(embedding_dim, num_rotation_classes)
        # (c) Jigsaw: MLP over all patch tokens flattened in order.
        self.jigsaw_head = nn.Sequential(
            nn.Linear(embedding_dim * self.num_patches, 1024),
            nn.ReLU(),
            nn.Linear(1024, num_jigsaw_classes)
        )
        # (d) Masked prediction: same shape as the denoising decoder.
        self.mask_decoder = nn.Linear(embedding_dim, patch_size * patch_size * in_chans)

    def forward_encoder(self, x):
        """
        Encode an image batch [B, C, H, W] into a token sequence.

        Mirrors torchvision's VisionTransformer forward up to (but excluding)
        CLS-token selection. It patch-embeds with ``conv_proj`` and prepends
        the class token, then runs ``self.encoder.encoder``. That module adds
        the learned positional embeddings before its dropout, transformer
        layers and final LayerNorm. The previous hand-written loop skipped the
        positional embeddings.

        Returns:
            [B, 1 + num_patches, embedding_dim]. Token 0 is the CLS token; the
            rest are patch tokens in row-major patch order.
        """
        x = self.encoder.conv_proj(x)                 # [B, D, H_p, W_p]
        x = x.flatten(2).transpose(1, 2)              # [B, H_p*W_p, D], row-major
        cls_tokens = self.encoder.class_token.expand(x.shape[0], -1, -1)  # [B, 1, D]
        x = torch.cat((cls_tokens, x), dim=1)         # [B, 1 + num_patches, D]
        return self.encoder.encoder(x)

    def _reconstruct(self, x, decoder):
        """Encode ``x`` and decode every patch token back into image pixels."""
        patch_tokens = self.forward_encoder(x)[:, 1:, :]   # [B, num_patches, D]
        patches = decoder(patch_tokens)                     # [B, num_patches, C*p*p]
        return self.patches_to_image(patches, x.shape[0])   # [B, C, H, W]

    def forward_denoise(self, x):
        """Denoising task: reconstruct the clean image from noisy input ``x``."""
        return self._reconstruct(x, self.denoise_decoder)

    def forward_masked(self, x):
        """Masked-prediction task: reconstruct the full image from masked ``x``."""
        return self._reconstruct(x, self.mask_decoder)

    def patches_to_image(self, patches, batch_size):
        """
        Tile per-patch pixel predictions back into images.

        Args:
            patches: [B, num_patches, in_chans * patch_size**2]. Row i is the
                i-th patch in row-major order (the token order produced by
                ``forward_encoder``). Each row is laid out as (C, p, p).
            batch_size: B.

        Returns:
            [B, in_chans, grid * patch_size, grid * patch_size] with
            grid = sqrt(num_patches).
        """
        p = self.patch_size
        c = self.in_chans
        grid = int(self.num_patches ** 0.5)

        x = patches.reshape(batch_size, grid, grid, c, p, p)  # [B, gh, gw, C, p, p]
        # Interleave grid and in-patch axes: pixel row = gh*p + row_in_patch,
        # pixel col = gw*p + col_in_patch.
        x = x.permute(0, 3, 1, 4, 2, 5)                       # [B, C, gh, p, gw, p]
        return x.reshape(batch_size, c, grid * p, grid * p)

    def forward_rotation(self, x):
        """Rotation task: return rotation logits computed from the CLS token."""
        cls_token = self.forward_encoder(x)[:, 0, :]   # [B, D]
        return self.rotation_head(cls_token)

    def forward_jigsaw(self, x):
        """Jigsaw task: return permutation logits from the flattened patch tokens."""
        patch_tokens = self.forward_encoder(x)[:, 1:, :]   # [B, num_patches, D]
        flat_tokens = patch_tokens.reshape(x.size(0), -1)  # [B, num_patches * D]
        return self.jigsaw_head(flat_tokens)

    def forward(self, x, task):
        """
        Run the pretext task named by ``task``: 'denoise', 'rotation',
        'jigsaw' or 'mask'.

        Raises:
            ValueError: for any other task name.
        """
        if task == 'denoise':
            return self.forward_denoise(x)
        elif task == 'rotation':
            return self.forward_rotation(x)
        elif task == 'jigsaw':
            return self.forward_jigsaw(x)
        elif task == 'mask':
            return self.forward_masked(x)
        else:
            raise ValueError("Invalid task specified. Choose from 'denoise', 'rotation', 'jigsaw', or 'mask'.")

In [None]:
class MultiTaskLossWrapper(nn.Module):
    """
    Combine four task losses using learned per-task log-variances.

    Each task k owns a learnable ``log_var_k`` of shape [1], initialised to 0.
    The combined loss is

        sum_k  L_k / (2 * exp(log_var_k))  +  0.5 * sum_k log_var_k

    A task's weight therefore shrinks as its log-variance grows. The second
    term penalises growing every variance just to shrink all losses.
    """

    def __init__(self):
        super().__init__()
        # Registered in this order: den(oise), rot(ation), puz(zle), m(a)sk.
        for param_name in ('log_var_den', 'log_var_rot', 'log_var_puz', 'log_var_msk'):
            setattr(self, param_name, nn.Parameter(torch.zeros(1)))

    def forward(self, L_den, L_rot, L_puz, L_msk):
        pairs = (
            (L_den, self.log_var_den),
            (L_rot, self.log_var_rot),
            (L_puz, self.log_var_puz),
            (L_msk, self.log_var_msk),
        )
        weighted_sum = sum(
            (1.0 / (2.0 * torch.exp(log_var))) * task_loss for task_loss, log_var in pairs
        )
        regulariser = sum(log_var for _, log_var in pairs) * 0.5
        return weighted_sum + regulariser

In [None]:
# Module holding the learnable task log-variances used to weight the SSL losses,
# placed on `device`.
multi_task_loss_fn = MultiTaskLossWrapper()
multi_task_loss_fn = multi_task_loss_fn.to(device)

In [None]:
def train_ssl_model(model, dataloader, num_epochs=10, device='cuda', learning_rate=0.01):
    """
    Pre-train ``model`` on the four self-supervised pretext tasks.

    Each batch from ``dataloader`` is a dict with (a subset of) the keys
    'denoise', 'rotation', 'jigsaw', 'mask'. Each key maps to an
    ``(input_tensor, target)`` pair. The per-task losses are combined by the
    module-level ``multi_task_loss_fn`` (learned uncertainty weighting), whose
    parameters are optimised together with the model. A task missing from a
    batch contributes 0.

    ``model.state_dict()`` is saved to ``vit_ssl_pretrained.pth`` whenever the
    epoch-average training loss improves. The LR is reduced 10x after 2
    stagnant epochs, and training stops after 3 epochs without improvement.

    Args:
        model: module called as ``model(inputs, task=<key>)``.
        dataloader: iterable of task-dict batches.
        num_epochs: maximum number of epochs.
        device: device for the model, loss weights and batch tensors.
        learning_rate: AdamW learning rate.

    Returns:
        list[float]: average training loss of every completed epoch.

    Raises:
        ValueError: if the dataloader yields no batches in an epoch.
    """
    model = model.to(device)
    # The learned task weights must live on the same device as the losses and
    # be in the optimizer; otherwise their log-variances never change.
    loss_weighting = multi_task_loss_fn.to(device)
    optimizer = torch.optim.AdamW(
        [
            {'params': model.parameters()},
            # No weight decay on the log-variances: it would bias the task weights.
            {'params': loss_weighting.parameters(), 'weight_decay': 0.0},
        ],
        lr=learning_rate,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.1)

    # Fixed per-task multipliers, applied before the learned weighting.
    task_weights = {'denoise': 1.0, 'rotation': 1.0, 'jigsaw': 1.0, 'mask': 1.0}
    # Reconstruction tasks use pixel MSE; classification tasks use cross-entropy.
    task_criteria = {
        'denoise': F.mse_loss,
        'rotation': F.cross_entropy,
        'jigsaw': F.cross_entropy,
        'mask': F.mse_loss,
    }
    # Positional order expected by the loss weighting module.
    task_order = ('denoise', 'rotation', 'jigsaw', 'mask')

    best_loss = float('inf')
    patience_counter = 0        # epochs since the last improvement
    early_stop_patience = 3     # stop after this many epochs without improvement
    metrics_loss = []

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}, LR: {scheduler.optimizer.param_groups[0]['lr']}")

        model.train()
        total_loss = 0.0
        num_batches = 0
        tk = tqdm(dataloader, desc="EPOCH" + "[TRAIN]" + str(epoch + 1) + "/" + str(num_epochs))

        for batch in tk:
            optimizer.zero_grad()

            task_losses = []
            for task in task_order:
                if task not in batch:
                    task_losses.append(0.0)
                    continue
                inputs, targets = batch[task]
                outputs = model(inputs.to(device), task=task)
                task_losses.append(task_weights[task] * task_criteria[task](outputs, targets.to(device)))

            L_ssl = loss_weighting(*task_losses)

            L_ssl.backward()
            optimizer.step()
            total_loss += L_ssl.item()
            num_batches += 1
            # Running mean of the loss over this epoch's batches so far.
            tk.set_postfix({'loss': '%6f' % (total_loss / num_batches)})

        if num_batches == 0:
            raise ValueError("dataloader yielded no batches")
        avg_loss = total_loss / num_batches

        metrics_loss.append(avg_loss)
        scheduler.step(avg_loss)

        if avg_loss < best_loss:
            # Save the pretrained weights (encoder and heads) for later fine-tuning.
            torch.save(model.state_dict(), "vit_ssl_pretrained.pth")
            print("SAVED-BEST-WEIGHTS!")
            best_loss = avg_loss
            patience_counter = 0
        else:
            patience_counter += 1
            print(f"No improvement in training loss for {patience_counter} epoch(s).")

        if patience_counter >= early_stop_patience:
            print("Patience exceeded. Early stopping at epoch " + str(epoch + 1))
            print("Early stopping triggered. Stopping training.")
            break

    print("")
    return metrics_loss

In [None]:
# -------------------------------
# Example usage:
# Assume we have a combined dataloader that yields a dict with keys: 'denoise', 'rotation', 'jigsaw', 'mask'.
# Each entry is a tuple: (input_tensor, target_tensor)
# In practice, you need to implement or combine datasets that perform the corresponding data augmentation.
# -------------------------------
if __name__ == "__main__":
    # SSL model for 48x48 inputs, i.e. a 3x3 grid of 16x16 ViT patches.
    model_ssl = ViT_SelfSupervised(
        image_size=48,
        patch_size=16,
        in_chans=3,
        embedding_dim=768,
        num_rotation_classes=4,
        num_jigsaw_classes=30,
    )

    # 30 fixed orderings of the 9 jigsaw patches; each one's list index is its class.
    jigsaw_permutations = generate_random_permutations(num_patches=9, num_permutations=30)
    fer2013_train_files = [
        path for path in Path(fer_2013_dir_train).rglob('*') if path.is_file()
    ]

    dataset = MultiTaskDataset(
        fer2013_train_files,
        train_transforms,
        image_size=48,
        jigsaw_permutations=jigsaw_permutations,
        mask_ratio=0.75,
    )
    # num_workers=0: batches are built in the main process.
    ssl_dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

    pretrained_metrics_loss = train_ssl_model(
        model_ssl, ssl_dataloader, num_epochs=10, device=device, learning_rate=0.01
    )

# ViT encoder fine-tuning

### Data Preprocessing

In [None]:
# Fine-tuning pipelines. Both resize to 48x48 (the resolution used for SSL
# pre-training) and normalise with ImageNet channel statistics; only the
# training pipeline adds random horizontal flips.
train_transforms = T.Compose(
    [
        T.Resize((48, 48)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

val_transforms = T.Compose(
    [
        T.Resize((48, 48)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# NOTE(review): despite its name, this points at the FER2013 folder, not CK+.
ckplus_dir_train = Path.cwd() / 'datasets' / 'fer2013'

# ImageFolder takes class labels from the sub-directory names.
train_dataset = ImageFolder(root=ckplus_dir_train / 'train', transform=train_transforms)
val_dataset = ImageFolder(root=ckplus_dir_train / 'test', transform=val_transforms)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

### FineTuningModel class

In [None]:
class FineTuningModel(nn.Module):
    """ViT-B/16 encoder with a linear classification head for fine-tuning.

    Args:
        num_classes (int): Number of output classes of the linear head.
        image_size (int): Side length of the square inputs the encoder expects;
            passed straight to ``vit_b_16`` (must be divisible by its 16-pixel
            patch size).
        use_pretrained_ssl (bool): If True, load encoder weights from
            ``pretrained_path`` (non-strictly) before fine-tuning.
        pretrained_path (str): Path to a state dict saved by the SSL stage.
    """

    def __init__(self, num_classes=7, image_size=224, use_pretrained_ssl=True, pretrained_path="vit_ssl_pretrained.pth"):
        super().__init__()

        # Randomly initialised torchvision ViT-B/16 (no ImageNet weights).
        self.encoder = vit_b_16(weights=None, image_size=image_size)
        # Replace the default classification head so the encoder returns the
        # final CLS-token embedding, shape [B, hidden_dim].
        self.encoder.heads = nn.Identity()

        if use_pretrained_ssl:
            self._load_ssl_weights(pretrained_path)

        # hidden_dim is 768 for vit_b_16; read it from the encoder instead of
        # hard-coding it so the head always matches the encoder width.
        self.classifier = nn.Linear(self.encoder.hidden_dim, num_classes)

    def _load_ssl_weights(self, pretrained_path):
        """Load SSL-stage weights into the encoder and report key mismatches."""
        # Load onto the CPU so this also works on machines without CUDA; the
        # whole model is moved to the training device afterwards.
        state_dict = torch.load(pretrained_path, map_location="cpu")
        result = self.encoder.load_state_dict(state_dict, strict=False)
        # strict=False silently skips mismatched keys. Surface them, because a
        # checkpoint saved from a wrapper module may use different key names
        # and end up loading (almost) nothing.
        if result.missing_keys or result.unexpected_keys:
            print(f"[FineTuningModel] SSL checkpoint '{pretrained_path}' loaded with "
                  f"{len(result.missing_keys)} missing and "
                  f"{len(result.unexpected_keys)} unexpected key(s).")

    def forward(self, x):
        """Return class logits of shape [B, num_classes].

        ``x`` is presumably an image batch [B, 3, image_size, image_size]
        (torchvision's ViT checks the spatial size) - confirm against the loader.
        """
        # With heads == Identity the encoder already returns the CLS embedding
        # [B, hidden_dim], so no token slicing is needed here.
        features = self.encoder(x)
        return self.classifier(features)

### Define the training and validation functions

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device, current_epoch, epochs):
    """
    Train the model for one epoch.

    Args:
        model (nn.Module): The model; ``model(images)`` must return logits.
        dataloader (DataLoader): Yields ``(images, labels)`` training batches.
        criterion (callable): Loss applied as ``criterion(logits, labels)``.
        optimizer (torch.optim.Optimizer): Optimizer that updates ``model``.
        device (torch.device): Device to train the model on (CPU/GPU).
        current_epoch (int): Zero-based epoch index (progress-bar label only).
        epochs (int): Total number of epochs (progress-bar label only).

    Returns:
        training_loss (float): Mean of the per-batch losses over the epoch,
        or 0.0 if the dataloader yielded no batches.
    """
    model.train()  # Enable training-mode behaviour (dropout etc.)
    epoch_loss = 0.0
    # Count batches ourselves: avoids dividing by zero on an empty loader and
    # does not require the dataloader to implement len().
    num_batches = 0
    tk = tqdm(dataloader, desc="EPOCH" + "[TRAIN]" + str(current_epoch + 1) + "/" + str(epochs))

    for images, labels in tk:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # Forward pass: raw class scores (logits), as expected by CrossEntropyLoss.
        logits = model(images)
        loss = criterion(logits, labels)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

        # Show the running mean batch loss on the progress bar.
        tk.set_postfix({'loss': '%6f' % float(epoch_loss / num_batches)})

    return epoch_loss / num_batches if num_batches else 0.0

In [None]:
def test_one_epoch(model, dataloader, criterion, device, current_epoch, epochs):
    """
    Evaluate the model for one epoch without updating its weights.

    Args:
        model (nn.Module): The model; ``model(images)`` must return logits.
        dataloader (DataLoader): Yields ``(images, labels)`` evaluation batches.
        criterion (callable): Loss applied as ``criterion(logits, labels)``.
        device (torch.device): Device to run evaluation on (CPU/GPU).
        current_epoch (int): Zero-based epoch index (progress-bar label only).
        epochs (int): Total number of epochs (progress-bar label only).

    Returns:
        validation_loss (float): Mean of the per-batch losses, or 0.0 if the
        dataloader yielded no batches.

        running_acc (float): Accuracy over all evaluated samples, in percent
        (0.0 if there were none).
    """
    model.eval()  # Disable training-only behaviour (dropout etc.)
    epoch_loss = 0.0
    running_acc = 0.0
    correct = 0
    total = 0
    # Count batches ourselves so an empty loader does not divide by zero.
    num_batches = 0

    tk = tqdm(dataloader, desc="EPOCH" + "[VALID]" + str(current_epoch + 1) + "/" + str(epochs))

    with torch.no_grad():  # No gradients are needed for evaluation
        for images, labels in tk:
            images = images.to(device)
            labels = labels.to(device)

            # Raw class scores; the predicted class is the highest-scoring one.
            logits = model(images)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += images.size(0)

            # Cross-entropy (or whatever `criterion` is) on the raw logits.
            loss = criterion(logits, labels)
            epoch_loss += loss.item()
            num_batches += 1

            running_acc = correct / total * 100

            tk.set_postfix({'loss': '%6f' % float(epoch_loss / num_batches), 'acc': '%2f%%' % float(running_acc),})

    mean_loss = epoch_loss / num_batches if num_batches else 0.0
    return mean_loss, running_acc


In [None]:
def train_and_validate_model(model, training_dataloader, testing_dataloader, epochs, learning_rate, device,
                             checkpoint_path='fine_tuned_with_fer2013_NO_PRETRAIN.pt', early_stopping_patience=5):
    """
    Train a classifier with cross-entropy loss, validating after every epoch.

    Uses plain SGD and a ReduceLROnPlateau scheduler on the validation loss
    (factor 0.1, patience 3). The model's ``state_dict`` is saved to
    ``checkpoint_path`` whenever the validation loss improves. Training stops
    early after ``early_stopping_patience`` consecutive epochs without improvement.

    Args:
        model (nn.Module): The model; ``model(images)`` must return logits.
        training_dataloader (DataLoader): DataLoader for training data.
        testing_dataloader (DataLoader): DataLoader for validation data.
        epochs (int): Maximum number of training epochs.
        learning_rate (float): Initial learning rate for the SGD optimizer.
        device (torch.device): Device to train the model on (CPU/GPU).
        checkpoint_path (str): File the best weights are written to.
        early_stopping_patience (int): Epochs without validation-loss
            improvement before training stops.

    Returns:
        dict: ``{'training_loss': [...], 'validation_loss': [...]}`` with one
        entry per completed epoch.
    """
    # Move the model first: PyTorch recommends placing parameters on their
    # final device before constructing an optimizer over them.
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.1)

    best_valid_loss = np.inf
    patience_counter = 0  # Consecutive epochs without validation-loss improvement

    # Per-epoch loss history returned to the caller.
    metrics_loss = {
        'training_loss': [],
        'validation_loss': []
    }

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}, LR: {optimizer.param_groups[0]['lr']}")

        train_loss = train_one_epoch(model, training_dataloader, criterion, optimizer, device, epoch, epochs)
        # Accuracy is already shown on the validation progress bar; only the loss drives control flow.
        valid_loss, _ = test_one_epoch(model, testing_dataloader, criterion, device, epoch, epochs)

        metrics_loss['training_loss'].append(train_loss)
        metrics_loss['validation_loss'].append(valid_loss)

        # Decay the learning rate when the validation loss plateaus.
        scheduler.step(valid_loss)

        if valid_loss < best_valid_loss:
            torch.save(model.state_dict(), checkpoint_path)
            print("SAVED-BEST-WEIGHTS!")
            best_valid_loss = valid_loss
            patience_counter = 0
        else:
            patience_counter += 1
            print(f"No improvement in validation loss for {patience_counter} epoch(s).")

        if patience_counter >= early_stopping_patience:
            print("Patience exceeded. Early stopping at epoch " + str(epoch + 1))
            print("Early stopping triggered. Stopping training.")
            break

    print("")
    return metrics_loss


### Instantiate the model and run fine-tuning

In [None]:
# Build a 7-class classifier on 48x48 inputs from a randomly initialised
# ViT-B/16 encoder (the SSL checkpoint is not loaded in this run).
model_fer = FineTuningModel(
    num_classes=7,
    image_size=48,
    use_pretrained_ssl=False,
)

# Fine-tune on the FER2013 loaders for up to 20 epochs (early stopping may end
# sooner); the result holds the per-epoch training/validation losses.
finetuning_losses = train_and_validate_model(
    model_fer,
    train_loader,
    val_loader,
    epochs=20,
    learning_rate=0.01,
    device=device,
)

# FSL Domain Adaptation