## Learning Few-Step Posterior Samplers by Unfolding and Distillation of Diffusion Models
### Demo (Sampling a distilled model for Gaussian deblurring)

#### Introduction
This notebook demonstrates how to generate posterior samples $x\sim p(x|y)$ from a trained unfolded and distilled conditional diffusion model (UD$^2$M), as described in the paper [Learning Few-Step Posterior Samplers by Unfolding and Distillation of Diffusion Models](https://arxiv.org/abs/2507.02686). The demo considers the image restoration problem $y = k * x + n$ with $n \sim \mathcal{N}(0, \sigma^2 I)$, where $k$ is an anisotropic Gaussian blur kernel, using a model trained on the LSUN bedroom training data.
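
As a rough illustration of the observation model above, the following NumPy sketch blurs an image with an anisotropic Gaussian kernel (by circular convolution) and adds white Gaussian noise. The function names, kernel size and kernel widths are illustrative assumptions; the sketch is independent of the `physics` module used later in this notebook.

```python
import numpy as np

def anisotropic_gaussian_kernel(size=9, std_x=3.0, std_y=1.0):
    """Normalized 2D Gaussian kernel with a different width along each axis (assumed values)."""
    r = np.arange(size) - size // 2
    gx = np.exp(-0.5 * (r / std_x) ** 2)
    gy = np.exp(-0.5 * (r / std_y) ** 2)
    k = np.outer(gy, gx)
    return k / k.sum()

def blur_and_add_noise(x, k, sigma, rng):
    """Compute y = k * x + n with circular convolution; x has shape (C, H, W)."""
    _, h, w = x.shape
    kh, kw = k.shape
    k_pad = np.zeros((h, w))
    k_pad[:kh, :kw] = k
    # Centre the kernel at the origin so that the blur does not shift the image
    k_pad = np.roll(k_pad, (-(kh // 2), -(kw // 2)), axis=(0, 1))
    blurred = np.real(np.fft.ifft2(np.fft.fft2(x) * np.fft.fft2(k_pad)))
    return blurred + sigma * rng.standard_normal(x.shape)

rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=(3, 32, 32))  # stand-in image in [-1, 1]
k = anisotropic_gaussian_kernel()
y = blur_and_add_noise(x, k, sigma=0.05, rng=rng)
```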

The method uses a DDIM discretization to approximate the reverse diffusion process. The conditional score given $y$ and $x_t$ is approximated using estimated samples $\hat x_{0,t} \sim p(x_0|y, x_t)$ from the joint posterior, for a decreasing sequence of values $t$, where $x_t \sim \mathcal N(\sqrt{\overline{\alpha}_t}x_0, (1 - \overline{\alpha}_t)I)$.
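
To make the role of the estimate $\hat x_{0,t}$ concrete, here is a minimal sketch of a single deterministic DDIM update (the $\eta = 0$ variant) that moves from noise level $t$ to a lower level $s$ given an estimate of $x_0$. The function name and the numerical values of $\overline{\alpha}$ are assumptions made for illustration; the sampler in this repository may use a different (for example stochastic) variant.

```python
import numpy as np

def ddim_step(x_t, x0_hat, abar_t, abar_s):
    """Deterministic DDIM update (eta = 0) from noise level t to a lower level s."""
    # Noise implied by the current state and the estimate of x_0
    eps_hat = (x_t - np.sqrt(abar_t) * x0_hat) / np.sqrt(1.0 - abar_t)
    # Re-noise the estimate to the lower noise level s
    return np.sqrt(abar_s) * x0_hat + np.sqrt(1.0 - abar_s) * eps_hat

rng = np.random.default_rng(0)
x0 = rng.standard_normal((3, 8, 8))
eps = rng.standard_normal(x0.shape)
abar_t, abar_s = 0.2, 0.7  # illustrative values; abar grows as t decreases
x_t = np.sqrt(abar_t) * x0 + np.sqrt(1.0 - abar_t) * eps
x_s = ddim_step(x_t, x0_hat=x0, abar_t=abar_t, abar_s=abar_s)
```

In this toy call the estimate is exact (`x0_hat = x0`), so the update simply re-expresses the same noise at the lower level; in UD$^2$M the estimate would instead come from the unrolled network.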

![UDDM Algorithm](./demo/figs/alg1.png "UDDM Algorithm")

The samples $\hat x_{0,t} = x_{0,t}^{(K)}$ are obtained by unrolling $K$ iterations of the LATINO algorithm:
$$
\begin{aligned}
    \tilde x_{0, t}^{(k+1)} &= \textnormal{prox}_{-\delta\log p(y, x_t|\cdot)}(x_{0,t}^{(k)})\\
    x_{t_\delta}^{(k+1)} &= \sqrt{\overline{\alpha}_{t_\delta}}\,\tilde x_{0,t}^{(k+1)} + \sqrt{1 - \overline{\alpha}_{t_\delta}}\,\epsilon^{(k+1)}, \quad \epsilon^{(k+1)} \sim \mathcal{N}(0, I)\\
    x_{0, t}^{(k+1)} &= \tilde G_\vartheta(x_{t_\delta}^{(k+1)}, t_\delta),
\end{aligned}
$$
where $\tilde G_\vartheta$ is a denoiser with weights $\vartheta$, adapted from a denoiser with weights $\theta$ that was pre-trained by denoising score matching on images whose distribution is close to the prior $p(x)$. The algorithm is unfolded, with the following trainable parameters:

- $\Delta\theta$, where $\vartheta = \theta + \Delta\theta$ and $\Delta\theta$ is a low-rank perturbation of the (large) pre-trained denoiser weights $\theta$. Using [LoRA](https://arxiv.org/abs/2106.09685), we learn low-rank perturbations of the attention layers of the denoiser.
- The step size $\delta$.
- The level of augmented noise added before each denoiser step, $t_\delta$.
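
The unrolled iteration above can be sketched in NumPy under strong simplifying assumptions: the forward operator is the identity ($y = x + n$, so the proximal step has a closed form), and the learned denoiser $\tilde G_\vartheta$ is replaced by a toy MMSE denoiser for a standard Gaussian prior. All function names and numerical values are hypothetical and are not taken from the repository. The sketch relies on $y$ and $x_t$ being conditionally independent given $x_0$, so that $-\log p(y, x_t|x_0)$ is a sum of two quadratic terms.

```python
import numpy as np

def prox_joint_likelihood(v, y, x_t, sigma, abar_t, delta):
    """Closed-form prox of delta * (-log p(y, x_t | x)) at v, assuming y = x + noise."""
    w_y = delta / sigma**2        # weight of ||y - x||^2 / (2 sigma^2)
    w_t = delta / (1.0 - abar_t)  # weight of ||x_t - sqrt(abar_t) x||^2 / (2 (1 - abar_t))
    num = v + w_y * y + w_t * np.sqrt(abar_t) * x_t
    return num / (1.0 + w_y + w_t * abar_t)

def toy_denoiser(x_noisy, abar):
    """MMSE estimate of x_0 when x_0 ~ N(0, I); stands in for the learned denoiser."""
    return np.sqrt(abar) * x_noisy

def unrolled_sampler(y, x_t, denoiser, sigma, abar_t, abar_delta, delta, K, rng):
    """K unrolled iterations: prox step, re-noising to level t_delta, denoising."""
    x0 = y.copy()  # initialization x^(0) = y
    for _ in range(K):
        x0_tilde = prox_joint_likelihood(x0, y, x_t, sigma, abar_t, delta)
        noise = rng.standard_normal(x0.shape)
        x_td = np.sqrt(abar_delta) * x0_tilde + np.sqrt(1.0 - abar_delta) * noise
        x0 = denoiser(x_td, abar_delta)
    return x0

rng = np.random.default_rng(0)
x_true = rng.standard_normal((3, 8, 8))
sigma, abar_t, abar_delta, delta = 0.05, 0.3, 0.95, 1e-3  # illustrative values
y = x_true + sigma * rng.standard_normal(x_true.shape)
x_t = np.sqrt(abar_t) * x_true + np.sqrt(1.0 - abar_t) * rng.standard_normal(x_true.shape)
x0_hat = unrolled_sampler(y, x_t, toy_denoiser, sigma, abar_t, abar_delta, delta, K=3, rng=rng)
```

In the trained model, $\delta$ and $t_\delta$ are learned rather than fixed by hand, and the proximal step has to account for the blur operator.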

To test different initializations $x_{0,t}^{(0)}$ of the unrolled sampling architecture, we consider the following schemes:
- Start from $y$: $x_{0,t}^{(0)} = y$.
- Start from the output of the Reconstruct Anything Model ([RAM](https://arxiv.org/abs/2503.08915)) applied to the joint observation $(y, x_t)$: $x_{0,t}^{(0)} = \textnormal{RAM}(y, x_t)$. In this setup, the RAM weights are fine-tuned as part of the unrolled sampling architecture.

#### Training 
The unrolled sampling weights are trained using the procedure described in Section 3.2 of the paper. To summarize, the loss takes the form 
$$
    \mathcal{L} = \mathcal{L}_{\text{Adv}, \phi} + \omega_1 \mathcal L_{\text{MSE}} + \omega_2 \mathcal L_{\text{PS}},
$$
where $\mathcal{L}_{\text{Adv}, \phi}$ is an adversarial loss based on the Jensen-Shannon divergence, $\mathcal L_{\text{MSE}}$ is a mean squared error loss, and $\mathcal L_{\text{PS}}$ is a perceptual similarity loss. $\omega_1$ and $\omega_2$ are hyper-parameters used to stabilize training. The adversarial loss is computed between real and fake samples of the joint distribution $(x, y)$ using a discriminator with weights $\phi$, which are trained jointly with the unrolled sampling weights $\vartheta$. To avoid overfitting, a gradient penalty is applied to the discriminator loss.
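
The structure of this objective can be sketched as follows. This is only an illustration: it assumes the standard logistic (non-saturating) GAN loss for the adversarial term, takes the perceptual term as a user-supplied callable (`perceptual_fn` is a hypothetical placeholder, for example an LPIPS-style metric), and treats the gradient penalty as a precomputed scalar because NumPy has no automatic differentiation. The exact formulation and weights are those of Section 3.2 of the paper, not of this sketch.

```python
import numpy as np

def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return np.logaddexp(0.0, z)

def generator_loss(d_fake_logits, x_hat, x, perceptual_fn, w1, w2):
    """L = L_adv + w1 * L_MSE + w2 * L_PS for the unrolled sampler."""
    adv = softplus(-d_fake_logits).mean()  # non-saturating logistic GAN loss (assumption)
    mse = np.mean((x_hat - x) ** 2)
    ps = perceptual_fn(x_hat, x)           # placeholder perceptual similarity
    return adv + w1 * mse + w2 * ps

def discriminator_loss(d_real_logits, d_fake_logits, grad_penalty=0.0, gp_weight=0.0):
    """Logistic discriminator loss plus a weighted, precomputed gradient penalty."""
    real_term = softplus(-d_real_logits).mean()
    fake_term = softplus(d_fake_logits).mean()
    return real_term + fake_term + gp_weight * grad_penalty

rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=(2, 3, 8, 8))
x_hat = x + 0.1 * rng.standard_normal(x.shape)
dummy_ps = lambda a, b: float(np.mean(np.abs(a - b)))  # stand-in, not a real perceptual metric
g_loss = generator_loss(rng.standard_normal(2), x_hat, x, dummy_ps, w1=1.0, w2=0.1)
d_loss = discriminator_loss(rng.standard_normal(2), rng.standard_normal(2), grad_penalty=0.5, gp_weight=10.0)
```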

#### Setup
The dependencies can be installed using 
```
pip install -r requirements.txt
```

Download the pretrained diffusion UNet model weights from [here](https://heibox.uni-heidelberg.de/d/01207c3f6b8441779abf/?p=%2Fdiffusion_models_converted%2Fdiffusion_lsun_bedroom_model&mode=list) and save them as `./model_zoo/diffusion_lsun_bedroom_model-2388000.ckpt`.

Download the UDDM model fine-tuned weights from [here](https://drive.google.com/file/d/17bC8nSU7Sd6M_eTZh37T3wQJLo4RvyQh/view?usp=sharing) and save them as `./demo_models/LSUN_RAM_gaussian_deblur.ckpt`.


In [None]:
# Necessary imports 
import torch 
import torchvision
import sys
from pathlib import Path
import os 
# Import data
from datasets import GetDatasets
from utils import utils_model
from metrics.metrics import Metrics

In [None]:
# Model specification
from configs.args_parse import configs
import yaml
from utils import dict_to_dotdict
im_size = 256 # Image resolution

# Select the device and load the demo configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with open("configs/demo.yml", 'r') as file:
    config_dict = yaml.safe_load(file)
args = dict_to_dotdict(config_dict)

lora_checkpoint = "demo_models/LSUN_RAM_gaussian_deblur.ckpt"
Num_Grad_steps = 3  # Number of unrolled LATINO iterations (K)
Num_Diff_steps = 3  # Number of diffusion sampling steps

In [None]:
# Load the demo dataset
data = GetDatasets(
    dataset_dir="./demo/demo_data",
    im_size=im_size,
    dataset_name="demo",
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.CenterCrop((im_size, im_size)),
    ]),
)
print(f"Dataset length: {len(data)}")

data_loader = torch.utils.data.DataLoader(
    data,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    drop_last=False,
)

#### Define forward observation model

In [None]:
# Define physics operator for deblurring
from physics import Deblurring, Kernels

sigma = 0.05 #  Noise level

kernel = Kernels(
    operator_name="anisotropic",
    kernel_size=(5,1),
    device=device
).get_blur()

physics_operator = Deblurring(
    sigma_model=sigma,
    operator_name="gaussian",
    device=device,
    scale_image=True
)

# Plot example data
import matplotlib.pyplot as plt

x = next(iter(data_loader)).to(device)  # Images loaded in [-1, 1] range
# x = (x + 1)/2  # Normalize to [0, 1]
y = physics_operator.y(x, blur=kernel)  # Apply the forward (blur) operator to obtain the observation y

x_pl = (x.permute(0, 2, 3, 1) + 1) / 2  # Permute to (batch, height, width, channels) and rescale to [0, 1]
x_pl = x_pl[0].cpu().numpy()  # Convert to numpy for plotting
y_pl = (y.permute(0, 2, 3, 1) + 1) / 2  # Permute to (batch, height, width, channels) and rescale to [0, 1]
y_pl = y_pl[0].cpu().clip(0, 1).numpy()  # Clip to [0, 1] and convert to numpy for plotting
plt.subplot(1,2,1)
plt.title("Original Image")
plt.axis("off")
plt.imshow(x_pl)
plt.subplot(1,2,2)
plt.title("Blurred Image")
plt.imshow(y_pl)
plt.axis("off")

#### Load the diffusion UNet model
- Apply LoRA fine-tuning to the model weights $\theta_{\text{attention}}$ within each attention layer such that $\vartheta_{\text{attention}} = \theta_{\text{attention}} + \Delta\theta_{\text{attention}}$, where $\Delta\theta_{\text{attention}} = BA$ with $B\in \mathbb{R}^{d\times r}$ and $A\in \mathbb{R}^{r\times d}$ is a rank-$r$ perturbation of the attention weights. For this demo, we use $r=5$.
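
For intuition, a minimal NumPy sketch of a LoRA-adapted linear layer is given below, using the standard LoRA parameterization $\Delta W = BA$ with $B$ initialized to zero. The class name, layer dimensions and initialization scale are illustrative assumptions; the actual adapter is constructed by `adapter_lora_model` in the next cell.

```python
import numpy as np

class LoRALinear:
    """Linear layer y = x (W + B A)^T with a frozen W and trainable low-rank factors."""

    def __init__(self, weight, rank, rng):
        d_out, d_in = weight.shape
        self.weight = weight  # frozen pre-trained weight
        self.A = rng.standard_normal((rank, d_in)) / np.sqrt(d_in)  # trainable
        self.B = np.zeros((d_out, rank))  # trainable, zero so that the initial update vanishes

    def delta(self):
        """Low-rank update B A, of rank at most `rank`."""
        return self.B @ self.A

    def num_trainable(self):
        return self.A.size + self.B.size

    def __call__(self, x):
        return x @ (self.weight + self.delta()).T

rng = np.random.default_rng(0)
W = rng.standard_normal((64, 64))
layer = LoRALinear(W, rank=5, rng=rng)
x = rng.standard_normal((10, 64))
out = layer(x)  # identical to x @ W.T until B is updated
```

With $r = 5$ and $d = 64$ the adapter has $5 \times (64 + 64) = 640$ trainable entries, compared with $4096$ in the frozen weight matrix.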

In [None]:
# Load model
from models import adapter_lora_model
import contextlib

print("Loading model...")
with contextlib.redirect_stdout(None): # Reduce verbosity of model output
    
    state_dict = torch.load(lora_checkpoint, map_location=device, weights_only=True)

    model = adapter_lora_model(args)
    model.load_state_dict(
        state_dict["state_dict"],
        strict=False
    )

    model.eval()
    model.to(device)
print("Model loaded successfully.")
n_lora = sum(p.numel() for p in state_dict["state_dict"].values())
n_total = sum(p.numel() for p in model.parameters())
print(
    "\nLoRA weights:", n_lora,
    "\nTotal model parameters:", n_total,
    f"\nLoRA accounts for {n_lora / n_total * 100:.2f}% of the total model parameters.",
)

#### Load the Diffusion Schedule 

The pretrained diffusion model was trained using score matching to predict $x_0$ from the observations
$$
    x_t \sim \mathcal{N}(\sqrt{\overline{\alpha}_t}x_0, (1 - \overline{\alpha}_t)I),
$$
where $\overline{\alpha}_t = \prod_{s=1}^t (1-\beta_s)$ and $\beta_s$ is a linear noise schedule ranging from $0.0001$ to $0.02$ over $T=1000$ steps. The diffusion schedule is defined by the parameters $\overline{\alpha}_t$ and $\beta_t$. 
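
The schedule described above can be reproduced in a few lines of NumPy. This is a sketch of the formula only and is independent of the `DiffusionScheduler` class used in this notebook; the helper name `sample_x_t` is an assumption.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)    # beta_1, ..., beta_T (linear schedule)
alpha_bars = np.cumprod(1.0 - betas)  # alpha_bar_t for t = 1, ..., T

def sample_x_t(x0, t, rng):
    """Draw x_t ~ N(sqrt(alpha_bar_t) x0, (1 - alpha_bar_t) I) for t in 1..T."""
    abar = alpha_bars[t - 1]
    return np.sqrt(abar) * x0 + np.sqrt(1.0 - abar) * rng.standard_normal(x0.shape)

rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(3, 8, 8))
x_mid = sample_x_t(x0, t=500, rng=rng)
```

Since $\overline{\alpha}_t$ decreases monotonically from nearly $1$ towards $0$, $x_t$ interpolates between the clean image and almost pure Gaussian noise.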

The following code loads the diffusion schedule and a denoising timestep which is used to sample $x_{t_\delta}$ from $x_{0,t}$ within the LATINO architecture.

In [None]:
# Load diffusion model 
from unfolded_models import GetDenoisingTimestep, DiffusionScheduler

diffusion_schedule = DiffusionScheduler(device = device)
denoising_timestep = GetDenoisingTimestep(scheduler = diffusion_schedule, device = device)

#### Load the LATINO architecture and Conditional Diffusion Loop

In [None]:
# Define the unrolled posterior sampling loop to sample x_0 given y and x_t
from unfolded_models import HQS_models as LATINO_module

unrolled_loop = LATINO_module(
    model,
    physics_operator,
    diffusion_schedule,
    denoising_timestep,
    args,
    device=device,
    max_unfolded_iter=Num_Grad_steps,
)

# Use RAM to initialize the unrolled joint posterior sampling loop
use_RAM = False
if use_RAM:
    # Mirror the physics operator in deepinverse
    from deepinv.physics import BlurFFT, GaussianNoise
    dinv_physics = BlurFFT(
        (3, im_size, im_size),
        filter = kernel.unsqueeze(0).unsqueeze(0),
        device=device,
        noise_model = GaussianNoise(sigma=sigma),
    )


# Define the distilled conditional diffusion model
from runners import Conditional_sampler
diffusion = Conditional_sampler(
    unrolled_loop,
    physics_operator,
    diffusion_schedule,
    device,
    args,
    dphys = dinv_physics if use_RAM else None,
)

def count_changed_params(module, new_state):
    # Number of entries in new_state that differ from the module's current weights
    current = module.state_dict()
    return sum(
        p.numel() for k, p in new_state.items()
        if k in current and (current[k] - p).abs().sum() != 0
    )

if "HQS_state_dict" in state_dict:
    print("Loading HQS state dict...")
    # Count the differing parameters before loading, then report after loading
    n_changed = count_changed_params(diffusion.hqs_model, state_dict["HQS_state_dict"])
    diffusion.hqs_model.load_state_dict(state_dict["HQS_state_dict"], strict=False)
    print(f"Successfully loaded {n_changed} parameters.")

if use_RAM and "RAM_state_dict" in state_dict:
    print("Loading fine-tuned RAM state dict...")
    n_changed = count_changed_params(diffusion.RAM, state_dict["RAM_state_dict"])
    diffusion.RAM.load_state_dict(state_dict["RAM_state_dict"], strict=False)
    print(f"Successfully loaded {n_changed} parameters.")



#### Perform Inference on Test Data

In [None]:
# Run the diffusion sampler on the test data
outputs = []
for i, x in enumerate(data_loader):
    x = x.to(device)
    x_true = (x + 1) / 2  # Normalize to [0, 1]
    y = physics_operator.y(x, blur=kernel)  # Apply the forward (blur) operator to obtain the observation y
    # Sample from the diffusion model
    out = diffusion.sampler(
        y, 
        f"im_{i}",
        num_timesteps=Num_Diff_steps,
    )
    outputs.append((x_true, y, out))

#### Plot Results

In [None]:
# Visualize the sampled outputs and report the reconstruction PSNR
for x_true, y, out in outputs:
    metrics = Metrics(device=device)
    psnrs = metrics.psnr_function(x_true, out["xstart_pred"])
    plt.figure(layout="constrained")
    plt.subplot(1,3,1)
    plt.title("True Image")
    plt.axis("off")
    plt.imshow(x_true.permute(0, 2, 3, 1)[0].cpu().numpy())  # Display the true image
    
    plt.subplot(1,3,2)
    plt.axis("off")
    plt.title("Blurred Image")
    plt.imshow(y.permute(0, 2, 3, 1)[0].cpu().add(1).mul(0.5).clip(0,1).numpy())  # Display the blurred image
    
    plt.subplot(1,3,3)
    plt.title("Sampled Image")
    plt.text(0.5, -0.01, f"PSNR: {psnrs.item():.2f} dB", ha='center', va='top', transform=plt.gca().transAxes)
    sampled_x0 = out["xstart_pred"].permute(0, 2, 3, 1)  # Permute to (batch, height, width, channels)
    sampled_x0 = sampled_x0[0].cpu().numpy()   # Convert to numpy for plotting
    plt.axis("off")
    plt.imshow(sampled_x0)  # Display the sampled image
    


    # Plot the progressive image through the diffusion steps
    seq,_ = diffusion_schedule.get_seq_progress_seq(Num_Diff_steps)
    progression = out["progress_list"]
    fig = plt.figure(layout="constrained")
    fig.suptitle(r"Progressive $x_t$ through Diffusion Steps", y=0.85)
    for j, img in enumerate(progression):
        ax = fig.add_subplot(1, len(progression), j + 1)
        ax.axis("off")
        ax.set_title(fr"$t = {seq[j]}$")
        ax.imshow(img)  # Display the progressive image
        psnrs = metrics.psnr_function(x_true, torch.tensor(img).unsqueeze(0).permute(0,3,1,2).to(device))
        plt.text(0.5, -0.01, f"PSNR: {psnrs.item():.2f} dB", ha='center', va='top', transform=plt.gca().transAxes)

    # Plot the progressive prediction through the diffusion steps
    seq,_ = diffusion_schedule.get_seq_progress_seq(Num_Diff_steps)
    progression = out["progress_zero"]
    fig = plt.figure(layout="constrained")
    fig.suptitle(r"Progressive $x_0$ through Diffusion Steps", y=0.85)
    for j, img in enumerate(progression):
        img = (img.permute(1, 2, 0) + 1)/2
        img = img.cpu().clip(0,1).numpy()  # Convert to numpy for plotting
        ax = fig.add_subplot(1, len(progression), j + 1)
        ax.axis("off")
        ax.set_title(fr"$t = {seq[j]}$")
        ax.imshow(img)  # Display the progressive image
        psnrs = metrics.psnr_function(x_true, torch.tensor(img).unsqueeze(0).permute(0,3,1,2).to(device))
        plt.text(0.5, -0.01, f"PSNR: {psnrs.item():.2f} dB", ha='center', va='top', transform=plt.gca().transAxes)

