<a href="https://colab.research.google.com/github/nilaynishant/AIMLTutorial/blob/main/11_Basic_foundation_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Small Geospatial Foundation Model Training (Prithvi-inspired)
# Uses EuroSAT dataset for remote sensing

# Install required packages
!pip install torch torchvision einops timm pillow matplotlib scikit-learn tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from torchvision.datasets import EuroSAT
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from pathlib import Path

# Set device
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# ====================
# 1. CONFIGURATION
# ====================
class Config:
    """Hyper-parameters for the small MAE-style pre-training run.

    Every setting is a class attribute, so all ``Config()`` instances share
    the same values and the instance itself carries no per-object state.
    """

    # ---- Model architecture ----
    img_size: int = 64        # input side length; smaller for faster training
    patch_size: int = 8       # side length of each square patch
    in_channels: int = 3
    embed_dim: int = 256      # smaller than full Prithvi
    depth: int = 6            # encoder transformer blocks (fewer than Prithvi)
    num_heads: int = 8
    mlp_ratio: int = 4        # MLP hidden width = embed_dim * mlp_ratio

    # ---- Optimisation ----
    batch_size: int = 32
    num_epochs: int = 20
    learning_rate: float = 1e-4
    weight_decay: float = 0.05
    warmup_epochs: int = 2

    # ---- Self-supervised masking ----
    mask_ratio: float = 0.75  # fraction of patches hidden from the encoder

    # ---- Data loading ----
    data_root: str = './data'
    num_workers: int = 2


config = Config()

In [None]:


# ====================
# 2. VISION TRANSFORMER (Encoder-Decoder)
# ====================
class PatchEmbed(nn.Module):
    """Turn an image into a sequence of patch embeddings.

    A Conv2d whose kernel size and stride both equal ``patch_size`` projects
    every non-overlapping patch to ``embed_dim`` channels.

    Args:
        img_size: side length of the (square) input image.
        patch_size: side length of each square patch.
        in_chans: number of input channels.
        embed_dim: width of each output token.
    """

    def __init__(self, img_size, patch_size, in_chans, embed_dim):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        # (B, C, H, W) -> (B, D, H/P, W/P) -> (B, H/P, W/P, D) -> (B, num_patches, D)
        grid = self.proj(x)
        return grid.permute(0, 2, 3, 1).flatten(1, 2)

class Attention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        dim: token width; split evenly across ``num_heads`` heads.
        num_heads: number of attention heads.
        qkv_bias: whether the fused query/key/value projection has a bias.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5  # 1/sqrt(head_dim) logit scaling

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # fused Q, K, V
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        batch, tokens, channels = x.shape
        heads = self.num_heads

        # (B, N, 3C) -> (B, N, 3, heads, C/heads) -> (3, B, heads, N, C/heads)
        fused = self.qkv(x).reshape(batch, tokens, 3, heads, channels // heads)
        query, key, value = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (query @ key.transpose(-2, -1)) * self.scale
        weights = scores.softmax(dim=-1)

        # Merge heads back: (B, heads, N, C/heads) -> (B, N, C)
        merged = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj(merged)

class MLP(nn.Module):
    """Two-layer feed-forward network with a GELU between the layers.

    Args:
        dim: input and output width.
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        return self.fc2(hidden)

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: LayerNorm -> attention and LayerNorm -> MLP,
    each wrapped in a residual connection.

    Args:
        dim: token width.
        num_heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, hidden_dim)

    def forward(self, x):
        attn_out = self.attn(self.norm1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out

class MaskedAutoencoder(nn.Module):
    """Masked Autoencoder (MAE-style) for self-supervised pre-training.

    The encoder embeds image patches, keeps a random subset of them (plus a
    cls token) and runs them through ``config.depth`` transformer blocks. The
    decoder projects the latent to ``embed_dim // 2`` channels, re-inserts a
    learnable mask token at every dropped position, runs ``config.depth // 2``
    blocks and predicts the pixel values of every patch.

    Args:
        config: object exposing ``img_size``, ``patch_size``, ``in_channels``,
            ``embed_dim``, ``depth``, ``num_heads`` and ``mlp_ratio``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Encoder
        self.patch_embed = PatchEmbed(
            config.img_size, config.patch_size,
            config.in_channels, config.embed_dim
        )

        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim))
        # Learnable positional embedding; index 0 is reserved for the cls token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))

        self.encoder_blocks = nn.ModuleList([
            TransformerBlock(config.embed_dim, config.num_heads, config.mlp_ratio)
            for _ in range(config.depth)
        ])
        self.encoder_norm = nn.LayerNorm(config.embed_dim)

        # Decoder (narrower and shallower than the encoder)
        decoder_embed_dim = config.embed_dim // 2
        self.decoder_embed = nn.Linear(config.embed_dim, decoder_embed_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, decoder_embed_dim)
        )

        self.decoder_blocks = nn.ModuleList([
            TransformerBlock(decoder_embed_dim, config.num_heads, config.mlp_ratio)
            for _ in range(config.depth // 2)
        ])
        self.decoder_norm = nn.LayerNorm(decoder_embed_dim)

        # Predict pixel values: one (patch_size x patch_size x in_channels) patch per token
        self.decoder_pred = nn.Linear(
            decoder_embed_dim,
            config.patch_size ** 2 * config.in_channels
        )

        self.initialize_weights()

    def initialize_weights(self):
        """Initialise positional embeddings, cls token and mask token from N(0, 0.02^2).

        All other layers keep PyTorch's default initialisation.
        """
        torch.nn.init.normal_(self.pos_embed, std=0.02)
        torch.nn.init.normal_(self.decoder_pos_embed, std=0.02)
        torch.nn.init.normal_(self.cls_token, std=0.02)
        torch.nn.init.normal_(self.mask_token, std=0.02)

    def random_masking(self, x, mask_ratio):
        """Keep a random subset of tokens per sample (MAE per-sample shuffling).

        Args:
            x: tokens of shape (B, N, D).
            mask_ratio: fraction of the N tokens to drop.

        Returns:
            x_masked: kept tokens, shape (B, int(N * (1 - mask_ratio)), D).
            mask: (B, N) with 0 for kept and 1 for dropped tokens, in the
                original token order.
            ids_restore: (B, N) indices that undo the random shuffle.
        """
        B, N, D = x.shape
        len_keep = int(N * (1 - mask_ratio))

        noise = torch.rand(B, N, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)      # ascending: smallest noise is kept
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Build the mask in shuffled order, then unshuffle it to original order.
        mask = torch.ones([B, N], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        """Embed patches, drop a random ``mask_ratio`` of them, encode the rest.

        Returns the normalised latent (cls token first), the binary mask and
        the restore indices produced by ``random_masking``.
        """
        x = self.patch_embed(x)
        x = x + self.pos_embed[:, 1:, :]  # position 0 belongs to the cls token

        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        for blk in self.encoder_blocks:
            x = blk(x)
        x = self.encoder_norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        """Fill dropped positions with the mask token and predict patch pixels.

        Returns predictions of shape (B, num_patches, patch_size**2 * in_channels)
        with the cls token removed.
        """
        x = self.decoder_embed(x)

        # One mask token per dropped patch: N - (tokens kept excluding cls).
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
        )
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        # Unshuffle back to the original patch order.
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        x = torch.cat([x[:, :1, :], x_], dim=1)

        x = x + self.decoder_pos_embed

        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        x = self.decoder_pred(x)
        x = x[:, 1:, :]  # Remove cls token

        return x

    def forward(self, imgs, mask_ratio=0.75):
        """Return (patch pixel predictions, mask) for a batch of images."""
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)
        return pred, mask

    def patchify(self, imgs):
        """Convert images (B, C, H, W) to patches (B, num_patches, p*p*C).

        ``C`` is ``config.in_channels``, so the last dimension matches the
        output size of ``decoder_pred``. Pixels inside a patch are flattened in
        (row, column, channel) order; ``unpatchify`` is the inverse.
        """
        p = self.config.patch_size
        c = self.config.in_channels
        h = w = self.config.img_size // p
        x = imgs.reshape(imgs.shape[0], c, h, p, w, p)
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(imgs.shape[0], h * w, p**2 * c)
        return x

    def unpatchify(self, x):
        """Convert patches (B, num_patches, p*p*C) back to images (B, C, H, W).

        Inverse of ``patchify``; ``C`` is ``config.in_channels``.
        """
        p = self.config.patch_size
        c = self.config.in_channels
        h = w = self.config.img_size // p
        x = x.reshape(x.shape[0], h, w, p, p, c)
        x = torch.einsum('nhwpqc->nchpwq', x)
        x = x.reshape(x.shape[0], c, h * p, w * p)
        return x







In [None]:
# ====================
# 3. DATASET PREPARATION
# ====================
# Download EuroSAT dataset
# Resize/crop to config.img_size and normalise with ImageNet RGB statistics
# (the same mean/std are used to de-normalise images for visualisation).
transform = transforms.Compose([
    transforms.Resize(config.img_size),
    transforms.CenterCrop(config.img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
])

print("Downloading EuroSAT dataset...")
dataset = EuroSAT(root=config.data_root, download=True, transform=transform)

# Split into train/val (80/20) with a fixed seed. An unseeded split changes
# every time this cell is re-run, so images used for validation earlier can
# end up in training and saved results are not reproducible.
split_seed = 42
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=True, num_workers=config.num_workers)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size,
                        shuffle=False, num_workers=config.num_workers)

print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")



In [None]:
# ====================
# 4. TRAINING SETUP
# ====================
model = MaskedAutoencoder(config).to(device)

# AdamW with weight decay applied only to weight matrices / conv kernels.
# 1-D parameters (biases and LayerNorm scales) go into a no-decay group,
# following the usual ViT/MAE practice of not shrinking them toward zero.
decay_params, no_decay_params = [], []
for param_name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    if param.ndim <= 1 or param_name.endswith('.bias'):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer = optim.AdamW(
    [
        {'params': decay_params, 'weight_decay': config.weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ],
    lr=config.learning_rate,
)

# Cosine learning rate schedule with warmup
def adjust_learning_rate(optimizer, epoch, config):
    """Set the learning rate for ``epoch`` on every optimizer param group.

    Epochs before ``config.warmup_epochs`` ramp linearly up to
    ``config.learning_rate``; later epochs follow a half-cosine decay over the
    remaining ``num_epochs - warmup_epochs`` epochs.

    Returns:
        The learning rate that was applied.
    """
    base_lr = config.learning_rate
    warmup = config.warmup_epochs

    if epoch >= warmup:
        new_lr = base_lr * 0.5 * (1 + np.cos(
            np.pi * (epoch - warmup) / (config.num_epochs - warmup)
        ))
    else:
        new_lr = base_lr * (epoch + 1) / warmup

    for group in optimizer.param_groups:
        group['lr'] = new_lr
    return new_lr

def compute_loss(model, imgs, mask_ratio, norm_pix_loss=False):
    """Masked-patch reconstruction loss (MAE objective).

    Args:
        model: called as ``model(imgs, mask_ratio)`` returning
            ``(pred, mask)``, and providing ``patchify(imgs)`` for the target.
        imgs: batch of images passed to ``model``.
        mask_ratio: fraction of patches to mask.
        norm_pix_loss: if True, standardise each target patch to zero mean and
            unit variance before computing the error (the ``norm_pix_loss``
            option of the reference MAE implementation). Defaults to False,
            i.e. raw (already dataset-normalised) pixels.

    Returns:
        Scalar tensor: per-patch mean squared error averaged over masked
        patches only. If no patch is masked the result is 0 rather than NaN.
    """
    pred, mask = model(imgs, mask_ratio)
    target = model.patchify(imgs)

    if norm_pix_loss:
        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + 1.e-6) ** 0.5

    # MSE per patch, then average over masked patches only (mask: 1 = masked)
    loss = (pred - target) ** 2
    loss = loss.mean(dim=-1)
    # clamp avoids 0/0 = NaN when mask_ratio leaves nothing masked
    return (loss * mask).sum() / mask.sum().clamp(min=1)


In [None]:
# ====================
# 5. TRAINING LOOP
# ====================
train_losses = []
val_losses = []

print("\nStarting training...")
for epoch in range(config.num_epochs):
    # Per-epoch learning rate (warmup, then cosine decay)
    lr = adjust_learning_rate(optimizer, epoch, config)

    # ---- Training ----
    model.train()
    train_loss_sum, train_count = 0.0, 0
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{config.num_epochs}')

    for imgs, _ in pbar:
        imgs = imgs.to(device)

        optimizer.zero_grad()
        loss = compute_loss(model, imgs, config.mask_ratio)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()  # one host sync per step
        # Weight each batch by its size so a short final batch does not skew
        # the epoch mean (every sample has the same number of masked patches).
        train_loss_sum += batch_loss * imgs.size(0)
        train_count += imgs.size(0)
        pbar.set_postfix({'loss': f'{batch_loss:.4f}', 'lr': f'{lr:.6f}'})

    train_loss = train_loss_sum / max(train_count, 1)
    train_losses.append(train_loss)

    # ---- Validation (masks are still drawn at random each pass) ----
    model.eval()
    val_loss_sum, val_count = 0.0, 0
    with torch.no_grad():
        for imgs, _ in val_loader:
            imgs = imgs.to(device)
            loss = compute_loss(model, imgs, config.mask_ratio)
            val_loss_sum += loss.item() * imgs.size(0)
            val_count += imgs.size(0)

    val_loss = val_loss_sum / max(val_count, 1)
    val_losses.append(val_loss)

    print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")


In [None]:
# ====================
# 6. VISUALIZATION
# ====================
# Plot training curves
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Progress')
plt.legend()
plt.grid(True)
plt.show()

# Visualize reconstructions
model.eval()
with torch.no_grad():
    # Take up to 4 images from the first validation batch (the batch may be smaller).
    imgs, _ = next(iter(val_loader))
    num_show = min(4, imgs.shape[0])
    imgs = imgs[:num_show].to(device)

    # Forward pass; pred covers every patch (masked and visible)
    pred, mask = model(imgs, config.mask_ratio)
    pred = model.unpatchify(pred)

    # Undo the ImageNet normalisation used by the dataset transform
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)

    imgs_vis = imgs * std + mean
    pred_vis = pred * std + mean

    # Expand the per-patch mask (1 = masked) to pixel space.
    mask = mask.unsqueeze(-1).repeat(1, 1, config.patch_size**2 * config.in_channels)
    mask = model.unpatchify(mask)
    # Masked pixels are set to 1.0 in normalised space before de-normalising.
    imgs_masked = imgs * (1 - mask) + mask
    imgs_masked_vis = imgs_masked * std + mean

    # Plot: one row per image -> original | masked | reconstruction
    fig, axes = plt.subplots(num_show, 3, figsize=(12, 4 * num_show), squeeze=False)
    for i in range(num_show):
        axes[i, 0].imshow(imgs_vis[i].cpu().permute(1, 2, 0).clamp(0, 1))
        axes[i, 0].set_title('Original')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(imgs_masked_vis[i].cpu().permute(1, 2, 0).clamp(0, 1))
        axes[i, 1].set_title(f'Masked ({config.mask_ratio*100:.0f}%)')
        axes[i, 1].axis('off')

        axes[i, 2].imshow(pred_vis[i].cpu().permute(1, 2, 0).clamp(0, 1))
        axes[i, 2].set_title('Reconstruction')
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# ====================
# 7. SAVE MODEL
# ====================
# Store the configuration as a plain dict of its public settings instead of
# the notebook-defined Config instance: unpickling that class fails in a
# session where Config is not defined, and torch.load(weights_only=True)
# refuses arbitrary classes. Rebuild with e.g. types.SimpleNamespace(**ckpt['config']).
config_dict = {
    key: getattr(config, key)
    for key in dir(config)
    if not key.startswith('_') and not callable(getattr(config, key))
}

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'config': config_dict,
    'train_losses': train_losses,
    'val_losses': val_losses
}, 'geospatial_foundation_model.pth')

print("\nModel saved as 'geospatial_foundation_model.pth'")
print("\nTo use this model for downstream tasks (classification, segmentation, etc.),")
print("you can load the encoder weights and fine-tune on your specific task!")