# Setup

In [None]:
import math
import warnings
from typing import Any, Dict, Optional

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from lightly.models import utils
from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
from timm.models.vision_transformer import VisionTransformer
from torch import nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

In [None]:
# Reproducibility: seed the RNGs through Lightning (workers=True also derives
# seeds for dataloader workers) and pin cuDNN to deterministic, non-autotuned
# kernels.
pl.seed_everything(seed=3, workers=True)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Allow reduced-precision internals for float32 matmuls in exchange for speed.
torch.set_float32_matmul_precision("medium")

# Silence two known warning messages; `filterwarnings` matches each string as a
# regex against the start of the warning text. Each call inserts at the front
# of the filter list, so the order below matches the original two calls.
for _noisy_warning in (
    "Precision bf16-mixed is not supported by the model summary",
    "Please use the new API settings to control TF32 behavior",
):
    warnings.filterwarnings("ignore", _noisy_warning)
del _noisy_warning

# Models

## Masked Autoencoder

In [None]:
class MaskedAutoencoder(nn.Module):
    """Masked Autoencoder (MAE) with a timm ViT backbone.

    The encoder (lightly's ``MaskedVisionTransformerTIMM``) processes only the
    kept tokens; the decoder (``MAEDecoderTIMM``) fills the masked positions
    with its mask token and predicts pixel values for them.

    Args:
        general_cfg: Shared settings; reads ``mask_ratio`` (default 0.75),
            ``image_size`` (96), ``patch_size`` (6) and ``in_chans`` (3).
        encoder_cfg: ViT settings; reads ``embed_dim`` (384), ``depth`` (12)
            and ``num_heads`` (6).
        decoder_cfg: Decoder settings; reads ``decoder_embed_dim`` (512),
            ``decoder_depth`` (4) and ``decoder_num_heads`` (16).
    """

    def __init__(
        self,
        general_cfg: Dict[str, Any],
        encoder_cfg: Dict[str, Any],
        decoder_cfg: Dict[str, Any],
    ):
        super().__init__()

        # Mutable at runtime: forward() reads it on every call.
        self.mask_ratio = general_cfg.get("mask_ratio", 0.75)
        self.image_size = general_cfg.get("image_size", 96)
        self.patch_size = general_cfg.get("patch_size", 6)
        self.in_chans = general_cfg.get("in_chans", 3)
        embed_dim = encoder_cfg.get("embed_dim", 384)

        vit = VisionTransformer(
            img_size=self.image_size,
            patch_size=self.patch_size,
            in_chans=self.in_chans,
            embed_dim=embed_dim,
            depth=encoder_cfg.get("depth", 12),
            num_heads=encoder_cfg.get("num_heads", 6),
            num_classes=0,
        )

        self.encoder = MaskedVisionTransformerTIMM(vit=vit)
        # Token count including the class token; fall back to patches + 1 if
        # the wrapper does not expose ``sequence_length``.
        self.sequence_length = getattr(
            self.encoder,
            "sequence_length",
            vit.patch_embed.num_patches + 1,
        )

        self.decoder = MAEDecoderTIMM(
            num_patches=vit.patch_embed.num_patches,
            patch_size=self.patch_size,
            # The prediction head must emit patch_size**2 * in_chans values per
            # patch to match the targets built by utils.patchify in forward();
            # omitting this falls back to the decoder's own default channels.
            in_chans=self.in_chans,
            embed_dim=embed_dim,
            decoder_embed_dim=decoder_cfg.get("decoder_embed_dim", 512),
            decoder_depth=decoder_cfg.get("decoder_depth", 4),
            # Head count must divide decoder_embed_dim; the previous default of
            # 6 does not divide the default width of 512.
            decoder_num_heads=decoder_cfg.get("decoder_num_heads", 16),
        )

    def forward_encoder(
        self, images: torch.Tensor, idx_keep: Optional[torch.Tensor] = None
    ):
        """Encode ``images``; ``idx_keep`` is forwarded to ``encoder.encode``
        (presumably ``None`` keeps every token)."""
        return self.encoder.encode(images=images, idx_keep=idx_keep)

    def forward_decoder(self, x_encoded, idx_keep, idx_mask):
        """Predict pixel values for the masked tokens.

        Args:
            x_encoded: Encoder output for the kept tokens.
            idx_keep: Positions (in the full token sequence) of the kept tokens.
            idx_mask: Positions of the masked tokens to reconstruct.

        Returns:
            Decoder predictions gathered at ``idx_mask``.
        """
        batch_size = x_encoded.shape[0]
        x_decode = self.decoder.embed(x_encoded)

        # Start from a full sequence of mask tokens, then scatter the embedded
        # kept tokens back into their original positions.
        x_masked = utils.repeat_token(
            token=self.decoder.mask_token,
            size=(batch_size, self.sequence_length),
        )
        x_masked = utils.set_at_index(
            tokens=x_masked,
            index=idx_keep,
            value=x_decode.type_as(x_masked),
        )

        x_decoded = self.decoder.decode(x_masked)
        x_pred = utils.get_at_index(tokens=x_decoded, index=idx_mask)
        x_pred = self.decoder.predict(x_pred)

        return x_pred

    def forward(self, images: torch.Tensor):
        """Randomly mask tokens and reconstruct them.

        The mask ratio is read from ``self.mask_ratio`` at call time.

        Returns:
            ``(x_pred, target)``: predictions for the masked tokens and the
            corresponding patchified pixels of the input images.
        """
        batch_size = images.shape[0]
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch_size, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )

        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
        x_pred = self.forward_decoder(
            x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask
        )

        patches = utils.patchify(images=images, patch_size=self.patch_size)
        # Token index 0 is the class token, which ``patches`` does not have, so
        # shift by one. lightly's random_token_mask does not mask the class
        # token by default; the clamp only guards against index -1.
        idx_mask_adj = torch.clamp(idx_mask - 1, min=0)
        target = utils.get_at_index(tokens=patches, index=idx_mask_adj)

        return x_pred, target

## Vision Transformer Classifier

In [None]:
class ViTClassifier(nn.Module):
    """Linear classification head on top of a (pretrained) ViT encoder.

    Args:
        pretrained_encoder: Backbone exposing ``forward_features`` (returning a
            token sequence, or a tuple/list whose first item is one) and
            ``embed_dim``. A timm ``VisionTransformer`` with ``num_classes=0``
            is the intended case.
        num_classes: Number of output logits.
        head_cfg: Optional settings. ``embed_dim`` overrides the head's input
            width (default ``pretrained_encoder.embed_dim``). ``pool`` selects
            ``"cls"`` (class token, default) or ``"mean"``/``"avg"`` (average
            of the patch tokens, excluding prefix/class tokens).

    Raises:
        ValueError: If ``pool`` is not one of the supported pooling types.
    """

    _POOL_TYPES = ("cls", "mean", "avg")

    def __init__(
        self,
        pretrained_encoder: nn.Module,
        num_classes: int = 10,
        head_cfg: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        self.encoder = pretrained_encoder
        head_cfg = head_cfg or {}

        embed_dim = head_cfg.get("embed_dim", pretrained_encoder.embed_dim)
        pool_type = head_cfg.get("pool", "cls")
        if pool_type not in self._POOL_TYPES:
            raise ValueError(
                f"Unsupported pool type {pool_type!r}; "
                f"expected one of {self._POOL_TYPES}"
            )

        self.pool_type = pool_type
        # Tokens prepended before the patch tokens (class/register tokens);
        # timm ViTs expose this count. Assume a single class token otherwise.
        self.num_prefix_tokens = getattr(pretrained_encoder, "num_prefix_tokens", 1)
        self.head = torch.nn.Linear(embed_dim, num_classes)

    def forward(self, x: torch.Tensor):
        """Return class logits for ``x``."""
        feats = self.encoder.forward_features(x)
        if isinstance(feats, (tuple, list)):
            feats = feats[0]

        if self.pool_type == "cls":
            pooled = feats[:, 0]
        else:
            # Average only the patch tokens so the class token does not skew
            # the global-average-pooled representation.
            pooled = feats[:, self.num_prefix_tokens :].mean(dim=1)

        return self.head(pooled)

# Training Modules

## Masked Autoencoder Pretraining

In [None]:
class MAEPretrainModule(pl.LightningModule):
    """Self-supervised pretraining for Masked Autoencoder.

    Args:
        model_cfg: Must contain ``general``, ``encoder`` and ``decoder`` dicts,
            which are forwarded to ``MaskedAutoencoder``.
        training_cfg: Training settings (defaults in parentheses).
            ``mask_ratio_start`` (0.5), ``mask_ratio_end`` (0.85) and
            ``mask_ramp_epochs`` (200) drive the mask-ratio ramp.
            ``base_learning_rate`` (1.5e-4), ``weight_decay`` (0.05),
            ``warmup_epochs`` (20), ``total_epochs`` (200) and ``batch_size``
            (512) drive the optimizer and LR schedule.
    """

    def __init__(
        self,
        model_cfg: Dict[str, Any],
        training_cfg: Dict[str, Any],
    ):
        super().__init__()
        self.save_hyperparameters()

        self.model = MaskedAutoencoder(
            general_cfg=model_cfg["general"],
            encoder_cfg=model_cfg["encoder"],
            decoder_cfg=model_cfg["decoder"],
        )

        self.mask_start = training_cfg.get("mask_ratio_start", 0.5)
        self.mask_end = training_cfg.get("mask_ratio_end", 0.85)
        self.ramp_epochs = training_cfg.get("mask_ramp_epochs", 200)

        self.lr = float(training_cfg.get("base_learning_rate", 1.5e-4))
        self.weight_decay = float(training_cfg.get("weight_decay", 0.05))
        self.warmup_epochs = int(training_cfg.get("warmup_epochs", 20))
        self.total_epochs = int(training_cfg.get("total_epochs", 200))
        self.batch_size = int(training_cfg.get("batch_size", 512))
        self.criterion = torch.nn.MSELoss()

    def forward(self, x: torch.Tensor):
        """Return ``(predictions, targets)`` for the masked patches of ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Reconstruction MSE on masked patches; batch labels are ignored."""
        imgs, _ = batch
        preds, targets = self(imgs)
        loss = self.criterion(preds, targets)
        self.log("train_loss", loss, prog_bar=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Reconstruction MSE on held-out data.

        Uses whatever mask ratio the current training epoch set, so values are
        not directly comparable across epochs while the ratio is ramping.
        """
        imgs, _ = batch
        preds, targets = self(imgs)
        loss = self.criterion(preds, targets)
        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """AdamW with a per-epoch linear-warmup x cosine-decay LR schedule."""
        # Linear scaling rule: the configured rate is defined per 256 samples.
        effective_lr = self.lr * self.batch_size / 256
        optimizer = AdamW(
            self.parameters(),
            lr=effective_lr,
            weight_decay=self.weight_decay,
        )

        def lr_lambda(epoch: int) -> float:
            warmup = min((epoch + 1) / max(1, self.warmup_epochs), 1.0)
            # Clamp progress so the half-cosine stays at 0 rather than rising
            # again if training runs past ``total_epochs``; also avoids a
            # division by zero when total_epochs is 0.
            progress = min(epoch / max(1, self.total_epochs), 1.0)
            cosine = 0.5 * (1 + math.cos(math.pi * progress))
            return warmup * cosine

        scheduler = LambdaLR(optimizer, lr_lambda)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "epoch", "name": "lr"},
        }

    def on_train_epoch_start(self):
        """Linearly ramp the mask ratio from ``mask_start`` to ``mask_end``.

        The end value is reached at epoch ``ramp_epochs - 1`` and held after.
        """
        progress = min(self.current_epoch / max(1, self.ramp_epochs - 1), 1.0)
        new_mask = self.mask_start + progress * (self.mask_end - self.mask_start)
        self.model.mask_ratio = new_mask
        self.log("mask_ratio", new_mask, prog_bar=True)

## Vision Transformer Training

In [None]:
class ViTClassifierTrainModule(pl.LightningModule):
    """
    Supervised training for ViTClassifier.

    Supports a randomly initialized or a pretrained ViT encoder. Provides full
    fine-tuning, linear probing (frozen encoder) and partial unfreezing of the
    last transformer blocks.

    Args:
        pretrained_encoder: Encoder to wrap. If ``None``, a fresh timm
            ``VisionTransformer`` is built from ``model_cfg``.
        model_cfg: Model settings. ``general`` supplies ``image_size`` (96),
            ``patch_size`` (6) and ``in_chans`` (3). ``encoder`` supplies
            ``embed_dim`` (384), ``depth`` (12) and ``num_heads`` (6). ``head``
            is forwarded to ``ViTClassifier``.
        training_cfg: Optimization settings (defaults in parentheses):
            ``learning_rate`` (3e-4), ``weight_decay`` (0.05),
            ``warmup_epochs`` (5), ``total_epochs`` (100),
            ``freeze_encoder`` (True).
        num_classes: Number of output classes.
    """

    def __init__(
        self,
        pretrained_encoder: Optional[torch.nn.Module] = None,
        model_cfg: Optional[Dict[str, Any]] = None,
        training_cfg: Optional[Dict[str, Any]] = None,
        num_classes: int = 10,
    ):
        super().__init__()
        self.save_hyperparameters(ignore=["pretrained_encoder"])

        self.model_cfg = model_cfg or {}
        self.training_cfg = training_cfg or {}

        # Read from the normalized dict so training_cfg=None uses the defaults
        # instead of raising AttributeError.
        self.learning_rate = float(self.training_cfg.get("learning_rate", 3e-4))
        self.weight_decay = float(self.training_cfg.get("weight_decay", 0.05))
        self.warmup_epochs = int(self.training_cfg.get("warmup_epochs", 5))
        self.total_epochs = int(self.training_cfg.get("total_epochs", 100))
        self.freeze_encoder_flag = self.training_cfg.get("freeze_encoder", True)
        self.num_classes = num_classes

        # Build model
        encoder = pretrained_encoder
        if encoder is None:
            general_cfg = self.model_cfg.get("general", {})
            encoder_cfg = self.model_cfg.get("encoder", {})
            # Same defaults as MaskedAutoencoder so a from-scratch baseline has
            # the same architecture as the pretrained one.
            encoder = VisionTransformer(
                img_size=general_cfg.get("image_size", 96),
                patch_size=general_cfg.get("patch_size", 6),
                in_chans=general_cfg.get("in_chans", 3),
                embed_dim=encoder_cfg.get("embed_dim", 384),
                depth=encoder_cfg.get("depth", 12),
                num_heads=encoder_cfg.get("num_heads", 6),
                num_classes=0,
            )

        self.model = ViTClassifier(
            pretrained_encoder=encoder,
            num_classes=self.num_classes,
            head_cfg=self.model_cfg.get("head", {}),
        )

        # Freeze or unfreeze encoder as requested
        if self.freeze_encoder_flag:
            self.freeze_encoder()
        else:
            self.unfreeze_encoder()

    def forward(self, x: torch.Tensor):
        """Return class logits for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Cross-entropy loss and accuracy on an ``(images, labels)`` batch."""
        imgs, labels = batch
        logits = self(imgs)
        loss = F.cross_entropy(logits, labels)
        acc = (logits.argmax(dim=1) == labels).float().mean()

        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Cross-entropy loss and accuracy on a validation batch."""
        imgs, labels = batch
        logits = self(imgs)
        loss = F.cross_entropy(logits, labels)
        acc = (logits.argmax(dim=1) == labels).float().mean()

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Cross-entropy loss and accuracy on a test batch (epoch-aggregated)."""
        imgs, labels = batch
        logits = self(imgs)
        loss = F.cross_entropy(logits, labels)
        acc = (logits.argmax(dim=1) == labels).float().mean()

        self.log("test_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test_acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def configure_optimizers(self):
        """AdamW over trainable parameters with warmup x cosine LR (per epoch).

        Only parameters with ``requires_grad=True`` at this point are handed to
        the optimizer, so freezing/unfreezing must happen before fitting starts.
        """
        optimizer = AdamW(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )

        def lr_lambda(epoch: int) -> float:
            warmup = min((epoch + 1) / max(1, self.warmup_epochs), 1.0)
            # Clamp progress so the half-cosine stays at 0 rather than rising
            # again if training runs past ``total_epochs``; also avoids a
            # division by zero when total_epochs is 0.
            progress = min(epoch / max(1, self.total_epochs), 1.0)
            cosine = 0.5 * (1 + math.cos(math.pi * progress))
            return warmup * cosine

        scheduler = LambdaLR(optimizer, lr_lambda)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "monitor": "val_loss",
            },
        }

    def freeze_encoder(self):
        """Freeze all encoder parameters; only the classifier head stays trainable."""
        # Freeze by module rather than by parameter-name substring, so any
        # encoder-internal ``head`` parameters are frozen too.
        for param in self.model.encoder.parameters():
            param.requires_grad = False
        for param in self.model.head.parameters():
            param.requires_grad = True
        print("🧊 Encoder frozen (only classifier head is trainable).")

    def unfreeze_encoder(self):
        """Make every model parameter trainable (full fine-tuning)."""
        for param in self.model.parameters():
            param.requires_grad = True
        print("🔥 Encoder unfrozen (all parameters trainable).")

    def unfreeze_last_layers(self, n_layers: int):
        """
        Unfreezes only the last `n_layers` Transformer blocks of the ViT encoder.
        All earlier layers remain frozen; the encoder's final ``norm`` (if any)
        and the classifier head are made trainable.

        Call before fitting: ``configure_optimizers`` only collects parameters
        that are trainable when it runs.

        Raises:
            ValueError: If ``n_layers`` is outside ``[0, len(encoder.blocks)]``.
        """
        encoder = self.model.encoder  # expects a timm-style ``blocks`` attribute
        blocks = encoder.blocks
        total = len(blocks)

        if n_layers < 0 or n_layers > total:
            raise ValueError(f"n_layers must be between 0 and {total}, got {n_layers}")

        print(f"🔓 Unfreezing last {n_layers} of {total} encoder layers...")

        # 1) Freeze ALL encoder parameters first
        for param in encoder.parameters():
            param.requires_grad = False

        # 2) Unfreeze the last N Transformer blocks
        for block in blocks[total - n_layers :]:
            for param in block.parameters():
                param.requires_grad = True

        # 3) Also unfreeze the final LN (norm) layer
        if hasattr(encoder, "norm"):
            for param in encoder.norm.parameters():
                param.requires_grad = True

        # 4) Head (classifier) is always trainable
        for param in self.model.head.parameters():
            param.requires_grad = True

        print("🔥 Selective unfreezing complete.")