In [78]:
import os
from dataclasses import dataclass

import torch
import torch.nn.functional as F
import torchvision
from diffusers import DDPMScheduler, UNet2DModel, DDPMPipeline
from diffusers.models.unets.unet_2d import UNet2DOutput
from diffusers.optimization import get_cosine_schedule_with_warmup
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

In [70]:
# Architecture switch for the whole notebook: it selects the output directory
# in TrainingConfig and is forwarded as `koopman=` to setup_model_and_scheduler
# (True -> KoopmanUNet, False -> stock UNet2DModel).
KOOPMAN = True

In [71]:
@dataclass
class TrainingConfig:
    """Hyper-parameters and output settings for the DDPM training run.

    Every attribute carries a type annotation so that ``@dataclass`` registers
    it as a field; without annotations the decorator generates an ``__init__``
    that accepts no arguments and overrides such as
    ``TrainingConfig(num_epochs=5)`` raise ``TypeError``.
    """

    image_size: int = 32  # side length images are resized to before training
    train_batch_size: int = 64
    eval_batch_size: int = 64  # number of images sampled per evaluation grid
    num_epochs: int = 2
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_image_epochs: int = 1  # sample an image grid every N epochs
    save_model_epochs: int = 2  # write a checkpoint every N epochs
    output_dir: str = "ddpm_mnist_koopman" if KOOPMAN else "ddpm_mnist"
    seed: int = 0  # seed used for the evaluation sampler

config = TrainingConfig()
os.makedirs(config.output_dir, exist_ok=True)

In [72]:
def get_mnist_dataloader(image_size, batch_size):
    """Build a shuffled DataLoader over the MNIST training split.

    Args:
        image_size: Side length (pixels) each image is resized to.
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` yielding ``(image, label)`` batches; the dataset is
        downloaded into ``./data`` on first use.
    """
    target_hw = (image_size, image_size)
    steps = [
        transforms.Resize(target_hw),
        transforms.ToTensor(),
        # Centre and rescale pixel values: (pixel - 0.5) / 0.5
        transforms.Normalize([0.5], [0.5]),
    ]

    mnist_train = torchvision.datasets.MNIST(
        root="./data",
        train=True,
        download=True,
        transform=transforms.Compose(steps),
    )

    return DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

In [73]:
import torch
import torch.nn as nn
from diffusers.models.embeddings import Timesteps, TimestepEmbedding
from diffusers.models.unets.unet_2d_blocks import (
    UNetMidBlock2D,
    get_down_block,
    get_up_block,
)

class KoopmanUNet(nn.Module):
    """U-Net noise predictor with a learned linear operator at the bottleneck.

    The encoder, mid block and decoder are assembled from the same diffusers
    building blocks (``get_down_block`` / ``UNetMidBlock2D`` / ``get_up_block``)
    and driven by a ``UNet2DModel``-style config. After the mid block the
    feature map is flattened, passed through a single ``nn.Linear`` (the
    "Koopman operator") and reshaped back before decoding.

    The module exposes ``config``, ``device`` and ``dtype`` and returns a
    ``UNet2DOutput`` by default so that it can be used where a diffusers UNet
    is expected (e.g. inside ``DDPMPipeline``).

    Args:
        config: Config object providing ``sample_size``, ``in_channels``,
            ``out_channels``, ``block_out_channels``, ``layers_per_block``,
            ``down_block_types`` and ``up_block_types``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        image_size = config.sample_size
        in_channels = config.in_channels
        out_channels = config.out_channels
        block_out_channels = config.block_out_channels
        layers_per_block = config.layers_per_block
        down_block_types = config.down_block_types
        up_block_types = config.up_block_types

        time_embed_dim = block_out_channels[0] * 4

        # Sinusoidal timestep features followed by an MLP projection.
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedding = TimestepEmbedding(
            in_channels=block_out_channels[0],
            time_embed_dim=time_embed_dim,
        )

        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)

        # --- Encoder ---
        self.down_blocks = nn.ModuleList([])
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(down_block_types) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                # Every block except the last halves the spatial resolution.
                add_downsample=not is_final_block,
                resnet_eps=1e-5,
                resnet_act_fn="silu",
                resnet_groups=32,
                attention_head_dim=8,
                downsample_padding=1,
            )
            self.down_blocks.append(down_block)

        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            temb_channels=time_embed_dim,
            resnet_eps=1e-5,
            resnet_act_fn="silu",
            output_scale_factor=1,
            attention_head_dim=8,
            resnet_groups=32,
        )

        # --- Koopman bottleneck ---
        # len(down_block_types) - 1 blocks downsample, each by a factor of 2.
        self.bottleneck_c = block_out_channels[-1]
        downsample_factor = 2 ** (len(down_block_types) - 1)
        self.bottleneck_h = image_size // downsample_factor
        self.bottleneck_w = image_size // downsample_factor
        self.bottleneck_features = self.bottleneck_c * self.bottleneck_h * self.bottleneck_w

        self.koopman_operator = nn.Linear(self.bottleneck_features, self.bottleneck_features)

        print("Koopman Bottleneck Initialized:")
        print(f"  Shape: ({self.bottleneck_c}, {self.bottleneck_h}, {self.bottleneck_w})")
        print(f"  Total Features: {self.bottleneck_features}")

        # --- Decoder ---
        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            is_final_block = i == len(up_block_types) - 1

            up_block = get_up_block(
                up_block_type,
                # One extra layer per block: each layer consumes one skip tensor.
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final_block,
                resnet_eps=1e-5,
                resnet_act_fn="silu",
                resnet_groups=32,
                attention_head_dim=8,
            )
            self.up_blocks.append(up_block)

        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    @property
    def device(self):
        """Device of the module's parameters (queried by diffusers pipelines)."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the module's parameters (queried by diffusers pipelines)."""
        return next(self.parameters()).dtype

    def forward(self, x, t, return_dict=True):
        """Predict the noise residual for a batch of noisy images.

        Args:
            x: Noisy images, indexed as (batch, channels, height, width); the
                spatial size must match ``config.sample_size`` because the
                bottleneck is reshaped to a fixed size.
            t: Timestep(s) as a Python number, a 0-dim tensor or a 1-D tensor
                with one entry per batch element.
            return_dict: When True (default) return ``UNet2DOutput(sample=...)``;
                when False return the 1-tuple ``(sample,)``. Integer indexing
                (``out[0]``) works on both forms.

        Returns:
            ``UNet2DOutput`` or ``tuple`` holding the predicted tensor.
        """
        batch_size = x.shape[0]

        # 1. Time embedding.
        # The timestep projection only accepts 1-D tensors, while samplers pass
        # a single scalar step, so normalise to one timestep per batch element.
        if not torch.is_tensor(t):
            t = torch.tensor([t], dtype=torch.long, device=x.device)
        elif t.dim() == 0:
            t = t[None]
        t = t.to(x.device)
        t = t * torch.ones(batch_size, dtype=t.dtype, device=t.device)

        t_emb = self.time_proj(t)
        t_emb = self.time_embedding(t_emb)

        # 2. Input projection.
        x = self.conv_in(x)

        # 3. Encoder: the conv_in output is the first skip connection, then
        # each down block contributes all of its intermediate outputs.
        skip_connections = (x,)
        for block in self.down_blocks:
            x, skips = block(x, t_emb)
            skip_connections += skips

        # 4. Mid block.
        x = self.mid_block(x, t_emb)

        # 5. Koopman bottleneck: flatten -> linear operator -> restore shape.
        x = x.reshape(batch_size, -1)
        x = self.koopman_operator(x)
        x = x.reshape(batch_size, self.bottleneck_c, self.bottleneck_h, self.bottleneck_w)

        # 6. Decoder: each up block consumes one skip per resnet, taken from
        # the end of the tuple while keeping their original order.
        for block in self.up_blocks:
            num_skips = len(block.resnets)
            skips_for_this_block = skip_connections[-num_skips:]
            skip_connections = skip_connections[:-num_skips]
            x = block(x, skips_for_this_block, temb=t_emb)

        # 7. Output head.
        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)

        if not return_dict:
            return (x,)
        return UNet2DOutput(sample=x)

In [74]:
def setup_model_and_scheduler(image_size, koopman=False):
    """Create the denoising network and its DDPM noise scheduler.

    Args:
        image_size: Side length of the (square) images the network handles.
        koopman: If True, return a ``KoopmanUNet`` built from the stock
            model's config; otherwise return the stock ``UNet2DModel``.

    Returns:
        ``(model, noise_scheduler)`` tuple.
    """
    architecture = dict(
        sample_size=image_size,
        in_channels=1,
        out_channels=1,
        layers_per_block=2,
        block_out_channels=(32, 64, 128, 128),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )
    # The stock model is always instantiated: its config is what the Koopman
    # variant is constructed from.
    stock_unet = UNet2DModel(**architecture)
    network = KoopmanUNet(stock_unet.config) if koopman else stock_unet

    scheduler = DDPMScheduler(num_train_timesteps=1000)
    return network, scheduler

In [75]:
def tensor_to_pil(image_tensor):
    """Converts a batch of normalized image tensors to a list of PIL Images.

    Args:
        image_tensor: 4-D tensor indexed as (batch, channels, height, width)
            with values nominally in [-1, 1]; values outside are clamped.
            It may live on any device and may require grad.

    Returns:
        A list with one PIL Image per batch element. Single-channel inputs
        become 2-D uint8 arrays (grayscale); for other channel counts the
        mode is left for PIL to infer from the array shape.
    """
    # detach() first: .numpy() raises on tensors that are part of a graph.
    image = (image_tensor.detach() / 2 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
    image = image.cpu().permute(0, 2, 3, 1).numpy()  # channels-last for PIL
    images = (image * 255).round().astype("uint8")
    # Drop only the channel axis for single-channel images instead of forcing
    # mode 'L', which would reject multi-channel arrays.
    return [Image.fromarray(img[..., 0] if img.shape[-1] == 1 else img) for img in images]

def generate_and_save_images(model, scheduler, epoch, config):
    """Samples a batch of images from the model and saves them as a grid.

    Args:
        model: Denoising network handed to ``DDPMPipeline`` as its ``unet``.
        scheduler: Noise scheduler handed to ``DDPMPipeline``.
        epoch: Zero-based epoch index; used (as ``epoch + 1``) in the log
            messages and in the output file name.
        config: Object providing ``eval_batch_size``, ``seed`` and
            ``output_dir``.

    Side effects:
        Puts ``model`` into eval mode, writes ``epoch_XXXX.png`` into
        ``config.output_dir`` and displays the grid in the notebook.
    """
    print(f"Generating images for epoch {epoch+1}...")
    model.eval()
    pipeline = DDPMPipeline(unet=model, scheduler=scheduler)

    # Use a dedicated generator. torch.manual_seed() would reseed the *global*
    # RNG, so every epoch following a sampling pass would replay the same
    # shuffle order, noise and timesteps during training.
    sample_generator = torch.Generator(device="cpu").manual_seed(config.seed)

    # generate a batch of images
    images = pipeline(
        batch_size=config.eval_batch_size,
        generator=sample_generator,
    ).images

    image_grid = torchvision.utils.make_grid(
        [transforms.ToTensor()(img) for img in images], nrow=8
    )

    pil_grid = transforms.ToPILImage()(image_grid)
    save_path = os.path.join(config.output_dir, f"epoch_{epoch+1:04d}.png")
    pil_grid.save(save_path)
    print(f"Saved image grid to {save_path}")

    display(pil_grid)

In [79]:
def train_loop(config, model, noise_scheduler, optimizer, lr_scheduler, train_dataloader):
    """Trains the noise-prediction model with the DDPM objective.

    For every batch, random noise and random timesteps are drawn, the clean
    images are noised with ``noise_scheduler.add_noise`` and the model is
    trained to predict the added noise (MSE loss).

    Args:
        config: Object providing ``num_epochs``, ``save_image_epochs``,
            ``save_model_epochs`` and ``output_dir`` (plus whatever
            ``generate_and_save_images`` reads).
        model: Network called as ``model(noisy, timesteps, return_dict=False)``
            whose first output is the noise prediction.
        noise_scheduler: Scheduler providing ``add_noise`` and
            ``config.num_train_timesteps``.
        optimizer: Optimizer over ``model``'s parameters.
        lr_scheduler: Learning-rate scheduler, stepped once per batch.
        train_dataloader: Iterable of ``(images, labels)`` batches; labels
            are ignored.

    Side effects:
        Moves ``model`` to CUDA when available, writes sample grids and a
        checkpoint under ``config.output_dir``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    print(f"--- Starting Training on {device} ---")

    for epoch in range(config.num_epochs):
        model.train()
        epoch_loss = 0.0

        for images, _ in tqdm(train_dataloader):
            clean_images = images.to(device)
            batch_size = clean_images.shape[0]

            # sample random noise and timesteps
            noise = torch.randn_like(clean_images)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (batch_size,), device=device).long()

            # add noise to the clean images (the "forward process")
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            # get the model's prediction for the noise
            noise_pred = model(noisy_images, timesteps, return_dict=False)[0]

            loss = F.mse_loss(noise_pred, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            # accumulate for the per-epoch average
            epoch_loss += loss.item()

        print(f"Epoch {epoch+1} Average Loss: {epoch_loss / len(train_dataloader)}")

        # --- Save images and model checkpoint (always on the last epoch) ---
        is_last_epoch = epoch == config.num_epochs - 1

        if (epoch + 1) % config.save_image_epochs == 0 or is_last_epoch:
            generate_and_save_images(model, noise_scheduler, epoch, config)

        if (epoch + 1) % config.save_model_epochs == 0 or is_last_epoch:
            model_path = os.path.join(config.output_dir, "unet_model")
            if hasattr(model, "save_pretrained"):
                model.save_pretrained(model_path)
            else:
                # Plain nn.Module (no save_pretrained): store the raw weights.
                os.makedirs(model_path, exist_ok=True)
                torch.save(model.state_dict(), os.path.join(model_path, "model_state_dict.pt"))
            print(f"Saved model to {model_path}")

In [80]:
# Seed the global RNG before anything stochastic happens (weight init, data
# shuffling, noise/timestep sampling) so that a Restart & Run All repeats the
# same run. config.seed was previously only used for the evaluation sampler.
torch.manual_seed(config.seed)

train_dataloader = get_mnist_dataloader(config.image_size, config.train_batch_size)

model, noise_scheduler = setup_model_and_scheduler(config.image_size, koopman=KOOPMAN)

optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
# Cosine decay over the full run (one scheduler step per batch) after warmup.
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(len(train_dataloader) * config.num_epochs),
)

train_loop(config, model, noise_scheduler, optimizer, lr_scheduler, train_dataloader)

Koopman Bottleneck Initialized:
  Shape: (128, 4, 4)
  Total Features: 2048
--- Starting Training on cpu ---


  7%|▋         | 61/938 [01:27<20:54,  1.43s/it]


KeyboardInterrupt: 

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# --- 1. Pull the learned operator out of the network ---
# `model` must be the KoopmanUNet from the training cell. Only the weight
# matrix of the linear layer is analysed here; its bias is not included.
model.eval()
K_matrix = model.koopman_operator.weight.detach().cpu().numpy()

# --- 2. Eigendecomposition, ordered by decreasing |λ| ---
eigenvalues, eigenvectors = np.linalg.eig(K_matrix)
sorted_indices = np.argsort(np.abs(eigenvalues))[::-1]
eigenvalues, eigenvectors = eigenvalues[sorted_indices], eigenvectors[:, sorted_indices]

print(f"Found {len(eigenvalues)} Koopman modes.")

# --- 3. Inspect the spectrum ---
# Working interpretation (to be checked against the plots): components with
# |λ| ≈ 1 are preserved by the operator, components with |λ| < 1 are damped.
fig, (ax_plane, ax_magnitude) = plt.subplots(1, 2, figsize=(10, 5))

ax_plane.plot(eigenvalues.real, eigenvalues.imag, 'o', markersize=5)
ax_plane.set_title("Eigenvalues in Complex Plane")
ax_plane.set_xlabel("Real Part")
ax_plane.set_ylabel("Imaginary Part")

ax_magnitude.plot(np.abs(eigenvalues), 'o-')
ax_magnitude.set_title("Eigenvalue Magnitudes")
ax_magnitude.set_xlabel("Mode Index")
ax_magnitude.set_ylabel("Magnitude |λ|")
# Reference line at |λ| = 1.
ax_magnitude.axhline(1.0, color='red', linestyle='--')
plt.show()

# --- 4. Visualize the Eigenvectors (The "Patterns") ---
# The eigenvectors (v) are the "Koopman Modes".
# They live in the flattened bottleneck space (model.bottleneck_features
# values; 2048 in the run above). We need to pass them
# through the DECODER to see what they look like as images.

def visualize_mode(model, eigenvector_column):
    """Decode one Koopman eigenvector into image space.

    The (real part of the) eigenvector is reshaped to the bottleneck feature
    map and pushed through the decoder with all skip connections set to zero,
    so the output reflects the mode itself rather than any encoder features.

    Args:
        model: Network exposing ``config`` (``sample_size``, ``in_channels``),
            ``time_proj``, ``time_embedding``, ``conv_in``, ``down_blocks``,
            ``up_blocks`` (each with ``resnets``), ``conv_norm_out``,
            ``conv_act``, ``conv_out`` and ``bottleneck_c/h/w``.
        eigenvector_column: 1-D numpy array (real or complex) with
            ``bottleneck_c * bottleneck_h * bottleneck_w`` entries.

    Returns:
        CPU tensor with the decoder output for a batch of one.
    """
    device = next(model.parameters()).device

    # Eigenvectors of a real matrix may be complex; only the real part is decoded.
    real_part = np.ascontiguousarray(np.real(eigenvector_column))
    mode_vector = torch.from_numpy(real_part).float().to(device)

    # Reshape it to the bottleneck shape [batch, C, H, W]
    mode_tensor = mode_vector.reshape(1, model.bottleneck_c, model.bottleneck_h, model.bottleneck_w)

    with torch.no_grad():
        dummy_t_emb = model.time_embedding(model.time_proj(torch.tensor([1.0], device=device)))

        # Run a zero image through the encoder purely to learn the shapes of
        # the skip tensors. As in the model's forward pass, the conv_in output
        # is the first skip and must precede the down blocks.
        probe = torch.zeros(
            1,
            model.config.in_channels,
            model.config.sample_size,
            model.config.sample_size,
            device=device,
        )
        probe = model.conv_in(probe)
        dummy_skips = [torch.zeros_like(probe)]
        for block in model.down_blocks:
            probe, skips = block(probe, dummy_t_emb)
            # We use zero skips to see the "pure" mode
            dummy_skips.extend(torch.zeros_like(skip) for skip in skips)

        # --- Run the DECODER ---
        x = mode_tensor
        for block in model.up_blocks:
            num_skips = len(block.resnets)
            # Take the last `num_skips` entries in their original order; the
            # block itself consumes them from the end.
            skips_for_block = tuple(dummy_skips[-num_skips:])
            del dummy_skips[-num_skips:]
            x = block(x, skips_for_block, temb=dummy_t_emb)

        # Final output layers
        x = model.conv_norm_out(x)
        x = model.conv_act(x)
        x = model.conv_out(x)

    return x.detach().cpu()

# --- Now, let's visualize the top 5 modes ---
# --- Decode and plot the five leading modes side by side ---
fig, axes = plt.subplots(1, 5, figsize=(15, 3))
for i, ax in enumerate(axes):
    decoded = visualize_mode(model, eigenvectors[:, i])

    # Same mapping used to denormalize images ([-1, 1] -> [0, 1]), clamped for display
    img = (decoded[0].squeeze() / 2 + 0.5).clamp(0, 1)

    ax.imshow(img, cmap='gray')
    ax.set_title(f"Mode {i}\n|λ| = {np.abs(eigenvalues[i]):.3f}")
    ax.axis('off')

plt.show()