In [1]:
import os
import torch
from transformers import ViTMAEConfig
from lightning.pytorch.trainer.trainer import Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch import seed_everything
from runner_utils import start_of_a_run
from lightning.pytorch.strategies import DDPStrategy
from src.models.MultiDecoderQuantileViTMAE import MultiDecoderQuantileViTMAELightning
from src.datamodules.cifar_100 import DataModule as CIFAR100DataModule
import warnings
warnings.filterwarnings("ignore")

from transformers import ViTMAEConfig, ViTMAEForPreTraining, AutoImageProcessor
import torch
from torch import nn
from copy import deepcopy

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
# ViT-MAE configuration built with the library defaults (no overrides passed),
# and the list of quantile levels handed to the model in the next cell.
model_config = ViTMAEConfig()
quantiles = [0.1, 0.5, 0.9]

In [None]:
# Instantiate the project-local Lightning module with the configuration and
# quantile levels defined above.
model = MultiDecoderQuantileViTMAELightning(
    config=model_config,
    quantiles=quantiles,
    learning_rate=1e-4,
)

In [None]:
# A second default ViT-MAE configuration, bound to `config`; the bare name on
# the last line makes the notebook render it as the cell output.
config = ViTMAEConfig()
config

In [3]:
class MultiDecoderQuantileViTMAE(ViTMAEForPreTraining):
    """ViT-MAE pre-training model with one decoder per quantile level.

    The encoder (``self.vit``) built by the parent constructor is shared. The
    parent's single decoder is copied once per quantile level into
    ``self.decoders`` and then removed, and the model is trained with the
    pinball (quantile) loss averaged over all levels.

    NOTE(review): because ``self.decoder`` is set to ``None``, checkpoints saved
    from the stock ``ViTMAEForPreTraining`` presumably will not populate
    ``self.decoders`` when loaded - confirm before relying on pretrained weights.
    """

    def __init__(self, config, quantiles=(0.1, 0.5, 0.9)):
        """
        Args:
            config: Configuration forwarded unchanged to ``ViTMAEForPreTraining``.
            quantiles: Iterable of quantile levels, each strictly between 0 and 1.
        Raises:
            ValueError: If ``quantiles`` is empty or contains a level outside (0, 1).
        """
        # Validate before building the (large) model so bad input fails fast.
        # A tuple copy also protects against the caller mutating its list later,
        # which would desynchronise the levels from the number of decoders.
        levels = tuple(float(q) for q in quantiles)
        if not levels:
            raise ValueError("quantiles must contain at least one level")
        if any(not 0.0 < q < 1.0 for q in levels):
            raise ValueError(f"quantiles must lie strictly between 0 and 1, got {levels}")

        super().__init__(config)
        self.quantiles = levels
        self.num_quantiles = len(levels)

        # Separate decoders for each quantile. Every copy starts from the same
        # weights as the decoder created by the parent constructor.
        self.decoders = nn.ModuleList([
            deepcopy(self.decoder) for _ in range(self.num_quantiles)
        ])

        # The parent's single decoder is no longer used by forward().
        self.decoder = None

    def forward_loss(self, pixel_values, preds, mask, interpolate_pos_encoding: bool = False):
        """
        Custom loss for quantile regression with separate decoders.
        Args:
            pixel_values: Original pixel values.
            preds: List of predicted outputs from each decoder, in the same
                order as ``self.quantiles``.
            mask: Binary mask indicating which patches were masked.
            interpolate_pos_encoding: Forwarded to ``self.patchify``.
        Returns:
            Combined quantile regression loss: the pinball loss averaged over the
            patch dimension and the quantile levels, then averaged over the
            patches selected by ``mask``.
        Raises:
            ValueError: If ``preds`` does not hold one prediction per quantile.
        """
        if len(preds) != len(self.quantiles):
            raise ValueError(
                f"expected {len(self.quantiles)} predictions (one per quantile), got {len(preds)}"
            )

        target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        # Honour the same per-patch target normalisation switch as the parent
        # class's loss; previously this flag was silently ignored here.
        if self.config.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        per_quantile = []
        for quantile, pred in zip(self.quantiles, preds):
            diff = target - pred  # (batch, num_patches, patch_dim)
            # Pinball loss: under-prediction is weighted by q, over-prediction by (1 - q).
            pinball = torch.max(quantile * diff, (quantile - 1) * diff)
            per_quantile.append(pinball.mean(dim=-1))  # Average over patch_dim

        quantile_loss = sum(per_quantile) / len(self.quantiles)  # Average over quantiles
        loss = (quantile_loss * mask).sum() / mask.sum()  # Apply mask

        return loss

    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        noise: torch.FloatTensor = None,
        head_mask: torch.FloatTensor = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = None,
        interpolate_pos_encoding: bool = False,
    ):
        """
        Encode the input once, decode it with every quantile decoder and compute the loss.
        Args:
            pixel_values: Input images, passed to the encoder and used as the loss target.
            noise: Forwarded to the encoder.
            head_mask: Forwarded to the encoder.
            output_attentions: Forwarded to the encoder.
            output_hidden_states: Forwarded to the encoder.
            return_dict: When ``None``, falls back to ``self.config.use_return_dict``.
            interpolate_pos_encoding: Forwarded to the encoder, the decoders and the loss.
        Returns:
            If ``return_dict`` is truthy, a dict with keys ``loss``, ``preds`` (list with
            one tensor per quantile), ``mask``, ``ids_restore``, ``hidden_states`` and
            ``attentions``. Otherwise the tuple ``(loss, preds, mask, ids_restore)``
            followed by the encoder hidden states and attentions when they are not ``None``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward pass through the encoder. The result is read by attribute name
        # below, so a structured output is always requested here regardless of the
        # caller's `return_dict`; a plain tuple would raise AttributeError.
        outputs = self.vit(
            pixel_values,
            noise=noise,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        latent = outputs.last_hidden_state
        ids_restore = outputs.ids_restore
        mask = outputs.mask

        # Forward pass through each decoder; keep only the reconstructions.
        preds = [
            decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding).logits
            for decoder in self.decoders
        ]

        # Calculate combined loss
        loss = self.forward_loss(pixel_values, preds, mask, interpolate_pos_encoding=interpolate_pos_encoding)

        if not return_dict:
            # Optional encoder extras are appended only when they were produced.
            extras = tuple(
                value for value in (outputs.hidden_states, outputs.attentions) if value is not None
            )
            return (loss, preds, mask, ids_restore) + extras

        return {
            "loss": loss,
            "preds": preds,
            "mask": mask,
            "ids_restore": ids_restore,
            "hidden_states": outputs.hidden_states,
            "attentions": outputs.attentions,
        }

In [4]:
# Rebind `model` to the notebook-defined multi-decoder class, using the class's
# default quantile levels.
model = MultiDecoderQuantileViTMAE(config=config)

In [None]:
# Move the model to the first GPU when CUDA is available; otherwise keep it on
# the CPU instead of raising on machines without a GPU.
model.to('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Training pipeline: random crop resized to 224x224 and random horizontal flip,
# then tensor conversion and per-channel normalisation with mean 0.5 / std 0.5.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

# Evaluation pipeline: same output size and normalisation as training, but with
# a plain resize and no random augmentation, so results on the test split are
# deterministic and comparable between runs.
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

train_dataset = datasets.CIFAR100(root='./data', train=True, transform=transform_train, download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)

# The test split is neither augmented nor shuffled.
test_dataset = datasets.CIFAR100(root='./data', train=False, transform=transform_test, download=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)

In [None]:
# Number of samples in the train and test splits, shown as the cell output.
(len(train_dataset), len(test_dataset))

In [7]:
# Take the device from the model itself rather than repeating a hard-coded
# 'cuda:0' literal: batches moved with `.to(device)` then always land where the
# parameters live, whether or not the model was moved to a GPU beforehand.
device = next(model.parameters()).device
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [8]:
# Put the model into training mode and start from an empty `total_loss` list.
model.train()
total_loss = list()

In [None]:
# One pass over the training loader.
# The running sum is created once, before the loop: resetting it inside the
# loop (as was done previously) discards everything accumulated so far.
epoch_loss = 0.0
for batch_idx, (images, _) in enumerate(train_loader):
    images = images.to(device)

    outputs = model(pixel_values=images)
    loss = outputs["loss"]

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Read the scalar once per step and reuse it for accumulation and logging.
    batch_loss = loss.item()
    epoch_loss += batch_loss
    if batch_idx % 10 == 0:
        print(f"Epoch {0 + 1}/{10}, Step {batch_idx + 1}/{len(train_loader)}, Loss: {batch_loss}")

# Record the mean loss of this pass in the (otherwise unused) `total_loss` list;
# the max() guard avoids a division by zero for an empty loader.
total_loss.append(epoch_loss / max(len(train_loader), 1))

In [11]:
# Rebind `model` to a stock ViTMAEForPreTraining built from `config`; whatever
# `model` referred to before is no longer reachable through this name.
model = ViTMAEForPreTraining(config=config)

In [None]:
# Keep direct handles on the two sub-modules of the current `model`.
encoder, decoder = model.vit, model.decoder

In [4]:
# Minimal pipeline: tensor conversion followed by normalisation with
# mean 0.5 / std 0.5 (no resizing or augmentation).
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

In [None]:
# Re-create both CIFAR-100 splits with the minimal `transform`; this rebinds
# `train_dataset` and `test_dataset`.
train_dataset = datasets.CIFAR100(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
test_dataset = datasets.CIFAR100(
    root='./data',
    train=False,
    transform=transform,
    download=True,
)

In [7]:
# Loaders over the current datasets: only the training split is shuffled.
batch_size = 64
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)