# Super-Resolution Diagnostics for MAMBA-GINR

This notebook diagnoses WHY super-resolution might be sub-optimal.

**Key insight**: True super-resolution CAN work even when the model is trained at a single resolution, through:
1. High-frequency Fourier features (σ_l = [128, 32])
2. MLP learning to generate textures from semantic features
3. Proper jittering for continuous coordinate learning

We'll test which component is the bottleneck.

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from scipy.fft import fft2, fftshift
import seaborn as sns

# Prerequisites expected from earlier cells (not defined in this notebook):
#   model       - the trained MAMBA-GINR model
#   test_images - batch of test images
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Diagnostic 1: Fourier Feature Activation Analysis

Check whether the network is actually USING the high-frequency Fourier features.

In [None]:
def analyze_fourier_feature_usage(model, test_image, resolutions=(32, 64, 128),
                                  sigmas=(128, 32)):
    """
    Plot per-frequency activation magnitudes at the hyponet's bandwidth layers.

    For each query resolution, the first image of ``test_image`` is encoded
    and decoded on a ``res``x``res`` coordinate grid. Forward hooks on every
    layer of ``model.hyponet.bandwidth_lins`` record that layer's *input*
    (presumably the Fourier-encoded coordinates -- TODO confirm against the
    hyponet code). The mean absolute value of each feature dimension is
    plotted against a log-spaced frequency axis running from 10 up to that
    layer's sigma. Both that axis and the [X-sin, X-cos, Y-sin, Y-cos] feature
    layout are assumptions about the encoder; they are not read from the model.

    Args:
        model: trained MAMBA-GINR model exposing ``encode``, ``decode`` and
            ``hyponet.bandwidth_lins``.
        test_image: image batch; only ``test_image[:1]`` is analysed.
        resolutions: square query-grid sizes, one subplot column each.
        sigmas: maximum frequency assumed for each bandwidth layer, in layer
            order (defaults to the notebook's sigma_l = [128, 32]).

    Raises:
        ValueError: if the model has no bandwidth layers, or if fewer
            ``sigmas`` are given than there are bandwidth layers.
    """
    model.eval()

    bandwidth_layers = list(model.hyponet.bandwidth_lins)
    n_scales = len(bandwidth_layers)
    if n_scales == 0:
        raise ValueError("model.hyponet.bandwidth_lins is empty; nothing to analyse")
    if len(sigmas) < n_scales:
        raise ValueError(
            f"got {len(sigmas)} sigmas for {n_scales} bandwidth layers; "
            "pass one sigma per layer"
        )
    resolutions = list(resolutions)

    img = test_image[:1].to(device)
    # Inference only: no autograd graph is needed for the encoder pass either.
    with torch.no_grad():
        lp_features = model.encode(img)

    # squeeze=False keeps ``axes`` 2-D even for a single resolution or scale.
    fig, axes = plt.subplots(n_scales, len(resolutions),
                             figsize=(15, 4 * n_scales), squeeze=False)

    for idx, res in enumerate(resolutions):
        coord = create_coordinate_grid(res, res, device).unsqueeze(0)

        # Inputs seen by each bandwidth layer during this decode, in call order.
        fourier_features_list = []

        def hook_fn(module, inputs, output, store=fourier_features_list):
            # The list is bound via a default argument, so the hook always
            # writes to this iteration's list. Returning None leaves the
            # layer output untouched.
            store.append(inputs[0].detach())

        hooks = [layer.register_forward_hook(hook_fn) for layer in bandwidth_layers]
        try:
            with torch.no_grad():
                model.decode(lp_features, coord)
        finally:
            # Always detach the hooks, even if decode raises, so later
            # forward passes are not silently instrumented.
            for hook in hooks:
                hook.remove()

        for scale_idx, fourier_feats in enumerate(fourier_features_list[:n_scales]):
            # Flatten all position axes of the first batch element, so this
            # works whether features arrive as (B, HW, F) or (B, H, W, F).
            feats = fourier_feats[0].reshape(-1, fourier_feats.shape[-1])
            activation_magnitude = feats.abs().mean(dim=0).float().cpu().numpy()

            ax = axes[scale_idx, idx]

            sigma = sigmas[scale_idx]
            n = len(activation_magnitude) // 4  # 2 coords * 2 (sin/cos)
            # torch.logspace is base 10: omegas run from 10**1 up to sigma.
            omegas = torch.logspace(1, np.log10(sigma), n).numpy()

            # Assumed layout: [X-sin | X-cos | Y-sin | Y-cos], n features each.
            for part, label in enumerate(('X-sin', 'X-cos', 'Y-sin', 'Y-cos')):
                ax.plot(omegas, activation_magnitude[part * n:(part + 1) * n],
                        label=label, alpha=0.7)

            ax.set_xlabel('Frequency ω')
            ax.set_ylabel('Activation Magnitude')
            ax.set_title(f'{res}×{res}, σ={sigma}')
            ax.set_xscale('log')
            ax.grid(True, alpha=0.3)
            if idx == 0:
                ax.legend(fontsize=8)

    plt.tight_layout()
    plt.savefig('fourier_feature_activation.png', dpi=150)
    plt.show()

    print("\n" + "="*70)
    print("FOURIER FEATURE ACTIVATION ANALYSIS")
    print("="*70)
    print("\nWhat to look for:")
    print("  ✓ High activations at HIGH frequencies (ω > 50) → Network uses high-freq")
    print("  ✗ Low activations at high frequencies → Network ignores them (resampling)")
    print("\nIf activations drop off before max frequency (σ), the network is NOT")
    print("utilizing the full high-frequency capacity!")
    print("="*70)

# Diagnostic 1: plot Fourier-feature activations for the loaded model
analyze_fourier_feature_usage(model=model, test_image=test_images)

## Diagnostic 2: Gradient Magnitude Analysis

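Compare spatial gradients in super-resolved images against bicubic upsampling.
Sharper edges produce larger gradients, so gradients that are consistently larger than bicubic's point to true super-resolution. Noise and ringing artifacts also raise gradients, though, so confirm visually with Diagnostic 4.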

In [None]:
def compute_gradient_magnitude(img):
    """
    Compute the per-channel Sobel gradient magnitude of an image batch.

    Args:
        img: (B, C, H, W) tensor. Floating dtypes are kept as-is; integer
            inputs are promoted to float32 first.

    Returns:
        (B, C, H, W) tensor sqrt(Gx**2 + Gy**2), on ``img``'s device and in
        its (floating) dtype. ``padding=1`` zero-pads, so border pixels see
        an implicit 0-valued frame around the image.
    """
    if not img.is_floating_point():
        img = img.float()
    channels = img.shape[1]

    # Build the kernels in img's dtype. F.conv2d rejects mixed dtypes, so
    # hard-coded float32 kernels broke float16/float64 inputs.
    sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                           dtype=img.dtype, device=img.device).view(1, 1, 3, 3)
    sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                           dtype=img.dtype, device=img.device).view(1, 1, 3, 3)

    # Depthwise convolution: one copy of each kernel per channel (groups=C).
    grad_x = F.conv2d(img, sobel_x.repeat(channels, 1, 1, 1),
                      padding=1, groups=channels)
    grad_y = F.conv2d(img, sobel_y.repeat(channels, 1, 1, 1),
                      padding=1, groups=channels)

    return torch.sqrt(grad_x**2 + grad_y**2)


def analyze_gradient_sharpness(model, test_images, device, num_samples=8):
    """
    Compare Sobel gradient-magnitude distributions of model renders vs bicubic.

    For each of the first ``num_samples`` images, the model is queried on
    32x32, 64x64 and 128x128 coordinate grids. The input image is also
    bicubically upsampled to 64x64 and 128x128. Gradient magnitudes are pooled
    over all samples, plotted as log-density histograms (saved to
    ``gradient_sharpness_analysis.png``) and summarised. The model/bicubic
    ratios of the mean and the 95th percentile drive the printed verdict.

    A ratio above 1 means larger gradients than bicubic. That is consistent
    with sharper edges, but noise or ringing would produce it too, so read it
    together with the visual texture test.

    Args:
        model: callable ``model(img, coord)`` returning channels-last
            (B, H, W, C) renders.
        test_images: (N, C, H, W) image batch. The bicubic baselines upsample
            it directly, so it is presumably at the 32x32 training resolution.
        device: device to run on.
        num_samples: number of leading images to use; capped at
            ``len(test_images)``.

    Returns:
        dict with ``ratio_64_mean`` and ``ratio_128_mean`` plus
        ``ratio_64_p95`` and ``ratio_128_p95`` (model / bicubic).

    Raises:
        ValueError: if there are no images to analyse.
    """
    model.eval()

    # Never slice past the end: an empty slice would be fed to the model.
    num_samples = min(num_samples, len(test_images))
    if num_samples <= 0:
        raise ValueError("analyze_gradient_sharpness needs at least one test image")

    model_resolutions = (32, 64, 128)
    bicubic_resolutions = (64, 128)
    model_grads = {res: [] for res in model_resolutions}
    bicubic_grads = {res: [] for res in bicubic_resolutions}

    with torch.no_grad():
        # The coordinate grids do not depend on the sample, so build them once.
        coords = {res: create_coordinate_grid(res, res, device).unsqueeze(0)
                  for res in model_resolutions}

        for i in range(num_samples):
            img = test_images[i:i + 1].to(device)

            # Model reconstructions: (b, h, w, c) -> (b, c, h, w) for Sobel.
            for res in model_resolutions:
                recon = model(img, coords[res]).permute(0, 3, 1, 2)
                model_grads[res].append(compute_gradient_magnitude(recon).flatten())

            # Bicubic upsampling baselines.
            for res in bicubic_resolutions:
                bicubic = F.interpolate(img, size=res, mode='bicubic', align_corners=False)
                bicubic_grads[res].append(compute_gradient_magnitude(bicubic).flatten())

    def _pool(chunks):
        """Concatenate per-sample gradient tensors into one float numpy array."""
        return torch.cat(chunks).float().cpu().numpy()

    grad_32_model = _pool(model_grads[32])
    grad_64_model = _pool(model_grads[64])
    grad_128_model = _pool(model_grads[128])
    grad_64_bicubic = _pool(bicubic_grads[64])
    grad_128_bicubic = _pool(bicubic_grads[128])

    def _finish_axis(ax, title):
        """Apply the shared labels, log y-scale, grid and legend."""
        ax.set_xlabel('Gradient Magnitude')
        ax.set_ylabel('Density')
        ax.set_title(title)
        ax.set_yscale('log')
        ax.grid(True, alpha=0.3)
        ax.legend()

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # 32×32 (reference)
    axes[0].hist(grad_32_model, bins=100, alpha=0.7, density=True, label='32×32')
    _finish_axis(axes[0], '32×32 (Training Resolution)')

    # Super-resolved outputs, each against its bicubic baseline.
    for ax, res, grad_model, grad_bicubic in (
        (axes[1], 64, grad_64_model, grad_64_bicubic),
        (axes[2], 128, grad_128_model, grad_128_bicubic),
    ):
        ax.hist(grad_bicubic, bins=100, alpha=0.5, density=True,
                label='Bicubic', color='gray')
        ax.hist(grad_model, bins=100, alpha=0.7, density=True,
                label='MAMBA-GINR', color='steelblue')
        _finish_axis(ax, f'{res}×{res} Super-Resolution')

    plt.tight_layout()
    plt.savefig('gradient_sharpness_analysis.png', dpi=150)
    plt.show()

    print("\n" + "="*70)
    print("GRADIENT MAGNITUDE ANALYSIS")
    print("="*70)

    def compute_stats(grads, name):
        """Print summary statistics and return (mean, p95, p99)."""
        mean = grads.mean()
        median = np.median(grads)
        p95 = np.percentile(grads, 95)
        p99 = np.percentile(grads, 99)
        print(f"\n{name}:")
        print(f"  Mean:   {mean:.4f}")
        print(f"  Median: {median:.4f}")
        print(f"  95th percentile: {p95:.4f}")
        print(f"  99th percentile: {p99:.4f}")
        return mean, p95, p99

    def _ratio(numerator, denominator):
        # Epsilon avoids division by zero for a perfectly flat baseline.
        return numerator / (denominator + 1e-10)

    compute_stats(grad_32_model, "32×32 (reference)")

    print("\n" + "-"*70)
    stats_64_model = compute_stats(grad_64_model, "64×64 MAMBA-GINR")
    stats_64_bicubic = compute_stats(grad_64_bicubic, "64×64 Bicubic")
    ratio_64_mean = _ratio(stats_64_model[0], stats_64_bicubic[0])
    ratio_64_p95 = _ratio(stats_64_model[1], stats_64_bicubic[1])

    print("\n" + "-"*70)
    stats_128_model = compute_stats(grad_128_model, "128×128 MAMBA-GINR")
    stats_128_bicubic = compute_stats(grad_128_bicubic, "128×128 Bicubic")
    ratio_128_mean = _ratio(stats_128_model[0], stats_128_bicubic[0])
    ratio_128_p95 = _ratio(stats_128_model[1], stats_128_bicubic[1])

    print("\n" + "="*70)
    print("SHARPNESS RATIO (Model / Bicubic):")
    print(f"  64×64  - Mean gradient: {ratio_64_mean:.3f}x")
    print(f"  64×64  - 95th %ile:     {ratio_64_p95:.3f}x")
    print(f"  128×128 - Mean gradient: {ratio_128_mean:.3f}x")
    print(f"  128×128 - 95th %ile:     {ratio_128_p95:.3f}x")

    print("\n📊 INTERPRETATION:")
    if ratio_128_mean > 1.5 and ratio_128_p95 > 1.5:
        print("  ✅ TRUE SUPER-RESOLUTION: Model generates sharper edges than bicubic")
    elif ratio_128_mean > 1.1:
        print("  ⚠️  MARGINAL: Model slightly sharper, but limited")
    else:
        print("  ❌ RESAMPLING: Model similar to bicubic (smooth interpolation)")

    print("="*70)

    return {
        'ratio_64_mean': ratio_64_mean,
        'ratio_128_mean': ratio_128_mean,
        'ratio_64_p95': ratio_64_p95,
        'ratio_128_p95': ratio_128_p95,
    }

# Diagnostic 2: sharpness vs bicubic; keep the ratio dict for later cells
gradient_analysis = analyze_gradient_sharpness(
    model=model,
    test_images=test_images,
    device=device,
)

## Diagnostic 3: Modulation Vector Analysis

Check whether modulation vectors vary continuously across space or merely behave like a nearest-neighbor lookup.

In [None]:
def analyze_modulation_diversity(model, test_image, resolution=128):
    """
    Measure how much the hyponet's modulation vectors vary across space.

    The first image is encoded and then decoded on a
    ``resolution``x``resolution`` grid, while a forward hook records the
    output of ``model.hyponet.modulation_ca``. That output is assumed to be a
    (B, HW, D) or (B, H, W, D) tensor in row-major grid order -- TODO confirm.
    Absolute differences between neighbouring grid points, averaged over D,
    are plotted (saved to ``modulation_diversity_analysis.png``) and
    summarised.

    High diversity = continuous learning
    Low diversity = discrete lookup table
    (the printed verdict uses mean difference > 0.01)

    Args:
        model: trained MAMBA-GINR model exposing ``encode``, ``decode`` and
            ``hyponet.modulation_ca``.
        test_image: either an image batch (B, C, H, W), of which only the
            first image is used, or a single image (C, H, W).
        resolution: side length of the query grid.

    Returns:
        Mean of the Y- and X-direction mean modulation differences.
    """
    model.eval()

    # A single (C, H, W) image gets a batch axis. Without it, ``[:1]`` would
    # keep only the first channel instead of the first image.
    if test_image.dim() == 3:
        test_image = test_image.unsqueeze(0)
    img = test_image[:1].to(device)

    with torch.no_grad():
        lp_features = model.encode(img)

    coord = create_coordinate_grid(resolution, resolution, device).unsqueeze(0)

    modulation_vectors = []

    def hook_fn(module, inputs, output):
        # Returning None keeps the module's output unchanged.
        modulation_vectors.append(output.detach())

    hook = model.hyponet.modulation_ca.register_forward_hook(hook_fn)
    try:
        with torch.no_grad():
            model.decode(lp_features, coord)
    finally:
        # Detach the hook even if decode raises.
        hook.remove()

    # First batch element, reshaped to (H, W, D); .float() makes bf16 safe for numpy.
    mod_vec = modulation_vectors[0][0].float().cpu().numpy()
    mod_vec = mod_vec.reshape(resolution, resolution, -1)

    # Absolute differences between neighbouring grid points.
    grad_y = np.abs(np.diff(mod_vec, axis=0))  # (H-1, W, D)
    grad_x = np.abs(np.diff(mod_vec, axis=1))  # (H, W-1, D)

    # Average across feature dimensions.
    grad_y_mag = grad_y.mean(axis=-1)  # (H-1, W)
    grad_x_mag = grad_x.mean(axis=-1)  # (H, W-1)

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Up to the first 3 modulation channels; blank panels if D < 3.
    n_channels = mod_vec.shape[-1]
    for i in range(3):
        ax = axes[0, i]
        if i < n_channels:
            im = ax.imshow(mod_vec[:, :, i], cmap='viridis')
            ax.set_title(f'Modulation Channel {i}')
            ax.axis('off')
            plt.colorbar(im, ax=ax)
        else:
            ax.axis('off')

    # Gradient maps
    im1 = axes[1, 0].imshow(grad_y_mag, cmap='hot')
    axes[1, 0].set_title('Modulation Gradient (Y)')
    axes[1, 0].axis('off')
    plt.colorbar(im1, ax=axes[1, 0])

    im2 = axes[1, 1].imshow(grad_x_mag, cmap='hot')
    axes[1, 1].set_title('Modulation Gradient (X)')
    axes[1, 1].axis('off')
    plt.colorbar(im2, ax=axes[1, 1])

    # Gradient histogram
    axes[1, 2].hist(grad_y_mag.flatten(), bins=100, alpha=0.5, label='Y-grad')
    axes[1, 2].hist(grad_x_mag.flatten(), bins=100, alpha=0.5, label='X-grad')
    axes[1, 2].set_xlabel('Modulation Gradient Magnitude')
    axes[1, 2].set_ylabel('Frequency')
    axes[1, 2].set_title('Gradient Distribution')
    axes[1, 2].set_yscale('log')
    axes[1, 2].legend()
    axes[1, 2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('modulation_diversity_analysis.png', dpi=150)
    plt.show()

    # Statistics
    mean_grad = (grad_y_mag.mean() + grad_x_mag.mean()) / 2
    std_grad = (grad_y_mag.std() + grad_x_mag.std()) / 2

    print("\n" + "="*70)
    print("MODULATION VECTOR DIVERSITY ANALYSIS")
    print("="*70)
    print(f"\nMean modulation gradient: {mean_grad:.6f}")
    print(f"Std modulation gradient:  {std_grad:.6f}")

    print("\n📊 INTERPRETATION:")
    if mean_grad > 0.01:
        print("  ✅ HIGH DIVERSITY: Modulation vectors vary smoothly")
        print("     → Network learned continuous coordinate mapping")
    else:
        print("  ❌ LOW DIVERSITY: Modulation vectors are blocky/constant")
        print("     → Network doing discrete lookup (need better jittering!)")

    print("="*70)

    return mean_grad

# Run analysis on the first test image. Pass the whole batch: the function
# keeps only test_images[:1]. Passing test_images[0] (a single C×H×W image)
# would make that slice pick the first channel instead.
modulation_diversity = analyze_modulation_diversity(model, test_images)

## Diagnostic 4: Texture Synthesis Test

Test if the network generates NEW texture patterns or just smooths existing ones

In [None]:
def texture_synthesis_test(model, test_images, device, num_samples=4):
    """
    Show model super-resolution side by side with bicubic upsampling.

    The figure has one row per sample and five columns:
      1. the input image
      2. the model's 64x64 render
      3. bicubic 64x64
      4. the central 64x64 crop of the model's 128x128 render
      5. the central 64x64 crop of bicubic 128x128
    It is saved to ``texture_synthesis_comparison.png``.

    Args:
        model: callable ``model(img, coord)`` returning channels-last
            (B, H, W, C) renders.
        test_images: (N, C, H, W) image batch; values are clamped to [0, 1]
            for display.
        device: device to run on.
        num_samples: rows to draw; capped at ``len(test_images)``.

    Raises:
        ValueError: if there are no images to show.
    """
    model.eval()

    # Never slice past the end: an empty slice would be fed to the model.
    num_samples = min(num_samples, len(test_images))
    if num_samples <= 0:
        raise ValueError("texture_synthesis_test needs at least one test image")

    # squeeze=False keeps ``axes`` 2-D when only one row is drawn. The height
    # scales with the row count (12 inches for the default 4 rows).
    fig, axes = plt.subplots(num_samples, 5, figsize=(15, 3 * num_samples),
                             squeeze=False)

    def _chw_to_display(t):
        """(C, H, W) tensor -> clamped channels-last CPU tensor for imshow."""
        return t.cpu().permute(1, 2, 0).clamp(0, 1)

    def _show(ax, image_hwc, title, first_row):
        """Draw one panel; column titles appear only on the first row."""
        ax.imshow(image_hwc)
        ax.set_title(title if first_row else '')
        ax.axis('off')

    with torch.no_grad():
        # The query grids are the same for every sample, so build them once.
        coord_64 = create_coordinate_grid(64, 64, device).unsqueeze(0)
        coord_128 = create_coordinate_grid(128, 128, device).unsqueeze(0)

        for i in range(num_samples):
            img = test_images[i:i + 1].to(device)
            first = i == 0

            # Original 32×32
            _show(axes[i, 0], _chw_to_display(img[0]), '32×32 Original', first)

            # Model 64×64
            recon_64 = model(img, coord_64)[0].cpu().clamp(0, 1)
            _show(axes[i, 1], recon_64, '64×64 Model', first)

            # Bicubic 64×64
            bicubic_64 = F.interpolate(img, 64, mode='bicubic', align_corners=False)
            _show(axes[i, 2], _chw_to_display(bicubic_64[0]), '64×64 Bicubic', first)

            # Model 128×128, centre 64×64 crop for detail
            recon_128 = model(img, coord_128)[0].cpu().clamp(0, 1)
            _show(axes[i, 3], recon_128[32:96, 32:96, :], '128×128 Model (crop)', first)

            # Bicubic 128×128, same centre crop
            bicubic_128 = F.interpolate(img, 128, mode='bicubic', align_corners=False)
            _show(axes[i, 4], _chw_to_display(bicubic_128[0, :, 32:96, 32:96]),
                  '128×128 Bicubic (crop)', first)

    plt.tight_layout()
    plt.savefig('texture_synthesis_comparison.png', dpi=200)
    plt.show()

    print("\n" + "="*70)
    print("TEXTURE SYNTHESIS TEST")
    print("="*70)
    print("\nVisual inspection:")
    print("  • Look at 128×128 crops (columns 4 vs 5)")
    print("  • TRUE SR: Model shows texture details not in bicubic")
    print("  • RESAMPLING: Model looks similar to bicubic (smooth)")
    print("\nKey areas to check:")
    print("  - Edges: Sharp or blurry?")
    print("  - Textures: Repeated patterns or smooth?")
    print("  - Fine details: Visible or missing?")
    print("="*70)

# Diagnostic 4: side-by-side texture comparison against bicubic
texture_synthesis_test(
    model=model,
    test_images=test_images,
    device=device,
)

## Diagnostic 5: Fix Jittering and Retrain (Quick Test)

Test if fixing jittering improves super-resolution

In [None]:
# Corrected training epoch for testing whether the jittering fix improves
# super-resolution: train a few epochs with it, then re-run the diagnostics.

def train_epoch_corrected_jittering(model, loader, optimizer, device, epoch,
                                    use_jittering=True, jitter_std=0.01):
    """
    Train ``model`` for one epoch with coordinate jittering done per batch.

    CORRECTED: the coordinate grid is rebuilt and jittered INSIDE the batch
    loop, so every batch draws fresh noise. All samples within one batch
    share the same jittered grid, because it is expanded over the batch axis.

    Relies on ``tqdm``, ``create_coordinate_grid`` and
    ``add_gaussian_noise_to_grid`` being defined in earlier cells.

    Args:
        model: callable ``model(images, coord)`` returning channels-last
            (B, H, W, C) predictions.
        loader: iterable of ``(images, labels)`` with images (B, C, H, W);
            labels are ignored.
        optimizer: optimizer over ``model``'s parameters.
        device: device to train on.
        epoch: epoch index, used only in the progress-bar label.
        use_jittering: add Gaussian noise to the grid when True.
        jitter_std: standard deviation passed to ``add_gaussian_noise_to_grid``.

    Returns:
        (mean loss, mean PSNR) over the batches of the epoch. PSNR uses a
        peak value of 1.0, i.e. -10 * log10(mse).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_psnr = 0.0
    num_batches = 0

    pbar = tqdm(loader, desc=f"Epoch {epoch}")
    for images, _labels in pbar:
        images = images.to(device)
        batch_size, _, height, width = images.shape

        # Create the coordinate grid INSIDE the loop (CORRECTED!), at the
        # images' own size rather than a hard-coded 32×32.
        coord = create_coordinate_grid(height, width, device)

        # Jitter coordinates (different for each batch!)
        if use_jittering:
            coord = add_gaussian_noise_to_grid(coord, std=jitter_std)

        # (h, w, d) -> (b, h, w, d), shared across the batch.
        coord = coord.unsqueeze(0).expand(batch_size, -1, -1, -1)

        pred = model(images, coord)
        gt = images.permute(0, 2, 3, 1)  # channels-last, to match pred

        # Per-sample MSE. Use reshape, not view: gt is a permuted view, so
        # the squared error may be non-contiguous and .view would raise.
        mses = (pred - gt).pow(2).reshape(batch_size, -1).mean(dim=-1)
        loss = mses.mean()
        # Clamp so a perfect reconstruction gives a finite PSNR (100 dB cap)
        # instead of inf, which would poison the epoch average.
        psnr = (-10 * torch.log10(mses.detach().clamp_min(1e-10))).mean()

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        loss_value = loss.item()
        psnr_value = psnr.item()
        total_loss += loss_value
        total_psnr += psnr_value
        num_batches += 1
        pbar.set_postfix({'loss': f"{loss_value:.4f}", 'psnr': f"{psnr_value:.2f}"})

    if num_batches == 0:
        raise ValueError("loader yielded no batches")
    return total_loss / num_batches, total_psnr / num_batches

print("Corrected training function defined.\nTo test: Train for a few epochs and re-run diagnostics.")

## Summary: Diagnostic Checklist

Run all diagnostics above and check:

### ✅ Signs of TRUE Super-Resolution:
1. **Fourier features**: High activations at ω > 50
2. **Gradients**: Mean and 95th-percentile gradients > 1.5x bicubic (the Diagnostic 2 threshold)
3. **Modulation diversity**: Smooth spatial variation (grad > 0.01)
4. **Textures**: Visible details not in bicubic

### ❌ Signs of RESAMPLING:
1. **Fourier features**: Low activations at high frequencies
2. **Gradients**: Similar to bicubic (mean-gradient ratio ≤ 1.1)
3. **Modulation diversity**: Blocky/constant (mean grad ≤ 0.01, the Diagnostic 3 threshold)
4. **Textures**: Smooth like bicubic

### 🔧 If Resampling, Try:
1. Fix jittering (move inside batch loop)
2. Train longer (more epochs)
3. Increase MLP capacity (hidden_dim = 512)
4. Stronger jittering (std = 0.02)
5. Add texture loss / perceptual loss

The network SHOULD be capable of true super-resolution with proper training!