In [None]:
!pip install einops

In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from einops import rearrange
import matplotlib.pyplot as plt

In [None]:
class PatchEmbed(nn.Module):
    """Turn an image batch into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size`` is
    applied, so every non-overlapping patch yields one ``embed_dim`` vector.

    Args:
        img_size: Side length used to compute ``num_patches``.
        patch_size: Side length of one square patch.
        in_chans: Number of input channels of the convolution.
        embed_dim: Number of output channels, i.e. the token width.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size, self.patch_size = img_size, patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Project ``x`` with the patch convolution and flatten the grid.

        Returns:
            Tensor laid out as (batch, grid_h * grid_w, channels).
        """
        feature_map = self.proj(x)  # (B, C, H, W)
        return rearrange(feature_map, 'b c h w -> b (h w) c')  # (B, N, D)

In [5]:
class MAE(nn.Module):
    """Masked autoencoder: encode a random subset of patches, decode all of them.

    Args:
        encoder: Module applied to the visible tokens, shaped
            (batch, num_visible, embed_dim). It must return the same layout.
        decoder: Module applied to the full token sequence, shaped
            (batch, num_patches, decoder_dim). It must return the same layout.
        embed_dim: Token width of the patch embedding / encoder.
        decoder_dim: Token width of the decoder.
        patch_size: Side length of one square patch.
        img_size: Side length of the (square) input image.
        in_chans: Number of image channels (defaults to 3, the previously
            hard-coded value).
    """

    def __init__(self, encoder, decoder, embed_dim=768, decoder_dim=512, patch_size=16, img_size=224, in_chans=3):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.encoder = encoder
        self.decoder = decoder
        # Learned positions for the encoder; without them the encoder cannot
        # tell where a visible patch came from.
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim) * 0.02)
        # Bridges the encoder width to the decoder width. Previously the
        # encoder output was scattered straight into a decoder_dim-wide buffer,
        # which silently used only its first decoder_dim channels.
        self.decoder_embed = nn.Linear(embed_dim, decoder_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))
        self.decoder_pos_embed = nn.Parameter(torch.randn(1, num_patches, decoder_dim))
        self.reconstruction_head = nn.Linear(decoder_dim, patch_size**2 * in_chans)

    def forward(self, x, mask_ratio=0.75):
        """Mask patches at random and reconstruct every patch.

        Args:
            x: Image batch accepted by ``self.patch_embed``.
            mask_ratio: Fraction of patches hidden from the encoder; must be
                in ``[0, 1)``. ``int(mask_ratio * N)`` patches are masked.

        Returns:
            Tuple ``(x_reconstructed, ids_restore)``:
            ``x_reconstructed`` is (batch, N, patch_size**2 * in_chans);
            ``ids_restore`` is (batch, N) where ``ids_restore[b, j]`` is the
            rank of patch ``j`` in the random shuffle. Patch ``j`` was visible
            to the encoder iff ``ids_restore[b, j] < N - int(mask_ratio * N)``.

        Raises:
            ValueError: If ``mask_ratio`` is outside ``[0, 1)``.
        """
        if not 0.0 <= mask_ratio < 1.0:
            raise ValueError(f"mask_ratio must be in [0, 1), got {mask_ratio}")

        x = self.patch_embed(x) + self.pos_embed
        B, N, D = x.shape

        num_mask = int(mask_ratio * N)
        # Computed as a count rather than slicing with `:-num_mask`, which
        # would keep nothing when num_mask == 0.
        len_keep = N - num_mask

        noise = torch.rand(B, N, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :len_keep]
        x_visible = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        x_encoded = self.decoder_embed(self.encoder(x_visible))
        decoder_dim = x_encoded.size(-1)

        # Append one learned mask token per hidden patch, then undo the shuffle
        # so every token sits at its original patch position.
        mask_tokens = self.mask_token.to(x_encoded.dtype).expand(B, num_mask, -1)
        decoder_tokens = torch.cat([x_encoded, mask_tokens], dim=1)
        decoder_tokens = torch.gather(
            decoder_tokens, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, decoder_dim)
        )
        decoder_tokens = decoder_tokens + self.decoder_pos_embed

        x_decoded = self.decoder(decoder_tokens)
        x_reconstructed = self.reconstruction_head(x_decoded)

        return x_reconstructed, ids_restore


In [None]:
def patchify(imgs, patch_size=16):
    """Cut images into non-overlapping square patches and flatten each one.

    Args:
        imgs: Tensor shaped (batch, channels, height, width). Any channel
            count is accepted.
        patch_size: Side length ``p`` of one patch.

    Returns:
        Tensor shaped (batch, h * w, channels * p * p) with
        ``h = height // p`` and ``w = width // p``. Patches are ordered
        row-major over the grid, and each patch is flattened in
        (channel, row, column) order.
    """
    p = patch_size
    batch, chans = imgs.shape[0], imgs.shape[1]
    h, w = imgs.shape[2] // p, imgs.shape[3] // p
    # Two unfolds produce (batch, channels, h, w, p, p).
    patches = imgs.unfold(2, p, p).unfold(3, p, p)
    # Move the grid axes forward, then merge (h, w) and (channels, p, p).
    patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(batch, h * w, chans * p * p)
    return patches

def mae_loss(pred, target, mask):
    """Mean squared error averaged over the patches selected by ``mask``.

    The squared error is first averaged over the last axis of
    ``pred - target``. The result is then weighted by ``mask``, summed, and
    divided by ``mask.sum()``.

    Args:
        pred: Predicted patch values.
        target: Reference patch values, broadcast-compatible with ``pred``.
        mask: Per-patch weights, broadcast-compatible with the per-patch error.

    Returns:
        Scalar tensor holding the weighted mean error.
    """
    per_patch_mse = (pred - target).pow(2).mean(dim=-1)
    masked_total = per_patch_mse.mul(mask).sum()
    return masked_total / mask.sum()


In [None]:
from torchvision.datasets import ImageFolder
from glob import glob
import os

# Gather (path, label) pairs from every training shard directory.
train_dirs = sorted(glob("/content/ssl_dataset/train.X*"))
if not train_dirs:
    raise FileNotFoundError("No training shards match /content/ssl_dataset/train.X*")

all_train_data = []
# ImageFolder numbers classes from 0 inside each root, so raw indices from
# different shards would collide; map class names to one global index instead.
class_to_idx = {}
for d in train_dirs:
    # Only the file listing is needed here, so no transform is passed. This
    # also avoids referencing `transform` before the cell defines it.
    ds = ImageFolder(d)
    for path, local_label in ds.samples:
        class_name = ds.classes[local_label]
        global_label = class_to_idx.setdefault(class_name, len(class_to_idx))
        all_train_data.append((path, global_label))

# Dataset that serves the combined (path, label) samples collected from all shards
from torch.utils.data import Dataset
class CombinedImageNet100(Dataset):
    """Dataset over a pre-built list of ``(path, label)`` pairs.

    Args:
        samples: Sequence of ``(path, label)`` pairs.
        transform: Optional callable applied to each loaded image.
    """

    def __init__(self, samples, transform=None):
        self.samples, self.transform = samples, transform

    def __len__(self):
        """Number of ``(path, label)`` pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load the image at ``idx`` and return ``(image, label)``.

        The image is passed through ``self.transform`` when one is set.
        """
        path, label = self.samples[idx]
        img = datasets.folder.default_loader(path)
        return (self.transform(img) if self.transform else img), label

# Resize the shorter side to 224, then crop the central 224x224 region.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

dataset = CombinedImageNet100(all_train_data, transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2)

# Encoder and Decoder
# MAE feeds tokens as (batch, tokens, dim). nn.TransformerEncoderLayer defaults
# to (tokens, batch, dim), which would make attention mix samples across the
# batch, so batch_first=True is required.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(768, 12, batch_first=True), num_layers=6)
decoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(512, 8, batch_first=True), num_layers=4)
model = MAE(encoder, decoder).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


In [None]:
# Fraction of patches hidden from the encoder. The same value is used for the
# model call and for rebuilding the loss mask.
MASK_RATIO = 0.75
# Run on whichever device the model's parameters live on.
device = next(model.parameters()).device
model.train()

for epoch in range(3):
    for imgs, _ in loader:
        imgs = imgs.to(device)
        preds, ids_restore = model(imgs, mask_ratio=MASK_RATIO)
        target = patchify(imgs)

        N = preds.shape[1]
        # ids_restore[b, j] is the shuffle rank of patch j. The model keeps the
        # first N - int(MASK_RATIO * N) ranks, so every patch with a higher
        # rank was masked and should contribute to the loss (mask == 1).
        len_keep = N - int(MASK_RATIO * N)
        mask = (ids_restore >= len_keep).to(preds.dtype)

        loss = mae_loss(preds, target, mask)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Reports the loss of the last batch of the epoch.
    print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")