# Self-Supervised Learning (SSL) on CIFAR-10


- Masked-image reconstruction pretext with a ResNet-18 encoder, trained sequentially (pretext first, then downstream fine-tuning)
- masking ratio: 0.6 (fraction of 4×4 patches zeroed per image)
- batch size: labeled 32, unlabeled 128, valid 128
- simple classifier / deeper decoder
<br/><br/>
- optimizer: supervised, downstream, pretext -> Adam
- weight decay: supervised 0, downstream 0, pretext 1e-5
- learning rate: supervised 1e-3, downstream 1e-3, pretext 1e-3
<br/><br/>
- pretext epochs: 20
- downstream / supervised epochs: 10

## Import Libraries

In [None]:
from typing import Optional, Callable
from enum import Enum

import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import models
from torchvision import transforms
from torchvision.datasets import CIFAR10

### Check GPU


In [None]:
# Print GPU status via the IPython shell escape (only valid inside a notebook kernel, not plain Python).
!nvidia-smi

In [None]:
# Set CUDA Device Number 0~7
DEVICE_NUM = 0

# Prefer the selected GPU when CUDA is present; otherwise run on the CPU.
if torch.cuda.is_available():
    torch.cuda.set_device(DEVICE_NUM)
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("INFO: Using device -", device)

## Load Dataset

In [None]:
class DataType(Enum):
    """Selects which CIFAR-10 subset ``CIFAR10Dataset`` exposes."""

    # Carved out of the official training set:
    LABELED_TRAIN = 0    # subset that keeps its labels
    UNLABELED_TRAIN = 1  # remaining training images, labels hidden
    VALID = 2            # held-out validation images
    # The official test set:
    TEST = 3

In [None]:
class CIFAR10Dataset(CIFAR10):
    """CIFAR-10 with a fixed labeled / unlabeled / validation split of the training set.

    The official training set is shuffled with a seeded permutation and carved into
    three disjoint parts:

    * the first ``validation_split`` fraction becomes the validation set,
    * of the remainder, the first ``labeled_split`` fraction keeps its labels,
    * everything else is the unlabeled pool (targets replaced by ``-1``).

    ``DataType.TEST`` exposes the untouched official test set; the split arguments
    are ignored in that case.

    Args:
        data_type: Which subset to expose.
        transform: Optional transform applied to each image.
        validation_split: Fraction of the training set held out for validation, in [0, 1].
        labeled_split: Fraction of the remaining training images that keep their
            labels, in [0, 1].
        seed: Seed of the split permutation. Instances built with the same seed and
            fractions share one permutation, so their subsets never overlap.
        root: Directory the dataset is downloaded to / read from.

    Raises:
        ValueError: If a split fraction is outside [0, 1] (non-test subsets only).
    """

    def __init__(self, data_type: DataType, transform: Optional[Callable] = None, validation_split: float = 0.1, labeled_split: float = 0.1, seed: int = 42, root: str = './data'):
        is_test = data_type == DataType.TEST
        if not is_test:
            # Fail before downloading anything if the split is nonsensical.
            for name, fraction in (("validation_split", validation_split), ("labeled_split", labeled_split)):
                if not 0.0 <= fraction <= 1.0:
                    raise ValueError(f"{name} must be in [0, 1], got {fraction}")

        super().__init__(root, train=not is_test, transform=transform, download=True)

        if is_test:
            return  # official test split, used as-is

        labeled_indices, unlabeled_indices, val_indices = self._split_indices(
            len(self.data), validation_split, labeled_split, seed
        )

        if data_type == DataType.LABELED_TRAIN:
            keep = labeled_indices
            self.targets = [self.targets[i] for i in keep]
        elif data_type == DataType.UNLABELED_TRAIN:
            keep = unlabeled_indices
            # Labels are hidden: -1 is a placeholder target for every unlabeled image.
            self.targets = [-1] * len(keep)
        else:  # DataType.VALID
            keep = val_indices
            self.targets = [self.targets[i] for i in keep]
        self.data = self.data[keep]

    @staticmethod
    def _split_indices(num_samples, validation_split, labeled_split, seed):
        """Return ``(labeled, unlabeled, validation)`` index arrays over ``range(num_samples)``."""
        # A private RandomState(seed) yields exactly the permutation that
        # np.random.seed(seed) + np.random.permutation(...) would, without
        # resetting NumPy's global RNG as a side effect of building a dataset.
        indices = np.random.RandomState(seed).permutation(num_samples)

        val_size = int(num_samples * validation_split)
        val_indices = indices[:val_size]

        train_indices = indices[val_size:]
        labeled_size = int(len(train_indices) * labeled_split)
        return train_indices[:labeled_size], train_indices[labeled_size:], val_indices

In [None]:
# Target spatial size. NOTE(review): CIFAR-10 images are presumably already 32x32,
# which would make the Resize below a no-op — confirm.
IMG_SIZE = (32, 32)

# Per-channel normalisation constants.
# NOTE(review): these match the commonly used ImageNet mean/std rather than
# CIFAR-10's own statistics; the ResNet-18 here is built without pretrained
# weights, so dataset-specific values may fit better — confirm before changing.
IMG_NORM = {
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225],
}

# Image -> resized float tensor in [0, 1] -> per-channel normalised tensor.
resizer = transforms.Compose(
    [
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMG_NORM["mean"], std=IMG_NORM["std"]),
    ]
)

In [None]:
# Labeled / unlabeled / validation subsets come from the official training set;
# the test subset is the official test set.
labeled_data = CIFAR10Dataset(data_type=DataType.LABELED_TRAIN, transform=resizer)
unlabeled_data = CIFAR10Dataset(data_type=DataType.UNLABELED_TRAIN, transform=resizer)
valid_data = CIFAR10Dataset(data_type=DataType.VALID, transform=resizer)
test_data = CIFAR10Dataset(data_type=DataType.TEST, transform=resizer)

## DataLoader

In [None]:
# Set Batch Size
class BatchSize:
    """Mini-batch sizes used when building the DataLoaders."""

    labeled: int = 32     # supervised / downstream classification batches
    unlabeled: int = 128  # pretext (reconstruction) batches
    valid: int = 128      # validation batches


batch_config = BatchSize()

In [None]:
# Training loaders reshuffle every epoch.
labeled_loader = DataLoader(labeled_data, batch_size=batch_config.labeled, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_data, batch_size=batch_config.unlabeled, shuffle=True)

# Evaluation loaders keep a fixed order: sample order does not matter for
# evaluation, and a fixed order keeps runs comparable. The test loader needs an
# explicit batch size because DataLoader defaults to batch_size=1, which would
# evaluate the test set one image at a time.
valid_loader = DataLoader(valid_data, batch_size=batch_config.valid, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_config.valid, shuffle=False)

## Define Model

In [None]:
class Model(nn.Module):
    """ResNet-18 encoder with a classification head and a masked-reconstruction decoder.

    * ``forward(x)`` -> class logits (encoder + classifier), used for supervised /
      downstream training and evaluation.
    * ``forward_pretext(x)`` -> ``(reconstructed, masked_x, mask)``: random square
      patches of ``x`` are zeroed, the encoder sees the masked image and the decoder
      produces a full-size reconstruction.

    The decoder upsamples the encoder feature map 8x (three stride-2 transposed
    convolutions). NOTE(review): this presumably relies on the modified ResNet-18
    (3x3 stride-1 stem, no max-pool) producing a 4x4 map for 32x32 inputs, so the
    reconstruction matches the input size only for 32x32 images — confirm before
    using other resolutions.

    Args:
        supervised: Stored on the instance; not otherwise used by this class.
        num_classes: Number of classifier output logits.
        patch_size: Side length in pixels of the square patches used for masking.
        mask_ratio: Fraction of patches zeroed by :meth:`random_mask`.
    """

    def __init__(self, supervised: bool = True, num_classes=10, patch_size: int = 4, mask_ratio: float = 0.6):
        super().__init__()

        self.supervised = supervised
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        self.embed_dim = 512  # must equal the backbone's output channels (512 for ResNet-18)
        self.in_channels = 3

        backbone = models.resnet18()
        # 3x3 stride-1 stem and no max-pool keep more spatial resolution for small images.
        backbone.conv1 = nn.Conv2d(self.in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        backbone.maxpool = nn.Identity()
        # Drop the final global pool and fc layer; keep the spatial feature map.
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])

        # No output activation: reconstruction targets are Normalize()-ed images whose
        # values extend well beyond [-1, 1], which a final Tanh could never reach.
        # (Removing the parameter-free Tanh leaves every state_dict key unchanged.)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(self.embed_dim, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 3, 3, 1, 1),
        )

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(self.embed_dim, num_classes)
        )

    def random_mask(self, x):
        """Zero a random subset of ``patch_size`` x ``patch_size`` patches in each sample.

        Args:
            x: Image batch indexed as ``(B, C, H, W)``.

        Returns:
            ``(masked_x, mask)``. ``mask`` has shape ``(B, 1, H, W)``: 0 on masked pixels
            and 1 elsewhere, broadcast over channels. ``masked_x = x * mask``. Each sample
            gets ``int(num_patches * mask_ratio)`` masked patches, chosen independently.
            Border pixels not covered by a whole patch (when H or W is not a multiple
            of ``patch_size``) are never masked.
        """
        B, _, H, W = x.shape
        p = self.patch_size

        num_patches_h = H // p
        num_patches_w = W // p
        total_patches = num_patches_h * num_patches_w
        num_masked = int(total_patches * self.mask_ratio)

        # Vectorised per-sample random subset: argsort of uniform noise is a random
        # permutation per row, and its first `num_masked` entries are the patches to mask.
        ranks = torch.rand(B, total_patches, device=x.device).argsort(dim=1)
        patch_keep = torch.ones(B, total_patches, device=x.device)
        patch_keep.scatter_(1, ranks[:, :num_masked], 0.0)
        patch_keep = patch_keep.reshape(B, 1, num_patches_h, num_patches_w)

        # Expand each patch flag to a p x p block of pixels.
        pixel_keep = patch_keep.repeat_interleave(p, dim=2).repeat_interleave(p, dim=3)

        mask = torch.ones(B, 1, H, W, device=x.device)
        mask[:, :, :num_patches_h * p, :num_patches_w * p] = pixel_keep

        masked_x = x * mask

        return masked_x, mask

    def forward_encoder(self, x):
        """Return the backbone's spatial feature map for ``x``."""
        x = self.backbone(x)
        return x

    def forward_decoder(self, x):
        """Map encoder features back to image space."""
        x = self.decoder(x)
        return x

    def forward_classifier(self, x):
        """Pool encoder features and return class logits."""
        x = self.classifier(x)
        return x

    def forward(self, x):
        """Classify ``x``: encoder followed by classifier head."""
        x = self.forward_encoder(x)
        output = self.forward_classifier(x)
        return output

    def forward_pretext(self, x):
        """Mask ``x``, encode it and decode a reconstruction.

        Returns:
            ``(reconstructed, masked_x, mask)``.
        """
        masked_x, mask = self.random_mask(x)
        features = self.forward_encoder(masked_x)
        reconstructed = self.forward_decoder(features)

        return reconstructed, masked_x, mask

In [None]:
# Baseline model trained on the labeled subset only.
supervised_model = Model(supervised=True)
supervised_model = supervised_model.to(device)

In [None]:
# Model for the two-stage run: reconstruction pretext, then downstream classification.
self_supervised_model = Model(supervised=False)
self_supervised_model = self_supervised_model.to(device)

## Model Training

In [None]:
class Trainer:
    """Train and evaluate a model in supervised or self-supervised (two-stage) mode.

    Supervised mode trains classification on ``labeled_loader`` only. Self-supervised
    mode first trains the reconstruction pretext task on ``unlabeled_loader`` and then
    fine-tunes classification on ``labeled_loader``.

    Models passed to :meth:`train` / :meth:`evaluate` must be callable on an image
    batch (returning class logits). For the pretext task they must also provide
    ``forward_pretext(images) -> (reconstructed, masked_images, mask)``.

    Args:
        labeled_loader: Yields ``(images, labels)`` batches for classification training.
        valid_loader: Yields ``(images, labels)`` batches, evaluated after every epoch
            of both stages (only the images are used for the pretext stage).
        test_loader: Yields ``(images, labels)`` batches for :meth:`evaluate`.
        unlabeled_loader: Yields ``(images, _)`` batches for the pretext task.
            Required when ``supervised=False``.
        supervised: Selects supervised-only or self-supervised training.
    """

    def __init__(self, labeled_loader, valid_loader, test_loader, unlabeled_loader=None, supervised=True):
        self.supervised = supervised
        self.labeled_loader = labeled_loader
        self.unlabeled_loader = unlabeled_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

        self.classification_loss = nn.CrossEntropyLoss()
        # Only the self-supervised trainer runs the reconstruction task.
        self.reconstruction_loss = nn.MSELoss() if not supervised else None

    def _compute_loss(self, model, batch, device, pretext):
        """Run one batch through ``model`` and return ``(loss, outputs, labels)``.

        For the pretext task the loss is the MSE between the reconstruction of the
        masked images and the original images, and ``outputs``/``labels`` are ``None``.
        """
        images, labels = batch
        images = images.to(device)
        if pretext:
            reconstructed, _, _ = model.forward_pretext(images)
            return self.reconstruction_loss(reconstructed, images), None, None
        labels = labels.to(device)
        outputs = model(images)
        return self.classification_loss(outputs, labels), outputs, labels

    def _train_step(self, model, batch, optimizer, device, pretext=False):
        """Perform one optimisation step on ``batch`` and return its scalar loss."""
        loss, _, _ = self._compute_loss(model, batch, device, pretext)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

    def _train_epoch(self, model, optimizer, device, data_loader, epochs, pretext=False):
        """Train for ``epochs`` epochs (despite the name), validating after each one.

        Prints the mean per-batch training loss and the validation loss per epoch.
        """
        num_batches = len(data_loader)
        for epoch in range(epochs):
            model.train()
            total_loss = 0.0

            for batch in data_loader:
                total_loss += self._train_step(model, batch, optimizer, device, pretext)

            # NaN instead of ZeroDivisionError when the loader is empty.
            avg_train_loss = total_loss / num_batches if num_batches else float("nan")
            avg_valid_loss = self._evaluate(model, self.valid_loader, device, pretext)

            print(f'Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Valid Loss: {avg_valid_loss:.4f}')

    def _evaluate(self, model, data_loader, device, pretext=False, return_accuracy=False):
        """Evaluate ``model`` on ``data_loader`` without gradient tracking.

        The loss is averaged over samples: each batch's mean loss is weighted by its
        batch size, so a short final batch does not skew the result.

        Returns:
            ``avg_loss``, or ``(avg_loss, accuracy_percent)`` when ``return_accuracy``
            is true. Accuracy is only measured for the classification task. A value
            is NaN when there were no samples to compute it from.
        """
        model.eval()
        total_loss = 0.0
        num_samples = 0
        correct = 0
        total = 0  # samples counted towards accuracy (classification only)

        with torch.no_grad():
            for batch in data_loader:
                loss, outputs, labels = self._compute_loss(model, batch, device, pretext)
                batch_size = batch[0].size(0)
                total_loss += loss.item() * batch_size
                num_samples += batch_size

                if return_accuracy and not pretext:
                    predicted = outputs.argmax(dim=1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

        avg_loss = total_loss / num_samples if num_samples else float("nan")

        if return_accuracy:
            accuracy = 100 * correct / total if total else float("nan")
            return avg_loss, accuracy  # evaluate
        return avg_loss  # validate

    def train(self, model, device, optimizer, pre_optimizer=None, pretext_epochs=15, downstream_epochs=10):
        """Run supervised training, or pretext-then-downstream training.

        Args:
            model: Model to train, moved to ``device``.
            device: Device that batches are moved to.
            optimizer: Optimizer for the classification stage.
            pre_optimizer: Optimizer for the pretext stage. Required when ``supervised=False``.
            pretext_epochs: Epochs of pretext training (self-supervised mode only).
            downstream_epochs: Epochs of classification training.

        Raises:
            ValueError: In self-supervised mode, if ``pre_optimizer`` or the unlabeled
                loader is missing.
        """
        if not self.supervised:
            if pre_optimizer is None:
                raise ValueError("pre_optimizer is required for self-supervised training")
            if self.unlabeled_loader is None:
                raise ValueError("unlabeled_loader is required for self-supervised training")

        model.to(device)
        self.pre_optimizer = pre_optimizer

        if self.supervised:
            # Supervised
            self._train_epoch(model, optimizer, device, self.labeled_loader, downstream_epochs, pretext=False)
        else:
            # Selfsupervised
            print("Pretext Task Training")
            self._train_epoch(model, pre_optimizer, device, self.unlabeled_loader, pretext_epochs, pretext=True)

            print("\n" + "Downstream Task Training")
            self._train_epoch(model, optimizer, device, self.labeled_loader, downstream_epochs, pretext=False)

    def evaluate(self, model, device):
        """Print the sample-averaged test loss and classification accuracy of ``model``."""
        model.to(device)

        avg_loss, accuracy = self._evaluate(model, self.test_loader, device, pretext=False, return_accuracy=True)

        print(f'Test Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')

In [None]:
# Baseline: classification on the labeled subset only (no unlabeled loader needed).
supervised_trainer = Trainer(
    labeled_loader=labeled_loader,
    valid_loader=valid_loader,
    test_loader=test_loader,
    supervised=True,
)
# Two-stage run: reconstruction on the unlabeled pool, then classification.
self_supervised_trainer = Trainer(
    labeled_loader=labeled_loader,
    valid_loader=valid_loader,
    test_loader=test_loader,
    unlabeled_loader=unlabeled_loader,
    supervised=False,
)

### Set Optimizer

In [None]:
# Adam for the supervised baseline, without weight decay.
sl_optimizer = torch.optim.Adam(params=supervised_model.parameters(), lr=1e-3, weight_decay=0)

In [None]:
# Two separate Adam instances over the same parameters, one per training stage,
# so the downstream stage starts with fresh optimizer state.
# Pretext stage: light weight decay.
ssl_optimizer_pre = torch.optim.Adam(params=self_supervised_model.parameters(), lr=1e-3, weight_decay=1e-5)
# Downstream stage: no weight decay.
ssl_optimizer_down = torch.optim.Adam(params=self_supervised_model.parameters(), lr=1e-3, weight_decay=0)

### Train Model

In [None]:
# Supervised baseline: classification on the labeled subset, then checkpoint the weights.
supervised_trainer.train(supervised_model, device, sl_optimizer, downstream_epochs=10)
torch.save(supervised_model.state_dict(), 'supervised_model_sequential.pth')

In [None]:
# Pretext (reconstruction) stage followed by the downstream classification stage,
# then checkpoint the weights.
self_supervised_trainer.train(
    self_supervised_model,
    device,
    ssl_optimizer_down,
    ssl_optimizer_pre,
    pretext_epochs=20,
    downstream_epochs=10,
)
torch.save(self_supervised_model.state_dict(), 'self_supervised_model_sequential.pth')

### Model Evaluation

In [None]:
# Test-set loss / accuracy of the supervised baseline.
print("Supervised Model")
supervised_trainer.evaluate(model=supervised_model, device=device)

In [None]:
# Test-set loss / accuracy of the pretext + downstream model.
print("\nSelf-supervised Model")
self_supervised_trainer.evaluate(model=self_supervised_model, device=device)