# Self-Supervised Learning (SSL)


## Import Libraries

In [None]:
from typing import Optional, Callable, Union
from enum import Enum
from pathlib import Path

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import torchvision
from torchvision import models, transforms
from torchvision.datasets import CIFAR10

from tqdm.auto import tqdm

### Check GPU


In [None]:
!nvidia-smi

In [None]:
# Index of the CUDA device to use when a GPU is available.
DEVICE_NUM = 0

# Prefer the selected GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{DEVICE_NUM}")
else:
    device = torch.device("cpu")

print("INFO: Using device -", device)

## Load Dataset

In [None]:
class DataType(Enum):
    """Dataset views that ``CIFAR10Dataset`` can serve."""
    # Labeled fraction of the training split (class labels as targets).
    LABELED_TRAIN = 0
    # Remaining training images, served as rotation-prediction samples.
    UNLABELED_TRAIN = 1
    # Validation images, served as rotation-prediction samples.
    UNLABELED_VALID = 2
    # Validation images with their class labels.
    VALID = 3
    # The CIFAR-10 test split (loaded with ``train=False``).
    TEST = 4

In [None]:
# Rebind the `tqdm` name inside torchvision.datasets.utils to tqdm.auto's tqdm;
# presumably so torchvision's download progress uses the notebook-aware bar.
torchvision.datasets.utils.tqdm = tqdm


class CIFAR10Dataset(CIFAR10):
    """CIFAR-10 with labeled, rotation-pretext ("unlabeled") and test views.

    The training split is shuffled once with a fixed seed and divided into a
    validation part and a training part; the training part is further divided
    into a labeled and an unlabeled subset. Unlabeled views expose every image
    four times (rotated by 0, 90, 180 and 270 degrees) and use the rotation
    index as the target instead of the class label.
    """

    # Views whose samples are (rotated image, rotation index) pairs.
    UNLABELED_DATA_TYPE = {
        DataType.UNLABELED_TRAIN,
        DataType.UNLABELED_VALID
    }
    # Class-level cache of index splits, shared by all instances and keyed by
    # (dataset size, validation_split, labeled_split).
    _indices_cache = {}

    def __init__(
        self,
        root: str,
        data_type: DataType,
        validation_split: float = 0.1,
        labeled_split: float = 0.1,
        transform: Optional[Callable] = None,
    ):
        """Download/load CIFAR-10 and select the indices for ``data_type``.

        Args:
            root: Directory passed to the torchvision ``CIFAR10`` base class.
            data_type: Which view of the data this instance serves.
            validation_split: Fraction of the training split held out for
                validation.
            labeled_split: Fraction of the remaining training images that
                keep their class labels.
            transform: Optional callable applied to the PIL image in
                ``__getitem__`` (after any rotation).
        """
        uses_train_split = data_type != DataType.TEST
        super().__init__(root, train=uses_train_split, transform=None, download=True)
        self.transform = transform
        self.data_type = data_type
        self.is_unlabeled = data_type in self.UNLABELED_DATA_TYPE

        self.indices = (
            self._get_indices(validation_split, labeled_split)
            if uses_train_split
            else np.arange(len(self.data))
        )

    def _get_indices(self, validation_split: float, labeled_split: float):
        """Return the sample indices belonging to ``self.data_type``.

        The split is computed once per cache key from a permutation seeded
        with 42 and then reused from ``_indices_cache``.
        """
        n_samples = len(self.data)
        cache_key = (n_samples, validation_split, labeled_split)

        splits = self._indices_cache.get(cache_key)
        if splits is None:
            shuffled = np.random.default_rng(seed=42).permutation(n_samples)

            n_valid = int(n_samples * validation_split)
            held_out, remaining = shuffled[:n_valid], shuffled[n_valid:]

            n_labeled = int(len(remaining) * labeled_split)

            splits = {
                DataType.LABELED_TRAIN: remaining[:n_labeled],
                DataType.UNLABELED_TRAIN: remaining[n_labeled:],
                # Both validation views share the same held-out images.
                DataType.UNLABELED_VALID: held_out,
                DataType.VALID: held_out,
            }
            self._indices_cache[cache_key] = splits

        return splits[self.data_type]

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index``.

        For unlabeled views ``index`` encodes both the image position and the
        rotation: ``index // 4`` selects the image and ``index % 4`` is the
        number of 90-degree turns, which is also returned as the target.
        """
        if not self.is_unlabeled:
            sample_idx = self.indices[index]
            label = self.targets[sample_idx]
            image = Image.fromarray(self.data[sample_idx])
        else:
            position, quarter_turns = divmod(index, 4)
            sample_idx = self.indices[position]
            label = quarter_turns
            image = Image.fromarray(self.data[sample_idx])
            if quarter_turns > 0:
                image = image.rotate(90 * quarter_turns)

        if self.transform:
            image = self.transform(image)

        return image, label

    def __len__(self):
        """Number of samples; unlabeled views count each image four times."""
        copies_per_image = 4 if self.is_unlabeled else 1
        return len(self.indices) * copies_per_image

In [None]:
# Directory where CIFAR-10 is downloaded / looked up.
DATA_ROOT = './data'

# Target spatial size handed to transforms.Resize.
IMG_SIZE = (32, 32)
# Per-channel normalisation statistics.
# NOTE(review): these look like the usual ImageNet statistics rather than
# CIFAR-10's own — confirm this is intended.
IMG_NORM = {
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225],
}

# Preprocessing shared by every dataset view: resize, convert to a tensor,
# then normalise each channel.
resizer = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMG_NORM["mean"], std=IMG_NORM["std"]),
])

In [None]:
# One dataset object per view; all share the same root and preprocessing.
# Labeled subset of the training split (class labels as targets).
labeled_data = CIFAR10Dataset(DATA_ROOT, DataType.LABELED_TRAIN, transform=resizer)
# Unlabeled subset of the training split (rotation index as target).
unlabeled_data = CIFAR10Dataset(DATA_ROOT, DataType.UNLABELED_TRAIN, transform=resizer)
# Validation images with class labels.
valid_data = CIFAR10Dataset(DATA_ROOT, DataType.VALID, transform=resizer)
# The same validation images, served as rotation-prediction samples.
unlabeled_valid_data = CIFAR10Dataset(DATA_ROOT, DataType.UNLABELED_VALID, transform=resizer)
# CIFAR-10 test split.
test_data = CIFAR10Dataset(DATA_ROOT, DataType.TEST, transform=resizer)

## DataLoader

In [None]:
# Batch sizes used when building the DataLoaders below.
class BatchSize:
    """Plain namespace of per-loader batch sizes (class-level defaults)."""
    # Batch size for the labeled training loader.
    labeled: int = 256
    # Batch size for the unlabeled (rotation pretext) training loader.
    unlabeled: int = 1024
    # Batch size for the validation loaders.
    valid: int = 1024

# Shared instance; its attributes resolve to the class-level defaults above.
batch_config = BatchSize()

In [None]:
# Training loaders: reshuffled every epoch.
labeled_loader = DataLoader(labeled_data, batch_size=batch_config.labeled, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_data, batch_size=batch_config.unlabeled, shuffle=True)

# Evaluation loaders: a fixed order keeps the batch composition (and therefore
# the reported metrics) deterministic; shuffling brings no benefit here.
valid_loader = DataLoader(valid_data, batch_size=batch_config.valid, shuffle=False)
unlabeled_valid_loader = DataLoader(unlabeled_valid_data, batch_size=batch_config.valid, shuffle=False)
# An explicit batch size is needed: DataLoader defaults to batch_size=1, which
# would run the test set one image per forward pass.
test_loader = DataLoader(test_data, batch_size=batch_config.valid, shuffle=False)

## Define Model

In [None]:
class Model(nn.Module):
    """ResNet-18 feature extractor with a class head and a rotation head.

    ``forward`` sends the backbone features through ``rotation_classifier``
    (4 outputs) when ``is_pretext`` is True and through ``classifier``
    (``num_classes`` outputs) otherwise.
    """

    def __init__(self, num_classes: int = 10):
        """Build the backbone and both heads.

        Args:
            num_classes: Number of outputs of the downstream ``classifier``.
        """
        super().__init__()

        self.embed_dim = 512
        self.in_channels = 3
        # Selects which head ``forward`` uses; toggled by the trainer.
        self.is_pretext = False

        backbone = models.resnet18()
        # Swap the stem for a 3x3 stride-1 convolution and disable the
        # max-pool — presumably to suit small 32x32 inputs.
        backbone.conv1 = nn.Conv2d(self.in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        backbone.maxpool = nn.Identity()
        # Drop the last two child modules of the torchvision model (presumably
        # its own pooling and fc layers); the heads below replace them.
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(self.embed_dim, num_classes)
        )

        self.rotation_classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(self.embed_dim, 4)
        )

    def forward(self, x):
        """Return logits from the head selected by ``self.is_pretext``."""
        out = self.backbone(x)
        if self.is_pretext:
            out = self.rotation_classifier(out)
        else:
            out = self.classifier(out)
        return out

    @classmethod
    def from_checkpoint(cls, ckpt_path: Union[str, Path], num_classes: int = 10):
        """Create a model and load weights from a saved ``state_dict``.

        Loading is non-strict, so keys that exist on only one side are
        tolerated; they are reported instead of being silently ignored.

        Args:
            ckpt_path: Path to a file written with ``torch.save(state_dict)``.
            num_classes: Number of outputs of the downstream ``classifier``.

        Returns:
            The newly constructed model (on the CPU).
        """
        model = cls(num_classes=num_classes)
        # NOTE(security): torch.load unpickles the file. Only load checkpoints
        # from trusted sources; consider ``weights_only=True`` where the
        # installed PyTorch version supports it.
        state_dict = torch.load(ckpt_path, map_location='cpu')
        load_result = model.load_state_dict(state_dict, strict=False)

        # strict=False hides mismatches; surface them so a wrong or partial
        # checkpoint does not pass as a clean load.
        if load_result.missing_keys:
            print(f"WARNING: Keys missing from checkpoint (left at init values): {list(load_result.missing_keys)}")
        if load_result.unexpected_keys:
            print(f"WARNING: Unexpected keys in checkpoint (ignored): {list(load_result.unexpected_keys)}")

        print(f"INFO: Backbone weights successfully loaded from: {ckpt_path}")
        return model

## Trainer Class

In [None]:
class Trainer:
    """Training, validation, checkpointing and testing loop for one model.

    The loss is ``nn.CrossEntropyLoss``. Call :meth:`to` before training so
    that ``self.device`` is set and the model is moved to that device.
    """

    def __init__(self, model: nn.Module):
        """Wrap ``model``; ``device`` stays ``None`` until :meth:`to` is called."""
        self.model = model
        self.device = None
        self.criterion = nn.CrossEntropyLoss()

    def _train_one_epoch(
        self,
        data_loader: DataLoader,
        optimizer: torch.optim.Optimizer,
        progress_bar: tqdm
    ):
        """Run one optimisation pass over ``data_loader``.

        Args:
            data_loader: Yields ``(images, labels)`` batches.
            optimizer: Optimizer stepping the model's parameters.
            progress_bar: Bar advanced by one per batch; shows the batch loss.

        Returns:
            The average loss per sample over the epoch.

        Raises:
            ValueError: If ``data_loader`` yields no samples.
        """
        self.model.train()
        total_loss = 0.0
        total_samples = 0

        for images, labels in data_loader:
            images, labels = images.to(self.device), labels.to(self.device)
            optimizer.zero_grad()
            outputs = self.model(images)
            loss = self.criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            current_loss = loss.item()
            # The criterion returns a batch mean, so weight it by the batch
            # size; otherwise a short final batch would be over-represented.
            batch_size = labels.size(0)
            total_loss += current_loss * batch_size
            total_samples += batch_size

            progress_bar.set_postfix(loss=f"{current_loss:.6f}")
            progress_bar.update(1)

        if total_samples == 0:
            raise ValueError("Cannot train on an empty data loader.")
        return total_loss / total_samples

    def _evaluate(self, data_loader: DataLoader, progress_bar: Optional[tqdm] = None):
        """Compute loss and accuracy over ``data_loader`` without gradients.

        Args:
            data_loader: Yields ``(images, labels)`` batches.
            progress_bar: Optional bar advanced by one per batch.

        Returns:
            ``(avg_loss, accuracy)`` where ``avg_loss`` is the average loss
            per sample and ``accuracy`` is the fraction of correct predictions.

        Raises:
            ValueError: If ``data_loader`` yields no samples.
        """
        self.model.eval()
        total_loss = 0.0
        correct, total = 0, 0

        with torch.no_grad():
            for images, labels in data_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

                # Weight the batch-mean loss by the batch size so the result
                # is a true per-sample average regardless of batch layout.
                batch_size = labels.size(0)
                total_loss += loss.item() * batch_size

                _, predicted = torch.max(outputs, 1)
                total += batch_size
                correct += (predicted == labels).sum().item()

                # Compare against None: a tqdm bar's truthiness depends on
                # its total, not on whether a bar was supplied.
                if progress_bar is not None:
                    progress_bar.update(1)

        if total == 0:
            raise ValueError("Cannot evaluate on an empty data loader.")

        avg_loss = total_loss / total
        accuracy = correct / total
        return avg_loss, accuracy

    def train(
        self,
        train_loader: DataLoader,
        valid_loader: DataLoader,
        optimizer: torch.optim.Optimizer,
        epochs: int,
        is_pretext: bool = False,
        patience: int = 10,
        save_top_k: int = 3,
        save_path: Union[str, Path] = Path("./checkpoints"),
    ):
        """Train with per-epoch validation, checkpointing and early stopping.

        A checkpoint (the model's ``state_dict``) is written whenever the
        validation loss improves on the best value so far; only the
        ``save_top_k`` lowest-loss checkpoints are kept on disk.

        Args:
            train_loader: Batches used for optimisation.
            valid_loader: Batches used for validation after every epoch.
            optimizer: Optimizer stepping the model's parameters.
            epochs: Maximum number of epochs.
            is_pretext: If True the model's ``is_pretext`` flag is set, and
                checkpoints are prefixed "pretext" instead of "downstream".
            patience: Stop once this many epochs in a row bring no
                validation-loss improvement.
            save_top_k: Number of best checkpoints to keep.
            save_path: Directory for checkpoints (created if missing).
        """
        task_name = "pretext" if is_pretext else "downstream"
        tqdm.write(f"--- Starting Task: {task_name.capitalize()} ---")

        if hasattr(self.model, 'is_pretext'):
            self.model.is_pretext = is_pretext

        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

        best_val_loss = float('inf')
        patience_counter = 0
        top_k_checkpoints = []

        train_length, valid_length = len(train_loader), len(valid_loader)

        # All three bars are context-managed so they are closed even on early
        # stopping or when an exception interrupts training.
        with tqdm(range(epochs), desc="Epochs", position=0, leave=True) as epochs_progress, \
             tqdm(total=train_length, desc="Training", position=1, leave=False) as train_progress, \
             tqdm(total=valid_length, desc="Validation", position=2, leave=False) as valid_progress:

            for epoch in epochs_progress:
                train_progress.reset()
                valid_progress.reset()

                # Training & Validation
                train_loss = self._train_one_epoch(train_loader, optimizer, train_progress)
                valid_loss, valid_acc = self._evaluate(valid_loader, valid_progress)

                final_log = (
                    f"Epoch [{epoch+1:>{len(str(epochs))}}/{epochs}] | "
                    f"Train Loss: {train_loss:.6f} | "
                    f"Valid Loss: {valid_loss:.6f} | "
                    f"Valid Acc: {valid_acc:.4%}"
                )
                tqdm.write(final_log)

                # Checkpoint saving logic
                if valid_loss < best_val_loss:
                    best_val_loss = valid_loss
                    patience_counter = 0

                    ckpt_path = save_path / f"{task_name}_epoch_{epoch+1}_loss_{valid_loss:.6f}.pt"
                    torch.save(self.model.state_dict(), ckpt_path)
                    top_k_checkpoints.append((valid_loss, ckpt_path))
                    top_k_checkpoints.sort(key=lambda x: x[0])

                    # Evict the highest-loss checkpoint once over budget.
                    if len(top_k_checkpoints) > save_top_k:
                        worst_checkpoint_path = top_k_checkpoints.pop()[1]
                        if worst_checkpoint_path.exists():
                            worst_checkpoint_path.unlink()

                else:  # Early stopping logic
                    patience_counter += 1

                if patience_counter >= patience:
                    tqdm.write(f"\nEarly stopping at epoch {epoch+1} as validation loss did not improve for {patience} epochs.")
                    break

    def test(self, test_loader: DataLoader):
        """Evaluate with the downstream head on ``test_loader`` and print the results."""
        if hasattr(self.model, 'is_pretext'):
            self.model.is_pretext = False

        # The bar only counts batches; ``_evaluate`` iterates the loader itself.
        with tqdm(total=len(test_loader), desc="Testing", leave=True) as test_progress:
            test_loss, test_acc = self._evaluate(test_loader, progress_bar=test_progress)
        print(f"\nTest Results | Test Loss: {test_loss:.6f} | Test Acc: {test_acc:.6%}")

    def to(self, device: torch.device):
        """Move the model to ``device``, remember it for batches, and return ``self``."""
        self.model.to(device)
        self.device = device
        return self

## Train Model

### Train Supervised Model

In [None]:
# Supervised baseline: a freshly initialised model trained on labels only.
supervised_model = Model()

# Alternative: resume from saved weights by passing a checkpoint path.
# supervised_model = Model.from_checkpoint()

# Wrap the model in a Trainer and move it to the selected device.
supervised_trainer = Trainer(supervised_model).to(device)

In [None]:
# Learning rate for the supervised baseline.
SL_DOWNSTREAM_LR = 1e-3

# AdamW over every parameter of the supervised model.
sl_optimizer = AdamW(
    supervised_model.parameters(),
    lr=SL_DOWNSTREAM_LR,
    weight_decay=1e-4
)

In [None]:
# Train the supervised baseline on the labeled subset: up to 50 epochs, early
# stopping after 5 epochs without validation-loss improvement, keeping the 3
# best checkpoints.
supervised_trainer.train(
    train_loader=labeled_loader,
    valid_loader=valid_loader,
    optimizer=sl_optimizer,
    epochs=50,
    is_pretext=False,
    patience=5,
    save_top_k=3,
    save_path="./checkpoints/supervised"
)

### Train Self-Supervised Model

#### A. Pretext

In [None]:
from huggingface_hub import hf_hub_download

# Download pre-trained weights from Hugging Face Hub
REPO_ID = "haesol1013/2025-2_dAiv-SSL-Lecture"
FILE_NAME = "pretext_pretrained.pt"
# hf_hub_download is expected to return the local path of the fetched file;
# that path is handed to Model.from_checkpoint in the next cell.
pretrained_weights_path = hf_hub_download(repo_id=REPO_ID, filename=FILE_NAME)

In [None]:
# Alternative: start the pretext task from random initialisation.
# self_supervised_model = Model()

# from_checkpoint is a classmethod that builds its own instance, so call it on
# the class; going through Model() would construct a throwaway network first.
self_supervised_model = Model.from_checkpoint(pretrained_weights_path)

# Wrap the model in a Trainer and move it to the selected device.
self_supervised_trainer = Trainer(self_supervised_model).to(device)

In [None]:
# Learning rate for the rotation-prediction pretext task.
SSL_PRETEXT_LR = 1e-3

# AdamW over every parameter of the self-supervised model.
ssl_pretext_optim = AdamW(
    self_supervised_model.parameters(),
    lr=SSL_PRETEXT_LR,
    weight_decay=1e-4
)

In [None]:
# Directory for pretext checkpoints; the downstream stage scans it later.
PRETEXT_CKPT_PATH = Path("./checkpoints/self_supervised/pretext")

# Pretext stage: predict the rotation of unlabeled images, validating on the
# rotation view of the validation images.
self_supervised_trainer.train(
    train_loader=unlabeled_loader,
    valid_loader=unlabeled_valid_loader,
    optimizer=ssl_pretext_optim,
    epochs=30,
    is_pretext=True,
    patience=5,
    save_top_k=3,
    save_path=PRETEXT_CKPT_PATH
)

#### B. Downstream

In [None]:
# Collect (validation_loss, path) pairs for every pretext checkpoint whose
# file name ends in "_loss_<float>"; files that do not parse are ignored.
checkpoints = []
for ckpt_file in PRETEXT_CKPT_PATH.glob("*.pt"):
    loss_str = ckpt_file.stem.rpartition("_loss_")[2]
    try:
        loss_val = float(loss_str)
    except ValueError:
        continue
    checkpoints.append((loss_val, ckpt_file))

In [None]:
# Resume from the lowest-loss pretext checkpoint; fall back to a freshly
# initialised model when none were found.
if not checkpoints:
    self_supervised_model = Model()
    print("INFO: No pretext checkpoints found. Training from scratch.")
else:
    best_loss, best_ckpt = min(checkpoints, key=lambda entry: entry[0])
    self_supervised_model = Model.from_checkpoint(best_ckpt)

# New trainer for the downstream stage, moved to the selected device.
self_supervised_trainer = Trainer(self_supervised_model).to(device)

In [None]:
# Learning rate for downstream fine-tuning (lower than the pretext rate).
SSL_DOWNSTREAM_LR = 1e-4

# AdamW over every parameter of the self-supervised model.
ssl_downstream_optimizer = AdamW(
    self_supervised_model.parameters(),
    lr=SSL_DOWNSTREAM_LR,
    weight_decay=1e-4
)

In [None]:
# Downstream stage: fine-tune on the labeled subset with the class head,
# stopping early after 7 epochs without validation-loss improvement.
self_supervised_trainer.train(
    train_loader=labeled_loader,
    valid_loader=valid_loader,
    optimizer=ssl_downstream_optimizer,
    epochs=30,
    is_pretext=False,
    patience=7,
    save_top_k=3,
    save_path="./checkpoints/self_supervised/downstream"
)

### Model Evaluation

In [None]:
# Print test loss and accuracy for the supervised baseline.
print("-- Supervised Model ---")
supervised_trainer.test(test_loader)

In [None]:
# Print test loss and accuracy for the self-supervised model.
print("--- Self-supervised Model ---")
self_supervised_trainer.test(test_loader)