# MAMBA-GINR with SCENT-Style Decoder (FIXED)

## Fix Applied:
- **Spatial bias dimension fix**: Now uses actual coordinate grid dimensions instead of fixed patch_num
- This prevents RuntimeError: tensor size mismatch (256 vs 1024)

## Key Changes from ResidualBlock Version:
1. **Attention + FeedForward blocks** (not ResidualBlocks)
2. **4x expansion** in FeedForward (512 → 4096, GEGLU-gated down to 2048 → 512)
3. **Skip connections** from Fourier encoding to output
4. **GEGLU gating**, LayerNorm, self-attention
5. **Sinusoidal initialization** for LP tokens

## Expected Improvement:
- **+5-10 dB PSNR** at 128×128 super-resolution
- Sharp high-frequency details instead of smooth blur

In [None]:
# Cell 1: Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import einops
from einops import rearrange, repeat
import numpy as np
import math
from math import log, pi
from tqdm import tqdm
import matplotlib.pyplot as plt
from mamba_ssm import Mamba

# Prefer the GPU when one is visible to PyTorch.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

In [None]:
# Cell 2: Helper Functions

def exists(val):
    """Return True for any value other than ``None`` (falsy values included)."""
    return not (val is None)

def default(val, d):
    """Return *val*, or the fallback *d* when *val* is ``None``."""
    if val is None:
        return d
    return val

def gaussian_fourier_encode(coords, B_matrix):
    """Map coordinates to random Fourier features ``[cos(2*pi*x@B.T), sin(2*pi*x@B.T)]``.

    Args:
        coords: coordinate tensor whose last dimension equals
            ``B_matrix.shape[1]``. A 3-D input ``(B, N, D)`` is flattened to
            ``(B*N, D)`` before projection; other ranks are projected as-is.
        B_matrix: frequency matrix of shape ``(F, D)``.

    Returns:
        Tensor whose last dimension is ``2 * F``: the cosines followed by the sines.
    """
    if coords.dim() == 3:
        # reshape (not view) so non-contiguous inputs, e.g. permuted or sliced
        # coordinate tensors, are accepted instead of raising.
        coords = coords.reshape(-1, coords.shape[-1])
    proj = 2 * np.pi * coords @ B_matrix.T
    return torch.cat([torch.cos(proj), torch.sin(proj)], dim=-1)

def create_coordinate_grid(H, W, device):
    """Build an ``(H, W, 2)`` grid of ``(y, x)`` coordinates spanning [0, 1].

    Both axes include their endpoints (``torch.linspace``), so the first
    row/column is at 0 and the last is at 1.
    """
    rows = torch.linspace(0, 1, H, device=device)
    cols = torch.linspace(0, 1, W, device=device)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing='ij')
    return torch.stack((grid_y, grid_x), dim=-1)

def get_sinusoidal_embeddings(n, d):
    """Return an ``(n, d)`` table of Transformer-style sinusoidal embeddings.

    Even columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``,
    with ``w_k = exp(-2k * ln(10000) / d)``. *d* must be even.
    """
    assert d % 2 == 0
    freqs = torch.exp(torch.arange(0, d, 2).float() * -(log(10000.0) / d))
    angles = torch.arange(n, dtype=torch.float).unsqueeze(1) * freqs
    table = torch.zeros(n, d)
    table[:, 0::2] = torch.sin(angles)
    table[:, 1::2] = torch.cos(angles)
    return table


print("✓ Helper functions defined")

In [None]:
# Cell 3: SCENT-Style Building Blocks

class PreNorm(nn.Module):
    """Layer-normalise the input (and optionally ``context``) before calling ``fn``.

    Args:
        dim: feature size of the main input ``x``.
        fn: callable invoked as ``fn(normed_x, **kwargs)``.
        context_dim: when given, ``kwargs['context']`` is layer-normalised too,
            so a ``context`` keyword is then required at call time.
    """

    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_context = None if context_dim is None else nn.LayerNorm(context_dim)

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        if self.norm_context is not None:
            # kwargs is a fresh per-call dict, so overwriting the entry is local.
            kwargs['context'] = self.norm_context(kwargs['context'])
        return self.fn(normed, **kwargs)

class GEGLU(nn.Module):
    """Gated GELU: split the last dim in half and return ``value * gelu(gate)``."""

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return value * F.gelu(gate)

class FeedForward(nn.Module):
    """Position-wise GEGLU MLP: ``dim -> 2*mult*dim``, gated to ``mult*dim``, then ``-> dim``."""

    def __init__(self, dim, mult=4):
        super().__init__()
        hidden = dim * mult
        self.net = nn.Sequential(
            nn.Linear(dim, hidden * 2),
            GEGLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    """Multi-head (cross-)attention with an optional additive logit bias and key mask.

    Args:
        query_dim: feature size of the queries ``x`` and of the output.
        context_dim: feature size of keys/values; defaults to ``query_dim``
            (self-attention).
        heads: number of attention heads.
        dim_head: per-head feature size.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64):
        super().__init__()
        inner_dim = heads * dim_head
        kv_dim = default(context_dim, query_dim)
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(kv_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context=None, mask=None, bias=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when omitted).

        ``bias`` is added to the pre-softmax logits ``(batch*heads, Lq, Lk)``:
        a 3-D bias whose leading dim equals the batch size is tiled across
        heads, any other bias must already broadcast against the logits.
        ``mask`` is a boolean key mask (True = keep) per batch element.
        """
        n_heads = self.heads
        source = default(context, x)
        q = self.to_q(x)
        k, v = self.to_kv(source).chunk(2, dim=-1)
        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=n_heads) for t in (q, k, v))

        logits = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(bias):
            if bias.dim() == 3 and bias.shape[0] == x.shape[0]:
                # Per-batch bias: copy it for every head to match (b h) ordering.
                bias = repeat(bias, 'b l n -> (b h) l n', h=n_heads)
            logits = logits + bias

        if exists(mask):
            keep = repeat(rearrange(mask, 'b ... -> b (...)'), 'b j -> (b h) () j', h=n_heads)
            logits.masked_fill_(~keep, -torch.finfo(logits.dtype).max)

        weights = logits.softmax(dim=-1)
        out = torch.einsum('b i j, b j d -> b i d', weights, v)
        return self.to_out(rearrange(out, '(b h) n d -> b n (h d)', h=n_heads))


print("✓ SCENT-style blocks defined")

In [None]:
# Cell 4: Encoder Components

class BiMamba(nn.Module):
    """Bidirectional Mamba over dim 1.

    One Mamba reads the sequence left-to-right, a second reads it
    right-to-left; the two outputs are concatenated and projected back to
    ``d_model``.
    """

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        mamba_kwargs = dict(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand)
        self.forward_mamba = Mamba(**mamba_kwargs)
        self.backward_mamba = Mamba(**mamba_kwargs)
        self.proj = nn.Linear(2 * d_model, d_model)

    def forward(self, x):
        fwd = self.forward_mamba(x)
        # Scan the reversed sequence, then flip back so position i lines up.
        bwd = self.backward_mamba(x.flip(1)).flip(1)
        return self.proj(torch.cat((fwd, bwd), dim=-1))

class PatchEncoder(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    Input ``(B, C, H, W)`` -> output ``(B, (H/p)*(W/p), dim)`` with patches in
    row-major order; H and W must be divisible by ``patch_size``.
    """

    def __init__(self, patch_size=2, in_channels=3, dim=256):
        super().__init__()
        self.patch_size = patch_size
        patch_pixels = patch_size * patch_size * in_channels
        self.proj = nn.Linear(patch_pixels, dim)

    def forward(self, x):
        _batch, _channels, _height, _width = x.shape  # requires a 4-D input
        p = self.patch_size
        patches = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        return self.proj(patches)

class LearnablePositionTokens(nn.Module):
    """Trainable ``(num_tokens, dim)`` token table, tiled over the batch on call.

    With ``use_sinusoidal`` the table starts from sinusoidal embeddings,
    otherwise from Gaussian noise with std 0.02.
    """

    def __init__(self, num_tokens=256, dim=256, use_sinusoidal=True):
        super().__init__()
        init_tokens = (
            get_sinusoidal_embeddings(num_tokens, dim)
            if use_sinusoidal
            else 0.02 * torch.randn(num_tokens, dim)
        )
        self.tokens = nn.Parameter(init_tokens, requires_grad=True)

    def forward(self, B):
        """Return the token table expanded to ``(B, num_tokens, dim)``."""
        return repeat(self.tokens, 'n d -> b n d', b=B)

class LatentProcessor(nn.Module):
    """Stack of pre-norm transformer blocks (self-attention, GEGLU feed-forward),
    each sub-layer wrapped in a residual connection."""

    def __init__(self, dim, num_layers=2, heads=8, dim_head=64):
        super().__init__()
        blocks = []
        for _ in range(num_layers):
            attn = PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head))
            ff = PreNorm(dim, FeedForward(dim, mult=4))
            blocks.append(nn.ModuleList([attn, ff]))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        for attn_block, ff_block in self.layers:
            x = x + attn_block(x)
            x = x + ff_block(x)
        return x


print("✓ Encoder components defined")

In [None]:
# Cell 5: SCENT-Style Decoder (FIXED)

class LAINRDecoderSCENT(nn.Module):
    """Locality-aware INR decoder built from SCENT-style attention/feed-forward blocks.

    Pipeline implemented in ``forward``:
      1. Query coordinates are Fourier-encoded with ``B_q`` and cross-attend
         to the latent ``tokens``. The attention logits get a locality bias
         ``-alpha * d**2``, where ``d`` is the distance between a query's
         raster position and a token's position, both mapped onto [0, 1).
      2. For each frequency band ``k`` the decoding coordinates are
         Fourier-encoded with band ``k``'s frequencies and passed through
         ``bandwidth_lins[k]``. The modulation vector goes through
         ``modulation_lins[k]``, and the two are summed and ReLU'd.
      3. Bands are fused progressively through ``hv_lins``.
      4. Each band's hidden state plus a projected Fourier skip is mapped to
         ``output_dim``, and the band outputs are summed.

    NOTE(review): the blocks in ``bandwidth_lins``/``modulation_lins``/
    ``hv_lins`` are *self*-attention over all query points. Memory therefore
    grows as O(num_queries**2), and each prediction depends on the other
    queries decoded in the same call.

    NOTE(review): frequencies are initialised as ``randn / sigma`` (std
    ``1/sigma``), so a larger sigma gives *lower* frequencies. Confirm this is
    the intended convention.
    """

    def __init__(self, feature_dim=64, input_dim=2, output_dim=3,
                 sigma_q=16, sigma_ls=(128, 32), n_patches=256,
                 hidden_dim=512, context_dim=256,
                 learnable_frequencies=True, num_layers=3,
                 heads=8, dim_head=64):
        super().__init__()

        self.layer_num = len(sigma_ls)
        self.n_features = feature_dim // 2
        # Not used by forward(), which derives the grid size from the
        # coordinates; kept for code that reads this attribute.
        self.patch_num = int(math.sqrt(n_patches))
        self.alpha = 10.0  # scale of the -alpha * distance**2 locality bias
        self.num_layers = num_layers

        B_q_init = torch.randn(self.n_features, input_dim) / sigma_q
        B_ls_init = [torch.randn(self.n_features, input_dim) / sigma for sigma in sigma_ls]

        if learnable_frequencies:
            self.B_q = nn.Parameter(B_q_init)
            self.B_ls = nn.ParameterList([nn.Parameter(b) for b in B_ls_init])
        else:
            self.register_buffer('B_q', B_q_init)
            for i, b in enumerate(B_ls_init):
                self.register_buffer(f'B_l_{i}', b)
            # No Python list of these buffers is kept: such a list would keep
            # pointing at the original tensors after .to(device). They are
            # looked up by name in _band_frequencies() instead.

        def transformer_block():
            # One residual (self-attention, feed-forward) pair.
            return nn.ModuleList([
                PreNorm(hidden_dim, Attention(hidden_dim, heads=heads, dim_head=dim_head)),
                PreNorm(hidden_dim, FeedForward(hidden_dim, mult=4)),
            ])

        def block_sequence(in_dim):
            # Linear -> ReLU stem followed by (num_layers - 1) transformer
            # blocks; run via apply_block_sequence, not Sequential.forward.
            return nn.Sequential(
                nn.Linear(in_dim, hidden_dim), nn.ReLU(),
                *[transformer_block() for _ in range(num_layers - 1)],
            )

        self.query_lin = nn.Linear(feature_dim, hidden_dim)
        self.modulation_ca = PreNorm(
            hidden_dim,
            Attention(hidden_dim, context_dim, heads=2, dim_head=64),
            context_dim=context_dim,
        )
        self.bandwidth_lins = nn.ModuleList([block_sequence(feature_dim) for _ in range(self.layer_num)])
        self.modulation_lins = nn.ModuleList([block_sequence(hidden_dim) for _ in range(self.layer_num)])
        self.hv_lins = nn.ModuleList([transformer_block() for _ in range(self.layer_num - 1)])
        self.fourier_skip_projs = nn.ModuleList([nn.Linear(feature_dim, hidden_dim) for _ in range(self.layer_num)])
        self.out_lins = nn.ModuleList([nn.Linear(hidden_dim, output_dim) for _ in range(self.layer_num)])
        self.act = nn.ReLU()

    def _band_frequencies(self, k):
        """Return band ``k``'s frequency matrix (buffer if registered, else parameter)."""
        buffer = self._buffers.get(f'B_l_{k}')
        return buffer if buffer is not None else self.B_ls[k]

    @staticmethod
    def _infer_grid_hw(spatial_shape, num_points):
        """Return ``(H, W)`` of the query grid.

        Uses the explicit spatial shape when coordinates were given as
        ``(B, H, W, D)``. Otherwise it assumes a square grid of ``num_points``.
        """
        if len(spatial_shape) == 2:
            return int(spatial_shape[0]), int(spatial_shape[1])
        side = math.isqrt(num_points)
        return side, side

    def get_patch_index(self, grid, H, W):
        """Raster index ``row * W + col`` of each ``(y, x)`` point in ``grid`` (shape ``(L, 2)``)."""
        y, x = grid[:, 0], grid[:, 1]
        row = (y * H).to(torch.int32).clamp(0, H - 1)
        col = (x * W).to(torch.int32).clamp(0, W - 1)
        return row * W + col

    def approximate_relative_distances(self, target_index, H, W, m):
        """Locality bias between each query and each of ``m`` latent tokens.

        Query ``q`` sits at ``target_index[q] / (H*W)`` and token ``i`` at
        ``(i + 0.5) / m``.

        Returns:
            ``-alpha * distance**2`` of shape ``(len(target_index), m)``. This
            is the (query, key) layout of the attention logits.
        """
        t = target_index.float() / (H * W)
        token_positions = (torch.arange(m, device=target_index.device, dtype=torch.float32) + 0.5) / m
        return -self.alpha * (t.unsqueeze(1) - token_positions.unsqueeze(0)) ** 2

    def apply_block_sequence(self, x, block_seq):
        """Run ``x`` through a block sequence built in ``__init__``.

        If the sequence starts with ``nn.Linear``, its first two entries
        (Linear, ReLU) are applied as a stem. Every later ``nn.ModuleList``
        entry is an ``(attention, feed-forward)`` pair, each applied
        residually.
        """
        start_idx = 0
        if isinstance(block_seq[0], nn.Linear):
            x = block_seq[1](block_seq[0](x))
            start_idx = 2
        for i in range(start_idx, len(block_seq)):
            block = block_seq[i]
            if isinstance(block, nn.ModuleList):
                attn, ff = block[0], block[1]
                x = attn(x) + x
                x = ff(x) + x
        return x

    def forward(self, coords_decoding, tokens, coords_modulation=None):
        """Predict ``output_dim`` values at the query coordinates.

        Args:
            coords_decoding: ``(B, *query_shape, input_dim)`` coordinates,
                assumed to lie in [0, 1]. Only batch element 0's coordinates
                are used; they are broadcast over the batch.
            tokens: ``(B, m, context_dim)`` latent tokens.
            coords_modulation: optional coordinates for the modulation
                cross-attention (defaults to ``coords_decoding``). Must hold
                the same number of points, because the modulation is added
                point-wise to every band.

        Returns:
            ``(B, *query_shape, output_dim)`` predictions.

        Raises:
            ValueError: if the modulation and decoding point counts differ.
        """
        B, query_shape = coords_decoding.shape[0], coords_decoding.shape[1:-1]
        coords_dec = coords_decoding.reshape(B, -1, coords_decoding.shape[-1])
        if coords_modulation is None:
            coords_mod, mod_shape = coords_dec, query_shape
        else:
            coords_mod = coords_modulation.reshape(B, -1, coords_modulation.shape[-1])
            mod_shape = coords_modulation.shape[1:-1]
        if coords_mod.shape[1] != coords_dec.shape[1]:
            raise ValueError(
                "coords_modulation must contain as many points as coords_decoding "
                f"({coords_mod.shape[1]} vs {coords_dec.shape[1]})"
            )

        # Locality bias of shape (B, num_queries, num_tokens), matching the
        # (query, key) logits inside the modulation cross-attention.
        grid_mod = coords_mod[0]
        H_mod, W_mod = self._infer_grid_hw(mod_shape, grid_mod.shape[0])
        indexes = self.get_patch_index(grid_mod, H_mod, W_mod)
        rel_distances = self.approximate_relative_distances(indexes, H_mod, W_mod, tokens.shape[1])
        bias = repeat(rel_distances, 'l n -> b l n', b=B)

        x_q = repeat(gaussian_fourier_encode(grid_mod, self.B_q), 'l d -> b l d', b=B)
        x_q = self.act(self.query_lin(x_q))
        modulation_vector = self.modulation_ca(x_q, context=tokens, bias=bias)

        modulations_l, fourier_encodings = [], []
        for k in range(self.layer_num):
            x_l_fourier = repeat(
                gaussian_fourier_encode(coords_dec[0], self._band_frequencies(k)),
                'l d -> b l d', b=B,
            )
            fourier_encodings.append(x_l_fourier)
            h_l = self.apply_block_sequence(x_l_fourier, self.bandwidth_lins[k])
            m_proj = self.apply_block_sequence(modulation_vector, self.modulation_lins[k])
            modulations_l.append(self.act(h_l + m_proj))

        # Progressive fusion: band i+1 is refined on top of the fused band i.
        h_v = [modulations_l[0]]
        for i, (attn, ff) in enumerate(self.hv_lins):
            x_combined = modulations_l[i + 1] + h_v[i]
            x_combined = attn(x_combined) + x_combined
            x_combined = ff(x_combined) + x_combined
            h_v.append(x_combined)

        outs = [
            out_lin(h + skip_proj(enc))
            for out_lin, skip_proj, h, enc in zip(
                self.out_lins, self.fourier_skip_projs, h_v, fourier_encodings
            )
        ]
        return sum(outs).view(B, *query_shape, -1)

def count_parameters(model):
    """Return the total number of scalar parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


print("✓ SCENT-style decoder defined (FIXED spatial bias)")

In [None]:
# Cell 6: Complete Model

class MambaGINR_SCENT(nn.Module):
    """Image-to-latent encoder (patch embed, BiMamba, transformer) plus the SCENT-style INR decoder.

    Args:
        patch_size: side of the square patches fed to the encoder.
        in_channels: image channels. Also used as the decoder's output
            channels, so predictions have the same channel count as the input.
        dim: encoder/token feature size (the decoder's ``context_dim``).
        num_lp_tokens: number of learnable latent tokens handed to the decoder.
        feature_dim, sigma_q, sigma_ls, hidden_dim, num_layers: decoder
            hyper-parameters, forwarded to ``LAINRDecoderSCENT``.
        num_latent_layers: depth of the encoder's transformer stack.
        image_size: training image side length, used only to compute
            ``num_patches`` (default 32, as for CIFAR-10).
    """

    def __init__(self, patch_size=2, in_channels=3, dim=256, num_lp_tokens=256,
                 feature_dim=64, sigma_q=16, sigma_ls=(128, 32),
                 hidden_dim=512, num_layers=3, num_latent_layers=2,
                 image_size=32):
        super().__init__()
        self.patch_encoder = PatchEncoder(patch_size=patch_size, in_channels=in_channels, dim=dim)
        self.lp_tokens = LearnablePositionTokens(num_tokens=num_lp_tokens, dim=dim, use_sinusoidal=True)
        self.mamba = BiMamba(d_model=dim)
        self.latent_processor = LatentProcessor(dim=dim, num_layers=num_latent_layers, heads=8, dim_head=64)
        self.num_patches = (image_size // patch_size) ** 2
        self.hyponet = LAINRDecoderSCENT(
            feature_dim=feature_dim, input_dim=2,
            # Reconstruct as many channels as the input has; a fixed 3 would
            # silently broadcast against non-RGB targets in the loss.
            output_dim=in_channels,
            sigma_q=sigma_q, sigma_ls=sigma_ls, n_patches=self.num_patches,
            hidden_dim=hidden_dim, context_dim=dim,
            learnable_frequencies=True, num_layers=num_layers, heads=8, dim_head=64
        )

    def encode(self, x):
        """Encode images ``(B, C, H, W)`` into ``(B, num_lp_tokens, dim)`` latent tokens.

        Patch embeddings and the learnable tokens are concatenated into one
        sequence and mixed by BiMamba and the latent processor. Only the
        trailing (learnable-token) positions are returned.
        """
        B = x.shape[0]
        patch_features = self.patch_encoder(x)
        lp_tokens = self.lp_tokens(B)
        combined = torch.cat([patch_features, lp_tokens], dim=1)
        features = self.latent_processor(self.mamba(combined))
        return features[:, -lp_tokens.shape[1]:, :]

    def forward(self, x, coords_decoding, coords_modulation=None):
        """Encode ``x`` and decode it at ``coords_decoding``."""
        lp_features = self.encode(x)
        return self.hyponet(coords_decoding, lp_features, coords_modulation)


print("✓ Complete model defined")

In [None]:
# Cell 7: Training Functions

def adjust_learning_rate(optimizer, epoch, base_lr=5e-4, warmup_epochs=5, max_epoch=40, min_lr=1e-8):
    """Set every param group's LR for ``epoch`` (0-based) and return it.

    Schedule:
      * Linear warm-up while ``epoch < warmup_epochs``, reaching ``base_lr``
        at epoch ``warmup_epochs - 1``.
      * Then cosine decay from ``base_lr`` down to ``min_lr`` at
        ``epoch == max_epoch``.
      * Epochs beyond ``max_epoch`` stay at ``min_lr`` rather than letting
        the cosine rise again.
      * If ``max_epoch <= warmup_epochs`` there is no decay window, and
        post-warm-up epochs use ``min_lr``.
    """
    if epoch < warmup_epochs:
        lr = base_lr * (epoch + 1) / warmup_epochs
    else:
        decay_epochs = max_epoch - warmup_epochs
        # Clamp progress to [.., 1] and avoid dividing by zero when there is
        # no decay window.
        t = min((epoch - warmup_epochs) / decay_epochs, 1.0) if decay_epochs > 0 else 1.0
        lr = min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

def train_epoch(model, loader, optimizer, device, epoch, resolution=32):
    """Run one training epoch; return ``(mean loss, mean PSNR)`` averaged over batches.

    Each batch is decoded on a ``resolution x resolution`` grid jittered by
    Gaussian noise (std = 1/6 of a grid cell) and clamped to [0, 1]. The same
    jittered grid serves as decoding and modulation coordinates. Gradients are
    clipped to a global norm of 1.0.
    """
    model.train()
    loss_sum = psnr_sum = 0
    cell = 1.0 / resolution
    jitter_std = cell / 6
    # The un-jittered grid depends only on resolution/device: build it once.
    base_coords = create_coordinate_grid(resolution, resolution, device)

    progress = tqdm(loader, desc=f"Epoch {epoch}")
    for images, _labels in progress:
        images = images.to(device)
        batch = images.shape[0]
        coords = (base_coords + torch.randn_like(base_coords) * jitter_std).clamp(0, 1)
        coords_batch = repeat(coords, 'h w d -> b h w d', b=batch)

        pred = model(images, coords_batch, coords_batch)
        target = rearrange(images, 'b c h w -> b h w c')
        per_image_mse = ((pred - target) ** 2).view(batch, -1).mean(dim=-1)
        loss = per_image_mse.mean()
        psnr = (-10 * torch.log10(per_image_mse)).mean()

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        loss_value, psnr_value = loss.item(), psnr.item()
        loss_sum += loss_value
        psnr_sum += psnr_value
        progress.set_postfix({'loss': f"{loss_value:.4f}", 'psnr': f"{psnr_value:.2f}"})

    n_batches = len(loader)
    return loss_sum / n_batches, psnr_sum / n_batches

def validate(model, loader, device, resolution=32):
    """Evaluate on an un-jittered grid; return ``(mean MSE, mean PSNR)`` averaged over batches."""
    model.eval()
    loss_sum = psnr_sum = 0
    grid = create_coordinate_grid(resolution, resolution, device)

    with torch.no_grad():
        for images, _labels in tqdm(loader, desc="Validation"):
            images = images.to(device)
            batch = images.shape[0]
            coords_batch = repeat(grid, 'h w d -> b h w d', b=batch)
            pred = model(images, coords_batch, None)
            target = rearrange(images, 'b c h w -> b h w c')
            per_image_mse = ((pred - target) ** 2).view(batch, -1).mean(dim=-1)
            loss_sum += per_image_mse.mean().item()
            psnr_sum += (-10 * torch.log10(per_image_mse)).mean().item()

    n_batches = len(loader)
    return loss_sum / n_batches, psnr_sum / n_batches

def super_resolve(model, images, target_resolution=128, device=None):
    """Decode ``images`` on a ``target_resolution x target_resolution`` grid.

    Args:
        model: object exposing ``encode(images)`` and
            ``hyponet(coords, tokens, coords_modulation)``.
        images: ``(B, C, H, W)`` input batch.
        target_resolution: output side length.
        device: device for the query grid. Defaults to ``images.device`` so
            the grid is on the same device as the encoded tokens.

    Returns:
        ``(B, C_out, target_resolution, target_resolution)`` predictions.

    NOTE(review): the decoder self-attends over all target_resolution**2
    query points at once, so memory grows quadratically with the pixel count.
    """
    if device is None:
        device = images.device
    model.eval()
    with torch.no_grad():
        B = images.shape[0]
        lp_features = model.encode(images)
        coords = create_coordinate_grid(target_resolution, target_resolution, device)
        coords_batch = repeat(coords, 'h w d -> b h w d', b=B)
        pred = model.hyponet(coords_batch, lp_features, None)
        return rearrange(pred, 'b h w c -> b c h w')


print("✓ Training functions defined")

In [None]:
# Cell 8: Data Loading

transform = transforms.Compose([transforms.ToTensor()])
cifar_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = torchvision.datasets.CIFAR10(train=True, **cifar_kwargs)
test_dataset = torchvision.datasets.CIFAR10(train=False, **cifar_kwargs)

# Shared loader settings; only shuffling differs between the splits.
loader_kwargs = dict(batch_size=64, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print(f"Training: {len(train_dataset)}, Test: {len(test_dataset)}")

In [None]:
# Cell 9: Model Initialization

model = MambaGINR_SCENT(
    patch_size=2, in_channels=3, dim=256, num_lp_tokens=256,
    feature_dim=64, sigma_q=16, sigma_ls=[128, 32],
    hidden_dim=512, num_layers=3, num_latent_layers=2,
).to(device)

total_params = count_parameters(model)
decoder_params = count_parameters(model.hyponet)
print(f"Total parameters: {total_params:,}")
print(f"Decoder parameters: {decoder_params:,}")
print(f"Expected: +5-10 dB PSNR at 128×128")

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)

In [None]:
# Cell 10: Training Loop

num_epochs = 40
# Start at -inf rather than 0 so the first finite validation PSNR always
# writes a checkpoint; the super-resolution cell below loads it
# unconditionally.
best_val_psnr = float('-inf')

for epoch in range(num_epochs):
    lr = adjust_learning_rate(optimizer, epoch, base_lr=5e-4, max_epoch=num_epochs)
    train_loss, train_psnr = train_epoch(model, train_loader, optimizer, device, epoch+1)
    val_loss, val_psnr = validate(model, test_loader, device)

    print(f"\nEpoch {epoch+1}/{num_epochs} | LR: {lr:.6f}")
    print(f"  Train - Loss: {train_loss:.4f}, PSNR: {train_psnr:.2f} dB")
    print(f"  Val   - Loss: {val_loss:.4f}, PSNR: {val_psnr:.2f} dB")

    if val_psnr > best_val_psnr:
        best_val_psnr = val_psnr
        torch.save(model.state_dict(), 'mamba_ginr_scent_best.pth')
        print(f"  → Best saved (PSNR: {best_val_psnr:.2f} dB)")

print(f"\nBest validation PSNR: {best_val_psnr:.2f} dB")

In [None]:
# Cell 11: Super-Resolution Test

# map_location puts the checkpoint on the current device even if it was
# saved from a different one.
model.load_state_dict(torch.load('mamba_ginr_scent_best.pth', map_location=device))
test_images, _ = next(iter(test_loader))
test_images = test_images[:8].to(device)

# Super-resolve one image per call. The decoder self-attends over all
# resolution**2 query points, so attention memory scales with
# batch_size * num_queries**2. Samples are decoded independently, so
# per-image calls give equivalent results at roughly 1/8 of that peak.
# NOTE(review): 256x256 (65,536 queries) may still exceed GPU memory.
sr_64, sr_128, sr_256 = (
    torch.cat([super_resolve(model, img.unsqueeze(0), res, device) for img in test_images])
    for res in (64, 128, 256)
)

rows = [test_images, sr_64, sr_128, sr_256]
labels = ['32×32', '64×64 SR', '128×128 SR', '256×256 SR']

fig, axes = plt.subplots(4, 8, figsize=(20, 10))
for r, (row_images, label) in enumerate(zip(rows, labels)):
    for i in range(8):
        ax = axes[r, i]
        ax.imshow(row_images[i].cpu().permute(1, 2, 0).clamp(0, 1))
        # Hide ticks and frame individually: ax.axis('off') would also hide
        # the y-label, so the row labels would never be drawn.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    axes[r, 0].set_ylabel(label, fontsize=12, rotation=0, labelpad=30)

plt.suptitle('SCENT-Style Decoder: Super-Resolution Results', fontsize=16)
plt.tight_layout()
plt.savefig('scent_sr_results.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Results saved as 'scent_sr_results.png'")