In [1]:
from transformers import ViTMAEConfig, ViTMAEForPreTraining, AutoImageProcessor
import torch
from torch import nn
from copy import deepcopy

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
# Build a ViTMAE configuration with the library defaults (no overrides are passed).
config = ViTMAEConfig()
# Bare last expression: the notebook displays the configuration's repr.
config

ViTMAEConfig {
  "attention_probs_dropout_prob": 0.0,
  "decoder_hidden_size": 512,
  "decoder_intermediate_size": 2048,
  "decoder_num_attention_heads": 16,
  "decoder_num_hidden_layers": 8,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "mask_ratio": 0.75,
  "model_type": "vit_mae",
  "norm_pix_loss": false,
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": true,
  "transformers_version": "4.44.2"
}

In [3]:
class MultiDecoderQuantileViTMAE(ViTMAEForPreTraining):
    """ViTMAE pre-training model with one decoder per target quantile.

    The encoder (``self.vit``) built by the parent class is shared. The parent's
    single decoder is deep-copied once per quantile into ``self.decoders`` and
    ``self.decoder`` is then set to ``None``. Training minimises the pinball
    (quantile) loss of each decoder's reconstruction, averaged over quantiles.

    Args:
        config: Model configuration forwarded to ``ViTMAEForPreTraining``.
        quantiles: Quantile levels, one decoder is created for each. Every
            level must lie in ``[0, 1]`` and at least one must be given.

    Raises:
        ValueError: If ``quantiles`` is empty or contains a level outside ``[0, 1]``.
    """

    def __init__(self, config, quantiles=(0.1, 0.5, 0.9)):
        # Validate before building the (large) parent model so bad input fails fast.
        quantiles = tuple(quantiles)
        if not quantiles:
            raise ValueError("`quantiles` must contain at least one quantile level.")
        if any(not 0.0 <= q <= 1.0 for q in quantiles):
            # Outside [0, 1] the pinball loss is unbounded below, so training diverges.
            raise ValueError(f"Every quantile level must lie in [0, 1], got {quantiles}.")

        super().__init__(config)
        self.quantiles = quantiles
        self.num_quantiles = len(quantiles)

        # Separate decoders for each quantile. They are deep copies of the parent's
        # decoder, so all of them start from identical weights.
        self.decoders = nn.ModuleList(
            [deepcopy(self.decoder) for _ in range(self.num_quantiles)]
        )

        # The single parent decoder is no longer used once it has been copied.
        self.decoder = None

    def forward_loss(self, pixel_values, preds, mask, interpolate_pos_encoding: bool = False):
        """
        Custom loss for quantile regression with separate decoders.
        Args:
            pixel_values: Original pixel values.
            preds: List of predicted outputs from each decoder, in the same
                order as ``self.quantiles``.
            mask: Binary mask indicating which patches were masked.
            interpolate_pos_encoding: Forwarded unchanged to ``self.patchify``.
        Returns:
            Combined quantile regression loss, averaged over the masked patches.
        """
        target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        # Honour the `norm_pix_loss` config switch the same way the parent loss does:
        # normalise every target patch by its own mean and variance.
        if self.config.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        quantile_losses = []
        for i, quantile in enumerate(self.quantiles):
            diff = target - preds[i]  # (batch, num_patches, patch_dim)
            # Pinball loss: weight `quantile` on under-prediction, `1 - quantile` on over-prediction.
            pinball = torch.max(quantile * diff, (quantile - 1) * diff)
            quantile_losses.append(pinball.mean(dim=-1))  # Average over patch_dim

        quantile_loss = sum(quantile_losses) / len(self.quantiles)  # Average over quantiles
        loss = (quantile_loss * mask).sum() / mask.sum()  # Keep only patches selected by the mask

        return loss

    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        noise: torch.FloatTensor = None,
        head_mask: torch.FloatTensor = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = None,
        interpolate_pos_encoding: bool = False,
    ):
        """Encode once, decode with every quantile decoder and compute the loss.

        Args:
            pixel_values: Input images, passed to the encoder and used as the loss target.
            noise: Forwarded to the encoder.
            head_mask: Forwarded to the encoder.
            output_attentions: Forwarded to the encoder.
            output_hidden_states: Forwarded to the encoder.
            return_dict: When ``None``, falls back to ``self.config.use_return_dict``.
                Selects between the dict and the tuple return form.
            interpolate_pos_encoding: Forwarded to the encoder, the decoders and the loss.

        Returns:
            If ``return_dict`` is truthy, a dict with keys ``"loss"``, ``"preds"``
            (list with one prediction tensor per quantile), ``"mask"``,
            ``"ids_restore"``, ``"hidden_states"`` and ``"attentions"`` (the last
            two come from the encoder output). Otherwise a tuple
            ``(loss, preds, mask, ids_restore)`` followed by the encoder's
            hidden states and attentions when they are not ``None``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward pass through the encoder. A structured output is always requested
        # here because the fields are read by attribute below; passing the caller's
        # `return_dict=False` through would yield a plain tuple and fail on attribute access.
        outputs = self.vit(
            pixel_values,
            noise=noise,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        latent = outputs.last_hidden_state
        ids_restore = outputs.ids_restore
        mask = outputs.mask

        # Forward pass through each decoder; keep only the predictions.
        preds = [
            decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding).logits
            for decoder in self.decoders
        ]

        # Calculate combined loss
        loss = self.forward_loss(pixel_values, preds, mask, interpolate_pos_encoding=interpolate_pos_encoding)

        if not return_dict:
            # Same field order as the dict form, with absent optional outputs dropped.
            optional = tuple(
                item for item in (outputs.hidden_states, outputs.attentions) if item is not None
            )
            return (loss, preds, mask, ids_restore) + optional

        return {
            "loss": loss,
            "preds": preds,
            "mask": mask,
            "ids_restore": ids_restore,
            "hidden_states": outputs.hidden_states,
            "attentions": outputs.attentions,
        }

In [4]:
# Instantiate the multi-decoder model from `config`, using the class's default quantiles.
model = MultiDecoderQuantileViTMAE(config)

In [5]:
# Move the model to the first GPU when CUDA is available, otherwise keep it on the
# CPU so the cell does not fail on machines without a usable CUDA device.
# Bare last expression: the notebook displays the returned module's repr.
model.to('cuda:0' if torch.cuda.is_available() else 'cpu')

MultiDecoderQuantileViTMAE(
  (vit): ViTMAEModel(
    (embeddings): ViTMAEEmbeddings(
      (patch_embeddings): ViTMAEPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
    )
    (encoder): ViTMAEEncoder(
      (layer): ModuleList(
        (0): ViTMAELayer(
          (attention): ViTMAEAttention(
            (attention): ViTMAESelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTMAESelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTMAEIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias

In [6]:
# Training-time preprocessing: random 224-pixel resized crop, random horizontal
# flip, conversion to a tensor, then per-channel normalisation with mean 0.5 / std 0.5.
transform_train = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-100 training split stored under ./data (downloaded if it is not already there).
train_dataset = datasets.CIFAR100(
    root='./data',
    train=True,
    transform=transform_train,
    download=True,
)

# Shuffled batches of 128 images, loaded by 4 worker processes into pinned memory.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

Files already downloaded and verified


In [7]:
# Device used for the input batches: the first GPU when CUDA is available, else the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Make sure the parameters live on that same device before the optimizer is built
# (a no-op when the model is already there).
model.to(device)
# AdamW over every model parameter (shared encoder and all decoders).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [8]:
# Switch the model to training mode before the optimisation loop.
model.train()
# NOTE(review): this list is created here but nothing in the visible cells appends
# to it — presumably meant to collect per-epoch losses; confirm or remove.
total_loss = []

In [9]:
# Single pass over `train_loader` (the "Epoch 1/10" label in the log line is
# hard-coded; this cell does not loop over epochs).
# Running sum of the per-batch losses. It has to be initialised once, before the
# loop: resetting it inside the loop leaves only the last batch's loss in it.
epoch_loss = 0.0
for batch_idx, (images, _) in enumerate(train_loader):
    images = images.to(device)

    outputs = model(pixel_values=images)
    loss = outputs["loss"]

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Read the scalar once and reuse it for both the accumulator and the log line.
    batch_loss = loss.item()
    epoch_loss += batch_loss
    if batch_idx % 10 == 0:
        print(f"Epoch {0 + 1}/{10}, Step {batch_idx + 1}/{len(train_loader)}, Loss: {batch_loss}")
    

Epoch 1/10, Step 1/391, Loss: 0.26937270164489746
Epoch 1/10, Step 11/391, Loss: 0.14104987680912018
Epoch 1/10, Step 21/391, Loss: 0.11876505613327026
Epoch 1/10, Step 31/391, Loss: 0.11275138705968857
Epoch 1/10, Step 41/391, Loss: 0.10808661580085754
Epoch 1/10, Step 51/391, Loss: 0.09891676157712936
Epoch 1/10, Step 61/391, Loss: 0.09892072528600693
Epoch 1/10, Step 71/391, Loss: 0.09659170359373093
Epoch 1/10, Step 81/391, Loss: 0.09140793234109879
Epoch 1/10, Step 91/391, Loss: 0.0943424180150032
Epoch 1/10, Step 101/391, Loss: 0.08375376462936401
Epoch 1/10, Step 111/391, Loss: 0.08868476748466492
Epoch 1/10, Step 121/391, Loss: 0.08531692624092102
Epoch 1/10, Step 131/391, Loss: 0.08129403740167618
Epoch 1/10, Step 141/391, Loss: 0.07727376371622086
Epoch 1/10, Step 151/391, Loss: 0.07531385123729706
Epoch 1/10, Step 161/391, Loss: 0.07093816250562668
Epoch 1/10, Step 171/391, Loss: 0.06670038402080536
Epoch 1/10, Step 181/391, Loss: 0.06402884423732758
Epoch 1/10, Step 191/391

In [11]:
# Build a stock ViTMAE pre-training model from `config`.
# NOTE(review): this rebinds `model`, so the MultiDecoderQuantileViTMAE instance
# created and trained in the earlier cells is no longer reachable through this name.
model = ViTMAEForPreTraining(config)

In [None]:
# Keep direct handles on the two sub-modules of the current `model`.
encoder = model.vit
decoder = model.decoder

In [4]:
# Deterministic preprocessing: convert to a tensor, then normalise with mean 0.5 / std 0.5
# (single-element mean/std, as in the original cell).
# NOTE(review): unlike `transform_train`, no crop/resize to 224 is applied here —
# confirm that whatever consumes these tensors accepts the dataset's native image size.
transform = transforms.Compose(
    transforms=[
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [5]:
# CIFAR-100 train and test splits under ./data, both using `transform`.
# The splits are built in order (train first, then test) and downloaded when missing.
train_dataset, test_dataset = (
    datasets.CIFAR100(root='./data', train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


In [7]:
# Batch size shared by both loaders.
batch_size = 64

# Shuffled loader over the training split.
# Note: this rebinds `train_loader`, replacing the batch-size-128 loader defined earlier.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

# Ordered (unshuffled) loader over the test split.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)