<a href="https://colab.research.google.com/github/namesarnav/SimMIM/blob/main/ViT_MIM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install wandb



In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import matplotlib.pyplot as plt
from transformers import ViTModel, ViTFeatureExtractor
import time  # For tracking training duration
import wandb  # Import Weights & Biases

In [3]:
wandb.init(project="SimMIM-CIFAR100", config={"epochs": 5, "batch_size": 300, "learning_rate": 2e-5})

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mnamesarnav[0m ([33mnamesarnav-unt[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Use A100 GPU efficiently
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True  # Optimizes performance on A100

# Load CIFAR-100 dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize for ViT input
    transforms.ToTensor()
])

trainset = datasets.CIFAR100(root="./data", train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=wandb.config.batch_size, shuffle=True, num_workers=8, pin_memory=True)

# Load pre-trained Vision Transformer (backbone)
model_name = "google/vit-base-patch16-224"
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
backbone = ViTModel.from_pretrained(model_name).to(device)


100%|██████████| 169M/169M [00:13<00:00, 12.8MB/s]
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# Masking function
def mask_image(image, mask_ratio=0.5, patch_size=32):
    _, H, W = image.shape
    num_patches_h = H // patch_size
    num_patches_w = W // patch_size
    total_patches = num_patches_h * num_patches_w
    masked_indices = np.random.choice(total_patches, int(mask_ratio * total_patches), replace=False)

    image_patches = image.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    for idx in masked_indices:
        row, col = divmod(idx, num_patches_w)
        image_patches[:, row, col] = 0  # Mask out entire patch

    return image

# Define pixel reconstruction head (SimMIM-style)
class PixelReconstructionHead(nn.Module):
    def __init__(self, in_features, out_features=224*224*3):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)  # Predict raw pixel values
    def forward(self, x):
        return self.fc(x)

reconstruction_head = PixelReconstructionHead(backbone.config.hidden_size).to(device)

In [6]:
torch.cuda.empty_cache()

In [8]:
# Optimized mixed precision training setup
scaler = torch.cuda.amp.GradScaler()  # Helps optimize computations on A100
criterion = nn.L1Loss()  # Pixel regression loss
optimizer = optim.AdamW(list(backbone.parameters()) + list(reconstruction_head.parameters()), lr=wandb.config.learning_rate)

# Training loop with WandB logging
num_epochs = wandb.config.epochs
start_time = time.time()


for epoch in range(num_epochs):
    total_loss = 0
    print(f"\n🚀 Starting Epoch {epoch+1}/{num_epochs} 🚀")
    try:
        for batch_idx, (images, _) in enumerate(trainloader):
            images = images.to(device)

            # Apply structured masking
            masked_images = torch.stack([mask_image(img) for img in images]).to(device)

            # Forward pass with mixed precision
            with torch.cuda.amp.autocast():
                inputs = feature_extractor(masked_images, return_tensors="pt")["pixel_values"].to(device)
                features = backbone(inputs).last_hidden_state.mean(dim=1)
                reconstructed_pixels = reconstruction_head(features).view(-1, 3, 224, 224)
                loss = criterion(reconstructed_pixels, images)

            # Backpropagation
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()

            # Log metrics to WandB
            wandb.log({"batch_loss": loss.item()})

            # Display progress every 10 batches
            if batch_idx % 10 == 0:
                print(f"🔄 Batch {batch_idx}/{len(trainloader)} - Loss: {loss.item():.4f}")


        torch.save({
            'epoch': epoch,
            'model_state_dict': backbone.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'reconstruction_head_state_dict': reconstruction_head.state_dict(),
        }, f"simmim_checkpoint_epoch_{epoch}.pth")


        avg_loss = total_loss / len(trainloader)
        wandb.log({"epoch_loss": avg_loss})
        print(f"✅ Epoch {epoch+1} Completed! Avg Loss: {avg_loss:.4f}")

    except OutOfMemoryError:
        torch.cuda.empty_cache()
        # checkpoint = torch.load("simmim_checkpoint_epoch_4.pth", map_location=device)
        # backbone.load_state_dict(checkpoint['model_state_dict'])
        # reconstruction_head.load_state_dict(checkpoint['reconstruction_head_state_dict'])
        # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Display training duration
end_time = time.time()
print(f"\n🎉 Training Finished in {end_time - start_time:.2f} seconds 🎉")
wandb.log({"training_time_seconds": end_time - start_time})

# Save model checkpoint
torch.save({
    'epoch': num_epochs,
    'model_state_dict': backbone.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'reconstruction_head_state_dict': reconstruction_head.state_dict(),
}, "simmim_model_checkpoint.pth")

wandb.save("simmim_model_checkpoint.pth")  # Upload model to WandB
print(f"\n✅ Model saved and logged to WandB ✅")

# Finish WandB run
wandb.finish()




🚀 Starting Epoch 1/5 🚀


  scaler = torch.cuda.amp.GradScaler()  # Helps optimize computations on A100
  with torch.cuda.amp.autocast():


🔄 Batch 0/167 - Loss: 0.3415
🔄 Batch 10/167 - Loss: 0.3098
🔄 Batch 20/167 - Loss: 0.2834
🔄 Batch 30/167 - Loss: 0.2663
🔄 Batch 40/167 - Loss: 0.2587
🔄 Batch 50/167 - Loss: 0.2531
🔄 Batch 60/167 - Loss: 0.2512
🔄 Batch 70/167 - Loss: 0.2493
🔄 Batch 80/167 - Loss: 0.2451
🔄 Batch 90/167 - Loss: 0.2496
🔄 Batch 100/167 - Loss: 0.2396
🔄 Batch 110/167 - Loss: 0.2511
🔄 Batch 120/167 - Loss: 0.2405
🔄 Batch 130/167 - Loss: 0.2463
🔄 Batch 140/167 - Loss: 0.2499
🔄 Batch 150/167 - Loss: 0.2389
🔄 Batch 160/167 - Loss: 0.2429
✅ Epoch 1 Completed! Avg Loss: 0.2574

🚀 Starting Epoch 2/5 🚀
🔄 Batch 0/167 - Loss: 0.2444
🔄 Batch 10/167 - Loss: 0.2416
🔄 Batch 20/167 - Loss: 0.2431
🔄 Batch 30/167 - Loss: 0.2496
🔄 Batch 40/167 - Loss: 0.2431
🔄 Batch 50/167 - Loss: 0.2479
🔄 Batch 60/167 - Loss: 0.2344
🔄 Batch 70/167 - Loss: 0.2460
🔄 Batch 80/167 - Loss: 0.2465
🔄 Batch 90/167 - Loss: 0.2510
🔄 Batch 100/167 - Loss: 0.2471
🔄 Batch 110/167 - Loss: 0.2392
🔄 Batch 120/167 - Loss: 0.2458
🔄 Batch 130/167 - Loss: 0.2440

0,1
batch_loss,█▅▃▂▂▂▂▁▁▁▁▁▁▁▁▁▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁
epoch_loss,█▁▁▁▁
training_time_seconds,▁

0,1
batch_loss,0.24105
epoch_loss,0.24312
training_time_seconds,1735.34421


In [20]:
torch.cuda.empty_cache()

In [21]:
def evaluate_model(model, reconstruction_head, dataloader):
    print("\n🔍 Starting Evaluation...")
    start_time = time.time()  # Track start time

    total_loss, total_ssim, total_psnr, num_images = 0, 0, 0, 0
    model.eval(), reconstruction_head.eval()

    with torch.no_grad():  # Disable gradients for efficiency
        for batch_idx, (images, _) in enumerate(dataloader):
            images = images.to(device)
            masked_images = torch.stack([mask_image(img) for img in images]).to(device)

            # Forward pass
            inputs = feature_extractor(masked_images, return_tensors="pt")["pixel_values"].to(device)
            features = model(inputs).last_hidden_state.mean(dim=1)
            reconstructed_images = reconstruction_head(features).view(-1, 3, 224, 224)

            # Compute L1 loss (pixel reconstruction error)
            loss = F.l1_loss(reconstructed_images, images)
            total_loss += loss.item()

            # Compute SSIM & PSNR for each image
            for i in range(images.shape[0]):
                orig = images[i].permute(1, 2, 0).cpu().numpy()
                recon = reconstructed_images[i].permute(1, 2, 0).cpu().numpy()

                # Resize images for SSIM (ensuring win_size compatibility)
                orig_resized = resize(orig, (224, 224), anti_aliasing=True)
                recon_resized = resize(recon, (224, 224), anti_aliasing=True)

                ssim_score = ssim(orig_resized, recon_resized, data_range=255, channel_axis=-1, win_size=7)
                total_ssim += ssim_score

                # Compute PSNR (Peak Signal-to-Noise Ratio)
                mse = np.mean((orig - recon) ** 2)
                psnr_score = 20 * np.log10(255.0 / np.sqrt(mse))
                total_psnr += psnr_score

            num_images += images.shape[0]

            # Print progress for every 10 batches
            if batch_idx % 10 == 0:
                print(f"🔄 Batch {batch_idx}/{len(dataloader)} Processed")

    # Compute average metrics
    avg_loss = total_loss / num_images
    avg_ssim = total_ssim / num_images
    avg_psnr = total_psnr / num_images

    end_time = time.time()  # Track end time

    # Print results in a readable format
    print("\n✅ Evaluation Complete!")
    print(f"📌 Avg Reconstruction Loss (L1): {avg_loss:.4f}")
    print(f"📌 Avg Structural Similarity (SSIM): {avg_ssim:.4f}")
    print(f"📌 Avg Peak Signal-to-Noise Ratio (PSNR): {avg_psnr:.2f} dB")
    print(f"⏳ Evaluation Time: {end_time - start_time:.2f} seconds")

    return avg_loss, avg_ssim, avg_psnr

# Run evaluation loop
evaluate_model(backbone, reconstruction_head, trainloader)


🔍 Starting Evaluation...
🔄 Batch 0/167 Processed
🔄 Batch 10/167 Processed
🔄 Batch 20/167 Processed
🔄 Batch 30/167 Processed
🔄 Batch 40/167 Processed
🔄 Batch 50/167 Processed
🔄 Batch 60/167 Processed
🔄 Batch 70/167 Processed
🔄 Batch 80/167 Processed
🔄 Batch 90/167 Processed
🔄 Batch 100/167 Processed
🔄 Batch 110/167 Processed
🔄 Batch 120/167 Processed
🔄 Batch 130/167 Processed
🔄 Batch 140/167 Processed
🔄 Batch 150/167 Processed
🔄 Batch 160/167 Processed

✅ Evaluation Complete!
📌 Avg Reconstruction Loss (L1): 0.0008
📌 Avg Structural Similarity (SSIM): 0.9837
📌 Avg Peak Signal-to-Noise Ratio (PSNR): 57.89 dB
⏳ Evaluation Time: 1836.40 seconds


(0.0008072322288155555, np.float32(0.98368484), np.float32(57.88597))

In [30]:


# Convert accumulated values into averages
epochs = np.arange(1, num_images + 1)
ssim_values = np.array([total_ssim / num_images] * num_images)
psnr_values = np.array([total_psnr / num_images] * num_images)
loss_values = np.array([total_loss / num_images] * num_images)

plt.figure(figsize=(8, 5))
plt.plot(epochs, loss_values, marker='o', linestyle='-', color='b', label="Reconstruction Loss (L1)")
plt.xlabel("Evaluation Batches")
plt.ylabel("Loss")
plt.title("Reconstruction Loss Over Evaluation Batches")
plt.legend()
plt.grid(True)
plt.savefig("evaluation_loss.png")  # Save for PowerPoint
plt.show()

plt.figure(figsize=(8, 5))
plt.plot(epochs, ssim_values, marker='s', linestyle='-', color='g', label="SSIM Score")
plt.plot(epochs, psnr_values, marker='^', linestyle='-', color='r', label="PSNR (dB)")
plt.xlabel("Evaluation Batches")
plt.ylabel("Metric Value")
plt.title("Evaluation Metrics Over Batches")
plt.legend()
plt.grid(True)
plt.savefig("evaluation_metrics.png")  # Save for PowerPoint
plt.show()

def visualize_images(original, masked, reconstructed):
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    axes[0].imshow(original.permute(1, 2, 0).cpu())
    axes[0].set_title("Original Image")
    axes[1].imshow(masked.permute(1, 2, 0).cpu())
    axes[1].set_title("Masked Image")
    axes[2].imshow(reconstructed.permute(1, 2, 0).detach().cpu().numpy())
    axes[2].set_title("Reconstructed Image")

    plt.tight_layout()
    plt.savefig("image_reconstruction.png")  # Save for PowerPoint
    plt.show()

# Load a test image
sample_image = trainset[0][0].to(device)
masked_sample = mask_image(sample_image)
features = backbone(feature_extractor(masked_sample.unsqueeze(0), return_tensors="pt")["pixel_values"].to(device)).last_hidden_state.mean(dim=1)
reconstructed_sample = reconstruction_head(features).view(3, 224, 224)

visualize_images(sample_image, masked_sample, reconstructed_sample)

NameError: name 'num_images' is not defined