# DDIM Training
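In this notebook we train a Denoising Diffusion Implicit Model (DDIM) on the MNIST dataset. Following the "Diffusion Autoencoder" approach, the denoising process is conditioned on the feature vectors (internal representations) produced by a standard autoencoder. The pipeline is: (1) train or load the autoencoder, (2) encode the dataset into feature vectors, and (3) train or load the conditional diffusion model.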

In [3]:
# Imports: standard library, then third-party, then the local project modules.
import importlib
import os

import matplotlib.pyplot as plt
import torch

import ddim_mnist
import train_ddim_mnist

# Re-execute the local modules so edits to their .py files take effect without
# restarting the kernel. ddim_mnist is reloaded first: presumably
# train_ddim_mnist imports from ddim_mnist, so reloading it second lets it bind
# to the fresh ddim_mnist definitions (TODO confirm).
importlib.reload(ddim_mnist)
importlib.reload(train_ddim_mnist)

from ddim_mnist import DiffusionModel, SimpleAutoencoder
from train_ddim_mnist import (
    get_device,
    prepare_dataset,
    train_autoencoder,
    train_diffusion_model,
)

# Seed torch's global RNG so that torch-based randomness (weight init, sampled
# noise, default DataLoader shuffling) is reproducible across runs.
# Assumes the project code draws from torch's global RNG; TODO confirm.
SEED = 0
torch.manual_seed(SEED)

# Select the compute device via the project helper.
device = get_device()
print(f"Using device: {device}")

# Create the output directory used for samples (no-op if it already exists).
os.makedirs("samples", exist_ok=True)

Using device: mps


In [4]:
# Checkpoint locations and training settings. Each path is named once, so the
# existence check and the load call below cannot drift apart.
AUTOENCODER_CKPT = "autoencoder_weights.pt"
DIFFUSION_CKPT = "diffusion_model.pt"
AE_BATCH_SIZE = 128
AE_EPOCHS = 5
DIFFUSION_EPOCHS = 30

# --- Autoencoder: load from checkpoint if present, otherwise train a new one ---
if os.path.exists(AUTOENCODER_CKPT):
    print("Loading existing autoencoder...")
    autoencoder = SimpleAutoencoder(device=device)
    autoencoder.load_checkpoint(AUTOENCODER_CKPT)
else:
    # NOTE(review): this cell does not show whether train_autoencoder writes
    # AUTOENCODER_CKPT. If it does not, this branch re-trains on every run.
    autoencoder = train_autoencoder(device, batch_size=AE_BATCH_SIZE, epochs=AE_EPOCHS)

# Build the data loaders and the test-set feature vectors from the autoencoder;
# the features are passed to the diffusion model as its conditioning input.
train_loader, test_loader, test_features = prepare_dataset(autoencoder, device)

# --- Diffusion model: load from checkpoint if present, otherwise train ---
if os.path.exists(DIFFUSION_CKPT):
    print("Loading existing diffusion model...")
    diffusion_model = DiffusionModel(device=device)
    diffusion_model.load_checkpoint(DIFFUSION_CKPT)
else:
    # Long-running. If this call is interrupted, `diffusion_model` is not
    # (re)assigned. Whether partial progress is saved to DIFFUSION_CKPT depends
    # on train_diffusion_model; TODO confirm and add periodic checkpointing if not.
    diffusion_model = train_diffusion_model(
        train_loader, test_loader, test_features, device, epochs=DIFFUSION_EPOCHS
    )

Loading existing autoencoder...
Autoencoder using device: mps
Autoencoder loaded from autoencoder_weights.pt
Generating feature vectors...


  encoded = self.encoder(x)


Creating diffusion model...
DiffusionModel using device: mps
Computing dataset statistics...
Dataset statistics - Mean: 0.1275, Std: 0.3058
Training diffusion model...
Epoch: 1/30, Batch: 0/938, Loss: 1.061384
Epoch: 1/30, Batch: 100/938, Loss: 0.373298
Epoch: 1/30, Batch: 200/938, Loss: 0.269611
Epoch: 1/30, Batch: 300/938, Loss: 0.239107
Epoch: 1/30, Batch: 400/938, Loss: 0.207411
Epoch: 1/30, Batch: 500/938, Loss: 0.184052


KeyboardInterrupt: 