# MAMBA-GINR with LARGER DECODER

## Changes from previous version:
- ✅ Gaussian Fourier features (same as before)
- ✅ Correct jittering implementation (same as before)
- ✅ Separate modulation/decoding coordinates (same as before)
- ✅ No jittering at test time (same as before)
- 🆕 **LARGER DECODER**: hidden_dim=512, num_layers=3 (was 256, 1)
- 🆕 **Expected improvement**: +3-5 dB PSNR at 128×128 super-resolution

## Decoder capacity comparison:
- **Small decoder** (previous): 256 dims, 1 layer/scale, ~450K params
- **Large decoder** (this notebook): 512 dims, 3 layers/scale, ~8.4M params (with the default settings used below)
- **Increase**: roughly 19x parameters

In [None]:
# Cell 1: Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import einops
import numpy as np
import math
from tqdm import tqdm
import matplotlib.pyplot as plt
from mamba_ssm import Mamba

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Cell 2: Gaussian Fourier Feature Encoding

def gaussian_fourier_encode(coords, B_matrix):
    """
    Gaussian Fourier feature encoding of coordinates.

    Projects each coordinate onto the rows of ``B_matrix`` (scaled by 2*pi)
    and returns the cosines followed by the sines of those projections
    (random Fourier features, Rahimi & Recht 2007).

    Args:
        coords: (H, W, d) grid or (N, d) list of coordinates (the grids built
            in this notebook lie in [0, 1]). A 3-D grid is flattened to
            (H*W, d); any other rank is fed to the matmul as-is, so its
            leading dims are kept.
        B_matrix: (n_features, d) frequency matrix, one frequency per row.

    Returns:
        (H*W, 2*n_features) for a 3-D grid, otherwise
        coords.shape[:-1] + (2*n_features,). Columns [:n_features] hold the
        cosines, columns [n_features:] the sines.
    """
    if coords.dim() == 3:
        # reshape (not view): accepts non-contiguous grids, e.g. produced by
        # transpose/permute, and only copies when a view is impossible.
        coords = coords.reshape(-1, coords.shape[-1])

    proj = 2 * np.pi * coords @ B_matrix.T
    return torch.cat([torch.cos(proj), torch.sin(proj)], dim=-1)


def create_coordinate_grid(H, W, device):
    """
    Build an (H, W, 2) grid of (y, x) coordinates covering [0, 1] inclusive.

    Entry [i, j] holds (i / (H - 1), j / (W - 1)): both endpoints are sampled,
    and an axis of length 1 contains only the value 0.
    """
    rows = torch.linspace(0, 1, H, device=device)
    cols = torch.linspace(0, 1, W, device=device)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing='ij')
    return torch.stack((grid_y, grid_x), dim=-1)


def exists(val):
    """Return True for anything except None (falsy values such as 0 or '' count as existing)."""
    if val is None:
        return False
    return True

def default(val, d):
    """Return *val* unless it is None (per ``exists``), in which case return the fallback *d*."""
    if not exists(val):
        return d
    return val

In [None]:
# Cell 3: Core Components (BiMamba, Encoder, LP Tokens)

class BiMamba(nn.Module):
    """
    Bidirectional Mamba layer.

    One Mamba scans the sequence in its given order; a second one scans it
    reversed along dim 1, and its output is flipped back so both streams are
    aligned per position. The two are concatenated on the last axis and
    projected from 2*d_model back to d_model.
    """

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        mamba_kwargs = dict(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand)
        self.forward_mamba = Mamba(**mamba_kwargs)
        self.backward_mamba = Mamba(**mamba_kwargs)
        self.proj = nn.Linear(2 * d_model, d_model)

    def forward(self, x):
        # x is presumably (batch, seq_len, d_model); dim 1 is treated as the scan axis.
        fwd = self.forward_mamba(x)
        bwd = self.backward_mamba(x.flip(1)).flip(1)
        return self.proj(torch.cat((fwd, bwd), dim=-1))


class PatchEncoder(nn.Module):
    """
    Split an image into non-overlapping p x p patches and embed each linearly.

    Input (B, C, H, W) -> output (B, (H/p) * (W/p), dim). Patches are ordered
    row-major over the patch grid; within a patch, the flattened feature order
    is (p1, p2, c).
    """

    def __init__(self, patch_size=2, in_channels=3, dim=256):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Linear(patch_size * patch_size * in_channels, dim)

    def forward(self, x):
        if x.dim() != 4:
            raise ValueError(
                f"PatchEncoder expects a (B, C, H, W) tensor, got shape {tuple(x.shape)}"
            )
        p = self.patch_size
        height, width = x.shape[-2:]
        if height % p or width % p:
            raise ValueError(
                f"image size {height}x{width} is not divisible by patch_size={p}"
            )
        patches = einops.rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        return self.proj(patches)


class LearnablePositionTokens(nn.Module):
    """
    A bank of learnable tokens shared by every sample in the batch.

    The single (num_tokens, dim) parameter is initialised from N(0, 0.02^2);
    ``forward(B)`` returns it repeated to (B, num_tokens, dim).
    """

    def __init__(self, num_tokens=256, dim=256):
        super().__init__()
        initial = torch.randn(num_tokens, dim) * 0.02
        self.tokens = nn.Parameter(initial)

    def forward(self, B):
        batch_size = B
        return einops.repeat(self.tokens, 'n d -> b n d', b=batch_size)

In [None]:
# Cell 4: Cross-Attention and Residual Block

class SharedTokenCrossAttention(nn.Module):
    """
    Multi-head cross-attention from query points to a token set, with an
    optional additive (spatial) attention bias.

    Shapes:
        x:       (B, N, query_dim)   one query per coordinate
        context: (B, L, context_dim) tokens that are attended to
        bias:    (B, L, N) or (L, N) additive logits, token-major. The bias is
                 transposed to (N, L) and broadcast over heads (and over the
                 batch for the 2-D form) before the softmax over tokens.
    Returns (B, N, query_dim).
    """

    def __init__(self, query_dim, context_dim=None, heads=2, dim_head=64):
        super().__init__()
        # Same fallback as the notebook's `default` helper, inlined so this
        # class does not depend on another cell.
        context_dim = query_dim if context_dim is None else context_dim
        inner_dim = dim_head * heads
        self.heads = heads
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context, bias=None):
        B, N, _ = x.shape
        H, Dh = self.heads, self.dim_head

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        # (B, seq, H*Dh) -> (B, H, seq, Dh)
        q = q.view(B, N, H, Dh).transpose(1, 2)
        k = k.view(B, -1, H, Dh).transpose(1, 2)
        v = v.view(B, -1, H, Dh).transpose(1, 2)

        sim = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # (B, H, N, L)

        if bias is not None:
            bias = bias.transpose(-2, -1)  # (..., L, N) -> (..., N, L)
            if bias.dim() == 3:
                bias = bias.unsqueeze(1)  # (B, 1, N, L): shared by all heads
            # Broadcasting adds the same values as repeating the bias per head,
            # without allocating H copies of a (B, N, L) tensor.
            sim = sim + bias

        attn = sim.softmax(dim=-1)
        out = torch.matmul(attn, v)  # (B, H, N, Dh)
        out = out.transpose(1, 2).reshape(B, N, H * Dh)
        return self.to_out(out)


class ResidualBlock(nn.Module):
    """
    Width-preserving residual MLP block: relu(linear2(relu(linear1(x))) + x).

    The final ReLU is applied after the skip addition, so outputs are
    non-negative.
    """

    def __init__(self, dim):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.linear2 = nn.Linear(dim, dim)
        self.act = nn.ReLU()

    def forward(self, x):
        hidden = self.linear2(self.act(self.linear1(x)))
        return self.act(hidden + x)

In [None]:
# Cell 5: LARGER LAINR Decoder

class LAINRDecoderLarge(nn.Module):
    """
    LARGER LAINR decoder: (coordinates, LP tokens) -> RGB, with more capacity.

    Changes from the original (small) decoder:
    - hidden_dim: 256 -> 512 (2x width)
    - num_layers: 1 -> 3 per scale. Each per-scale trunk is Linear + ReLU
      followed by ``num_layers - 1`` ResidualBlocks; each output head is
      ``num_layers - 1`` ResidualBlocks followed by a Linear to output_dim.
    - Total parameters with the defaults (feature_dim=64, hidden_dim=512,
      context_dim=256, num_layers=3, two scales): about 8.4M (8,442,566,
      including the learnable frequency matrices). An earlier "~4M" estimate
      undercounted.

    Frequencies: ``B_q`` (modulation-query encoding) and one matrix per entry
    of ``sigma_ls`` are drawn as randn / sigma, so a larger sigma yields lower
    frequencies. They are nn.Parameters when ``learnable_frequencies`` is True,
    otherwise buffers named ``B_q`` and ``B_l_{i}``.

    NOTE: every sample in a batch must use the same coordinates. Only the
    first sample's coordinate grid is read; the others are ignored.
    """

    def __init__(self, feature_dim=64, input_dim=2, output_dim=3,
                 sigma_q=16, sigma_ls=(128, 32), n_patches=256,
                 hidden_dim=512,
                 context_dim=256,
                 learnable_frequencies=True,
                 num_layers=3,
                 alpha=10.0):
        super().__init__()

        if feature_dim % 2:
            raise ValueError(f"feature_dim must be even (cos + sin halves), got {feature_dim}")
        if len(sigma_ls) == 0:
            raise ValueError("sigma_ls must contain at least one scale")

        self.layer_num = len(sigma_ls)
        self.n_features = feature_dim // 2
        # Side length of the square patch grid used by the spatial bias.
        self.patch_num = math.isqrt(n_patches)
        # Sharpness of the distance-based attention bias.
        self.alpha = alpha
        self.num_layers = num_layers
        self.learnable_frequencies = learnable_frequencies

        # Gaussian Fourier frequency matrices, each (n_features, input_dim).
        B_q_init = torch.randn(self.n_features, input_dim) / sigma_q
        B_ls_init = [torch.randn(self.n_features, input_dim) / sigma for sigma in sigma_ls]

        if learnable_frequencies:
            self.B_q = nn.Parameter(B_q_init)
            self.B_ls = nn.ParameterList([nn.Parameter(b) for b in B_ls_init])
        else:
            self.register_buffer('B_q', B_q_init)
            for i, b in enumerate(B_ls_init):
                self.register_buffer(f'B_l_{i}', b)
            # Deliberately no cached Python list of the buffers: .to()/.cuda()/
            # .half() replace buffer tensors, and a cached list would keep
            # pointing at the old ones (device mismatch). See _band_frequencies.

        def trunk(in_dim):
            # Linear + ReLU, then (num_layers - 1) residual blocks.
            return nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                *[ResidualBlock(hidden_dim) for _ in range(num_layers - 1)],
            )

        # Query encoding for modulation extraction.
        self.query_lin = nn.Linear(feature_dim, hidden_dim)

        self.modulation_ca = SharedTokenCrossAttention(
            query_dim=hidden_dim, context_dim=context_dim, heads=2
        )

        # Per-scale coordinate (bandwidth) encoders.
        self.bandwidth_lins = nn.ModuleList([trunk(feature_dim) for _ in range(self.layer_num)])

        # Per-scale projections of the modulation vector.
        self.modulation_lins = nn.ModuleList([trunk(hidden_dim) for _ in range(self.layer_num)])

        # Links chaining scale k-1 into scale k.
        self.hv_lins = nn.ModuleList([trunk(hidden_dim) for _ in range(self.layer_num - 1)])

        # Per-scale RGB heads; their outputs are summed.
        self.out_lins = nn.ModuleList([
            nn.Sequential(
                *[ResidualBlock(hidden_dim) for _ in range(num_layers - 1)],
                nn.Linear(hidden_dim, output_dim),
            )
            for _ in range(self.layer_num)
        ])

        self.act = nn.ReLU()

    def _band_frequencies(self, k):
        """Return the frequency matrix of scale ``k`` (parameter or buffer)."""
        if self.learnable_frequencies:
            return self.B_ls[k]
        # Looked up by name on every call so it always tracks the module's device/dtype.
        return getattr(self, f'B_l_{k}')

    def get_patch_index(self, grid, H, W):
        """
        Map (N, 2) (y, x) coordinates to row-major patch indices on an H x W grid.

        Coordinates are scaled by the grid size and truncated toward zero; the
        clamp keeps y == 1 / x == 1 (and any out-of-range values) on the grid.
        """
        row = (grid[:, 0] * H).to(torch.int32).clamp(0, H - 1)
        col = (grid[:, 1] * W).to(torch.int32).clamp(0, W - 1)
        return row * W + col

    def approximate_relative_distances(self, target_index, H, W, m):
        """
        Spatial attention bias between query patches and the m LP tokens.

        Both sides are placed on a 1-D axis in [0, 1): a query at its
        row-major patch index divided by H*W, and token i at (i + 0.5) / m.
        The bias is -alpha * (distance ** 2).

        Returns:
            (m, len(target_index)) tensor, token-major.
        """
        t = target_index.float() / (H * W)
        token_positions = (
            torch.arange(m, device=target_index.device, dtype=torch.float32) + 0.5
        ) / m
        return -self.alpha * (t.unsqueeze(0) - token_positions.unsqueeze(1)) ** 2

    def forward(self, coords_decoding, tokens, coords_modulation=None):
        """
        Predict RGB values at ``coords_decoding``.

        Args:
            coords_decoding: (B, H, W, 2) (or (B, N, 2)) coordinates where RGB
                is predicted.
            tokens: (B, L, context_dim) LP token features.
            coords_modulation: same layout; coordinates at which the modulation
                is extracted. ``None`` (test mode) reuses ``coords_decoding``.

        Returns:
            (B, *coords_decoding.shape[1:-1], output_dim)
        """
        B, query_shape = coords_decoding.shape[0], coords_decoding.shape[1:-1]
        coords_dec = coords_decoding.reshape(B, -1, coords_decoding.shape[-1])

        if coords_modulation is not None:
            coords_mod = coords_modulation.reshape(B, -1, coords_modulation.shape[-1])
        else:
            coords_mod = coords_dec

        # Shared-grid assumption: only sample 0's coordinates are read.
        grid_mod = coords_mod[0]
        grid_dec = coords_dec[0]

        # === MODULATION EXTRACTION ===
        indexes = self.get_patch_index(grid_mod, self.patch_num, self.patch_num)
        rel_distances = self.approximate_relative_distances(
            indexes, self.patch_num, self.patch_num, tokens.shape[1]
        )
        bias = einops.repeat(rel_distances, 'l n -> b l n', b=B)

        # The query MLP sees identical inputs for every sample, so run it once
        # and broadcast over the batch instead of repeating the work B times.
        x_q = self.act(self.query_lin(gaussian_fourier_encode(grid_mod, self.B_q)))
        x_q = x_q.unsqueeze(0).expand(B, -1, -1)

        modulation_vector = self.modulation_ca(x_q, context=tokens, bias=bias)  # (B, N, hidden)

        # === MULTI-SCALE DECODING ===
        modulations_l = []
        for k in range(self.layer_num):
            # Coordinate branch is batch-independent: computed once, (N, hidden).
            h_l = self.bandwidth_lins[k](
                gaussian_fourier_encode(grid_dec, self._band_frequencies(k))
            )
            m_proj = self.modulation_lins[k](modulation_vector)  # (B, N, hidden)
            modulations_l.append(self.act(h_l + m_proj))  # broadcast -> (B, N, hidden)

        # Chain the scales: h_v[k] = hv_lins[k-1](m_k + h_v[k-1]).
        h_v = [modulations_l[0]]
        for i in range(self.layer_num - 1):
            h_v.append(self.hv_lins[i](modulations_l[i + 1] + h_v[i]))

        out = sum(head(h) for head, h in zip(self.out_lins, h_v))
        return out.view(B, *query_shape, -1)


# Count parameters
def count_parameters(model):
    """Return the number of scalar parameters in *model* that have requires_grad=True."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)

print("✓ LARGER LAINR Decoder defined")
print("  - hidden_dim: 512 (was 256)")
print("  - num_layers: 3 per scale (was 1)")
# With the defaults (feature_dim=64, hidden_dim=512, context_dim=256,
# num_layers=3, two scales) LAINRDecoderLarge has 8,442,566 parameters;
# the old "~4M" message undercounted by about 2x.
print("  - Expected parameters: ~8.4M (was ~450K)")

In [None]:
# Cell 6: Complete Model

class MambaGINR_LargeDecoder(nn.Module):
    """
    Mamba-GINR autoencoder with the larger LAINR decoder.

    Encoder: patch embeddings and learnable LP tokens are concatenated along
    the sequence axis and processed by a bidirectional Mamba; the outputs at
    the LP-token positions form the latent. Decoder: ``LAINRDecoderLarge``
    maps coordinates plus those LP features to RGB.

    Args (beyond the sub-module hyperparameters):
        image_size: side length of the square training images. Together with
            patch_size it sets the patch grid used by the decoder's spatial
            bias. Defaults to 32 (CIFAR-10), the value that was hard-coded.
    """

    def __init__(self,
                 patch_size=2,
                 in_channels=3,
                 dim=256,
                 num_lp_tokens=256,
                 feature_dim=64,
                 sigma_q=16,
                 sigma_ls=(128, 32),
                 hidden_dim=512,
                 num_layers=3,
                 image_size=32):
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError(
                f"image_size={image_size} must be divisible by patch_size={patch_size}"
            )

        # Encoder
        self.patch_encoder = PatchEncoder(patch_size=patch_size, in_channels=in_channels, dim=dim)
        self.lp_tokens = LearnablePositionTokens(num_tokens=num_lp_tokens, dim=dim)
        self.mamba = BiMamba(d_model=dim)

        # Decoder (LARGER)
        self.num_patches = (image_size // patch_size) ** 2
        self.hyponet = LAINRDecoderLarge(
            feature_dim=feature_dim,
            input_dim=2,
            output_dim=3,
            sigma_q=sigma_q,
            sigma_ls=sigma_ls,
            n_patches=self.num_patches,
            hidden_dim=hidden_dim,
            context_dim=dim,
            learnable_frequencies=True,
            num_layers=num_layers
        )

    def encode(self, x):
        """Encode images (B, C, H, W) into LP-token features (B, num_lp_tokens, dim)."""
        B = x.shape[0]
        patch_features = self.patch_encoder(x)
        lp_tokens = self.lp_tokens(B)
        # LP tokens are appended after the patch sequence; their outputs are kept.
        combined = torch.cat([patch_features, lp_tokens], dim=1)
        features = self.mamba(combined)
        return features[:, -lp_tokens.shape[1]:, :]

    def forward(self, x, coords_decoding, coords_modulation=None):
        """Encode ``x`` and decode RGB at ``coords_decoding`` ((B, H, W, 2) grids)."""
        lp_features = self.encode(x)
        return self.hyponet(coords_decoding, lp_features, coords_modulation)

print("✓ Complete model defined with LARGER decoder")

In [None]:
# Cell 7: Training Functions

def adjust_learning_rate(optimizer, epoch, base_lr=5e-4, warmup_epochs=5, max_epoch=40, min_lr=1e-8):
    """
    Set (and return) the learning rate for ``epoch`` (0-based).

    Schedule:
      - epoch < warmup_epochs: linear warmup, lr = base_lr * (epoch + 1) / warmup_epochs
      - otherwise: cosine decay from base_lr to min_lr over
        [warmup_epochs, max_epoch]. Epochs past max_epoch stay at min_lr
        instead of following the cosine back up, and a zero-length decay
        phase (max_epoch <= warmup_epochs) counts as already finished.

    Every param group of ``optimizer`` receives the same lr.
    """
    if epoch < warmup_epochs:
        lr = base_lr * (epoch + 1) / warmup_epochs
    else:
        decay_epochs = max_epoch - warmup_epochs
        if decay_epochs > 0:
            t = min((epoch - warmup_epochs) / decay_epochs, 1.0)
        else:
            t = 1.0
        lr = min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return lr


def train_epoch_correct(model, loader, optimizer, device, epoch,
                        resolution=32,
                        jitter_std=None,
                        offset_std=0.0):
    """
    Run one training epoch with coordinate jittering.

    For every batch a single Gaussian jitter field is drawn and added to the
    regular [0, 1]^2 grid. The field is shared by every image in the batch,
    which matches the decoder reading only the first sample's coordinates.
    The jittered coordinates drive modulation extraction and, unless
    ``offset_std > 0``, decoding as well. The target is always the input
    image's own pixel grid, so ``resolution`` should equal the image
    height/width.

    Args:
        epoch: label shown in the progress bar.
        jitter_std: std of the jitter. ``None`` -> (1/resolution) / 2 / 3, so
            half a pixel is a 3-sigma bound. (The grid spacing is actually
            1/(resolution-1), slightly larger, so the bound still holds.)
        offset_std: std of an extra Gaussian offset applied to the decoding
            coordinates only (0 disables it).

    Returns:
        (mean loss, mean PSNR in dB), averaged over batches.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_psnr = 0.0
    num_batches = 0

    # Auto-compute jittering std if not provided
    if jitter_std is None:
        pixel_size = 1.0 / resolution
        max_allowed_jitter = pixel_size / 2
        jitter_std = max_allowed_jitter / 3  # 3-sigma rule

    # The regular grid is never modified in place (jitter/offset create new
    # tensors), so it is built once per epoch rather than once per batch.
    base_coords = create_coordinate_grid(resolution, resolution, device)

    pbar = tqdm(loader, desc=f"Epoch {epoch}")
    for images, _ in pbar:
        images = images.to(device)
        B = images.shape[0]

        # Jittered coordinates for modulation extraction (fresh draw per batch).
        jitter_small = torch.randn_like(base_coords) * jitter_std
        coords_modulation = (base_coords + jitter_small).clamp(0, 1)

        # Optional extra offset for decoding coordinates (default: none).
        if offset_std > 0:
            prediction_offset = torch.randn_like(base_coords) * offset_std
            coords_decoding = (coords_modulation + prediction_offset).clamp(0, 1)
        else:
            coords_decoding = coords_modulation

        coords_mod_batch = einops.repeat(coords_modulation, 'h w d -> b h w d', b=B)
        coords_dec_batch = einops.repeat(coords_decoding, 'h w d -> b h w d', b=B)

        pred = model(images, coords_dec_batch, coords_mod_batch)
        gt = einops.rearrange(images, 'b c h w -> b h w c')

        # flatten(1) works for any memory layout of (pred - gt); view() needs contiguity.
        mses = ((pred - gt) ** 2).flatten(1).mean(dim=-1)
        loss = mses.mean()

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # PSNR is only a metric: compute it outside the autograd graph.
        psnr = (-10 * torch.log10(mses.detach())).mean()

        loss_value = loss.item()
        psnr_value = psnr.item()
        total_loss += loss_value
        total_psnr += psnr_value
        num_batches += 1
        pbar.set_postfix({
            'loss': f"{loss_value:.4f}",
            'psnr': f"{psnr_value:.2f}",
            'jitter': f"σ={jitter_std:.5f}"
        })

    if num_batches == 0:
        raise ValueError("train_epoch_correct: loader produced no batches")
    return total_loss / num_batches, total_psnr / num_batches


def validate_correct(model, loader, device, resolution=32):
    """
    Evaluate reconstruction on the exact pixel grid (no jittering).

    ``coords_modulation=None`` makes the decoder reuse the decoding grid for
    modulation extraction. Predictions are compared pixel-for-pixel with the
    input images, so ``resolution`` should equal their height/width.

    Returns:
        (mean loss, mean PSNR in dB), averaged over batches.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_psnr = 0.0
    num_batches = 0

    with torch.no_grad():
        # Exact coordinates; independent of the batch, so built once.
        coords = create_coordinate_grid(resolution, resolution, device)

        for images, _ in tqdm(loader, desc="Validation"):
            images = images.to(device)
            B = images.shape[0]

            coords_batch = einops.repeat(coords, 'h w d -> b h w d', b=B)

            # coords_modulation=None -> test mode
            pred = model(images, coords_batch, coords_modulation=None)
            gt = einops.rearrange(images, 'b c h w -> b h w c')

            # flatten(1) works for any memory layout of (pred - gt).
            mses = ((pred - gt) ** 2).flatten(1).mean(dim=-1)
            loss = mses.mean()
            psnr = (-10 * torch.log10(mses)).mean()

            total_loss += loss.item()
            total_psnr += psnr.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("validate_correct: loader produced no batches")
    return total_loss / num_batches, total_psnr / num_batches


def super_resolve_correct(model, images, target_resolution=128, device=None, chunk_rows=None):
    """
    Decode images on a denser coordinate grid (no jittering).

    The images are encoded once at their native resolution, then the decoder
    is queried on a target_resolution x target_resolution grid over [0, 1]^2
    in test mode (``coords_modulation=None``). Leaves ``model`` in eval mode.

    Args:
        device: device for the query grid. ``None`` (default) uses
            ``images.device``, so CUDA inputs work without passing it; an
            explicit device behaves exactly as before.
        chunk_rows: if given, decode this many grid rows at a time and
            concatenate the results, bounding peak memory for large targets.
            ``LAINRDecoderLarge`` decodes each query point independently, so
            this changes memory use only (up to floating-point reordering).

    Returns:
        (B, 3, target_resolution, target_resolution) tensor.
    """
    if chunk_rows is not None and chunk_rows < 1:
        raise ValueError(f"chunk_rows must be a positive integer, got {chunk_rows}")
    if device is None:
        device = images.device

    model.eval()

    with torch.no_grad():
        B = images.shape[0]

        # Encode at training resolution
        lp_features = model.encode(images)

        # Decode at target resolution (no jittering)
        coords = create_coordinate_grid(target_resolution, target_resolution, device)
        coords_batch = einops.repeat(coords, 'h w d -> b h w d', b=B)

        if chunk_rows is None:
            pred = model.hyponet(coords_batch, lp_features, coords_modulation=None)
        else:
            pieces = [
                model.hyponet(coords_batch[:, r:r + chunk_rows], lp_features, coords_modulation=None)
                for r in range(0, target_resolution, chunk_rows)
            ]
            pred = torch.cat(pieces, dim=1)

        pred_images = einops.rearrange(pred, 'b h w c -> b c h w')

    return pred_images

print("✓ Training functions defined")

In [None]:
# Cell 8: Data Loading

# ToTensor is the only transform: no normalisation or augmentation.
transform = transforms.Compose([transforms.ToTensor()])

# Build the train split first, then the test split (same order as before).
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(root='./data', train=split, download=True, transform=transform)
    for split in (True, False)
)

# Only the training loader shuffles; both use batch size 64 and 4 workers.
train_loader, test_loader = (
    DataLoader(ds, batch_size=64, shuffle=shuffle, num_workers=4)
    for ds, shuffle in ((train_dataset, True), (test_dataset, False))
)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

In [None]:
# Cell 9: Model Initialization

model = MambaGINR_LargeDecoder(
    patch_size=2, in_channels=3, dim=256, num_lp_tokens=256,
    feature_dim=64, sigma_q=16, sigma_ls=[128, 32],
    hidden_dim=512,   # LARGER decoder width
    num_layers=3,     # DEEPER per-scale MLPs
).to(device)

total_params, decoder_params = count_parameters(model), count_parameters(model.hyponet)

print(f"\nTotal parameters: {total_params:,}")
print(f"Decoder parameters: {decoder_params:,}")
print(f"\nDecoder capacity: {decoder_params / 1e6:.1f}M params (was ~0.45M)")
print("Expected improvement: +3-5 dB PSNR at 128×128 super-resolution")

# NOTE(review): weight decay here also acts on the learnable Fourier frequency
# matrices and the LP tokens, since all of model.parameters() share one group.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)

print("\n✓ Model initialized")

In [None]:
# Cell 10: Training Loop

num_epochs = 40
resolution = 32

print(f"\nTraining LARGER decoder for {num_epochs} epochs")
print(f"Resolution: {resolution}×{resolution}")
# Same value train_epoch_correct picks automatically: (1/resolution) / 2 / 3.
print(f"Auto jitter_std: {1/(6*resolution):.6f}\n")

for epoch in range(num_epochs):
    lr = adjust_learning_rate(optimizer, epoch, base_lr=5e-4, max_epoch=num_epochs)

    train_loss, train_psnr = train_epoch_correct(
        model, train_loader, optimizer, device, epoch + 1,
        resolution=resolution, jitter_std=None, offset_std=0.0,
    )
    val_loss, val_psnr = validate_correct(model, test_loader, device, resolution=resolution)

    # One print with embedded newlines emits the same three lines (plus blank line).
    print(
        f"Epoch {epoch + 1}/{num_epochs} | LR: {lr:.6f}\n"
        f"  Train - Loss: {train_loss:.4f}, PSNR: {train_psnr:.2f} dB\n"
        f"  Val   - Loss: {val_loss:.4f}, PSNR: {val_psnr:.2f} dB\n"
    )

print("Training complete!")

In [None]:
# Cell 11: Super-Resolution Test

# Get test images
test_images, _ = next(iter(test_loader))
test_images = test_images[:8].to(device)

print("Testing super-resolution with LARGER decoder...\n")

# Decode at 64, 128 and 256 (in that order) from the same 32×32 encodings.
sr_64, sr_128, sr_256 = (
    super_resolve_correct(model, test_images, target_resolution=res, device=device)
    for res in (64, 128, 256)
)

print(f"Original: {test_images.shape}")
print(f"SR 64×64: {sr_64.shape}")
print(f"SR 128×128: {sr_128.shape}")
print(f"SR 256×256: {sr_256.shape}")

# One figure row per resolution; the title goes on the first column only.
rows = [
    (test_images, '32×32'),
    (sr_64, '64×64 SR'),
    (sr_128, '128×128 SR'),
    (sr_256, '256×256 SR'),
]
fig, axes = plt.subplots(len(rows), 8, figsize=(16, 8))

for r, (batch, title) in enumerate(rows):
    for i in range(8):
        ax = axes[r, i]
        ax.imshow(batch[i].cpu().permute(1, 2, 0).clamp(0, 1))
        ax.axis('off')
    axes[r, 0].set_title(title, fontsize=10)

plt.tight_layout()
plt.savefig('larger_decoder_superresolution_results.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n✓ Super-resolution visualization saved as 'larger_decoder_superresolution_results.png'")
print("\nExpected improvements with LARGER decoder:")
print("  - Better texture synthesis")
print("  - More high-frequency details at 128×128 and 256×256")
print("  - Less smooth/blurry appearance")
print("  - Expected +3-5 dB PSNR improvement")