The goal is to explore average pooling (avgpool) for use as a downsampler in our diffusion model.

Sections
* AvgPool Twice
* Variance Adjustment

In [1]:
import torch
import torch.nn.functional as F

# AvgPool twice? 
It seems that we can expect a max error of around 2.3842e-07 per entry. 
In terms of vector magnitudes (the L2 norm of each sample's difference), we can expect a max error of around 1.0476e-05.
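
In exact arithmetic the two results are identical, because an average of equal-sized block averages is the overall average. As a short worked identity (the names $W$ for a $4 \times 4$ window and $B_1, \dots, B_4$ for its four $2 \times 2$ blocks are notation introduced here):
$$
\frac{1}{4} \sum_{m=1}^{4} \left( \frac{1}{4} \sum_{(k,l) \in B_m} x_{k,l} \right)
= \frac{1}{16} \sum_{(k,l) \in W} x_{k,l}.
$$
So any difference measured in the cells that follow should come only from floating-point rounding, since the two computations add the same 16 values in different groupings.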

In [2]:
D = lambda x, d: F.avg_pool2d(x,d, d, 0, False, False)

In [3]:
dim = 1024
N = 64
example_N3DimDim = torch.randn(N, 3, dim, dim)

In [4]:
twice = D(D(example_N3DimDim, 2), 2)
once = D(example_N3DimDim, 4)
torch.allclose(twice, once, atol=1e-7)

True

In [5]:
torch.abs(twice-once).max()

tensor(2.3842e-07)

In [6]:
mags = torch.sqrt(torch.sum((twice-once)**2, dim=(1,2,3)))
mags.max()

tensor(1.0427e-05)

In [7]:
def avgpool_stats_dim(N, dim, scale=1.0): 
    example_N3DimDim = torch.randn(N, 3, dim, dim) * scale
    
    twice = D(D(example_N3DimDim, 2), 2)
    once = D(example_N3DimDim, 4)
    close = torch.allclose(twice, once, atol=1e-7)
    print(f"Close: {close}", end=" ")

    max_abs_diff = torch.abs(twice-once).max()
    print(f"{max_abs_diff=}", end=" ")

    mags = torch.sqrt(torch.sum((twice-once)**2, dim=(1,2,3)))
    max_magnitude_diff = mags.max()
    print(f"{max_magnitude_diff=}")

In [8]:
for dim in [4, 16, 32, 64, 128, 256, 512, 1028]: 
    avgpool_stats_dim(N, dim)

Close: True max_abs_diff=tensor(5.9605e-08) max_magnitude_diff=tensor(6.8286e-08)
Close: True max_abs_diff=tensor(1.4901e-07) max_magnitude_diff=tensor(2.2781e-07)
Close: True max_abs_diff=tensor(1.1921e-07) max_magnitude_diff=tensor(3.7674e-07)
Close: True max_abs_diff=tensor(1.7881e-07) max_magnitude_diff=tensor(7.1357e-07)
Close: True max_abs_diff=tensor(2.3842e-07) max_magnitude_diff=tensor(1.3735e-06)
Close: True max_abs_diff=tensor(2.3842e-07) max_magnitude_diff=tensor(2.6434e-06)
Close: True max_abs_diff=tensor(2.3842e-07) max_magnitude_diff=tensor(5.2739e-06)
Close: False max_abs_diff=tensor(2.3842e-07) max_magnitude_diff=tensor(1.0474e-05)


In [9]:
for dim in [4, 16, 32, 64, 128, 256, 512, 1028]: 
    avgpool_stats_dim(N, dim, scale=100)

Close: False max_abs_diff=tensor(7.6294e-06) max_magnitude_diff=tensor(8.5831e-06)
Close: False max_abs_diff=tensor(1.5259e-05) max_magnitude_diff=tensor(2.3253e-05)
Close: False max_abs_diff=tensor(1.5259e-05) max_magnitude_diff=tensor(3.7434e-05)
Close: False max_abs_diff=tensor(1.5259e-05) max_magnitude_diff=tensor(6.9951e-05)
Close: False max_abs_diff=tensor(2.2888e-05) max_magnitude_diff=tensor(0.0001)
Close: False max_abs_diff=tensor(3.0518e-05) max_magnitude_diff=tensor(0.0003)
Close: False max_abs_diff=tensor(3.0518e-05) max_magnitude_diff=tensor(0.0005)
Close: False max_abs_diff=tensor(3.0518e-05) max_magnitude_diff=tensor(0.0010)


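I think the larger magnitude error can be explained as a small per-entry error summed over many entries. The larger errors with `scale=100` are expected rather than strange: float32 rounding error is relative to the size of the values, so scaling the inputs by 100 scales the absolute error by roughly the same factor (here about 128, from 2.3842e-07 to 3.0518e-05), and the fixed `atol=1e-7` check then fails. I don't think we should worry about it. The errors stay small relative to the values.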
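
One way to sanity-check that these are rounding effects: the largest reported `max_abs_diff` values are powers of two, which is what we would expect if each error is a small number of float32 units in the last place. The sketch below only uses values printed above; the helper name `nearest_power_of_two_exponent` is made up for this illustration.

```python
import math

def nearest_power_of_two_exponent(value: float) -> int:
    """Return the integer k for which 2**k is closest to value on a log scale."""
    return round(math.log2(value))

# Largest max_abs_diff values printed above (scale=1 and scale=100).
max_err_scale_1 = 2.3842e-07
max_err_scale_100 = 3.0518e-05

k1 = nearest_power_of_two_exponent(max_err_scale_1)
k100 = nearest_power_of_two_exponent(max_err_scale_100)

print(k1, 2.0 ** k1)       # -22 2.384185791015625e-07
print(k100, 2.0 ** k100)   # -15 3.0517578125e-05
print(2.0 ** (k100 - k1))  # 128.0, the ratio between the two errors
```

Scaling the inputs by 100 moves the values up by about log2(100) ≈ 6.6 powers of two, so a jump of 2^7 = 128 in the absolute error is in line with a purely relative rounding error.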

In [10]:
64*16, 7.1760e-05 * 16

(1024, 0.00114816)
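
The arithmetic above checks that the error norm grows about linearly with `dim`. A rough argument, assuming the per-entry errors have a roughly constant typical size $e$ and are roughly independent: each sample has $n = 3(\text{dim}/4)^2$ pooled entries, so the norm is about $e\sqrt{n}$, which is proportional to `dim`. The sketch below simply reads the ratios off the `max_magnitude_diff` values printed by In [8]; it does not rerun the experiment.

```python
# max_magnitude_diff values copied from the In [8] output above (scale=1).
reported = {
    64: 7.1357e-07,
    128: 1.3735e-06,
    256: 2.6434e-06,
    512: 5.2739e-06,
}

dims = sorted(reported)
pairs = list(zip(dims, dims[1:]))

# Ratio of the error norm each time dim doubles; linear growth predicts about 2.
ratios = [reported[b] / reported[a] for a, b in pairs]

for (a, b), r in zip(pairs, ratios):
    print(f"dim {a} -> {b}: norm ratio {r:.2f}")
```

All three ratios come out close to 2, which is consistent with the "small error summed over many entries" explanation.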

In [11]:
def avgpool_stats_dim_unif(N, dim): 
    example_N3DimDim = torch.rand(N, 3, dim, dim) * 2 - 1
    
    twice = D(D(example_N3DimDim, 2), 2)
    once = D(example_N3DimDim, 4)
    close = torch.allclose(twice, once, atol=1e-7)
    print(f"Close: {close}", end=" ")

    max_abs_diff = torch.abs(twice-once).max()
    print(f"{max_abs_diff=}", end=" ")

    mags = torch.sqrt(torch.sum((twice-once)**2, dim=(1,2,3)))
    max_magnitude_diff = mags.max()
    print(f"{max_magnitude_diff=}")

In [12]:
for dim in [4, 16, 32, 64, 128, 256, 512, 1028]: 
    avgpool_stats_dim_unif(N, dim)

Close: True max_abs_diff=tensor(5.9605e-08) max_magnitude_diff=tensor(6.1889e-08)
Close: True max_abs_diff=tensor(8.9407e-08) max_magnitude_diff=tensor(1.3637e-07)
Close: True max_abs_diff=tensor(8.9407e-08) max_magnitude_diff=tensor(2.0768e-07)
Close: True max_abs_diff=tensor(1.1921e-07) max_magnitude_diff=tensor(3.9192e-07)
Close: True max_abs_diff=tensor(1.1921e-07) max_magnitude_diff=tensor(7.3982e-07)
Close: True max_abs_diff=tensor(1.4901e-07) max_magnitude_diff=tensor(1.4206e-06)
Close: True max_abs_diff=tensor(1.4901e-07) max_magnitude_diff=tensor(2.8310e-06)
Close: True max_abs_diff=tensor(1.4901e-07) max_magnitude_diff=tensor(5.6384e-06)


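These errors disappear when we deal with integer (no decimal point) values. Note that `torch.arange` returns an integer tensor here, as the `tensor(0)` outputs suggest, so the pooling is carried out on integers rather than on float32 values: 
* This indicates that the discrepancies above come from 32-bit floating-point rounding, not from a real difference between pooling twice and pooling once. I wouldn't worry about it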
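
To make the rounding explanation concrete, here is a small standard-library sketch that imitates float32 arithmetic on a single 4x4 window. It assumes plain left-to-right accumulation with rounding after every addition, which need not match the order PyTorch's kernels use; the helpers `f32`, `mean_f32`, `pool_once`, `pool_twice` and `max_diff` are hypothetical names for this illustration.

```python
import random
import struct

def f32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

def mean_f32(values):
    """Average with float32 rounding after every addition and after the division."""
    total = 0.0
    for v in values:
        total = f32(total + v)
    return f32(total / len(values))

def pool_once(block):
    """Mean of all 16 entries of a 4x4 block (a list of 4 rows)."""
    return mean_f32([v for row in block for v in row])

def pool_twice(block):
    """Mean of the four 2x2 block means."""
    means = []
    for i in (0, 2):
        for j in (0, 2):
            means.append(mean_f32([block[i][j], block[i][j + 1],
                                   block[i + 1][j], block[i + 1][j + 1]]))
    return mean_f32(means)

def max_diff(blocks):
    """Largest |once - twice| over a list of 4x4 blocks."""
    return max(abs(pool_once(b) - pool_twice(b)) for b in blocks)

rng = random.Random(0)
# Gaussian entries, rounded to float32 first.
gauss_blocks = [[[f32(rng.gauss(0.0, 1.0)) for _ in range(4)] for _ in range(4)]
                for _ in range(2000)]
# Integer-valued entries, small enough that every sum is exact in float32.
int_blocks = [[[float(rng.randrange(1000)) for _ in range(4)] for _ in range(4)]
              for _ in range(2000)]

print("gaussian inputs:", max_diff(gauss_blocks))
print("integer inputs: ", max_diff(int_blocks))
```

With small integer-valued inputs every intermediate sum and every division by 4 or 16 is exactly representable, so the two groupings agree exactly; with Gaussian inputs each addition may round, and the two groupings can round differently.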

In [13]:
def avgpool_stats_dim_arange(N, dim): 
    example_N3DimDim = torch.arange(N*3*dim*dim).reshape(N, 3, dim, dim)
    
    twice = D(D(example_N3DimDim, 2), 2)
    once = D(example_N3DimDim, 4)
    close = torch.allclose(twice, once, atol=1e-7)
    print(f"Close: {close}", end=" ")

    max_abs_diff = torch.abs(twice-once).max()
    print(f"{max_abs_diff=}", end=" ")

    mags = torch.sqrt(torch.sum((twice-once)**2, dim=(1,2,3)))
    max_magnitude_diff = mags.max()
    print(f"{max_magnitude_diff=}")

In [14]:
for dim in [4,16,32,64,128, 256, 512,1028]: 
    avgpool_stats_dim_arange(N, dim)

Close: True max_abs_diff=tensor(0) max_magnitude_diff=tensor(0.)
Close: True max_abs_diff=tensor(0) max_magnitude_diff=tensor(0.)
Close: True max_abs_diff=tensor(0) max_magnitude_diff=tensor(0.)
Close: True max_abs_diff=tensor(0) max_magnitude_diff=tensor(0.)
Close: True max_abs_diff=tensor(0) max_magnitude_diff=tensor(0.)
Close: True max_abs_diff=tensor(0) max_magnitude_diff=tensor(0.)
Close: True max_abs_diff=tensor(0) max_magnitude_diff=tensor(0.)
Close: True max_abs_diff=tensor(0) max_magnitude_diff=tensor(0.)


# Variance Adjustment

Now, consider mean-pooling with a kernel and stride of size (a,b).

Suppose we apply it to a Gaussian with isotropic covariance, or 
$$
x \sim \mathcal{N}(\mu, I\sigma^2).
$$
After pooling, each entry is comprised of a sample average of $ab$ entries of $x$, or 
$$
    x_{ij}' = \frac{1}{ab} \sum_{(k,l) \in \text{idx}(i,j)} x_{k,l}.
$$
This is a sample average over independent gaussians, so the result is a gaussian as well. 

The mean is simply the average of the corresponding entries of $\mu$, and since the entries are independent, the variance is given by 
$$
\begin{align*}
\text{Var}(x_{ij}')
&= \text{Var}(\frac{1}{ab} \sum_{(k,l) \in \text{idx}(i,j)} x_{k,l}) \\ 
&= \frac{1}{(ab)^2} \sum_{(k,l) \in \text{idx}(i,j)} \text{Var}(x_{k,l}) \\
&= \frac{1}{(ab)^2} (ab) \text{Var}(x_{k,l}) \\ 
&= \frac{1}{(ab)^2} (ab) \sigma^2 \\
&= \frac{1}{ab} \sigma^2.
\end{align*}
$$
Since the original gaussian is isotropic, and the pooling grabs distinct sets of entries, the covariances are zero. 

Importantly, if we apply this to noise, the covariance of the noise is scaled by $(1/(ab))$.

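In order to avoid this, we can
* add independent isotropic noise of variance $(1 - 1/(ab))\sigma^2$ to the pooled output, which brings the total variance back to $\sigma^2/(ab) + (1 - 1/(ab))\sigma^2 = \sigma^2$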


This means that if we want to imagine adding noise at different scales, we need to be careful about how we handle it. We can imagine that downscale operations implicitly add this noise - to keep variance consistent across scales. 
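
A quick Monte Carlo check of the derivation, written with only the standard library so it runs without PyTorch. It pools windows of $ab = 4$ independent samples and compares three variances. The rescaling option (multiplying by $\sqrt{ab}$) is not from the text above; it is included as an assumption-labeled alternative that only makes sense for zero-mean noise, since it would also rescale a nonzero mean.

```python
import random
import statistics

rng = random.Random(0)
sigma = 1.0
a, b = 2, 2            # pooling kernel (and stride) size
ab = a * b
n_samples = 50_000

pooled, compensated, rescaled = [], [], []
extra_std = sigma * (1.0 - 1.0 / ab) ** 0.5   # std of the compensating noise

for _ in range(n_samples):
    window = [rng.gauss(0.0, sigma) for _ in range(ab)]
    mean = sum(window) / ab
    pooled.append(mean)
    # Option from the text: add independent noise of variance (1 - 1/ab) * sigma^2.
    compensated.append(mean + rng.gauss(0.0, extra_std))
    # Alternative (assumption, zero-mean noise only): rescale by sqrt(ab).
    rescaled.append(mean * ab ** 0.5)

var_pooled = statistics.pvariance(pooled)
var_compensated = statistics.pvariance(compensated)
var_rescaled = statistics.pvariance(rescaled)

print("pooled:     ", var_pooled)        # expected near sigma^2 / ab = 0.25
print("compensated:", var_compensated)   # expected near sigma^2 = 1
print("rescaled:   ", var_rescaled)      # expected near sigma^2 = 1
```

Note that the factor needed to restore unit variance by rescaling is $\sqrt{ab}$ (2 for a 2x2 kernel), not $ab$, because variance scales with the square of a multiplicative factor.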

## Example Diffusion

In [15]:
import sys
sys.path.append("../src/")

%load_ext autoreload
%autoreload 2
# Importing our custom module(s)
import unet
from unet import UNet

In [16]:
test_net = UNet()

In [17]:
input_NShape = torch.randn((64, 3, 32, 32))
t_N = torch.rand((64,))
test_net.forwardt_same(t_N, input_NShape).shape

torch.Size([64, 1, 32, 32])

In [18]:
test_net.forwardt(t_N, input_NShape).shape

torch.Size([64, 1, 64, 64])

In [93]:
# give T for each level
# we'll define our own betas
# assume half size each time

from typing import List

class BasicAvgPoolMultiScaleDiffusion(torch.nn.Module): 
    """
        We do diffusion until step T0. Then avgpool. 
        Then do diffusion until step T1.
    """
    def __init__(
            self, 
            Ts: List[int] = []
        ) -> None:
        super().__init__()
        self.Ts = Ts
        
        self.Tends = [sum(self.Ts[0:i+1]) for i in range(len(Ts))]
        # self.register_buffer("betas_A", torch.linspace(1e-4, 0.02, sum(self.Ts)))
        betas_A = torch.linspace(1e-4, 0.02, sum(self.Ts))
        self.register_buffer("alphas_cumprod_A", torch.cumprod(1.0 - betas_A, dim=0))


    def _xt_given_x0_same(self, x0_NShape: torch.Tensor, level: int, t_N: torch.Tensor): 
        """
            All t on the same resolution level.

            Samples x_t directly from x_0 (average-pooled down to this level's
            resolution) with the closed-form forward process, and returns
            (x_t, eps).

            t must be below the last t on this level
        """
        # downsize at end of level. 
        x0_NDown = F.avg_pool2d(x0_NShape, 2**level)
        shp = x0_NDown.shape

        N = x0_NDown.size(0)

        # we are assuming that downsize adds some noise to compensate for 
        # the variance reduction. SO we can just use betas as is
        # computing here, in case we want to make betas learnable
        alpha_bar_N = self.alphas_cumprod_A[t_N]

        eps_NDown = torch.randn_like(x0_NDown).to(x0_NShape.device)

        xt_NDown = (
            torch.sqrt(alpha_bar_N)[:, None, None, None] * x0_NDown
            + torch.sqrt(1 - alpha_bar_N)[:, None, None, None] * eps_NDown
        )
        # print("same")
        # print(x0_NShape.shape)
        # print(xt_NDown.shape)
        # print(eps_NDown.shape)
        # print(t_N.shape)
        return xt_NDown, eps_NDown
    
    def _xt_given_x0_jump(self, x0_NShape: torch.Tensor, level: int): 
        """
            Just came from a jump at level. So at level+1
        """
        # downsize at end of level. 
        x0_NDown = F.avg_pool2d(x0_NShape, 2**(level))
        x0_NSmall = F.avg_pool2d(x0_NDown, 2)

        N = x0_NSmall.size(0)

        # we are assuming that downsize adds some noise to compensate for 
        # the variance reduction. SO we can just use betas as is
        # computing here, in case we want to make betas learnable

        t_val = self.Tends[level]
        t_N = torch.full((N,), t_val, dtype=torch.long, device=x0_NShape.device)
        alpha_bar_N = self.alphas_cumprod_A[t_N]

        eps_NDown = torch.randn_like(x0_NDown).to(x0_NShape.device)
        # 2x2 average pooling divides the noise variance by 4 (the std by 2),
        # so multiply by sqrt(4) = 2 to keep unit variance
        eps_SNSmall = F.avg_pool2d(eps_NDown, 2) * 2

        xt_NSmall = (
            torch.sqrt(alpha_bar_N)[:, None, None, None] * x0_NSmall
            + torch.sqrt(1 - alpha_bar_N)[:, None, None, None] * eps_SNSmall
        )
        return xt_NSmall, eps_NDown


    def xts_given_x0(self, x0_NShape: torch.Tensor, level: int): 
        N = x0_NShape.size(0) 

        # this ends up over-emphasizing jumps? 
        do_jump = torch.rand(1).item() < 0.1

        if do_jump: 
            # from level 0 to 1: the jump sample is always built from level 0
            # here, while t_N below is filled with Tends[level]
            xt_NSmall, eps_NSmall = self._xt_given_x0_jump(x0_NShape, 0)
            t_end = self.Tends[level]
            t_N = torch.full((N,), t_end, dtype=torch.long, device=x0_NShape.device)

            return xt_NSmall, eps_NSmall, t_N
        
        else: 
            t_end = self.Tends[level]
            if level == 0: 
                t_start = 0 
            else: 
                t_start = self.Tends[level-1] + 1

            t_N = torch.randint(t_start, t_end, (N,), device=x0_NShape.device)
            xt_NSmall, eps_NSmall = self._xt_given_x0_same(x0_NShape, level, t_N)

            return xt_NSmall, eps_NSmall, t_N
    

    def loss(self, eps_model: UNet, x0_NShape: torch.Tensor): 
        # N x C x H x W
        min_dim = min(x0_NShape.shape[-2:])

        import math
        shrink = int(math.log(min_dim, 2))

        level = int(torch.randint(0, 1 + shrink,  (1,)).item())
        xt_NSmall, eps_NSmall, t_N = self.xts_given_x0(x0_NShape, level)

        # print(xt_NSmall.shape)
        # print(eps_NSmall.shape)
        # print(t_N.shape)

        if xt_NSmall.shape == eps_NSmall.shape: 
            eps_pred_NSmall = eps_model.forwardt_same(t_N, xt_NSmall)
        else: 
            eps_pred_NSmall = eps_model.forwardt(t_N, xt_NSmall)
        return torch.mean((eps_pred_NSmall - eps_NSmall)**2)
    

    @torch.no_grad()
    def sample(self, eps_model: UNet, xT_NShape: torch.Tensor):
        """
        Reverse DDPM across all timesteps, mapping each t to its corresponding
        level using self.Tends. Upscaling happens exactly when t == Tends[level]
        (for every level except the last): eps is predicted with forwardt on the
        low-resolution x, and x is then upsampled 2x before the reverse step.
        """
        import bisect

        device = xT_NShape.device
        N = xT_NShape.shape[0]

        # Precompute alpha schedules on device
        alpha_bar = self.alphas_cumprod_A      # (total_T,)
        alpha_bar_prev = torch.cat([torch.tensor([1.0], device=device), alpha_bar[:-1]])

        total_T = alpha_bar.shape[0]
        x = xT_NShape  # start at final (lowest-res) noise

        # We'll iterate global t from total_T-1 down to 0
        for t in range(total_T - 1, -1, -1):
            # find level for this t: first index i with t <= Tends[i]
            # note: self.Tends should be a list of ints
            level = bisect.bisect_left(self.Tends, t)
            # Safety clamp
            if level >= len(self.Tends):
                level = len(self.Tends) - 1

            # create batch-sized t tensor
            t_N = torch.full((N,), t, dtype=torch.long, device=device)

            # Choose the appropriate UNet call:
            # - At a level boundary (t == Tends[level], except on the last level), use forwardt,
            #   which predicts eps at twice the resolution, and upsample x to match.
            # - Otherwise use forwardt_same, which predicts eps at the resolution of x.
            if (t == self.Tends[level]) and (level < (len(self.Tends) - 1)):
                eps_pred = eps_model.forwardt(t_N, x)
                x = torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
            else:
                eps_pred = eps_model.forwardt_same(t_N, x)

            # DDPM reverse step (per-batch indexing so broadcasting works)
            ab_t = alpha_bar[t_N].view(N,1,1,1)
            ab_prev = alpha_bar_prev[t_N].view(N,1,1,1)
            a_t = ab_t / ab_prev

            coef1 = (1.0 / torch.sqrt(a_t)).view(N, 1, 1, 1)                   # (N,1,1,1)
            coef2 = (((1.0 - a_t) / torch.sqrt(1.0 - ab_t))).view(N, 1, 1, 1)  # (N,1,1,1)

            mean = coef1 * (x - coef2 * eps_pred)

            if t > 0:
                sigma_t = torch.sqrt(((1.0 - ab_prev) / (1.0 - ab_t)) * (1.0 - a_t))  # (N,)
                sigma_t = sigma_t.view(N, 1, 1, 1)
                noise = torch.randn_like(x)
                x = mean + sigma_t * noise
            else:
                x = mean  # at t==0 no noise

        return x

In [94]:
test_diff = BasicAvgPoolMultiScaleDiffusion([100, 100, 100, 100, 100, 5])
# 32, 16, 8, 4, 2, 1


In [95]:
eps_model = UNet(in_channels=3, num_classes=1)


In [96]:
device = 'cuda'

In [97]:
start = torch.randn((16,3,1,1), device=device)
eps_model = eps_model.to(device)
test_diff = test_diff.to(device)

test_diff.sample(eps_model, start).shape

torch.Size([16, 3, 32, 32])

In [98]:
# test_XT_NStart = torch.randn((5, 3, 1,1))
# sample_X0_NEnd = test.sample(test_net, test_XT_NStart)

In [99]:
# test_XT_NStart.shape, sample_X0_NEnd.shape

Timestep indexing:
* 0: original image
* 99: last step at the original resolution
* 100: first downsampled and noised step
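
The timestep-to-level bookkeeping can be traced without a model. The sketch below reuses the `bisect_left` lookup and the boundary condition from `sample`, with the step counts from In [94]; `level_for_t` and `trace_resolution` are hypothetical helper names, and the starting resolution of 1 is taken from the 1x1 noise used in In [97].

```python
import bisect
from itertools import accumulate

Ts = [100, 100, 100, 100, 100, 5]   # steps per level, as in In [94]
Tends = list(accumulate(Ts))        # [100, 200, 300, 400, 500, 505]

def level_for_t(t: int) -> int:
    """First level index i with t <= Tends[i], clamped to the last level."""
    return min(bisect.bisect_left(Tends, t), len(Tends) - 1)

def trace_resolution(start_res: int = 1):
    """Walk t from the last step down to 0 and record where x is upsampled 2x."""
    res = start_res
    upsample_steps = []
    for t in range(Tends[-1] - 1, -1, -1):
        level = level_for_t(t)
        # Same boundary test as in sample(): upsample at t == Tends[level],
        # except on the last (coarsest) level.
        if t == Tends[level] and level < len(Tends) - 1:
            res *= 2
            upsample_steps.append(t)
    return res, upsample_steps

print([level_for_t(t) for t in (0, 99, 100, 101, 504)])  # [0, 0, 0, 1, 5]
print(trace_resolution())  # (32, [500, 400, 300, 200, 100])
```

Note that `bisect_left` assigns t = 100 to level 0, even though x entering that step is still at the coarser resolution; that is exactly the step at which the sampler upsamples. Five upsamples take the 1x1 start to 32x32, matching the shape printed in In [97].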


# Generating a Loop with AI for now

In [100]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

# ---------------------------
# Hyperparameters
# ---------------------------
batch_size = 64
lr = 2e-4
num_epochs = 10
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---------------------------
# Dataset: MNIST 32x32
# ---------------------------
transform = transforms.Compose([
    transforms.Resize(32),        # rescale to 32x32
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),  # [-1,1]
])

train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

# ---------------------------
# Diffusion parameters
# ---------------------------
Ts = [100,100,50,50,50,5] 
diffusion = BasicAvgPoolMultiScaleDiffusion(Ts=Ts).to(device)

# ---------------------------
# UNet model
# ---------------------------
eps_model = UNet(in_channels=1, num_classes=1).to(device)
optimizer = optim.AdamW(eps_model.parameters(), lr=lr)

In [101]:
# ---------------------------
# Optional EMA
# ---------------------------
class EMA:
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {name: p.clone().detach() for name, p in model.named_parameters()}

    @torch.no_grad()
    def update(self, model):
        for name, p in model.named_parameters():
            self.shadow[name].mul_(self.decay).add_(p.data, alpha=1 - self.decay)

    @torch.no_grad()
    def copy_to(self, model):
        for name, p in model.named_parameters():
            p.data.copy_(self.shadow[name])

ema = EMA(eps_model, decay=0.999)

In [102]:
# chk = torch.load('checkpoints/checkpoint_epoch_0.pt', map_location="cpu")

# eps_model.load_state_dict(chk["eps_model"])
# ema = EMA(eps_model)
# diffusion.load_state_dict(chk["diffusion"])
# optimizer.load_state_dict(chk["optimizer"])

# eps_model = eps_model.to(device)
# diffusion = diffusion.to(device)

# print("Loaded checkpoint:", chk)

In [None]:


# ---------------------------
# Training loop
# ---------------------------
step = 0
for epoch in range(num_epochs):
    pbar = tqdm(train_dl, desc=f"Epoch {epoch}")
    for x0, _ in pbar:
        eps_model.train()
        
        x0 = x0.to(device)  # (N,1,32,32)
        optimizer.zero_grad()

        # diffusion loss
        loss = diffusion.loss(eps_model, x0)

        loss.backward()
        optimizer.step()
        ema.update(eps_model)

        pbar.set_postfix({"loss": loss.item()})
        step += 1

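        # ---------------------------
        # periodic sampling
        # ---------------------------
        if step % 2000 == 0:
            eps_model.eval()
            # keep a copy of the raw training weights: ema.copy_to overwrites them in place
            backup = {name: p.detach().clone() for name, p in eps_model.named_parameters()}
            with torch.no_grad():
                ema.copy_to(eps_model)
                N = 16
                noise = torch.randn(N, 1, 1, 1).to(device)
                samples = diffusion.sample(eps_model, noise)
                samples = samples.clamp(-1,1)
                grid = utils.make_grid((samples+1)/2, nrow=4)
                os.makedirs("samples", exist_ok=True)
                utils.save_image(grid, f"samples/sample_step_{step}.png")
                # restore the training weights so optimization continues from them
                for name, p in eps_model.named_parameters():
                    p.copy_(backup[name])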

    # ---------------------------
    # Save checkpoint per epoch
    # ---------------------------
    os.makedirs("checkpoints", exist_ok=True)
    torch.save({
        "eps_model": eps_model.state_dict(),
        "ema": ema.shadow,
        "diffusion": diffusion.state_dict(),
        "optimizer": optimizer.state_dict(),
    }, f"checkpoints/checkpoint_epoch_{epoch}.pt")


Epoch 0:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 0: 100%|██████████| 938/938 [02:00<00:00,  7.76it/s, loss=0.0499] 
Epoch 1: 100%|██████████| 938/938 [02:28<00:00,  6.31it/s, loss=0.118]   
Epoch 2: 100%|██████████| 938/938 [01:40<00:00,  9.34it/s, loss=0.107]   
Epoch 3: 100%|██████████| 938/938 [01:32<00:00, 10.15it/s, loss=0.0404]  
Epoch 4: 100%|██████████| 938/938 [01:42<00:00,  9.17it/s, loss=0.142]   
Epoch 5: 100%|██████████| 938/938 [01:32<00:00, 10.10it/s, loss=0.771]   
Epoch 6:  80%|███████▉  | 749/938 [01:19<00:19,  9.65it/s, loss=0.0434]  

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision  # needed for torchvision.utils.make_grid / save_image below
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# ---------------------------
#  Hyperparameters
# ---------------------------

batch_size = 64
lr = 1e-4
num_epochs = 100

# Multi-scale diffusion: 100 steps at 32x32, 100 at 16x16, 50 each at 8x8, 4x4 and 2x2, then 5 at 1x1.
Ts = [100,100,50,50,50,5]

device = "cuda" if torch.cuda.is_available() else "cpu"

# ---------------------------
#  Dataset: CIFAR10
# ---------------------------

transform = transforms.Compose([
    transforms.ToTensor(),          # [0, 1]
    transforms.Normalize([0.5]*3, [0.5]*3),    # scale to [-1,1]
])

train_ds = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)


# ---------------------------
#  Instantiate UNet + diffusion
# ---------------------------

diffusion = BasicAvgPoolMultiScaleDiffusion(Ts=Ts).to(device)

# Instantiate your UNet (you must adapt this to your constructor)
eps_model = UNet().to(device)

optimizer = optim.AdamW(eps_model.parameters(), lr=lr)

# Optional: EMA
class EMA:
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {
            name: p.clone().detach()
            for name, p in model.named_parameters()
        }

    @torch.no_grad()
    def update(self, model):
        for name, p in model.named_parameters():
            self.shadow[name].mul_(self.decay).add_(p.data, alpha=1 - self.decay)

    @torch.no_grad()
    def copy_to(self, model):
        for name, p in model.named_parameters():
            p.data.copy_(self.shadow[name])

ema = EMA(eps_model, decay=0.999)

# ---------------------------
#  Training Loop
# ---------------------------

step = 0
for epoch in range(num_epochs):

    pbar = tqdm(train_dl, desc=f"Epoch {epoch}")

    for x0, _ in pbar:

        x0 = x0.to(device)                # (N, 3, 32, 32)
        optimizer.zero_grad()

        # The diffusion object handles picking a random level,
        # generating xt, eps, t, and computing loss.
        loss = diffusion.loss(eps_model, x0)

        loss.backward()
        optimizer.step()
        ema.update(eps_model)

        pbar.set_postfix({"loss": loss.item()})
        step += 1

        # -------------------------------------------
        #  Periodic sampling
        # -------------------------------------------
        if step % 2000 == 0:
            # keep a copy of the raw training weights: ema.copy_to overwrites them in place
            backup = {name: p.detach().clone() for name, p in eps_model.named_parameters()}
            with torch.no_grad():
                # use EMA weights for sampling quality
                ema.copy_to(eps_model)

                N = 16
                # start from noise matching the lowest resolution
                # lowest resolution = 32 // (2 ** (len(Ts)-1))
                lowest_res = 32 // (2 ** (len(Ts)-1))   # 1 for the six levels in Ts
                noise = torch.randn(N, 3, lowest_res, lowest_res).to(device)

                samples = diffusion.sample(eps_model, noise)
                samples = samples.clamp(-1, 1)

                # Save samples
                grid = torchvision.utils.make_grid((samples+1)/2, nrow=4)
                torchvision.utils.save_image(grid, f"samples_step_{step}.png")

                # restore the training weights so optimization continues from them
                for name, p in eps_model.named_parameters():
                    p.copy_(backup[name])

    # -------------------------------------------
    #  Save checkpoint at each epoch
    # -------------------------------------------
    torch.save({
        "eps_model": eps_model.state_dict(),
        "ema": ema.shadow,
        "diffusion": diffusion.state_dict(),
        "optimizer": optimizer.state_dict(),
    }, f"checkpoint_epoch_{epoch}.pt")


Epoch 0:   0%|          | 0/782 [00:00<?, ?it/s]ERROR:tornado.general:SEND Error: Host unreachable
Epoch 0:   0%|          | 0/782 [01:12<?, ?it/s]


KeyboardInterrupt: 
