# MAMBA-GINR on CIFAR-10

## Original MAMBA-GINR Architecture

This notebook implements the exact MAMBA-GINR architecture as described in the original paper.

### Architecture:
1. **Patch Encoder**: Split each 32×32 image into a 16×16 grid of 2×2-pixel patches (256 patches)
2. **Learnable Position Tokens**: 256 LP tokens with equidistant insertion
3. **BiMamba Encoder**: 6 layers of bidirectional Mamba
4. **LAINR Decoder**: Hyponetwork for arbitrary-resolution decoding

### Target: 60 dB PSNR on CIFAR-10 (32×32 reconstruction)

### Key Implementation Details:
- dim=256, num_lp=256, mamba_depth=6
- Fourier features for positional encoding
- ResidualBlocks in decoder (original LAINR)
- AdamW optimizer with linear warmup followed by cosine annealing
- Batch size=64, lr=5e-4, 40 epochs

In [None]:
# Cell 1: Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from einops import rearrange, repeat
import math

from mamba_ssm import Mamba
from mamba_ssm.modules.block import Block

# Select the compute device once for the whole notebook: CUDA when
# torch reports it as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
print(f"PyTorch version: {torch.__version__}")

In [None]:
# Cell 2: Helper Functions

def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True

def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``.

    Falsy values such as ``0`` or ``""`` are kept; only ``None`` triggers
    the fallback.
    """
    if val is None:
        return d
    return val

def fourier_encode(coords, n_features=32, std=10.0, B=None):
    """
    Fourier feature encoding for coordinates.

    Args:
        coords: (..., C) coordinates, expected in [0, 1]
        n_features: number of frequency components (used only when ``B`` is None)
        std: standard deviation of the sampled frequencies (used only when
            ``B`` is None)
        B: optional fixed (n_features, C) frequency matrix. Pass it whenever
            the encoding must be reproducible across calls (e.g. inside a
            model's forward pass). When omitted, a NEW random matrix is drawn
            on every call, so two calls with the same ``coords`` return
            different features.
    Returns:
        features: (..., 2*n_features) laid out as [cos, sin]
    """
    if B is None:
        # Legacy behaviour: fresh random frequencies on each call.
        B = torch.randn(n_features, coords.shape[-1], device=coords.device) * std
    else:
        # Align the caller-supplied matrix with the coordinates so the
        # matmul below cannot fail on a device/dtype mismatch.
        B = B.to(device=coords.device, dtype=coords.dtype)
    proj = 2 * np.pi * coords @ B.T
    return torch.cat([torch.cos(proj), torch.sin(proj)], dim=-1)

def create_coordinate_grid(H, W, device='cpu'):
    """Build an (H, W, 2) grid of (row, col) coordinates spanning [0, 1].

    Both axes are sampled with ``linspace(0, 1, n)``, so the first and last
    samples sit exactly on 0 and 1 (endpoint-inclusive).
    """
    rows = torch.linspace(0, 1, H, device=device)
    cols = torch.linspace(0, 1, W, device=device)
    # Broadcast each axis over the other instead of going through meshgrid.
    row_plane = rows.unsqueeze(1).expand(H, W)
    col_plane = cols.unsqueeze(0).expand(H, W)
    return torch.stack((row_plane, col_plane), dim=-1)  # (H, W, 2)

def get_sinusoidal_embeddings(n, d):
    """Return an (n, d) table of sinusoidal positional embeddings.

    Even columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``
    with ``w_k = 10000 ** (-2k / d)``. ``d`` must be even.
    """
    assert d % 2 == 0
    positions = torch.arange(n, dtype=torch.float).unsqueeze(1)  # (n, 1)
    inv_freq = torch.exp(torch.arange(0, d, 2).float() * -(math.log(10000.0) / d))
    angles = positions * inv_freq  # (n, d/2)
    # Stacking on a trailing axis and flattening interleaves sin/cos as
    # [sin_0, cos_0, sin_1, cos_1, ...].
    interleaved = torch.stack((torch.sin(angles), torch.cos(angles)), dim=-1)
    return interleaved.reshape(n, d)

print("✓ Helper functions defined")

In [None]:
# Cell 3: BiMamba

class BiMamba(nn.Module):
    """Bidirectional Mamba mixer.

    One Mamba layer scans the sequence left-to-right, a second one scans the
    reversed sequence; the reverse output is flipped back so both are aligned
    per position, then the two are concatenated on the feature axis and
    projected back to ``d_model``.
    """
    def __init__(self, d_model=256, d_state=16, d_conv=4, expand=2):
        super().__init__()
        ssm_kwargs = dict(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand)
        self.f_mamba = Mamba(**ssm_kwargs)
        self.r_mamba = Mamba(**ssm_kwargs)
        self.proj = nn.Linear(2 * d_model, d_model)

    def forward(self, x, inference_params=None):
        # NOTE(review): the same inference_params object is handed to both
        # directions; this notebook only ever passes None.
        fwd = self.f_mamba(x, inference_params=inference_params)
        bwd = self.r_mamba(x.flip(1), inference_params=inference_params).flip(1)
        return self.proj(torch.cat((fwd, bwd), dim=-1))

print("✓ BiMamba defined")

In [None]:
# Cell 4: Mamba Encoder

class MambaEncoder(nn.Module):
    """Stack of ``depth`` mamba_ssm ``Block``s, each with a BiMamba mixer and an MLP.

    Args:
        depth: number of blocks
        dim: token feature width
        ff_dim: hidden width of each block's MLP
        dropout: dropout probability inside the MLP
    """
    def __init__(self, depth=6, dim=256, ff_dim=1024, dropout=0.0):
        super().__init__()
        self.blocks = nn.ModuleList([
            Block(
                dim=dim,
                mixer_cls=lambda d: BiMamba(d_model=d),
                mlp_cls=lambda d: nn.Sequential(
                    nn.Linear(d, ff_dim),
                    nn.GELU(),
                    nn.Dropout(dropout),
                    nn.Linear(ff_dim, d),
                    nn.Dropout(dropout),
                ),
                norm_cls=nn.LayerNorm,
                fused_add_norm=False
            )
            for _ in range(depth)
        ])

    def forward(self, x):
        """Run the sequence through every block and return the encoded tokens."""
        residual = None
        for block in self.blocks:
            x, residual = block(x, residual=residual, inference_params=None)
        # mamba_ssm's Block is "Add -> LN -> Mixer": it returns the branch
        # output and the running residual separately and leaves the final
        # add to the caller. Returning ``x`` alone would discard the whole
        # residual stream (including the input tokens).
        if residual is None:  # depth == 0: nothing was applied
            return x
        return x + residual

print("✓ MambaEncoder defined")

In [None]:
# Cell 5: Learnable Position Tokens

class LearnablePositionTokens(nn.Module):
    """Learnable position (LP) tokens interleaved at evenly spaced slots.

    The tokens are initialised from sinusoidal embeddings and inserted into
    an input sequence of length ``input_len`` so that, in the combined
    sequence of length ``input_len + num_tokens``, they occupy the positions
    ``linspace(0, total_len - 1, num_tokens)`` (truncated to integers).

    Args:
        num_tokens: number of LP tokens
        dim: feature width of each token (must be even, required by the
            sinusoidal initialisation)
        input_len: length of the sequences passed to ``add_lp``
    """
    def __init__(self, num_tokens=256, dim=256, input_len=256):
        super().__init__()
        self.num_tokens = num_tokens
        self.dim = dim
        self.input_len = input_len

        # Initialize with sinusoidal embeddings
        init_tokens = get_sinusoidal_embeddings(num_tokens, dim)
        self.tokens = nn.Parameter(init_tokens, requires_grad=True)

        # Compute equidistant placement
        total_len = input_len + num_tokens
        lp_idxs = torch.linspace(0, total_len - 1, steps=num_tokens).long()
        # Registered as a buffer so it follows the module across devices
        # (.to/.cuda). Non-persistent so the state_dict layout is unchanged
        # and existing checkpoints still load.
        self.register_buffer('lp_idxs', lp_idxs, persistent=False)

        # Interleave permutation over cat([inputs, lp_tokens]): LP slots pull
        # from the tail (indices >= input_len), every remaining slot pulls
        # the next input token in order.
        perm = torch.full((total_len,), -1, dtype=torch.long)
        perm[lp_idxs] = torch.arange(input_len, input_len + num_tokens)
        perm[perm == -1] = torch.arange(input_len)
        self.register_buffer('perm', perm)

    def add_lp(self, x):
        """Interleave the LP tokens into ``x``.

        Args:
            x: (B, input_len, dim) input tokens
        Returns:
            (B, input_len + num_tokens, dim) interleaved sequence
        Raises:
            ValueError: if ``x.shape[1]`` differs from ``input_len``; the
                precomputed permutation is only valid for that length.
        """
        if x.shape[1] != self.input_len:
            raise ValueError(
                f"expected sequence length {self.input_len}, got {x.shape[1]}"
            )
        batch = x.shape[0]
        lps = self.tokens.unsqueeze(0).expand(batch, -1, -1)
        x_full = torch.cat([x, lps], dim=1)  # (B, L+num_tokens, D)
        return x_full[:, self.perm]  # Interleave

    def extract_lp(self, x):
        """Return the (B, num_tokens, dim) entries of ``x`` at the LP slots."""
        return x[:, self.lp_idxs]

print("✓ LearnablePositionTokens defined")

In [None]:
# Cell 6: Original LAINR Decoder (ResidualBlock-based)

class ResidualBlock(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) with an identity skip connection."""
    def __init__(self, dim=512):
        super().__init__()
        layers = [nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        delta = self.net(x)
        return x + delta


class LAINRDecoder(nn.Module):
    """
    LAINR-style decoder with ResidualBlocks.

    Pipeline (see ``forward``):
    - Fourier encoding of the query coordinates with the fixed buffer ``B``
    - Cross-attention from the encoded queries to the LP tokens, with an
      additive distance-based bias on the attention logits
    - Linear + ReLU followed by three ResidualBlocks
    - Linear projection to ``output_dim`` channels

    Args:
        n_features: number of Fourier frequencies (encoding width is 2x this)
        input_dim: coordinate dimensionality
        output_dim: number of output channels
        hidden_dim: width of the attention and decoder layers
        context_dim: feature width of the LP tokens
        n_patches: number of encoder patches; its integer square root is the
            side of the patch grid used for the spatial bias
    """
    def __init__(self, n_features=32, input_dim=2, output_dim=3,
                 hidden_dim=512, context_dim=256, n_patches=256):
        super().__init__()

        self.n_features = n_features
        self.patch_num = int(math.sqrt(n_patches))
        self.alpha = 10.0  # Spatial bias coefficient

        # Fourier encoding frequencies (fixed, saved with the model)
        self.register_buffer('B', torch.randn(n_features, input_dim) * 10.0)
        feature_dim = 2 * n_features

        # Query encoding
        self.query_proj = nn.Linear(feature_dim, hidden_dim)

        # Cross-attention for modulation extraction
        self.to_q = nn.Linear(hidden_dim, hidden_dim)
        self.to_kv = nn.Linear(context_dim, hidden_dim * 2)
        self.attn_out = nn.Linear(hidden_dim, hidden_dim)
        # NOTE(review): queries/keys are hidden_dim wide, so the conventional
        # scale would be hidden_dim ** -0.5; kept as-is to preserve behaviour.
        self.scale = (hidden_dim // 2) ** -0.5

        # Decoder processing
        self.decoder_blocks = nn.Sequential(
            nn.Linear(feature_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            ResidualBlock(hidden_dim),
            ResidualBlock(hidden_dim),
            ResidualBlock(hidden_dim),
        )

        # Output projection
        self.output_proj = nn.Linear(hidden_dim, output_dim)

    def get_patch_index(self, coords, H, W):
        """Map (..., 2) coordinates in [0, 1] to flat indices on an H x W grid.

        Coordinates equal to 1.0 are clamped into the last row/column.
        """
        y, x = coords[..., 0], coords[..., 1]
        row = (y * H).long().clamp(0, H - 1)
        col = (x * W).long().clamp(0, W - 1)
        return row * W + col

    def compute_spatial_bias(self, target_index, H, W, num_tokens):
        """Additive attention bias penalising query/token distance.

        Each query's flat patch index is normalised to [0, 1) and compared
        with ``num_tokens`` evenly spaced token positions; the bias is
        ``-alpha * distance ** 2``.

        Args:
            target_index: (N,) flat patch index of every query
            H, W: patch grid size (``H * W`` normalises the index)
            num_tokens: number of context tokens L
        Returns:
            (N, L) bias, laid out like the attention logits (queries first).
        """
        N = H * W
        t = target_index.float() / N
        token_positions = torch.linspace(0.5 / num_tokens, 1 - 0.5 / num_tokens,
                                         num_tokens, device=target_index.device)
        # (N, 1) - (L,) -> (N, L). The previous (L, N) layout could not be
        # broadcast against (B, N, L) logits unless N == L, and was silently
        # transposed when it was.
        distances = torch.abs(t.unsqueeze(-1) - token_positions)
        return -self.alpha * distances ** 2

    def cross_attention(self, queries, context, bias=None):
        """Single-head cross-attention with an optional additive logit bias.

        Args:
            queries: (B, N, hidden_dim)
            context: (B, L, context_dim)
            bias: optional (N, L) tensor added to the logits of every batch item
        Returns:
            (B, N, hidden_dim)
        """
        q = self.to_q(queries)  # (B, N, D)
        k, v = self.to_kv(context).chunk(2, dim=-1)  # (B, L, D) each

        sim = torch.einsum('bnd,bld->bnl', q, k) * self.scale

        if bias is not None:
            sim = sim + bias.unsqueeze(0)

        attn = sim.softmax(dim=-1)
        out = torch.einsum('bnl,bld->bnd', attn, v)
        return self.attn_out(out)

    def forward(self, coords, tokens):
        """
        Args:
            coords: (B, H, W, 2) query coordinates. Every batch item is
                assumed to share the same grid: only ``coords[0]`` is read.
            tokens: (B, L, D) LP token features
        Returns:
            (B, H, W, output_dim) predicted values
        """
        batch, H, W, _ = coords.shape
        grid = coords[0].reshape(-1, coords.shape[-1]).to(self.B.dtype)  # (N, 2)

        # Fourier encoding with the registered frequencies. Sampling fresh
        # random frequencies here (as fourier_encode does without a fixed
        # matrix) would feed the network a different encoding on every pass.
        proj = 2 * math.pi * grid @ self.B.T
        fourier = torch.cat([torch.cos(proj), torch.sin(proj)], dim=-1)  # (N, 2F)
        fourier_features = fourier.unsqueeze(0).expand(batch, -1, -1)

        # Query projection
        queries = F.relu(self.query_proj(fourier_features))

        # Spatial bias
        indices = self.get_patch_index(grid, self.patch_num, self.patch_num)
        bias = self.compute_spatial_bias(indices, self.patch_num, self.patch_num,
                                         tokens.shape[1])

        # Extract modulation via cross-attention
        modulation = self.cross_attention(queries, tokens, bias)

        # Decode
        decoder_input = torch.cat([fourier_features, modulation], dim=-1)
        features = self.decoder_blocks(decoder_input)
        rgb = self.output_proj(features)

        # -1 instead of a literal 3 so output_dim != 3 also works.
        return rgb.reshape(batch, H, W, -1)

print("✓ Original LAINR decoder defined (ResidualBlock-based)")

In [None]:
# Cell 7: Complete MAMBA-GINR Model

class MambaGINR(nn.Module):
    """MAMBA-GINR: patch encoder + LP tokens + Mamba encoder + LAINR decoder.

    Args:
        img_size: side length of the (square) input images
        patch_size: side length of each square patch; must divide ``img_size``
        in_channels: number of image channels
        dim: token width in the encoder; must be even (the positional
            encoding concatenates ``dim // 2`` sines and cosines)
        num_lp: number of learnable position tokens
        mamba_depth: number of encoder blocks
        ff_dim: MLP hidden width inside each encoder block
        hidden_dim: decoder width
        n_features: number of Fourier frequencies in the decoder

    Raises:
        ValueError: if ``img_size`` is not a multiple of ``patch_size`` or
            ``dim`` is odd.
    """

    def __init__(self,
                 img_size=32,
                 patch_size=2,
                 in_channels=3,
                 dim=256,
                 num_lp=256,
                 mamba_depth=6,
                 ff_dim=1024,
                 hidden_dim=512,
                 n_features=32):
        super().__init__()

        # Fail early with a clear message instead of a shape error deep
        # inside patchify / pos_proj.
        if img_size % patch_size != 0:
            raise ValueError(
                f"img_size ({img_size}) must be divisible by patch_size ({patch_size})"
            )
        if dim % 2 != 0:
            raise ValueError(f"dim ({dim}) must be even")

        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_num = img_size // patch_size  # patches per side

        # Patch embedding
        self.patch_embed = nn.Linear(patch_size * patch_size * in_channels, dim)

        # Patch positional encoding (fixed random frequencies + learned projection)
        self.register_buffer('pos_freq', torch.randn(dim // 2, 2) * 10.0)
        self.pos_proj = nn.Linear(dim, dim)

        # Learnable position tokens
        self.lp_module = LearnablePositionTokens(
            num_tokens=num_lp,
            dim=dim,
            input_len=self.num_patches
        )

        # Mamba encoder
        self.encoder = MambaEncoder(
            depth=mamba_depth,
            dim=dim,
            ff_dim=ff_dim
        )

        # LAINR decoder
        self.decoder = LAINRDecoder(
            n_features=n_features,
            input_dim=2,
            output_dim=3,
            hidden_dim=hidden_dim,
            context_dim=dim,
            n_patches=self.num_patches
        )

    def patchify(self, images):
        """Split (B, C, H, W) images into (B, num_patches, p*p*C) flattened patches.

        Patches are ordered row-major over the patch grid.
        """
        p = self.patch_size
        return rearrange(images, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)

    def get_patch_positions(self, B, device):
        """Return (B, num_patches, 2) normalized (row, col) patch-center positions."""
        h = w = self.patch_num
        y = torch.linspace(0.5 / h, 1 - 0.5 / h, h, device=device)
        x = torch.linspace(0.5 / w, 1 - 0.5 / w, w, device=device)
        yy, xx = torch.meshgrid(y, x, indexing='ij')
        positions = torch.stack([yy, xx], dim=-1).reshape(-1, 2)
        return positions.unsqueeze(0).expand(B, -1, -1)

    def fourier_pos_encoding(self, positions):
        """Encode (..., 2) positions as projected [sin, cos] Fourier features of width dim."""
        proj = 2 * np.pi * positions @ self.pos_freq.T
        encoding = torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
        return self.pos_proj(encoding)

    def encode(self, images):
        """Encode (B, C, H, W) images into (B, num_lp, dim) LP features."""
        B = images.shape[0]

        # Patchify
        patches = self.patchify(images)  # (B, num_patches, patch_dim)
        tokens = self.patch_embed(patches)  # (B, num_patches, dim)

        # Add positional encoding
        positions = self.get_patch_positions(B, images.device)
        tokens = tokens + self.fourier_pos_encoding(positions)

        # Interleave LP tokens, run the encoder, keep only the LP slots
        tokens_with_lp = self.lp_module.add_lp(tokens)
        encoded = self.encoder(tokens_with_lp)
        return self.lp_module.extract_lp(encoded)

    def decode(self, lp_features, coords):
        """Decode LP features at (B, H, W, 2) query coordinates."""
        return self.decoder(coords, lp_features)

    def forward(self, images, coords):
        """Encode ``images`` and decode them at ``coords``."""
        lp_features = self.encode(images)
        return self.decode(lp_features, coords)

def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

print("✓ Complete MAMBA-GINR model defined")

In [None]:
# Cell 8: Training Functions

def adjust_learning_rate(optimizer, epoch, base_lr=5e-4, warmup_epochs=5, max_epoch=40):
    """Set and return the learning rate for ``epoch``: linear warmup, then cosine decay.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts); the
            ``'lr'`` entry of every group is overwritten
        epoch: zero-based epoch index
        base_lr: peak learning rate reached at the end of warmup
        warmup_epochs: number of epochs of linear warmup
        max_epoch: epoch at which the cosine decay reaches its floor
    Returns:
        The learning rate that was applied.
    """
    min_lr = 1e-8
    if epoch < warmup_epochs:
        lr = base_lr * (epoch + 1) / warmup_epochs
    else:
        decay_span = max_epoch - warmup_epochs
        if decay_span > 0:
            # Clamp so epochs past max_epoch stay at the floor instead of
            # riding the cosine back up.
            progress = min((epoch - warmup_epochs) / decay_span, 1.0)
        else:
            # No decay phase configured: avoid dividing by zero and treat
            # the schedule as finished.
            progress = 1.0
        lr = min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

def train_epoch(model, loader, optimizer, device, epoch, resolution=32):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: called as ``model(images, coords)`` and expected to return
            channel-last predictions matching the images
        loader: iterable of ``(images, labels)`` batches; labels are ignored
        optimizer: optimizer stepping the model parameters
        device: device the batches and the query grid are placed on
        epoch: epoch number, used only for the progress-bar label
        resolution: side length of the square query grid
    Returns:
        (mean loss, mean PSNR in dB), averaged over batches.
    """
    model.train()
    total_loss, total_psnr = 0, 0

    # The query grid depends only on resolution/device, so build it once.
    # create_coordinate_grid samples linspace(0, 1, n), i.e. the grid is
    # endpoint-inclusive rather than at pixel centers.
    coords = create_coordinate_grid(resolution, resolution, device)

    pbar = tqdm(loader, desc=f"Epoch {epoch}")
    for images, _ in pbar:
        images = images.to(device)
        B = images.shape[0]

        # Broadcast view of the shared grid; no per-batch copy needed.
        coords_batch = coords.unsqueeze(0).expand(B, -1, -1, -1)

        # Forward
        pred = model(images, coords_batch)
        gt = rearrange(images, 'b c h w -> b h w c')

        # Per-image MSE; reshape (not view) so a non-contiguous result
        # cannot raise.
        mses = ((pred - gt) ** 2).reshape(B, -1).mean(dim=-1)
        loss = mses.mean()
        # PSNR = -10*log10(MSE). The clamp keeps an exactly-zero MSE from
        # producing inf and poisoning the epoch average (caps at 120 dB).
        psnr = (-10 * torch.log10(mses.detach().clamp_min(1e-12))).mean()

        # Backward
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        total_psnr += psnr.item()
        pbar.set_postfix({'loss': f"{loss.item():.4f}", 'psnr': f"{psnr.item():.2f}"})

    return total_loss / len(loader), total_psnr / len(loader)

def validate(model, loader, device, resolution=32):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: called as ``model(images, coords)`` and expected to return
            channel-last predictions matching the images
        loader: iterable of ``(images, labels)`` batches; labels are ignored
        device: device the batches and the query grid are placed on
        resolution: side length of the square query grid
    Returns:
        (mean loss, mean PSNR in dB), averaged over batches.
    """
    model.eval()
    total_loss, total_psnr = 0, 0

    with torch.no_grad():
        # Grid depends only on resolution/device: build once, reuse per batch.
        coords = create_coordinate_grid(resolution, resolution, device)

        for images, _ in tqdm(loader, desc="Validation"):
            images = images.to(device)
            B = images.shape[0]

            coords_batch = coords.unsqueeze(0).expand(B, -1, -1, -1)

            pred = model(images, coords_batch)
            gt = rearrange(images, 'b c h w -> b h w c')

            # reshape (not view) tolerates a non-contiguous difference tensor.
            mses = ((pred - gt) ** 2).reshape(B, -1).mean(dim=-1)
            total_loss += mses.mean().item()
            # Clamp so an exactly-zero MSE yields 120 dB instead of inf.
            total_psnr += (-10 * torch.log10(mses.clamp_min(1e-12))).mean().item()

    return total_loss / len(loader), total_psnr / len(loader)

def super_resolve(model, images, target_resolution=128, device='cpu'):
    """Encode ``images`` once and decode them on a denser square grid.

    Puts the model in eval mode, queries a
    ``target_resolution x target_resolution`` coordinate grid created on
    ``device``, and returns the prediction in channel-first layout
    (B, C, target_resolution, target_resolution).
    """
    model.eval()
    with torch.no_grad():
        latent = model.encode(images)
        grid = create_coordinate_grid(target_resolution, target_resolution, device)
        grid_batch = repeat(grid, 'h w d -> b h w d', b=images.shape[0])
        rgb = model.decode(latent, grid_batch)
        return rearrange(rgb, 'b h w c -> b c h w')

print("✓ Training functions defined")

In [None]:
# Cell 9: Data Loading

# Only ToTensor() is applied; no normalisation or augmentation.
transform = transforms.Compose([transforms.ToTensor()])
dataset_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = CIFAR10(train=True, **dataset_kwargs)
test_dataset = CIFAR10(train=False, **dataset_kwargs)

# Settings shared by both loaders; only the training split is shuffled.
loader_kwargs = dict(batch_size=64, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Visualize the first ten training images with their labels
fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 5))
for i, ax in enumerate(axes.ravel()):
    img, label = train_dataset[i]
    ax.imshow(img.permute(1, 2, 0))  # CHW -> HWC for matplotlib
    ax.set_title(f"Label: {label}")
    ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# Cell 10: Model Initialization

# Single source of truth for the architecture: the summary below reads from
# this dict, so the printed values cannot drift from what was built.
model_config = dict(
    img_size=32,
    patch_size=2,
    in_channels=3,
    dim=256,
    num_lp=256,
    mamba_depth=6,
    ff_dim=1024,
    hidden_dim=512,
    n_features=32,
)

model = MambaGINR(**model_config).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)

total_params = count_parameters(model)
encoder_params = count_parameters(model.encoder)
decoder_params = count_parameters(model.decoder)

print(f"\n{'='*60}")
print(f"MAMBA-GINR ARCHITECTURE")
print(f"{'='*60}")
print(f"Total parameters: {total_params:,}")
print(f"  - Encoder (BiMamba):  {encoder_params:,}")
print(f"  - Decoder (LAINR):    {decoder_params:,}")
print(f"\nConfiguration:")
print(f"  - Image size: {model_config['img_size']}×{model_config['img_size']}")
print(f"  - Patch size: {model_config['patch_size']}×{model_config['patch_size']}")
print(f"  - Num patches: {model.num_patches}")
print(f"  - LP tokens: {model.lp_module.num_tokens}")
print(f"  - Hidden dim: {model_config['dim']}")
print(f"  - Mamba depth: {model_config['mamba_depth']}")
print(f"\nTarget: 60 PSNR on CIFAR-10")
print(f"{'='*60}\n")

In [None]:
# Cell 11: Training Loop

num_epochs = 40
# Start at -inf rather than 0: PSNR is negative whenever MSE > 1, and with a
# 0 threshold such a run would never write the checkpoint that the
# super-resolution cell loads afterwards.
best_val_psnr = float('-inf')
train_losses, train_psnrs = [], []
val_losses, val_psnrs = [], []

print(f"Training for {num_epochs} epochs\n")

for epoch in range(num_epochs):
    # The schedule is applied manually once per epoch, before training.
    lr = adjust_learning_rate(optimizer, epoch, base_lr=5e-4, max_epoch=num_epochs)
    train_loss, train_psnr = train_epoch(model, train_loader, optimizer, device, epoch+1)
    val_loss, val_psnr = validate(model, test_loader, device)

    train_losses.append(train_loss)
    train_psnrs.append(train_psnr)
    val_losses.append(val_loss)
    val_psnrs.append(val_psnr)

    print(f"\nEpoch {epoch+1}/{num_epochs} | LR: {lr:.6f}")
    print(f"  Train - Loss: {train_loss:.4f}, PSNR: {train_psnr:.2f} dB")
    print(f"  Val   - Loss: {val_loss:.4f}, PSNR: {val_psnr:.2f} dB")

    # Keep the weights with the best validation PSNR seen so far.
    if val_psnr > best_val_psnr:
        best_val_psnr = val_psnr
        torch.save(model.state_dict(), 'mamba_ginr_best.pth')
        print(f"  → Best model saved (PSNR: {best_val_psnr:.2f} dB)")

    # Plot every 5 epochs
    if (epoch + 1) % 5 == 0:
        fig, axes = plt.subplots(1, 3, figsize=(15, 4))

        axes[0].plot(train_losses, label='Train')
        axes[0].plot(val_losses, label='Val')
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('MSE Loss')
        axes[0].set_yscale('log')
        axes[0].legend()
        axes[0].grid(True)
        axes[0].set_title('Loss')

        axes[1].plot(train_psnrs, label='Train')
        axes[1].plot(val_psnrs, label='Val')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('PSNR (dB)')
        axes[1].legend()
        axes[1].grid(True)
        axes[1].set_title('PSNR')

        # Sample reconstruction: first test image, original next to prediction
        model.eval()
        with torch.no_grad():
            sample, _ = next(iter(test_loader))
            sample = sample[:1].to(device)
            coords = create_coordinate_grid(32, 32, device)
            coords_batch = coords.unsqueeze(0)
            pred = model(sample, coords_batch)
            # Both halves are HWC; concatenate along the width axis.
            comparison = torch.cat([
                sample[0].cpu().permute(1, 2, 0),
                pred[0].cpu().clamp(0, 1)
            ], dim=1)
            axes[2].imshow(comparison)
            axes[2].set_title('Original | Reconstruction')
            axes[2].axis('off')
        model.train()

        plt.tight_layout()
        plt.show()

print(f"\n{'='*60}")
print(f"Training complete!")
print(f"Best validation PSNR: {best_val_psnr:.2f} dB")
print(f"{'='*60}")

In [None]:
# Cell 12: Super-Resolution Test

# Load best model
# map_location lets a checkpoint saved from a CUDA run load on whatever
# device this session selected (including CPU-only hosts).
model.load_state_dict(torch.load('mamba_ginr_best.pth', map_location=device))
model.eval()

# Get test samples
test_images, _ = next(iter(test_loader))
test_images = test_images[:8].to(device)

print("Generating super-resolution outputs...\n")

# Generate at multiple resolutions
sr_64 = super_resolve(model, test_images, 64, device)
sr_128 = super_resolve(model, test_images, 128, device)
sr_256 = super_resolve(model, test_images, 256, device)

print(f"Original: {test_images.shape}")
print(f"SR 64×64: {sr_64.shape}")
print(f"SR 128×128: {sr_128.shape}")
print(f"SR 256×256: {sr_256.shape}")

# Visualize: one row per resolution, one column per sample
fig, axes = plt.subplots(4, 8, figsize=(20, 10))
for i in range(8):
    axes[0, i].imshow(test_images[i].cpu().permute(1, 2, 0).clamp(0, 1))
    axes[1, i].imshow(sr_64[i].cpu().permute(1, 2, 0).clamp(0, 1))
    axes[2, i].imshow(sr_128[i].cpu().permute(1, 2, 0).clamp(0, 1))
    axes[3, i].imshow(sr_256[i].cpu().permute(1, 2, 0).clamp(0, 1))
    for j in range(4):
        # Hide ticks and frame instead of calling axis('off'): turning the
        # axis off also hides the y-label used below as the row caption.
        axes[j, i].set_xticks([])
        axes[j, i].set_yticks([])
        for spine in axes[j, i].spines.values():
            spine.set_visible(False)

labels = ['32×32', '64×64 SR', '128×128 SR', '256×256 SR']
for j, label in enumerate(labels):
    axes[j, 0].set_ylabel(label, fontsize=12, rotation=0, labelpad=40)

plt.suptitle('MAMBA-GINR: Super-Resolution Results', fontsize=16)
plt.tight_layout()
plt.savefig('mamba_ginr_super_resolution.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n✓ Results saved as 'mamba_ginr_super_resolution.png'")

In [None]:
# Cell 13: Final Analysis

# Assemble the whole report first, then emit it with a single print; the
# resulting text is the same as printing each line separately.
rule = "="*70
report_lines = [
    rule,
    "MAMBA-GINR CIFAR-10 Results",
    rule,
    f"\nFinal Metrics:",
    f"  - Best validation PSNR: {best_val_psnr:.2f} dB",
    f"  - Final train PSNR: {train_psnrs[-1]:.2f} dB",
    f"  - Final val PSNR: {val_psnrs[-1]:.2f} dB",
    f"\nArchitecture Summary:",
    f"  - BiMamba encoder: {encoder_params:,} params",
    f"  - LAINR decoder: {decoder_params:,} params",
    f"  - Total: {total_params:,} params",
    f"\nKey Features:",
    f"  ✓ Learnable position tokens (implicit sequential bias)",
    f"  ✓ Bidirectional Mamba (O(L) complexity)",
    f"  ✓ Arbitrary-resolution generation",
    f"  ✓ ResidualBlock-based decoder",
    f"\nTarget: 60 PSNR (requires perfect reconstruction)",
    f"Achieved: {best_val_psnr:.2f} dB",
    rule,
]
print("\n".join(report_lines))