# In [None]:
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from einops import rearrange, repeat
import ssl
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from math import pi, log
from functools import wraps
import torch.nn.functional as F

# HACK: work around certificate-verification failures when torchvision
# downloads CIFAR-10 (presumably a missing/misconfigured local CA bundle).
# SECURITY NOTE: this swaps the process-wide default HTTPS context factory for
# an unverified one, so any later HTTPS connection that relies on the default
# context (e.g. urllib) skips certificate and hostname checks. Prefer fixing
# the local CA certificates and removing this override.
setattr(ssl, "_create_default_https_context", ssl._create_unverified_context)

# ===============================================================
# --- 1. The One True Perceiver IO Model Architecture ---
# ===============================================================

# helpers
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    return not (val is None)


def default(val, d):
    """Return ``val`` unless it is ``None``; in that case return ``d``."""
    if val is None:
        return d
    return val

def cache_fn(f):
    """
    Wrap ``f`` so its first non-None result is remembered and returned forever after.

    The wrapper accepts an extra keyword ``_cache`` (default True). With
    ``_cache=False`` it calls ``f`` afresh and neither reads nor updates the
    remembered value. Arguments are ignored when returning the cached value.
    A ``None`` result is never considered cached, so ``f`` is re-invoked.
    """
    slot = [None]  # one-element box shared by every call to the wrapper

    @wraps(f)
    def cached_fn(*args, _cache = True, **kwargs):
        if _cache and slot[0] is not None:
            return slot[0]
        result = f(*args, **kwargs)
        if _cache:
            slot[0] = result
        return result

    return cached_fn

# helper classes
class PreNorm(nn.Module):
    """
    Pre-normalization wrapper: LayerNorm the input before calling ``fn``.

    If ``context_dim`` is given, the ``context`` keyword argument is also
    LayerNorm-ed (with its own norm) before being forwarded to ``fn``; in that
    case ``context`` must be passed to ``forward``.
    """
    def __init__(self, dim, fn, context_dim = None):
        super().__init__()
        # Registration order (fn, norm, norm_context) kept stable for state_dict.
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        if exists(context_dim):
            self.norm_context = nn.LayerNorm(context_dim)
        else:
            self.norm_context = None

    def forward(self, x, **kwargs):
        normed_x = self.norm(x)
        if self.norm_context is not None:
            # kwargs is a fresh dict per call, so overwriting the key is local.
            kwargs['context'] = self.norm_context(kwargs['context'])
        return self.fn(normed_x, **kwargs)

class GEGLU(nn.Module):
    """GELU-gated linear unit: split the last dim in two, gate the first half by GELU of the second."""
    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim = -1)
        return value * F.gelu(gate)

class FeedForward(nn.Module):
    """
    Position-wise GEGLU MLP: ``dim -> 2*mult*dim -> GEGLU -> mult*dim -> dim``.
    """
    def __init__(self, dim, mult = 4):
        super().__init__()
        hidden_dim = dim * mult
        # GEGLU halves the last dimension, so the first projection is twice as
        # wide as the hidden size. Layer order/indices kept for state_dict keys.
        layers = [
            nn.Linear(dim, hidden_dim * 2),
            GEGLU(),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    """
    Multi-head scaled dot-product attention from ``x`` (queries) to ``context``
    (keys/values); with no context it is self-attention.

    Args:
        query_dim: feature size of ``x`` and of the output.
        context_dim: feature size of ``context``; defaults to ``query_dim``.
        heads: number of attention heads.
        dim_head: per-head feature size.
    """
    def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64):
        super().__init__()
        context_dim = default(context_dim, query_dim)
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        # Creation order (q, kv, out) fixes the RNG draw order for the weights.
        self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, 2 * inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context = None, mask = None):
        """
        Args:
            x: query tokens of shape (b, n, query_dim).
            context: key/value tokens of shape (b, m, context_dim); defaults to ``x``.
            mask: optional boolean mask over context positions (trailing dims
                are flattened to length m); True = may be attended to.

        Returns:
            Tensor of shape (b, n, query_dim).
        """
        h = self.heads
        context = default(context, x)
        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        def split_heads(t):
            return rearrange(t, 'b n (h d) -> (b h) n d', h = h)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)
        logits = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            keep = rearrange(mask, 'b ... -> b (...)')
            keep = repeat(keep, 'b j -> (b h) () j', h = h)
            # Large finite negative (not -inf) so fully-masked rows stay finite.
            logits.masked_fill_(~keep, -torch.finfo(logits.dtype).max)

        weights = logits.softmax(dim = -1)
        out = torch.einsum('b i j, b j d -> b i d', weights, v)
        return self.to_out(rearrange(out, '(b h) n d -> b n (h d)', h = h))

from math import log

# Sinusoidal (Transformer-style) position table; PerceiverIO uses it as fixed latents.
def get_sinusoidal_embeddings(n, d):
    """
    Generate sinusoidal positional embeddings.

    Column ``2i`` holds ``sin(p * w_i)`` and column ``2i + 1`` holds
    ``cos(p * w_i)`` for position ``p``, with ``w_i = 10000 ** (-2i / d)``.
    An odd ``d`` is supported: the last column is then a sine with no
    matching cosine. ``d == 0`` yields an ``(n, 0)`` tensor.

    Args:
        n (int): The number of positions (e.g. num_latents).
        d (int): The embedding dimension (e.g. latent_dim).

    Returns:
        torch.Tensor: A float32 tensor of shape (n, d).
    """
    if d == 0:
        # Avoid the division by d below; there are no columns to fill.
        return torch.zeros(n, 0)

    position = torch.arange(n, dtype=torch.float).unsqueeze(1)  # (n, 1)
    # One frequency per sine column, i.e. ceil(d / 2) frequencies.
    div_term = torch.exp(torch.arange(0, d, 2).float() * -(log(10000.0) / d))
    angles = position * div_term  # (n, ceil(d / 2))

    pe = torch.zeros(n, d)
    pe[:, 0::2] = torch.sin(angles)
    # For odd d there is one fewer cosine column than sine column, so the
    # highest-index frequency gets no cosine partner.
    pe[:, 1::2] = torch.cos(angles[:, : d // 2])
    return pe

def dropout_seq(seq, mask, dropout):
    """
    Randomly keep a subset of sequence positions for each batch element.

    Called by ``PerceiverIO.forward`` during training when
    ``seq_dropout_prob > 0``, so the encoder cross-attention sees fewer input
    tokens.

    Args:
        seq: tensor whose first two dims are (batch, seq_len).
        mask: optional boolean tensor, assumed (batch, seq_len) with True for
            valid positions (TODO confirm for multi-dim masks), or None.
        dropout: fraction of positions to drop, expected in [0, 1).

    Returns:
        (seq, mask): ``seq`` gathered down to
        ``max(1, int((1 - dropout) * seq_len))`` positions per batch element,
        and the matching mask (None when no mask was given).
    """
    b, n = seq.shape[0], seq.shape[1]
    device = seq.device

    # Random score per position; padded positions get the lowest score so
    # they are only picked when there are not enough valid ones.
    logits = torch.randn(b, n, device = device)
    if exists(mask):
        logits = logits.masked_fill(~mask, -torch.finfo(logits.dtype).max)

    keep_prob = 1. - dropout
    num_keep = max(1, int(keep_prob * n))
    # topk returns indices sorted by descending score, so valid positions
    # come first in each row.
    keep_indices = logits.topk(num_keep, dim = 1).indices  # (b, num_keep)

    batch_indices = rearrange(torch.arange(b, device = device), 'b -> b 1')
    seq = seq[batch_indices, keep_indices]

    if exists(mask):
        # Keep ceil(keep_prob * valid_count) of each row's valid positions and
        # mask out the remaining kept slots.
        seq_counts = mask.sum(dim = -1)
        seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
        keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')
        mask = mask[batch_indices, keep_indices] & keep_mask

    return seq, mask


# main class
class PerceiverIO(nn.Module):
    """
    Perceiver IO model.

    Pipeline:
    1. Cross-attend a latent array to the input ``data`` (encoder).
    2. Refine the latents with ``depth`` self-attention + feed-forward blocks.
    3. Optionally decode by cross-attending ``queries`` to the latents.

    Unlike the usual learned latent array, the latents here are a fixed
    sinusoidal table registered as a buffer. They are saved in ``state_dict``
    but are not trainable.

    Args:
        depth: number of latent self-attention + feed-forward blocks.
        dim: feature size of the input ``data``.
        queries_dim: feature size of the decoder ``queries``.
        logits_dim: if given, a final linear layer maps decoder outputs to this
            size; otherwise decoder outputs are returned unchanged.
        num_latents, latent_dim: shape of the latent array (latent_dim is
            passed to ``get_sinusoidal_embeddings``).
        cross_heads, cross_dim_head: head configuration of the encoder and
            decoder cross-attention.
        latent_heads, latent_dim_head: head configuration of the latent
            self-attention.
        weight_tie_layers: if True, all latent blocks share one set of weights.
        decoder_ff: if True, add a residual feed-forward block after decoding.
        seq_dropout_prob: during training, fraction of input positions dropped
            (via ``dropout_seq``) before the encoder cross-attention.
    """
    def __init__(
        self,
        *,
        depth,
        dim,
        queries_dim,
        logits_dim = None,
        num_latents = 512,
        latent_dim = 512,
        cross_heads = 1,
        latent_heads = 8,
        cross_dim_head = 64,
        latent_dim_head = 64,
        weight_tie_layers = False,
        decoder_ff = False,
        seq_dropout_prob = 0.
    ):
        super().__init__()
        self.seq_dropout_prob = seq_dropout_prob

        # Fixed sinusoidal latents, stored as a (non-trainable) buffer rather
        # than a learnable nn.Parameter.
        self.register_buffer('latents', get_sinusoidal_embeddings(num_latents, latent_dim))

        self.cross_attend_blocks = nn.ModuleList([
            PreNorm(latent_dim, Attention(latent_dim, dim, heads = cross_heads, dim_head = cross_dim_head), context_dim = dim),
            PreNorm(latent_dim, FeedForward(latent_dim))
        ])

        get_latent_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head))
        get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim))
        # With weight tying, cache_fn hands back the same module on every call.
        get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))

        self.layers = nn.ModuleList([])
        cache_args = {'_cache': weight_tie_layers}
        for i in range(depth):
            self.layers.append(nn.ModuleList([
                get_latent_attn(**cache_args),
                get_latent_ff(**cache_args)
            ]))

        self.decoder_cross_attn = PreNorm(queries_dim, Attention(queries_dim, latent_dim, heads = cross_heads, dim_head = cross_dim_head), context_dim = latent_dim)
        self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None
        self.to_logits = nn.Linear(queries_dim, logits_dim) if exists(logits_dim) else nn.Identity()

    def forward(
        self,
        data,
        mask = None,
        queries = None
    ):
        """
        Args:
            data: input tokens of shape (b, n, dim).
            mask: optional boolean mask over input positions (True = valid).
            queries: optional decoder queries, (n_q, queries_dim) shared across
                the batch or (b, n_q, queries_dim).

        Returns:
            The latents of shape (b, num_latents, latent_dim) when ``queries``
            is None; otherwise decoder outputs of shape (b, n_q, logits_dim),
            or (b, n_q, queries_dim) when no ``logits_dim`` was configured.
        """
        b = data.shape[0]
        x = repeat(self.latents, 'n d -> b n d', b = b)

        cross_attn, cross_ff = self.cross_attend_blocks
        if self.training and self.seq_dropout_prob > 0.:
            data, mask = dropout_seq(data, mask, self.seq_dropout_prob)

        x = cross_attn(x, context = data, mask = mask) + x
        x = cross_ff(x) + x

        for self_attn, self_ff in self.layers:
            x = self_attn(x) + x
            x = self_ff(x) + x

        if not exists(queries):
            return x

        if queries.ndim == 2:
            queries = repeat(queries, 'n d -> b n d', b = b)

        # No residual from the queries here: the decoder output is built purely
        # from attention over the latents.
        latents = self.decoder_cross_attn(queries, context = x)
        if exists(self.decoder_ff):
            latents = latents + self.decoder_ff(latents)

        return self.to_logits(latents)


# ===============================================================
# --- Training Script Starts Here ---
# ===============================================================


# --- Configuration and Setup ---
BATCH_SIZE = 64
EPOCHS = 50
# Adjusted learning rate to a more standard value for this kind of task
LEARNING_RATE = 2e-4
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

IMAGE_SIZE_TRAIN = 32
IMAGE_SIZE_HI_RES = 128
CHANNELS = 3

# The model predicts one value per colour channel at every query coordinate.
# POS_EMBED_DIM / INPUT_DIM / QUERIES_DIM depend on the Fourier encoder's
# mapping size and are defined alongside it (FOURIER_MAPPING_SIZE).
LOGITS_DIM = CHANNELS

# --- Data Loading ---
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize to [-1, 1]
])
# Pinned host memory only speeds up host-to-GPU copies; on CPU it merely warns.
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=(DEVICE == "cuda"))
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=(DEVICE == "cuda"))


class GaussianFourierFeatures(nn.Module):
    """
    Random Fourier-feature encoding of low-dimensional coordinates.

    A fixed Gaussian matrix ``B`` of shape (in_features, mapping_size), scaled
    by ``scale``, is stored as a buffer (not trained). ``forward`` returns
    ``[sin(coords @ B), cos(coords @ B)]`` concatenated on the last dim, so the
    output has ``2 * mapping_size`` features.
    """
    def __init__(self, in_features, mapping_size, scale=10.0):
        super().__init__()
        self.in_features, self.mapping_size = in_features, mapping_size
        gaussian = torch.randn(in_features, mapping_size)
        self.register_buffer('B', gaussian * scale)

    def forward(self, coords):
        angles = torch.matmul(coords, self.B)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

# --- Fourier-feature positional encoding and model construction ---
FOURIER_MAPPING_SIZE = 96
POS_EMBED_DIM = 2 * FOURIER_MAPPING_SIZE  # one sin + one cos feature per frequency
INPUT_DIM = POS_EMBED_DIM + CHANNELS      # encoder tokens = RGB pixel ++ positional code
QUERIES_DIM = POS_EMBED_DIM               # decoder queries are pure positional codes

# Created before the model so the random draw for B happens first.
fourier_encoder = GaussianFourierFeatures(2, FOURIER_MAPPING_SIZE, scale=15.0).to(DEVICE)

model = PerceiverIO(
    depth=6,
    dim=INPUT_DIM,
    queries_dim=QUERIES_DIM,
    logits_dim=LOGITS_DIM,
    num_latents=256,
    latent_dim=512,
    latent_heads=8,
    latent_dim_head=64,
    cross_heads=1,
    cross_dim_head=64,
    decoder_ff=True,
).to(DEVICE)

# GaussianFourierFeatures keeps B as a buffer, so it contributes no parameters
# here; it is included only in case the encoder ever becomes trainable.
optimizer = AdamW([*model.parameters(), *fourier_encoder.parameters()], lr=LEARNING_RATE)
loss_fn = nn.MSELoss()
# The scheduler is stepped once per batch, so T_max counts optimizer steps.
scheduler = CosineAnnealingLR(optimizer, T_max=len(train_loader) * EPOCHS)

print(f"Training on {DEVICE}")
total_params = sum(p.numel() for module in (model, fourier_encoder) for p in module.parameters())
print(f"Total parameters: {total_params/1e6:.2f}M")


# --- Helper Functions and Coordinate Grids ---
def create_coordinate_grid(h, w, device):
    """
    Return a flattened (h * w, 2) grid of (row, col) coordinates in [-1, 1].

    Rows vary slowest ('ij' indexing), matching a row-major flattening of an
    (h, w) image.
    """
    rows = torch.linspace(-1.0, 1.0, h, device=device)
    cols = torch.linspace(-1.0, 1.0, w, device=device)
    row_grid, col_grid = torch.meshgrid(rows, cols, indexing='ij')
    grid = torch.stack((row_grid, col_grid), dim=-1)  # (h, w, 2)
    return rearrange(grid, 'h w c -> (h w) c')

# Flattened coordinate grids for the training resolution and for the
# high-resolution queries; built once and reused by the training loop.
coords_32x32, coords_128x128 = (
    create_coordinate_grid(side, side, DEVICE)
    for side in (IMAGE_SIZE_TRAIN, IMAGE_SIZE_HI_RES)
)

def prepare_model_input(images, coords, fourier_encoder_fn):
    """
    Build Perceiver inputs from a batch of images.

    Args:
        images: (b, c, h, w) image batch.
        coords: (h * w, k) flattened coordinate grid (one row per pixel).
        fourier_encoder_fn: callable applied to the (b, h * w, k) batched coordinates.

    Returns:
        (input_with_pos, pixels, pos_embeddings):
        * ``pixels``: (b, h * w, c) flattened pixels, also the reconstruction target.
        * ``pos_embeddings``: encoder output for every pixel coordinate.
        * ``input_with_pos``: pixels and positional codes concatenated on the last dim.
    """
    batch, _, _, _ = images.shape
    pixels = rearrange(images, 'b c h w -> b (h w) c')
    pos_embeddings = fourier_encoder_fn(repeat(coords, 'n d -> b n d', b=batch))
    return torch.cat((pixels, pos_embeddings), dim=-1), pixels, pos_embeddings

def imshow(img, title):
    """
    Display a (C, H, W) image tensor normalized to [-1, 1] in a new figure.

    Args:
        img: image tensor on any device; it may require grad.
        title: figure title.
    """
    # Undo Normalize(mean=0.5, std=0.5). Model reconstructions are unbounded,
    # so clamp to [0, 1] ourselves instead of letting matplotlib clip with a
    # warning. detach() makes .numpy() safe for tensors that require grad.
    img = (img.detach().cpu() / 2 + 0.5).clamp(0.0, 1.0)
    npimg = img.numpy()
    plt.figure(figsize=(10, 10))
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.title(title, fontsize=14)
    plt.axis('off')
    plt.show()

# ===============================================================
# --- NEW: FFT Validation Function ---
# ===============================================================
def calculate_and_visualize_fft_power_delta(original_imgs, recon_imgs, epoch_num):
    """
    Plot FFT power spectra for the first original/reconstructed image pair.

    Both images are converted to grayscale. The centered 2-D FFT power of
    each is shown on a ``log1p`` scale, followed by the absolute difference of
    the two log spectra.

    Args:
        original_imgs: batch of RGB image tensors; only index 0 is used.
        recon_imgs: batch of reconstructed RGB image tensors; only index 0 is used.
        epoch_num: epoch number shown in the figure title.
    """
    def log_power(gray):
        # Drop the singleton channel dim, then take the zero-centered 2-D FFT.
        spectrum = torch.fft.fftshift(torch.fft.fft2(gray.squeeze(0)))
        return torch.log1p(spectrum.abs() ** 2).cpu().numpy()

    to_gray = transforms.functional.rgb_to_grayscale
    original_fft_power = log_power(to_gray(original_imgs[0]))
    recon_fft_power = log_power(to_gray(recon_imgs[0]))
    delta_power = np.abs(original_fft_power - recon_fft_power)

    panels = (
        (original_fft_power, 'viridis', 'Original Image FFT Power'),
        (recon_fft_power, 'viridis', 'Reconstructed Image FFT Power'),
        (delta_power, 'magma', 'Power Difference (Delta)'),
    )

    fig, axs = plt.subplots(1, 3, figsize=(18, 6))
    fig.suptitle(f'Epoch {epoch_num}: FFT Power Spectrum Comparison', fontsize=16)
    for ax, (data, cmap, panel_title) in zip(axs, panels):
        mappable = ax.imshow(data, cmap=cmap)
        ax.set_title(panel_title)
        ax.axis('off')
        fig.colorbar(mappable, ax=ax)

    plt.tight_layout()
    plt.show()

# --- Main Training and Validation Loop ---
# test_loader is not shuffled, so its first batch is the same every epoch.
# Fetch it once instead of starting a new worker pool per epoch just for plots.
context_images, _ = next(iter(test_loader))
context_images = context_images.to(DEVICE)[:8]

for epoch in range(EPOCHS):
    model.train()
    fourier_encoder.train()
    # Accumulate sample-weighted sums so the shorter final batch does not
    # skew the epoch mean (MSELoss already averages within each batch).
    train_loss_sum, train_samples = 0.0, 0

    for i, (images, _) in enumerate(train_loader):
        images = images.to(DEVICE, non_blocking=True)
        input_data, target_pixels, queries = prepare_model_input(images, coords_32x32, fourier_encoder)
        optimizer.zero_grad(set_to_none=True)
        reconstructed_pixels = model(input_data, queries=queries)
        loss = loss_fn(reconstructed_pixels, target_pixels)
        loss.backward()
        optimizer.step()
        scheduler.step()  # per-batch schedule (T_max counts optimizer steps)

        batch_loss = loss.item()  # single host sync per step
        train_loss_sum += batch_loss * images.size(0)
        train_samples += images.size(0)
        if (i + 1) % 200 == 0:
            print(f"Epoch [{epoch+1}/{EPOCHS}], Step [{i+1}/{len(train_loader)}], LR: {scheduler.get_last_lr()[0]:.6f}, Loss: {batch_loss:.4f}")

    avg_train_loss = train_loss_sum / train_samples

    model.eval()
    fourier_encoder.eval()
    val_loss_sum, val_samples = 0.0, 0
    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(DEVICE, non_blocking=True)
            input_data, target_pixels, queries = prepare_model_input(images, coords_32x32, fourier_encoder)
            reconstructed_pixels = model(input_data, queries=queries)
            val_loss_sum += loss_fn(reconstructed_pixels, target_pixels).item() * images.size(0)
            val_samples += images.size(0)

    avg_val_loss = val_loss_sum / val_samples
    print(f"--- Epoch [{epoch+1}/{EPOCHS}] Summary ---")
    print(f"  Avg Training Loss: {avg_train_loss:.4f}")
    print(f"  Avg Validation Loss: {avg_val_loss:.4f}\n")

    # --- Visualization at the end of each epoch (model is still in eval mode) ---
    with torch.no_grad():
        b, _, h, w = context_images.shape

        # 1. Low-Resolution Reconstruction Test
        input_context, _, queries_context = prepare_model_input(context_images, coords_32x32, fourier_encoder)
        reconstructed_pixels = model(input_context, queries=queries_context)
        reconstructed_images = rearrange(reconstructed_pixels, 'b (h w) c -> b c h w', h=h, w=w)

        comparison_grid = torch.cat((context_images, reconstructed_images), dim=0)
        final_grid = torchvision.utils.make_grid(comparison_grid, nrow=8, padding=2)
        imshow(final_grid, f"Epoch {epoch+1}: Top: Original | Bottom: Reconstructed (32x32)")

        # Frequency-domain comparison of the first original/reconstruction pair.
        calculate_and_visualize_fft_power_delta(context_images, reconstructed_images, epoch + 1)

        # 2. High-Resolution Generation Test: same encoder input, denser queries.
        high_res_batch_coords = repeat(coords_128x128, 'n d -> b n d', b=b)
        high_res_queries = fourier_encoder(high_res_batch_coords)

        generated_pixels = model(input_context, queries=high_res_queries)
        generated_images = rearrange(generated_pixels, 'b (h w) c -> b c h w', h=IMAGE_SIZE_HI_RES, w=IMAGE_SIZE_HI_RES)

        generated_grid = torchvision.utils.make_grid(generated_images, nrow=4, padding=2)
        imshow(generated_grid, f"Epoch {epoch+1}: Generated High-Resolution Images (128x128)")

print("--- Training finished. ---")