In [None]:
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn.functional as F
from diffusers import UNet2DModel
from torch.utils.data import DataLoader
from torchinfo import summary
from torchvision import datasets, transforms

In [None]:
def _default_device() -> torch.device:
    """Pick the training device.

    Prefers ``cuda:1`` (the notebook's original choice) when at least two GPUs are
    visible, falls back to ``cuda:0`` when only one GPU exists, and to CPU otherwise,
    so the notebook still runs on machines without a second GPU.
    """
    if torch.cuda.device_count() > 1:
        return torch.device("cuda:1")
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")


@dataclass
class TrainingConfig:
    """Hyper-parameters and paths for training an unconditional DDPM on grayscale CIFAR-10."""

    image_size: int = 32  # CIFAR-10 images are 32x32
    train_batch_size: int = 256
    eval_batch_size: int = 16  # images sampled per evaluation (also the test-loader batch size)
    num_epochs: int = 20
    gradient_accumulation_steps: int = 1  # NOTE: not applied by the training loop in this notebook
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_image_epochs: int = 10
    save_model_epochs: int = 30
    # `no` for float32, `fp16` for automatic mixed precision.
    # NOTE: not applied by the training loop in this notebook.
    mixed_precision: str = "fp16"
    device: torch.device = _default_device()
    output_dir: str = "cifar10-unconditional"  # local output directory (and model name on the HF Hub)

    push_to_hub: bool = False  # whether to upload the saved model to the HF Hub
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 0


config = TrainingConfig()
# Seed torch's global RNG (CPU and CUDA) so weight init, shuffling and noise are reproducible.
torch.manual_seed(config.seed);

In [None]:
root = Path("./datasets")
root.mkdir(parents=True, exist_ok=True)

# DDPM sampling with diffusers' DDPMPipeline assumes images live in [-1, 1]:
# its "pil" output step maps samples back with (x / 2 + 0.5).clamp(0, 1).
# ToTensor() yields [0, 1], so Normalize(mean=0.5, std=0.5) rescales to [-1, 1]
# and keeps the training data on the same scale the sampler decodes from.
transform = transforms.Compose(
    [
        transforms.Grayscale(),  # RGB -> 1 channel, matching the UNet's in_channels=1
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

train_ds = datasets.CIFAR10(root=root, train=True, transform=transform, download=True)
test_ds = datasets.CIFAR10(root=root, train=False, transform=transform, download=True)

train_loader = DataLoader(
    train_ds, batch_size=config.train_batch_size, shuffle=True, num_workers=16
)
test_loader = DataLoader(
    test_ds, batch_size=config.eval_batch_size, shuffle=False, num_workers=16
)

In [None]:
# Five-level UNet for 1-channel images. The k-th entries of block_out_channels,
# down_block_types and up_block_types (reversed) describe the same level.
# Self-attention is used only in the 4th down block and its mirrored 2nd up block.
model = UNet2DModel(
    sample_size=config.image_size,
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"),
    block_out_channels=(64, 64, 128, 128, 256),
)

In [None]:
class SimpleModel(torch.nn.Module):
    """Wrap a UNet so it can be driven by a single image tensor.

    The UNet's forward also needs a ``timestep``; this wrapper supplies a fixed
    timestep of 1.0 so single-input tools such as ``torchinfo.summary`` can run it.
    """

    def __init__(self, my_model):
        super().__init__()
        self.model = my_model

    def forward(self, x):
        # Build the timestep on the input's own device rather than a global device,
        # so the wrapper works wherever ``x`` (and the wrapped model) live.
        timestep = torch.tensor([1.0], device=x.device)
        return self.model(x, timestep=timestep)


sm = SimpleModel(model)

# One training batch is enough for a quick forward-pass sanity check.
img1, _ = next(iter(train_loader))

model = model.to(config.device)
img1 = img1.to(config.device)
ts = torch.tensor([1.0], device=config.device)

# Pure inference: no_grad avoids storing an activation graph for this check.
with torch.no_grad():
    smoke_out = model(img1, timestep=ts)

# The UNet predicts a residual with the same shape as its input.
assert smoke_out.sample.shape == img1.shape
smoke_out.sample.shape

In [None]:
# Module.to() moves parameters in place and returns the same module, so this also
# moves the wrapped UNet; then run one forward pass through the wrapper.
sm.to(config.device)
sm(img1)

In [None]:
# Layer-by-layer summary for a single image. The input size comes from the config
# and model instead of being hard-coded. The device is passed explicitly so torchinfo
# does not pick its own default device and move the model away from config.device.
summary(
    sm,
    input_size=(1, model.config.in_channels, config.image_size, config.image_size),
    device=config.device,
)

In [None]:
from diffusers import DDPMScheduler

# 1000-step DDPM noise schedule; passed to train_loop, which uses it both for
# noising training batches and for the sampling pipeline.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
# Example timestep for inspecting a single noising step (int64, on CPU).
timesteps = torch.tensor([50], dtype=torch.long)

In [None]:
from diffusers.optimization import get_cosine_schedule_with_warmup

# AdamW over all UNet weights; foreach=False selects the per-parameter loop implementation.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=config.learning_rate,
    foreach=False,
)
# Warm-up followed by cosine decay. train_loop calls lr_scheduler.step() once per batch,
# so the schedule length is (batches per epoch) x (epochs).
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_training_steps=config.num_epochs * len(train_loader),
    num_warmup_steps=config.lr_warmup_steps,
)

In [None]:
import os

from diffusers import DDPMPipeline
from diffusers.utils import make_image_grid


def _grid_shape(num_images):
    """Return (rows, cols) with rows * cols == num_images and the grid as square as possible."""
    if num_images < 1:
        raise ValueError(f"need at least one image to build a grid, got {num_images}")
    # Largest divisor not exceeding sqrt(num_images) -> rows; the co-divisor -> cols.
    rows = max(
        r for r in range(1, num_images + 1) if num_images % r == 0 and r * r <= num_images
    )
    return rows, num_images // rows


def evaluate(config, epoch, pipeline):
    """Sample ``config.eval_batch_size`` images with ``pipeline`` and save them as one PNG grid.

    Args:
        config: object providing ``eval_batch_size``, ``seed`` and ``output_dir``.
        epoch: integer used to name the output file ``<output_dir>/samples/<epoch:04d>.png``.
        pipeline: callable returning an object whose ``.images`` is a list of images
            (the DDPM pipeline's default output is a list of PIL images).
    """
    # Run the backward diffusion process from random noise.
    images = pipeline(
        batch_size=config.eval_batch_size,
        # A dedicated, freshly seeded CPU generator gives the same starting noise at
        # every evaluation and leaves the global RNG used by training untouched.
        generator=torch.Generator(device="cpu").manual_seed(config.seed),
    ).images

    # make_image_grid needs rows * cols == len(images); derive the shape instead of
    # hard-coding 4x4 so any eval_batch_size works (16 still gives 4x4).
    rows, cols = _grid_shape(len(images))
    image_grid = make_image_grid(images, rows=rows, cols=cols)

    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(f"{test_dir}/{epoch:04d}.png")

In [None]:
from pathlib import Path

# from huggingface_hub import create_repo, upload_folder
from tqdm.auto import tqdm


def train_loop(
    config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler
):
    """Train ``model`` to predict the noise that ``noise_scheduler`` adds to clean images.

    For every batch:
    1. add Gaussian noise at a uniformly random timestep per image;
    2. regress the model's noise prediction onto the true noise with MSE;
    3. step ``optimizer`` and ``lr_scheduler`` once.

    Sample grids are written via ``evaluate`` every ``config.save_image_epochs`` epochs
    and after the final epoch. The pipeline (UNet + scheduler) is saved to
    ``config.output_dir`` every ``config.save_model_epochs`` epochs and after the final
    epoch.

    Note: ``config.mixed_precision`` and ``config.gradient_accumulation_steps`` are not
    applied here. Training runs in the model's dtype with one optimizer step per batch.

    Args:
        config: training configuration (device, epochs, eval/save cadence, output_dir, ...).
        model: UNet called as ``model(noisy_images, timesteps, return_dict=False)``.
        noise_scheduler: scheduler providing ``add_noise`` and ``config.num_train_timesteps``.
        optimizer: optimizer over ``model``'s parameters.
        train_dataloader: iterable of ``(images, labels)`` batches.
        lr_scheduler: learning-rate scheduler stepped once per batch.
    """
    global_step = 0
    model.to(config.device)

    # The pipeline holds a reference to ``model``, so it always samples with the current weights.
    pipeline = DDPMPipeline(unet=model, scheduler=noise_scheduler)

    for epoch in range(config.num_epochs):
        # Training mode at the start of every epoch; sampling below switches to eval().
        model.train()
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch}", leave=False)

        for clean_images, _labels in progress_bar:
            # Building the noisy inputs needs no gradients.
            with torch.no_grad():
                clean_images = clean_images.to(config.device)
                noise = torch.randn_like(clean_images)

                # One random timestep per image in [0, num_train_timesteps).
                timesteps = torch.randint(
                    0,
                    noise_scheduler.config.num_train_timesteps,
                    (clean_images.shape[0],),
                    dtype=torch.int64,
                ).to(config.device)

                # Forward diffusion: noise each image according to its timestep.
                noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            # Predict the noise residual and regress it onto the true noise.
            noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
            loss = F.mse_loss(noise_pred, noise)
            loss.backward()

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            global_step += 1
            progress_bar.set_postfix(
                loss=loss.detach().item(),
                lr=lr_scheduler.get_last_lr()[0],
                step=global_step,
            )

        is_last_epoch = epoch == config.num_epochs - 1
        if (epoch + 1) % config.save_image_epochs == 0 or is_last_epoch:
            model.eval()
            evaluate(config, epoch, pipeline)
        if (epoch + 1) % config.save_model_epochs == 0 or is_last_epoch:
            pipeline.save_pretrained(config.output_dir)
        print(f"Finished epoch {epoch}")

In [None]:
# Run training; evaluation sample grids are written under config.output_dir/samples.
train_loop(
    config=config,
    model=model,
    noise_scheduler=noise_scheduler,
    optimizer=optimizer,
    train_dataloader=train_loader,
    lr_scheduler=lr_scheduler,
)

In [None]:
# Trying inference: sample from the trained UNet with a fresh 1000-step DDPM scheduler.
noise_scheduler2 = DDPMScheduler(num_train_timesteps=1000)

pipeline = DDPMPipeline(unet=model, scheduler=noise_scheduler2)

model.eval()
# Training epochs are numbered 0 .. num_epochs - 1, so using num_epochs as the index
# writes a new grid file instead of overwriting one saved during training.
evaluate(config, config.num_epochs, pipeline)