In [1]:
import os
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from diffusers import DDPMScheduler, UNet2DModel, DDPMPipeline
from diffusers.models.embeddings import Timesteps, TimestepEmbedding
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers.utils import BaseOutput
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.auto import tqdm

In [2]:
# Experiment switch read by the config, model-factory and training cells:
#   True  -> build the custom KoopmanUNet, output_dir "ddpm_mnist_koopman"
#   False -> build the stock diffusers UNet2DModel, output_dir "ddpm_mnist_baseline"
# TrainingConfig.output_dir is computed from this flag when the class body runs,
# so re-run the config cell after changing the value.
KOOPMAN = True

In [25]:
@dataclass
class TrainingConfig:
    """Hyper-parameters and architecture settings used by both model variants.

    Every attribute is type-annotated so that ``@dataclass`` registers it as a
    field. Without annotations the decorator generates an ``__init__`` with no
    parameters, so overrides such as ``TrainingConfig(num_epochs=1)`` fail and
    the repr shows no values. Defaults are unchanged, so ``TrainingConfig()``
    behaves as before.
    """

    # Optimization
    image_size: int = 32          # images are resized to image_size x image_size
    train_batch_size: int = 64
    eval_batch_size: int = 64     # number of images drawn when saving a sample grid
    num_epochs: int = 20
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_image_epochs: int = 5    # save a sample grid every N epochs
    save_model_epochs: int = 10   # save model weights every N epochs
    seed: int = 0

    # Dynamic Output Directory
    # Evaluated once, when this class body runs, from the module-level KOOPMAN flag.
    output_dir: str = "ddpm_mnist_koopman" if KOOPMAN else "ddpm_mnist_baseline"

    # Shared Architecture (Used by BOTH models for fair comparison)
    in_channels: int = 1
    out_channels: int = 1
    block_out_channels: tuple = (32, 64, 128, 128)
    layers_per_block: int = 2
    down_block_types: tuple = ("DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "DownBlock2D")
    up_block_types: tuple = ("UpBlock2D", "AttnUpBlock2D", "UpBlock2D", "UpBlock2D")
    # Default mirrors the image_size default above; when overriding image_size,
    # pass the same value for sample_size as well.
    # NOTE(review): presumably read by the sampling pipeline through
    # `unet.config.sample_size` -- confirm against the diffusers version in use.
    sample_size: int = image_size

# Pick the compute device first: CUDA when PyTorch reports it, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Instantiate the run configuration and make sure its output folder exists
# (exist_ok=True makes re-running this cell harmless).
config = TrainingConfig()
os.makedirs(config.output_dir, exist_ok=True)

In [16]:
def get_dataloader(config):
    """Build the shuffled MNIST training DataLoader.

    Args:
        config: object providing ``image_size`` (target height/width) and
            ``train_batch_size``.

    Returns:
        A ``DataLoader`` over the MNIST training split (downloaded to
        ``./data`` if missing), yielding ``(image, label)`` batches.
    """
    side = config.image_size
    steps = [
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        # mean 0.5 / std 0.5 on the single channel: (x - 0.5) / 0.5
        transforms.Normalize([0.5], [0.5]),
    ]
    mnist_train = torchvision.datasets.MNIST(
        root="./data",
        train=True,
        download=True,
        transform=transforms.Compose(steps),
    )
    loader = DataLoader(mnist_train, batch_size=config.train_batch_size, shuffle=True)
    return loader

def save_sample(model, scheduler, epoch, config):
    """Generate a grid of samples with the current model and save it as a PNG.

    Args:
        model: the denoising network, passed to ``DDPMPipeline`` as ``unet``.
        scheduler: the noise scheduler handed to the pipeline.
        epoch: zero-based epoch index; the file is named with ``epoch + 1``.
        config: object providing ``eval_batch_size``, ``seed`` and ``output_dir``.

    Side effects:
        Puts ``model`` into eval mode (it is not switched back here) and writes
        ``<output_dir>/epoch_XXXX.png``.
    """
    model.eval()
    pipeline = DDPMPipeline(unet=model, scheduler=scheduler)

    # Use a dedicated generator instead of torch.manual_seed(): manual_seed()
    # reseeds the *global* RNG, which would reset the training noise/timestep
    # streams every time a sample grid is written. A private generator keeps the
    # sample grids comparable across epochs without touching global RNG state.
    generator = torch.Generator(device="cpu").manual_seed(config.seed)
    images = pipeline(
        batch_size=config.eval_batch_size,
        generator=generator,
        num_inference_steps=50,
    ).images

    # Convert to grid
    to_tensor = transforms.ToTensor()
    grid = torchvision.utils.make_grid([to_tensor(img) for img in images], nrow=8)
    pil_grid = transforms.ToPILImage()(grid)

    # Make sure the destination exists even if config.output_dir was changed
    # after the setup cell ran.
    os.makedirs(config.output_dir, exist_ok=True)
    pil_grid.save(f"{config.output_dir}/epoch_{epoch+1:04d}.png")

In [21]:
class KoopmanUNet(nn.Module):
    """UNet noise predictor with a dense linear layer applied after the mid block.

    The encoder, mid block and decoder are assembled from diffusers building
    blocks (``get_down_block``, ``UNetMidBlock2D``, ``get_up_block``). Between
    the mid block and the decoder, the bottleneck feature map is flattened,
    passed through a single square ``nn.Linear`` (``koopman_operator``) and
    reshaped back.

    The constructor reads these attributes from ``config``:
    ``block_out_channels``, ``in_channels``, ``out_channels``,
    ``down_block_types``, ``up_block_types``, ``layers_per_block`` and
    ``image_size``. Because ``koopman_operator`` has a fixed feature count
    derived from ``image_size``, inputs must have that spatial size.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # 1. Time Embedding
        time_embed_dim = config.block_out_channels[0] * 4
        self.time_proj = Timesteps(config.block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedding = TimestepEmbedding(config.block_out_channels[0], time_embed_dim)

        # 2. Pre-process
        self.conv_in = nn.Conv2d(config.in_channels, config.block_out_channels[0], kernel_size=3, padding=1)

        # 3. Down Blocks (Encoder)
        self.down_blocks = nn.ModuleList([])
        output_channel = config.block_out_channels[0]
        for i, down_block_type in enumerate(config.down_block_types):
            input_channel = output_channel
            output_channel = config.block_out_channels[i]
            # The last encoder stage keeps its resolution (no downsample).
            is_final = i == len(config.down_block_types) - 1

            self.down_blocks.append(get_down_block(
                down_block_type,
                num_layers=config.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final,
                resnet_eps=1e-5,
                resnet_act_fn="silu",
                resnet_groups=32,
                attention_head_dim=8,
                downsample_padding=1
            ))

        # 4. Mid Block
        self.mid_block = UNetMidBlock2D(
            in_channels=config.block_out_channels[-1],
            temb_channels=time_embed_dim,
            resnet_eps=1e-5,
            resnet_act_fn="silu",
            resnet_groups=32,
            attention_head_dim=8,
            output_scale_factor=1
        )

        # 5. Koopman Bottleneck
        # Every encoder stage except the last halves the resolution, so the
        # bottleneck side length is image_size / 2**(num_stages - 1).
        self.bottleneck_c = config.block_out_channels[-1]
        ds_factor = 2 ** (len(config.down_block_types) - 1)
        self.bottleneck_h = config.image_size // ds_factor
        self.bottleneck_w = config.image_size // ds_factor
        features = self.bottleneck_c * self.bottleneck_h * self.bottleneck_w

        self.koopman_operator = nn.Linear(features, features)
        print(f"Koopman Operator initialized with {features} features.")

        # 6. Up Blocks (Decoder)
        self.up_blocks = nn.ModuleList([])
        reversed_ch = list(reversed(config.block_out_channels))
        output_channel = reversed_ch[0]
        for i, up_block_type in enumerate(config.up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_ch[i]
            # Channel count of the skip connections consumed by this stage;
            # clamped so the last stage reuses the first encoder width.
            input_channel = reversed_ch[min(i + 1, len(config.block_out_channels) - 1)]
            is_final = i == len(config.up_block_types) - 1

            self.up_blocks.append(get_up_block(
                up_block_type,
                # One extra layer per stage to absorb the additional skip tensor.
                num_layers=config.layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final,
                resnet_eps=1e-5,
                resnet_act_fn="silu",
                resnet_groups=32,
                attention_head_dim=8
            ))

        # 7. Output
        self.conv_norm_out = nn.GroupNorm(32, config.block_out_channels[0], eps=1e-5)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(config.block_out_channels[0], config.out_channels, kernel_size=3, padding=1)

    @property
    def device(self):
        """Helper for Diffusers pipeline compatibility."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Helper for Diffusers pipeline compatibility."""
        return next(self.parameters()).dtype

    def forward(self, x, t, return_dict=False):
        """Predict the output for input batch ``x`` at timestep(s) ``t``.

        Args:
            x: 4-D image batch ``(B, in_channels, H, W)`` with
                ``H == W == config.image_size``.
            t: timestep(s). May be a tensor (0-dim tensors are expanded to the
                batch size) or a plain Python number / sequence, which is
                converted to a ``torch.long`` tensor on ``x``'s device.
            return_dict: accepted for call-signature compatibility. Both values
                return the same ``BaseOutput``; previously ``True`` returned a
                plain ``dict`` without a ``.sample`` attribute.

        Returns:
            ``BaseOutput`` holding the prediction under ``sample``. It is a
            dict subclass, so ``out["sample"]`` and ``out.sample`` both work.
        """
        # Handle Time
        if not torch.is_tensor(t):
            # Plain ints/floats have no .to(); wrap them so scalar timesteps work.
            t = torch.tensor(t, dtype=torch.long, device=x.device)
        t = t.to(x.device)
        if t.dim() == 0:
            t = t.unsqueeze(0).expand(x.shape[0])
        t_emb = self.time_embedding(self.time_proj(t))

        # Encoder
        x = self.conv_in(x)
        skips = (x,)
        for block in self.down_blocks:
            x, s = block(x, t_emb)
            skips += s

        # Mid & Koopman
        x = self.mid_block(x, t_emb)
        B, C, H, W = x.shape

        # Use reshape() instead of view() to handle non-contiguous memory
        x = self.koopman_operator(x.reshape(B, -1)).reshape(B, C, H, W)

        # Decoder: each stage pops as many skip tensors as it has resnets,
        # taking them from the end of the tuple.
        for block in self.up_blocks:
            n_res = len(block.resnets)
            res_skips = skips[-n_res:]
            skips = skips[:-n_res]
            x = block(x, res_skips, temb=t_emb)

        # Output
        x = self.conv_out(self.conv_act(self.conv_norm_out(x)))

        return BaseOutput(sample=x)

In [22]:
def get_model(config, use_koopman=True):
    """Build the network to train.

    Args:
        config: settings object; the baseline reads ``image_size``,
            ``in_channels``, ``out_channels``, ``layers_per_block``,
            ``block_out_channels``, ``down_block_types`` and ``up_block_types``
            from it, and the custom model receives the whole object.
        use_koopman: truthy -> custom ``KoopmanUNet``; falsy -> diffusers
            ``UNet2DModel`` configured from the same fields.

    Returns:
        The freshly constructed (untrained) model.
    """
    if not use_koopman:
        print("Initializing Standard Baseline UNet2DModel...")
        baseline_kwargs = dict(
            sample_size=config.image_size,
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            layers_per_block=config.layers_per_block,
            block_out_channels=config.block_out_channels,
            down_block_types=config.down_block_types,
            up_block_types=config.up_block_types,
        )
        return UNet2DModel(**baseline_kwargs)

    # NOTE(review): the message says "Rank constrained", but KoopmanUNet in this
    # file builds a square nn.Linear(features, features) with no explicit rank
    # constraint -- confirm the intended wording.
    print("Initializing Custom KoopmanUNet (Rank constrained bottleneck)...")
    return KoopmanUNet(config)

In [23]:
def train(config):
    """Train the model selected by the module-level ``KOOPMAN`` flag.

    Each step adds noise to a batch at random timesteps, asks the model to
    predict that noise, and minimises the MSE between prediction and noise.

    Args:
        config: settings object providing ``seed``, ``learning_rate``,
            ``lr_warmup_steps``, ``num_epochs``, ``save_image_epochs``,
            ``save_model_epochs`` and ``output_dir`` (plus whatever
            ``get_model``, ``get_dataloader`` and ``save_sample`` read).

    Returns:
        The trained model (on the module-level ``device``).

    Side effects:
        Seeds the global torch RNG, writes sample grids and model weights into
        ``config.output_dir``.
    """
    # Apply the configured seed before any weights are initialised or any
    # noise is drawn, so a run is reproducible from config alone.
    torch.manual_seed(config.seed)

    # get the correct model
    model = get_model(config, use_koopman=KOOPMAN).to(device)

    # standard Setup
    scheduler = DDPMScheduler(num_train_timesteps=1000)
    # Single source of truth for the timestep range sampled below.
    num_train_timesteps = scheduler.config.num_train_timesteps
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    dataloader = get_dataloader(config)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer, config.lr_warmup_steps, len(dataloader) * config.num_epochs
    )

    print(f"--- Starting Training: {config.output_dir} ---")

    for epoch in range(config.num_epochs):
        # save_sample() leaves the model in eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{config.num_epochs}")
        losses = []

        for x, _ in pbar:
            x = x.to(device)
            noise = torch.randn_like(x)
            t = torch.randint(0, num_train_timesteps, (x.shape[0],), device=device)
            noisy_x = scheduler.add_noise(x, noise, t)

            # Both model variants are called the same way and expose the
            # prediction as `.sample`, so no branching on KOOPMAN is needed.
            noise_pred = model(noisy_x, t).sample

            loss = F.mse_loss(noise_pred, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            losses.append(loss.item())
            # Progress bar shows the mean over the last 100 steps of this epoch.
            pbar.set_postfix(loss=np.mean(losses[-100:]))

        is_last_epoch = epoch == config.num_epochs - 1

        # save images/models
        if (epoch + 1) % config.save_image_epochs == 0 or is_last_epoch:
            save_sample(model, scheduler, epoch, config)

        # Always persist the final weights too; otherwise the last epochs are
        # lost whenever num_epochs is not a multiple of save_model_epochs.
        if (epoch + 1) % config.save_model_epochs == 0 or is_last_epoch:
            # handle saving differently for diffusers model vs ours
            if hasattr(model, "save_pretrained") and not KOOPMAN:
                model.save_pretrained(config.output_dir)
            else:
                torch.save(model.state_dict(), f"{config.output_dir}/model.pth")

    return model

In [None]:
# Run training for the variant selected by KOOPMAN and keep the returned module
# in `model`; the analysis cell further down reads it.
model = train(config)

Initializing Custom KoopmanUNet (Rank constrained bottleneck)...
Koopman Operator initialized with 2048 features.
--- Starting Training: ddpm_mnist_koopman ---


Epoch 1/20:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 2/20:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 3/20:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 4/20:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 5/20:   0%|          | 0/938 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 6/20:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 7/20:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 8/20:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 9/20:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 10/20:   0%|          | 0/938 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 11/20:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 12/20:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 13/20:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 14/20:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 15/20:   0%|          | 0/938 [00:00<?, ?it/s]

In [4]:
def analyze_koopman(model):
    """Plot the singular values and eigenvalues of the model's Koopman weight matrix.

    Args:
        model: a module exposing ``koopman_operator.weight`` (a square 2-D
            weight, as created by ``KoopmanUNet``).

    Side effects:
        Puts ``model`` into eval mode, prints a header line and shows two
        matplotlib figures: the 100 largest singular values on a log scale, and
        the eigenvalues in the complex plane with the unit circle drawn for
        reference.

    Note:
        Only the weight matrix is analysed; any bias term of the linear layer
        is not part of these spectra.
    """
    model.eval()
    print("--- Analyzing Koopman Operator ---")

    # Extract K matrix (cast to float32 so half-precision weights also work
    # with numpy's linalg routines; a no-op for float32 weights).
    K = model.koopman_operator.weight.detach().float().cpu().numpy()

    # 1. Singular Values (Rank)
    # Only the values are plotted, so skip computing U and Vh.
    singular_values = np.linalg.svd(K, compute_uv=False)
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(singular_values[:100], 'o-')  # Top 100 (svd returns them in descending order)
    ax.set_yscale('log')
    ax.set_xlabel("Index")
    ax.set_ylabel("Singular value")
    ax.set_title("Singular Values of K (Top 100)")
    plt.show()

    # 2. Eigenvalues (Dynamics)
    # Eigenvectors are not used, so compute the eigenvalues only.
    eigenvalues = np.linalg.eigvals(K)
    fig, ax = plt.subplots(figsize=(6, 6))
    # Dashed unit circle as a visual reference for |lambda| = 1.
    ax.add_patch(plt.Circle((0, 0), 1, color='r', fill=False, ls='--'))
    ax.scatter(eigenvalues.real, eigenvalues.imag, alpha=0.5, s=10)
    ax.set_xlabel("Real part")
    ax.set_ylabel("Imaginary part")
    ax.set_title("Eigenvalues in Complex Plane")
    ax.axis('equal')
    plt.show()

# `model` must already exist in the kernel; it is assigned by `model = train(config)`.
# To analyse a saved checkpoint instead of retraining, uncomment the two lines
# below: they rebuild KoopmanUNet and load the state dict that train() writes to
# "<output_dir>/model.pth" for the Koopman variant.
# model = KoopmanUNet(config)
# model.load_state_dict(torch.load(f"{config.output_dir}/model.pth", map_location=device))
analyze_koopman(model)

  ax1.set_title("Eigenvalues $\lambda$ in Complex Plane", fontsize=16)
  ax2.set_ylabel("Magnitude $|\lambda|$")
  axs[i].set_title(f"Mode {i}\n$|\lambda|={np.abs(eigenvalues[mode_index]):.3f}$")
