### Import Libraries

In [1]:
from src.diffusion.ddpm import DDPM
from src.diffusion.unet import UNet
from src.data.mnist import get_mnist_loader_and_transform
from src.data.cifar10 import get_cifar10_loader_and_transform
from torchvision.utils import save_image, make_grid
from src.diffusion.train import train
import matplotlib.pyplot as plt

import torch.backends
import torch.backends.mps
import os

### Configuration of model

In [7]:
T = 1000
dataset = "cifar10" # can be "cifar10" or "mnist"

PATH_TO_READY_MODEL = "models/mnist_diffusion.pth" # input path for ready model

PATH_TO_SAVE_MODEL = "model.pth"
EPOCHS = 100 # for cifar10 it should be more than 1000, but for mnist 20-100 should be okay

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# For Mac OS
if torch.backends.mps.is_available():
    device = "mps"
    print("USE MPS")

USE MPS


### Load dataset

In [4]:
if dataset == "mnist":
    data = get_mnist_loader_and_transform()
elif dataset == "cifar10":
    data = get_cifar10_loader_and_transform()

Files already downloaded and verified
Files already downloaded and verified


### Setup Model

In [5]:
ddpm = DDPM(
    T = T,
    eps_model=UNet(
        in_channels=data.in_channels,
        out_channels=data.out_channels,
        T=T+1,
        steps=data.recommended_steps,
        attn_step_indexes=data.recommended_attn_step_indexes
    ),
    device=device
)

### Train or load ready model

In [8]:
if PATH_TO_READY_MODEL is not None:
    ddpm.load_state_dict(torch.load(PATH_TO_READY_MODEL, map_location=device))
else:
    _, val_losses = train(
        model=ddpm,
        optimizer=torch.optim.Adam(params=ddpm.parameters(), lr=2e-4),
        epochs=EPOCHS,
        device=device,
        train_dataloader=data.train_loader,
        val_dataloader=data.val_loader
    )

    path = PATH_TO_SAVE_MODEL if PATH_TO_SAVE_MODEL is not None else "model.pth"

    torch.save(ddpm.state_dict(), path)

    plt.plot(val_losses, label="Validation Loss")

    plt.legend()

RuntimeError: Error(s) in loading state_dict for DDPM:
	Missing key(s) in state_dict: "eps_model.down_blocks.0.models.1.K_W", "eps_model.down_blocks.0.models.1.K_b", "eps_model.down_blocks.0.models.1.Q_W", "eps_model.down_blocks.0.models.1.Q_b", "eps_model.down_blocks.0.models.1.V_W", "eps_model.down_blocks.0.models.1.V_b", "eps_model.down_blocks.0.models.1.O_W", "eps_model.down_blocks.0.models.1.O_b", "eps_model.down_blocks.0.models.1.norm.weight", "eps_model.down_blocks.0.models.1.norm.bias", "eps_model.down_blocks.0.models.1.mlp.0.weight", "eps_model.down_blocks.0.models.1.mlp.0.bias", "eps_model.down_blocks.0.models.1.mlp.1.weight", "eps_model.down_blocks.0.models.1.mlp.1.bias", "eps_model.down_blocks.0.models.1.mlp.3.weight", "eps_model.down_blocks.0.models.1.mlp.3.bias", "eps_model.down_blocks.0.models.2.conv_1.0.weight", "eps_model.down_blocks.0.models.2.conv_1.0.bias", "eps_model.down_blocks.0.models.2.conv_1.2.weight", "eps_model.down_blocks.0.models.2.conv_1.2.bias", "eps_model.down_blocks.0.models.2.conv_2.0.weight", "eps_model.down_blocks.0.models.2.conv_2.0.bias", "eps_model.down_blocks.0.models.2.conv_2.2.weight", "eps_model.down_blocks.0.models.2.conv_2.2.bias", "eps_model.down_blocks.0.models.2.time_emb.ln.1.weight", "eps_model.down_blocks.0.models.2.time_emb.ln.1.bias", "eps_model.down_blocks.0.models.3.K_W", "eps_model.down_blocks.0.models.3.K_b", "eps_model.down_blocks.0.models.3.Q_W", "eps_model.down_blocks.0.models.3.Q_b", "eps_model.down_blocks.0.models.3.V_W", "eps_model.down_blocks.0.models.3.V_b", "eps_model.down_blocks.0.models.3.O_W", "eps_model.down_blocks.0.models.3.O_b", "eps_model.down_blocks.0.models.3.norm.weight", "eps_model.down_blocks.0.models.3.norm.bias", "eps_model.down_blocks.0.models.3.mlp.0.weight", "eps_model.down_blocks.0.models.3.mlp.0.bias", "eps_model.down_blocks.0.models.3.mlp.1.weight", "eps_model.down_blocks.0.models.3.mlp.1.bias", "eps_model.down_blocks.0.models.3.mlp.3.weight", "eps_model.down_blocks.0.models.3.mlp.3.bias", "eps_model.down_blocks.4.models.1.K_W", "eps_model.down_blocks.4.models.1.K_b", "eps_model.down_blocks.4.models.1.Q_W", "eps_model.down_blocks.4.models.1.Q_b", "eps_model.down_blocks.4.models.1.V_W", "eps_model.down_blocks.4.models.1.V_b", "eps_model.down_blocks.4.models.1.O_W", "eps_model.down_blocks.4.models.1.O_b", "eps_model.down_blocks.4.models.1.norm.weight", "eps_model.down_blocks.4.models.1.norm.bias", "eps_model.down_blocks.4.models.1.mlp.0.weight", "eps_model.down_blocks.4.models.1.mlp.0.bias", "eps_model.down_blocks.4.models.1.mlp.1.weight", "eps_model.down_blocks.4.models.1.mlp.1.bias", "eps_model.down_blocks.4.models.1.mlp.3.weight", "eps_model.down_blocks.4.models.1.mlp.3.bias", "eps_model.down_blocks.4.models.2.conv_1.0.weight", "eps_model.down_blocks.4.models.2.conv_1.0.bias", "eps_model.down_blocks.4.models.2.conv_1.2.weight", "eps_model.down_blocks.4.models.2.conv_1.2.bias", "eps_model.down_blocks.4.models.2.conv_2.0.weight", "eps_model.down_blocks.4.models.2.conv_2.0.bias", "eps_model.down_blocks.4.models.2.conv_2.2.weight", "eps_model.down_blocks.4.models.2.conv_2.2.bias", "eps_model.down_blocks.4.models.2.time_emb.ln.1.weight", "eps_model.down_blocks.4.models.2.time_emb.ln.1.bias", "eps_model.down_blocks.4.models.3.K_W", "eps_model.down_blocks.4.models.3.K_b", "eps_model.down_blocks.4.models.3.Q_W", "eps_model.down_blocks.4.models.3.Q_b", "eps_model.down_blocks.4.models.3.V_W", "eps_model.down_blocks.4.models.3.V_b", "eps_model.down_blocks.4.models.3.O_W", "eps_model.down_blocks.4.models.3.O_b", "eps_model.down_blocks.4.models.3.norm.weight", "eps_model.down_blocks.4.models.3.norm.bias", "eps_model.down_blocks.4.models.3.mlp.0.weight", "eps_model.down_blocks.4.models.3.mlp.0.bias", "eps_model.down_blocks.4.models.3.mlp.1.weight", "eps_model.down_blocks.4.models.3.mlp.1.bias", "eps_model.down_blocks.4.models.3.mlp.3.weight", "eps_model.down_blocks.4.models.3.mlp.3.bias", "eps_model.down_blocks.6.models.0.conv_1.0.weight", "eps_model.down_blocks.6.models.0.conv_1.0.bias", "eps_model.down_blocks.6.models.0.conv_1.2.weight", "eps_model.down_blocks.6.models.0.conv_1.2.bias", "eps_model.down_blocks.6.models.0.conv_2.0.weight", "eps_model.down_blocks.6.models.0.conv_2.0.bias", "eps_model.down_blocks.6.models.0.conv_2.2.weight", "eps_model.down_blocks.6.models.0.conv_2.2.bias", "eps_model.down_blocks.6.models.0.time_emb.ln.1.weight", "eps_model.down_blocks.6.models.0.time_emb.ln.1.bias", "eps_model.down_blocks.6.models.1.K_W", "eps_model.down_blocks.6.models.1.K_b", "eps_model.down_blocks.6.models.1.Q_W", "eps_model.down_blocks.6.models.1.Q_b", "eps_model.down_blocks.6.models.1.V_W", "eps_model.down_blocks.6.models.1.V_b", "eps_model.down_blocks.6.models.1.O_W", "eps_model.down_blocks.6.models.1.O_b", "eps_model.down_blocks.6.models.1.norm.weight", "eps_model.down_blocks.6.models.1.norm.bias", "eps_model.down_blocks.6.models.1.mlp.0.weight", "eps_model.down_blocks.6.models.1.mlp.0.bias", "eps_model.down_blocks.6.models.1.mlp.1.weight", "eps_model.down_blocks.6.models.1.mlp.1.bias", "eps_model.down_blocks.6.models.1.mlp.3.weight", "eps_model.down_blocks.6.models.1.mlp.3.bias", "eps_model.down_blocks.6.models.2.conv_1.0.weight", "eps_model.down_blocks.6.models.2.conv_1.0.bias", "eps_model.down_blocks.6.models.2.conv_1.2.weight", "eps_model.down_blocks.6.models.2.conv_1.2.bias", "eps_model.down_blocks.6.models.2.conv_2.0.weight", "eps_model.down_blocks.6.models.2.conv_2.0.bias", "eps_model.down_blocks.6.models.2.conv_2.2.weight", "eps_model.down_blocks.6.models.2.conv_2.2.bias", "eps_model.down_blocks.6.models.2.time_emb.ln.1.weight", "eps_model.down_blocks.6.models.2.time_emb.ln.1.bias", "eps_model.down_blocks.6.models.3.K_W", "eps_model.down_blocks.6.models.3.K_b", "eps_model.down_blocks.6.models.3.Q_W", "eps_model.down_blocks.6.models.3.Q_b", "eps_model.down_blocks.6.models.3.V_W", "eps_model.down_blocks.6.models.3.V_b", "eps_model.down_blocks.6.models.3.O_W", "eps_model.down_blocks.6.models.3.O_b", "eps_model.down_blocks.6.models.3.norm.weight", "eps_model.down_blocks.6.models.3.norm.bias", "eps_model.down_blocks.6.models.3.mlp.0.weight", "eps_model.down_blocks.6.models.3.mlp.0.bias", "eps_model.down_blocks.6.models.3.mlp.1.weight", "eps_model.down_blocks.6.models.3.mlp.1.bias", "eps_model.down_blocks.6.models.3.mlp.3.weight", "eps_model.down_blocks.6.models.3.mlp.3.bias", "eps_model.up_blocks.4.models.1.K_W", "eps_model.up_blocks.4.models.1.K_b", "eps_model.up_blocks.4.models.1.Q_W", "eps_model.up_blocks.4.models.1.Q_b", "eps_model.up_blocks.4.models.1.V_W", "eps_model.up_blocks.4.models.1.V_b", "eps_model.up_blocks.4.models.1.O_W", "eps_model.up_blocks.4.models.1.O_b", "eps_model.up_blocks.4.models.1.norm.weight", "eps_model.up_blocks.4.models.1.norm.bias", "eps_model.up_blocks.4.models.1.mlp.0.weight", "eps_model.up_blocks.4.models.1.mlp.0.bias", "eps_model.up_blocks.4.models.1.mlp.1.weight", "eps_model.up_blocks.4.models.1.mlp.1.bias", "eps_model.up_blocks.4.models.1.mlp.3.weight", "eps_model.up_blocks.4.models.1.mlp.3.bias", "eps_model.up_blocks.4.models.2.conv_1.0.weight", "eps_model.up_blocks.4.models.2.conv_1.0.bias", "eps_model.up_blocks.4.models.2.conv_1.2.weight", "eps_model.up_blocks.4.models.2.conv_1.2.bias", "eps_model.up_blocks.4.models.2.conv_2.0.weight", "eps_model.up_blocks.4.models.2.conv_2.0.bias", "eps_model.up_blocks.4.models.2.conv_2.2.weight", "eps_model.up_blocks.4.models.2.conv_2.2.bias", "eps_model.up_blocks.4.models.2.time_emb.ln.1.weight", "eps_model.up_blocks.4.models.2.time_emb.ln.1.bias", "eps_model.up_blocks.4.models.3.K_W", "eps_model.up_blocks.4.models.3.K_b", "eps_model.up_blocks.4.models.3.Q_W", "eps_model.up_blocks.4.models.3.Q_b", "eps_model.up_blocks.4.models.3.V_W", "eps_model.up_blocks.4.models.3.V_b", "eps_model.up_blocks.4.models.3.O_W", "eps_model.up_blocks.4.models.3.O_b", "eps_model.up_blocks.4.models.3.norm.weight", "eps_model.up_blocks.4.models.3.norm.bias", "eps_model.up_blocks.4.models.3.mlp.0.weight", "eps_model.up_blocks.4.models.3.mlp.0.bias", "eps_model.up_blocks.4.models.3.mlp.1.weight", "eps_model.up_blocks.4.models.3.mlp.1.bias", "eps_model.up_blocks.4.models.3.mlp.3.weight", "eps_model.up_blocks.4.models.3.mlp.3.bias", "eps_model.up_blocks.5.upscale.weight", "eps_model.up_blocks.5.upscale.bias", "eps_model.up_blocks.6.models.0.conv_1.0.weight", "eps_model.up_blocks.6.models.0.conv_1.0.bias", "eps_model.up_blocks.6.models.0.conv_1.2.weight", "eps_model.up_blocks.6.models.0.conv_1.2.bias", "eps_model.up_blocks.6.models.0.conv_2.0.weight", "eps_model.up_blocks.6.models.0.conv_2.0.bias", "eps_model.up_blocks.6.models.0.conv_2.2.weight", "eps_model.up_blocks.6.models.0.conv_2.2.bias", "eps_model.up_blocks.6.models.0.time_emb.ln.1.weight", "eps_model.up_blocks.6.models.0.time_emb.ln.1.bias", "eps_model.up_blocks.6.models.0.conv_3.weight", "eps_model.up_blocks.6.models.0.conv_3.bias", "eps_model.up_blocks.6.models.1.conv_1.0.weight", "eps_model.up_blocks.6.models.1.conv_1.0.bias", "eps_model.up_blocks.6.models.1.conv_1.2.weight", "eps_model.up_blocks.6.models.1.conv_1.2.bias", "eps_model.up_blocks.6.models.1.conv_2.0.weight", "eps_model.up_blocks.6.models.1.conv_2.0.bias", "eps_model.up_blocks.6.models.1.conv_2.2.weight", "eps_model.up_blocks.6.models.1.conv_2.2.bias", "eps_model.up_blocks.6.models.1.time_emb.ln.1.weight", "eps_model.up_blocks.6.models.1.time_emb.ln.1.bias". 
	Unexpected key(s) in state_dict: "eps_model.down_blocks.0.models.1.conv_1.0.weight", "eps_model.down_blocks.0.models.1.conv_1.0.bias", "eps_model.down_blocks.0.models.1.conv_1.2.weight", "eps_model.down_blocks.0.models.1.conv_1.2.bias", "eps_model.down_blocks.0.models.1.conv_2.0.weight", "eps_model.down_blocks.0.models.1.conv_2.0.bias", "eps_model.down_blocks.0.models.1.conv_2.2.weight", "eps_model.down_blocks.0.models.1.conv_2.2.bias", "eps_model.down_blocks.0.models.1.time_emb.ln.1.weight", "eps_model.down_blocks.0.models.1.time_emb.ln.1.bias", "eps_model.down_blocks.4.models.0.conv_3.weight", "eps_model.down_blocks.4.models.0.conv_3.bias", "eps_model.down_blocks.4.models.1.conv_1.0.weight", "eps_model.down_blocks.4.models.1.conv_1.0.bias", "eps_model.down_blocks.4.models.1.conv_1.2.weight", "eps_model.down_blocks.4.models.1.conv_1.2.bias", "eps_model.down_blocks.4.models.1.conv_2.0.weight", "eps_model.down_blocks.4.models.1.conv_2.0.bias", "eps_model.down_blocks.4.models.1.conv_2.2.weight", "eps_model.down_blocks.4.models.1.conv_2.2.bias", "eps_model.down_blocks.4.models.1.time_emb.ln.1.weight", "eps_model.down_blocks.4.models.1.time_emb.ln.1.bias", "eps_model.up_blocks.4.models.1.conv_1.0.weight", "eps_model.up_blocks.4.models.1.conv_1.0.bias", "eps_model.up_blocks.4.models.1.conv_1.2.weight", "eps_model.up_blocks.4.models.1.conv_1.2.bias", "eps_model.up_blocks.4.models.1.conv_2.0.weight", "eps_model.up_blocks.4.models.1.conv_2.0.bias", "eps_model.up_blocks.4.models.1.conv_2.2.weight", "eps_model.up_blocks.4.models.1.conv_2.2.bias", "eps_model.up_blocks.4.models.1.time_emb.ln.1.weight", "eps_model.up_blocks.4.models.1.time_emb.ln.1.bias". 
	size mismatch for eps_model.first_conv.weight: copying a param with shape torch.Size([128, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 3, 3, 3]).
	size mismatch for eps_model.down_blocks.4.models.0.conv_1.2.weight: copying a param with shape torch.Size([512, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for eps_model.down_blocks.4.models.0.conv_1.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.down_blocks.4.models.0.conv_2.0.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.down_blocks.4.models.0.conv_2.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.down_blocks.4.models.0.conv_2.2.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for eps_model.down_blocks.4.models.0.conv_2.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.down_blocks.4.models.0.time_emb.ln.1.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for eps_model.down_blocks.4.models.0.time_emb.ln.1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.0.conv_1.0.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.0.conv_1.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.0.conv_1.2.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for eps_model.backbone.models.0.conv_1.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.0.conv_2.0.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.0.conv_2.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.0.conv_2.2.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for eps_model.backbone.models.0.conv_2.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.0.time_emb.ln.1.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for eps_model.backbone.models.0.time_emb.ln.1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.1.K_W: copying a param with shape torch.Size([4, 512, 128]) from checkpoint, the shape in current model is torch.Size([4, 256, 64]).
	size mismatch for eps_model.backbone.models.1.K_b: copying a param with shape torch.Size([4, 128]) from checkpoint, the shape in current model is torch.Size([4, 64]).
	size mismatch for eps_model.backbone.models.1.Q_W: copying a param with shape torch.Size([4, 512, 128]) from checkpoint, the shape in current model is torch.Size([4, 256, 64]).
	size mismatch for eps_model.backbone.models.1.Q_b: copying a param with shape torch.Size([4, 128]) from checkpoint, the shape in current model is torch.Size([4, 64]).
	size mismatch for eps_model.backbone.models.1.V_W: copying a param with shape torch.Size([4, 512, 128]) from checkpoint, the shape in current model is torch.Size([4, 256, 64]).
	size mismatch for eps_model.backbone.models.1.V_b: copying a param with shape torch.Size([4, 128]) from checkpoint, the shape in current model is torch.Size([4, 64]).
	size mismatch for eps_model.backbone.models.1.O_W: copying a param with shape torch.Size([4, 128, 512]) from checkpoint, the shape in current model is torch.Size([4, 64, 256]).
	size mismatch for eps_model.backbone.models.1.O_b: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.1.norm.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.1.norm.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.1.mlp.0.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.1.mlp.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.1.mlp.1.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 256]).
	size mismatch for eps_model.backbone.models.1.mlp.1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.1.mlp.3.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 256]).
	size mismatch for eps_model.backbone.models.1.mlp.3.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.2.conv_1.0.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.2.conv_1.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.2.conv_1.2.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for eps_model.backbone.models.2.conv_1.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.2.conv_2.0.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.2.conv_2.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.2.conv_2.2.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for eps_model.backbone.models.2.conv_2.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.backbone.models.2.time_emb.ln.1.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for eps_model.backbone.models.2.time_emb.ln.1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.0.models.0.conv_1.0.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for eps_model.up_blocks.0.models.0.conv_1.0.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for eps_model.up_blocks.0.models.0.conv_1.2.weight: copying a param with shape torch.Size([256, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 3, 3]).
	size mismatch for eps_model.up_blocks.0.models.0.conv_3.weight: copying a param with shape torch.Size([256, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 512, 1, 1]).
	size mismatch for eps_model.up_blocks.2.models.0.conv_1.2.weight: copying a param with shape torch.Size([128, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 3, 3]).
	size mismatch for eps_model.up_blocks.2.models.0.conv_1.2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.0.conv_2.0.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.0.conv_2.0.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.0.conv_2.2.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for eps_model.up_blocks.2.models.0.conv_2.2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.0.time_emb.ln.1.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for eps_model.up_blocks.2.models.0.time_emb.ln.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.0.conv_3.weight: copying a param with shape torch.Size([128, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 512, 1, 1]).
	size mismatch for eps_model.up_blocks.2.models.0.conv_3.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.1.K_W: copying a param with shape torch.Size([4, 128, 32]) from checkpoint, the shape in current model is torch.Size([4, 256, 64]).
	size mismatch for eps_model.up_blocks.2.models.1.K_b: copying a param with shape torch.Size([4, 32]) from checkpoint, the shape in current model is torch.Size([4, 64]).
	size mismatch for eps_model.up_blocks.2.models.1.Q_W: copying a param with shape torch.Size([4, 128, 32]) from checkpoint, the shape in current model is torch.Size([4, 256, 64]).
	size mismatch for eps_model.up_blocks.2.models.1.Q_b: copying a param with shape torch.Size([4, 32]) from checkpoint, the shape in current model is torch.Size([4, 64]).
	size mismatch for eps_model.up_blocks.2.models.1.V_W: copying a param with shape torch.Size([4, 128, 32]) from checkpoint, the shape in current model is torch.Size([4, 256, 64]).
	size mismatch for eps_model.up_blocks.2.models.1.V_b: copying a param with shape torch.Size([4, 32]) from checkpoint, the shape in current model is torch.Size([4, 64]).
	size mismatch for eps_model.up_blocks.2.models.1.O_W: copying a param with shape torch.Size([4, 32, 128]) from checkpoint, the shape in current model is torch.Size([4, 64, 256]).
	size mismatch for eps_model.up_blocks.2.models.1.O_b: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.1.norm.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.1.norm.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.1.mlp.0.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.1.mlp.0.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.1.mlp.1.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([256, 256]).
	size mismatch for eps_model.up_blocks.2.models.1.mlp.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.1.mlp.3.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([256, 256]).
	size mismatch for eps_model.up_blocks.2.models.1.mlp.3.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.2.conv_1.0.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.2.conv_1.0.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.2.conv_1.2.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for eps_model.up_blocks.2.models.2.conv_1.2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.2.conv_2.0.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.2.conv_2.0.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.2.conv_2.2.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for eps_model.up_blocks.2.models.2.conv_2.2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.2.time_emb.ln.1.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for eps_model.up_blocks.2.models.2.time_emb.ln.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.3.K_W: copying a param with shape torch.Size([4, 128, 32]) from checkpoint, the shape in current model is torch.Size([4, 256, 64]).
	size mismatch for eps_model.up_blocks.2.models.3.K_b: copying a param with shape torch.Size([4, 32]) from checkpoint, the shape in current model is torch.Size([4, 64]).
	size mismatch for eps_model.up_blocks.2.models.3.Q_W: copying a param with shape torch.Size([4, 128, 32]) from checkpoint, the shape in current model is torch.Size([4, 256, 64]).
	size mismatch for eps_model.up_blocks.2.models.3.Q_b: copying a param with shape torch.Size([4, 32]) from checkpoint, the shape in current model is torch.Size([4, 64]).
	size mismatch for eps_model.up_blocks.2.models.3.V_W: copying a param with shape torch.Size([4, 128, 32]) from checkpoint, the shape in current model is torch.Size([4, 256, 64]).
	size mismatch for eps_model.up_blocks.2.models.3.V_b: copying a param with shape torch.Size([4, 32]) from checkpoint, the shape in current model is torch.Size([4, 64]).
	size mismatch for eps_model.up_blocks.2.models.3.O_W: copying a param with shape torch.Size([4, 32, 128]) from checkpoint, the shape in current model is torch.Size([4, 64, 256]).
	size mismatch for eps_model.up_blocks.2.models.3.O_b: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.3.norm.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.3.norm.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.3.mlp.0.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.3.mlp.0.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.3.mlp.1.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([256, 256]).
	size mismatch for eps_model.up_blocks.2.models.3.mlp.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.2.models.3.mlp.3.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([256, 256]).
	size mismatch for eps_model.up_blocks.2.models.3.mlp.3.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.3.upscale.weight: copying a param with shape torch.Size([128, 128, 2, 2]) from checkpoint, the shape in current model is torch.Size([256, 256, 2, 2]).
	size mismatch for eps_model.up_blocks.3.upscale.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for eps_model.up_blocks.4.models.0.conv_1.0.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for eps_model.up_blocks.4.models.0.conv_1.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for eps_model.up_blocks.4.models.0.conv_1.2.weight: copying a param with shape torch.Size([128, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 512, 3, 3]).
	size mismatch for eps_model.up_blocks.4.models.0.conv_3.weight: copying a param with shape torch.Size([128, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 512, 1, 1]).
	size mismatch for eps_model.out.2.weight: copying a param with shape torch.Size([1, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([3, 128, 3, 3]).
	size mismatch for eps_model.out.2.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([3]).

### Show samples

In [None]:
n_samples = 10
x_t = ddpm.sample(n_samples=n_samples, size=data.train_dataset[0][0].shape)

result = []
for i in range(x_t.shape[0]):
    result.append(data.transform_to_pil(x_t[i]))

grid = make_grid(x_t, nrow=10)
save_image(grid, f"sample.png")

cols = 5
rows = (n_samples // cols) + (0 if n_samples % cols == 0 else 1)
fig, axs = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows))
for i in range(len(result)):
    row = i // cols
    axs[row, i % cols].imshow(result[i], cmap='gray')