<a href="https://colab.research.google.com/github/namesarnav/SimMIM/blob/main/ViT_MIM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Masked Image Modeling using ViT-B on a CIFAR-100 dataset

In [22]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR100
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import matplotlib.pyplot as plt
from tqdm import tqdm

In [23]:
# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Hyperparameters
BATCH_SIZE = 2048
EPOCHS = 15
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 0.01
IMAGE_SIZE = 32  # CIFAR images are 32x32
PATCH_SIZE = 4  # Size of patches for ViT
NUM_CLASSES = 100  # CIFAR-100 has 100 classes
MASK_RATIO = 0.75  # Portion of patches to mask
EMBED_DIM = 384  # Embedding dimension
DEPTH = 6  # Number of transformer layers
NUM_HEADS = 6  # Number of attention heads
MLP_RATIO = 4.0  # Ratio for MLP hidden dim

Using device: cuda


In [24]:
# Transformer components
class MLP(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(dim=dim, hidden_dim=mlp_hidden_dim, dropout=drop)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


In [25]:
# Masked Image Modeling ViT
class MaskedVisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=IMAGE_SIZE,
        patch_size=PATCH_SIZE,
        in_chans=3,
        embed_dim=EMBED_DIM,
        depth=DEPTH,
        num_heads=NUM_HEADS,
        mlp_ratio=MLP_RATIO,
        qkv_bias=True,
        drop_rate=0.1,
        attn_drop_rate=0.1
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # Number of patches
        self.num_patches = (img_size // patch_size) ** 2

        # Patch embedding
        self.patch_embed = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            nn.Linear(patch_size * patch_size * in_chans, embed_dim),
        )

        # CLS token and positional embedding
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))

        # Transformer blocks
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate
            )
            for _ in range(depth)
        ])

        # Norm layer
        self.norm = nn.LayerNorm(embed_dim)



        # Initialize weights
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

        # Simplified decoder for SimMIM
        # In SimMIM, the decoder is intentionally kept very lightweight - just a linear projection
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, patch_size * patch_size * in_chans),
            Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
                      h=img_size//patch_size, w=img_size//patch_size,
                      p1=patch_size, p2=patch_size)
        )

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self, x, mask_ratio):
        """
        Perform random masking following SimMIM approach - independent random masking.

        Args:
            x: [N, L, D], sequence
            mask_ratio: percentage of tokens to be masked

        Returns:
            x_unchanged: [N, L, D], the original sequence (not actually masked)
            mask: [N, L], mask -> 0 is keep, 1 is mask (matches SimMIM notation)
            ids_restore: None (not needed for SimMIM approach)
        """
        N, L, D = x.shape  # batch, length, dim

        # Generate random mask - following SimMIM's approach of independent random masking
        # In SimMIM, mask=1 means the token is masked, mask=0 means it's kept
        mask = torch.bernoulli(torch.ones(N, L, device=x.device) * mask_ratio)

        # In SimMIM, we don't actually need to mask the sequence for the encoder
        # The entire sequence is processed, and the loss is only applied on masked tokens

        return x, mask, None

    def forward_encoder(self, img, mask_ratio):
        # Convert image to patches
        patches = self.patch_embed(img)  # [B, num_patches, embed_dim]

        # Add position embeddings (exclude CLS token position at this point)
        patches = patches + self.pos_embed[:, 1:, :]

        # Apply masking strategy (following SimMIM, this doesn't actually mask inputs)
        patches, mask, _ = self.random_masking(patches, mask_ratio)

        # Append CLS token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(patches.shape[0], -1, -1)
        x = torch.cat((cls_tokens, patches), dim=1)

        # Apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)

        # Apply norm
        x = self.norm(x)

        return x, mask

    def forward_decoder(self, x):
        """
        SimMIM-style lightweight decoder - just project encoded patches back to pixels
        """
        # Exclude CLS token
        x = x[:, 1:, :]

        # Following SimMIM's lightweight decoder approach
        # Just a simple linear projection from embeddings back to pixels
        pixels = self.decoder(x)

        return pixels

    def forward(self, imgs, mask_ratio=MASK_RATIO):
        # Encoding with masking strategy (SimMIM style)
        latent, mask = self.forward_encoder(imgs, mask_ratio)

        # Decoding to reconstruct original image with simplified decoder
        pred = self.forward_decoder(latent)

        return pred, mask

In [26]:
# Dataset preparation with transformations
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
])

# Load CIFAR-100 dataset
trainset = CIFAR100(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

testset = CIFAR100(root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


# Create model
model = MaskedVisionTransformer().to(device)

# L1 Loss for Masked Image Modeling
criterion = nn.L1Loss()

# Optimizer with weight decay
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Learning rate scheduler
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [27]:
# Training function following SimMIM approach
def train_epoch(model, dataloader, criterion, optimizer, device, epoch):
    model.train()
    running_loss = 0.0
    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch+1}/{EPOCHS}")

    for i, (images, _) in progress_bar:
        images = images.to(device)

        # Forward pass with masking
        reconstructed, mask = model(images)

        # Calculate loss only on masked patches - SimMIM style
        # In SimMIM, mask=1 means the token is masked
        loss = compute_simmim_loss(criterion, reconstructed, images, mask)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update statistics
        running_loss += loss.item()
        progress_bar.set_postfix({"Loss": f"{running_loss/(i+1):.4f}"})

    return running_loss / len(dataloader)

# SimMIM loss computation - only on masked patches
def compute_simmim_loss(criterion, pred, target, mask):
    """
    Compute L1 loss only on masked patches as in SimMIM

    Args:
        criterion: loss function (L1Loss)
        pred: [B, C, H, W] reconstructed images
        target: [B, C, H, W] original images
        mask: [B, L] binary mask (1 = masked, 0 = kept)

    Returns:
        loss: scalar
    """
    # Reshape images to match mask dimensions
    B, C, H, W = pred.shape
    patch_size = PATCH_SIZE

    # Reshape predictions and targets to patch-level
    # [B, C, H, W] -> [B, L, C*P*P] where L is number of patches, P is patch size
    pred = rearrange(pred, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1=patch_size, p2=patch_size)
    target = rearrange(target, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1=patch_size, p2=patch_size)

    # Apply mask: only compute loss on masked patches (where mask == 1)
    loss = criterion(pred[mask.bool()], target[mask.bool()])

    return loss


# Function to visualize original and reconstructed images with SimMIM masking
def visualize_reconstruction(model, dataloader, device, num_images=5):
    model.eval()
    images, _ = next(iter(dataloader))
    images = images[:num_images].to(device)

    with torch.no_grad():
        reconstructed, mask = model(images)

    # Convert mask for visualization
    # In SimMIM, mask=1 means the token is masked, so we need to reshape it for display
    mask_vis = mask.reshape(mask.shape[0], 8, 8)  # Assuming 32x32 image with 4x4 patches
    mask_vis = mask_vis.unsqueeze(1).repeat(1, 3, 1, 1)  # [B, 3, 8, 8]
    # Upsample mask to match image dimensions
    mask_vis = F.interpolate(mask_vis.float(), size=(32, 32), mode='nearest')

    # Convert to CPU for plotting
    images = images.cpu()
    reconstructed = reconstructed.cpu()
    mask_vis = mask_vis.cpu()

    # Denormalize images
    mean = torch.tensor([0.5071, 0.4867, 0.4408]).view(1, 3, 1, 1)
    std = torch.tensor([0.2675, 0.2565, 0.2761]).view(1, 3, 1, 1)

    images = images * std + mean
    reconstructed = reconstructed * std + mean

    # Clip values to [0, 1] range
    images = torch.clamp(images, 0, 1)
    reconstructed = torch.clamp(reconstructed, 0, 1)

    # Create masked image by combining original and reconstructed based on mask
    masked_imgs = images.clone()
    for i in range(num_images):
        # Apply mask - use reconstructed image where mask=1 (masked regions)
        mask_image = mask_vis[i]
        masked_imgs[i] = images[i] * (1 - mask_image) + reconstructed[i] * mask_image

    # Plot
    fig, axs = plt.subplots(3, num_images, figsize=(15, 8))

    for i in range(num_images):
        # Original
        axs[0, i].imshow(images[i].permute(1, 2, 0))
        axs[0, i].set_title("Original")
        axs[0, i].axis('off')

        # Mask visualization (white = masked in SimMIM)
        axs[1, i].imshow(mask_vis[i].permute(1, 2, 0), cmap='gray')
        axs[1, i].set_title("Mask (white=masked)")
        axs[1, i].axis('off')

        # Reconstructed
        axs[2, i].imshow(masked_imgs[i].permute(1, 2, 0))
        axs[2, i].set_title("Masked + Reconstructed")
        axs[2, i].axis('off')

    plt.tight_layout()
    plt.savefig(f"simmim_reconstruction_example.png")
    plt.close()



In [28]:
# Training loop
def train_model(model, trainloader, testloader, criterion, optimizer, lr_scheduler, device, epochs):
    best_loss = float('inf')
    train_losses = []

    for epoch in range(epochs):
        # Train for one epoch
        train_loss = train_epoch(model, trainloader, criterion, optimizer, device, epoch)
        train_losses.append(train_loss)

        # Update learning rate
        lr_scheduler.step()

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}")

        # Save checkpoint if best model
        if train_loss < best_loss:
            best_loss = train_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_loss,
            }, 'best_vit_mim_model.pth')
            print(f"Checkpoint saved (Loss: {best_loss:.4f})")

        # Visualize reconstruction every 5 epochs
        if (epoch + 1) % 5 == 0:
            visualize_reconstruction(model, testloader, device)

    # Plot training loss
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training Loss Over Time')
    plt.legend()
    plt.savefig('training_loss.png')
    plt.close()

    return train_losses


# Main execution
if __name__ == "__main__":
    print("Starting training Vision Transformer with Masked Image Modeling on CIFAR-100...")

    # Train the model
    losses = train_model(
        model=model,
        trainloader=trainloader,
        testloader=testloader,
        criterion=criterion,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        device=device,
        epochs=EPOCHS
    )

    print("Training complete!")

    # Load best model for final visualization
    checkpoint = torch.load('best_vit_mim_model.pth')
    model.load_state_dict(checkpoint['model_state_dict'])

    # Final reconstruction visualization
    print("Creating final reconstruction visualization...")
    visualize_reconstruction(model, testloader, device, num_images=8)
    print("Visualization saved as 'reconstruction_example.png'")

Starting training Vision Transformer with Masked Image Modeling on CIFAR-100...


Epoch 1/15: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s, Loss=0.3742]


Epoch 1/15, Train Loss: 0.3742
Checkpoint saved (Loss: 0.3742)


Epoch 2/15: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s, Loss=0.1901]


Epoch 2/15, Train Loss: 0.1901
Checkpoint saved (Loss: 0.1901)


Epoch 3/15: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s, Loss=0.1484]


Epoch 3/15, Train Loss: 0.1484
Checkpoint saved (Loss: 0.1484)


Epoch 4/15: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s, Loss=0.1335]


Epoch 4/15, Train Loss: 0.1335
Checkpoint saved (Loss: 0.1335)


Epoch 5/15: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s, Loss=0.1162]


Epoch 5/15, Train Loss: 0.1162
Checkpoint saved (Loss: 0.1162)


Epoch 6/15:   0%|          | 0/25 [00:00<?, ?it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7feca9196a20>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^Exception ignored in: ^^<function _MultiProcessingDataLoaderIter.__del__ at 0x7feca9196a20>^
^Traceback (most recent call last):
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
^^    ^self._shutdown_workers()^
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers

      File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
if w.is_alive():    
assert self._parent_pid == os.getpid(), 'can only test a child p

Epoch 6/15, Train Loss: 0.1079
Checkpoint saved (Loss: 0.1079)


Epoch 7/15:   0%|          | 0/25 [00:00<?, ?it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7feca9196a20>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^Exception ignored in: ^^<function _MultiProcessingDataLoaderIter.__del__ at 0x7feca9196a20>
^Traceback (most recent call last):
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
^
      File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
self._shutdown_workers()    
assert self._parent_pid == os.getpid(), 'can only test a child process'  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers

      if w.is

Epoch 7/15, Train Loss: 0.1020
Checkpoint saved (Loss: 0.1020)


Epoch 8/15:   0%|          | 0/25 [00:00<?, ?it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7feca9196a20>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^Exception ignored in: 
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
<function _MultiProcessingDataLoaderIter.__del__ at 0x7feca9196a20>    
assert self._parent_pid == os.getpid(), 'can only test a child process'Traceback (most recent call last):

  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
      self._shutdown_workers() 
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
      if 

Epoch 8/15, Train Loss: 0.0967
Checkpoint saved (Loss: 0.0967)


Epoch 9/15:   4%|▍         | 1/25 [00:01<00:36,  1.53s/it, Loss=0.0929]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7feca9196a20>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
   Exception ignored in:   <function _MultiProcessingDataLoaderIter.__del__ at 0x7feca9196a20> 
 Traceback (most recent call last):
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
^^    ^self._shutdown_workers()^
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
^    ^^if w.is_alive():^
^ ^ 
    File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
      assert self._parent_pid == os.getpid()

Epoch 9/15, Train Loss: 0.0914
Checkpoint saved (Loss: 0.0914)


Epoch 10/15:   4%|▍         | 1/25 [00:01<00:36,  1.52s/it, Loss=0.0892]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7feca9196a20>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
     Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7feca9196a20> ^
^Traceback (most recent call last):
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
^^    ^self._shutdown_workers()^
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
^    ^^if w.is_alive():^

   File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
       assert self._parent_pid == os.getpid(),

Epoch 10/15, Train Loss: 0.0884
Checkpoint saved (Loss: 0.0884)


Epoch 11/15: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s, Loss=0.0861]


Epoch 11/15, Train Loss: 0.0861
Checkpoint saved (Loss: 0.0861)


Epoch 12/15: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s, Loss=0.0846]


Epoch 12/15, Train Loss: 0.0846
Checkpoint saved (Loss: 0.0846)


Epoch 13/15: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s, Loss=0.0837]


Epoch 13/15, Train Loss: 0.0837
Checkpoint saved (Loss: 0.0837)


Epoch 14/15: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s, Loss=0.0831]


Epoch 14/15, Train Loss: 0.0831
Checkpoint saved (Loss: 0.0831)


Epoch 15/15: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s, Loss=0.0828]


Epoch 15/15, Train Loss: 0.0828
Checkpoint saved (Loss: 0.0828)
Training complete!
Creating final reconstruction visualization...
Visualization saved as 'reconstruction_example.png'
