In [1]:
import os
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from accelerate import Accelerator, notebook_launcher
from cond_ddpm_pipeline import DDPMPipeline
from cond_sky_dataset import SkyDataset
from diffusers import DDPMScheduler, UNet2DModel
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers.utils import make_image_grid
from huggingface_hub import create_repo, upload_folder
from torch.utils.data import DataLoader, Dataset
from torchvision import utils
from tqdm.auto import tqdm


In [2]:
@dataclass
class TrainingConfig:
    """Hyper-parameters and I/O settings for conditional sky-frame diffusion training.

    Every setting carries a type annotation so ``@dataclass`` turns it into a real
    field: overrides such as ``TrainingConfig(num_epochs=5)`` work, and ``repr`` /
    ``==`` are generated. Plain attribute access (``config.image_size``) is unchanged.
    """

    # Frames per sample: seq_len - 1 conditioning frames plus 1 target frame
    # (the UNet takes 3 * seq_len input channels and predicts 3).
    seq_len: int = 5

    image_size: int = 128  # frames are resized to image_size x image_size
    train_batch_size: int = 16
    eval_batch_size: int = 16
    num_epochs: int = 10000
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_image_epochs: int = 3  # write a sample grid every N epochs
    save_model_epochs: int = 3  # write a checkpoint every N epochs
    mixed_precision: str = "fp16"
    # NOTE(review): the "e1000" suffix does not match num_epochs = 10000. The name is
    # kept as-is so existing local artifacts and the Hub repo are still found.
    output_dir: str = "cond_sky_diffusion_128_lr1e-4_bs16_e1000_4"

    push_to_hub: bool = True
    hub_model_id: str = "cond_diffusion_128_lr1e-4_bs16_e1000_4"
    hub_private_repo: bool = False
    overwrite_output_dir: bool = True
    seed: int = 42


config = TrainingConfig()
# Seed torch's global RNG (all devices) so that model init, data shuffling and the
# diffusion noise are reproducible; config.seed is otherwise never consumed.
torch.manual_seed(config.seed)

In [3]:
# Make sure the run's output folder exists before anything is written into it.
output_path = Path(config.output_dir)
if output_path.exists():
    print(f"Folder '{config.output_dir}' already exists.")
else:
    # parents=True mirrors os.makedirs: missing intermediate folders are created too.
    output_path.mkdir(parents=True)
    print(f"Folder '{config.output_dir}' created.")

Folder 'cond_sky_diffusion_128_lr1e-4_bs16_e1000_4' already exists.


In [4]:
# Square-resize each frame, convert it to a tensor, then shift it with
# (x - 0.5) / 0.5 -- ToTensor's [0, 1] output becomes [-1, 1], the range that later
# cells undo with (x + 1) / 2 before saving.
preprocess = transforms.Compose([
    transforms.Resize((config.image_size,) * 2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

In [5]:
# NOTE(review): SkyDataset is project-local. Judging by how batch[0] / batch[1] are used
# below, items are (conditioning frames, target frame) pairs -- confirm in cond_sky_dataset.
dataset = SkyDataset(seq_len=config.seq_len, offset=1, transform=preprocess)
train_dataloader = DataLoader(
    dataset,
    batch_size=config.train_batch_size,
    shuffle=True,
)

In [6]:
# Grab one fixed batch: it is visualised below and reused by evaluate() as the
# conditioning input when sampling. batch[0] holds the conditioning frames and
# batch[1] the target frame (the same convention train_loop uses).
# Fall back to CPU so the notebook still runs on a machine without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
batch = next(iter(train_dataloader))
conditioning = batch[0].to(device)
real_imgs = batch[1].to(device)

In [7]:
def save_images(generated_images, out_dir, filename, nrow=4):
    """Save a batch of images as one grid JPEG at ``{out_dir}/{filename}_grid.jpeg``.

    Args:
        generated_images: image batch handed to ``torchvision.utils.save_image``.
            Callers in this notebook rescale from [-1, 1] to [0, 1] first.
        out_dir: destination directory; created (including parents) if missing.
        filename: file stem; ``_grid.jpeg`` is appended.
        nrow: number of images per grid row (default 4, the previous fixed value).
    """
    # exist_ok avoids the race between an os.path.exists check and os.makedirs.
    os.makedirs(out_dir, exist_ok=True)
    utils.save_image(generated_images,
                     f"{out_dir}/{filename}_grid.jpeg",
                     nrow=nrow)

In [8]:
# Undo the [-1, 1] normalisation (back to [0, 1]) and write the target-frame grid.
save_images(generated_images=real_imgs.add(1).div(2), out_dir=config.output_dir, filename="actual")

In [9]:
# Split the stacked conditioning channels back into individual RGB frames for display.
# Assumes a frame-major channel layout, i.e. (B, F * 3, H, W) with F = seq_len - 1
# conditioning frames -- TODO confirm against SkyDataset.
# Batch size, frame count and H/W are read from the tensor instead of being hard-coded
# as (16, 4, 3, 128, 128), so this survives a different batch size, seq_len or image_size.
reshaped_images = ((conditioning + 1) / 2).reshape(
    conditioning.shape[0], -1, 3, *conditioning.shape[-2:]
)
split_images = torch.unbind(reshaped_images, dim=1)
split_images_list = list(split_images)

In [10]:
# Write one grid per conditioning frame: cond_1_grid.jpeg, cond_2_grid.jpeg, ...
# enumerate() covers however many frames were split out, instead of a fixed range(1, 5).
for frame_idx, frame_batch in enumerate(split_images_list, start=1):
    save_images(frame_batch, config.output_dir, f"cond_{frame_idx}")

In [11]:
# Conditional UNet. Input: the conditioning RGB frames concatenated channel-wise with
# the noisy target frame (3 * seq_len channels). Output: predicted noise for the single
# RGB target frame. Six resolution levels; self-attention only at the 5th encoder level
# and the mirrored 2nd decoder level.
model = UNet2DModel(
    sample_size=config.image_size,
    in_channels=3 * config.seq_len,  # conditioning frames + noisy target
    out_channels=3,                  # noise prediction for one RGB frame
    layers_per_block=2,              # ResNet layers per UNet block
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=("DownBlock2D",) * 4 + ("AttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D") + ("UpBlock2D",) * 4,
)

In [12]:
# 1000-step DDPM schedule; used for forward noising in training and for sampling.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

# Under Accelerate, the LR scheduler only advances on real optimizer steps: once every
# gradient_accumulation_steps batches, plus at the end of each epoch. That gives
# ceil(batches / accumulation) steps per epoch. Using batches * epochs would stretch the
# cosine decay by the accumulation factor. Both agree when accumulation is 1.
optimizer_steps_per_epoch = -(-len(train_dataloader) // config.gradient_accumulation_steps)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=optimizer_steps_per_epoch * config.num_epochs,
)

In [13]:
def evaluate(config, epoch, pipeline, cond=None):
    """Sample from ``pipeline`` and save the grid to ``{output_dir}/samples/{epoch:04d}.png``.

    Args:
        config: TrainingConfig-like object; reads ``eval_batch_size`` and ``output_dir``.
        epoch: epoch index used in the file name (zero-padded to 4 digits).
        pipeline: the conditional DDPMPipeline to sample from.
        cond: conditioning frames passed to the pipeline. Defaults to the notebook-level
            ``conditioning`` batch fetched earlier (the previous implicit behaviour).
            NOTE(review): its batch size should agree with ``config.eval_batch_size``.
    """
    if cond is None:
        cond = conditioning  # module-level batch from the data-inspection cell
    images = pipeline(
        conditioning=cond,
        batch_size=config.eval_batch_size,
    ).images

    # make_image_grid requires rows * cols == len(images); the old fixed 4x4 grid only
    # worked for exactly 16 samples. Use the most square factorisation instead
    # (rows x cols: 16 -> 4x4 as before, 12 -> 4x3, 15 -> 5x3).
    n_images = len(images)
    cols = max(d for d in range(1, int(n_images ** 0.5) + 1) if n_images % d == 0)
    image_grid = make_image_grid(images, rows=n_images // cols, cols=cols)

    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(f"{test_dir}/{epoch:04d}.png")

In [16]:
def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
    """Train the conditional UNet with Accelerate, periodically sampling and checkpointing.

    Each batch provides conditioning frames (``batch[0]``) and a clean target frame
    (``batch[1]``). The target is noised at a random timestep and concatenated
    channel-wise after the conditioning frames. The model is trained with MSE to predict
    the added noise.

    Side effects, on the main process only:
    - creates ``{output_dir}/models``;
    - logs to TensorBoard under ``{output_dir}/logs``;
    - writes sample grids via ``evaluate`` and ``.pt`` checkpoints;
    - uploads ``output_dir`` to the Hub, or (when ``push_to_hub`` is off) saves the
      pipeline locally.
    """
    # Accelerator handles device placement, mixed precision, accumulation and logging.
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs"),
    )
    models_dir = os.path.join(config.output_dir, "models")
    if accelerator.is_main_process:
        # torch.save does not create parent folders, so make the checkpoint dir up front
        # (this also creates output_dir itself).
        os.makedirs(models_dir, exist_ok=True)
        if config.push_to_hub:
            repo_id = create_repo(
                repo_id=config.hub_model_id or Path(config.output_dir).name,
                exist_ok=True,
                private=config.hub_private_repo,
            ).repo_id
        accelerator.init_trackers("train_example")

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0

    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for batch in train_dataloader:
            conditioning_frames = batch[0]
            clean_images = batch[1]

            noise = torch.randn(clean_images.shape, device=clean_images.device)
            bs = clean_images.shape[0]

            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device,
                dtype=torch.int64
            )

            # Forward diffusion: noise the target according to each sampled timestep,
            # then condition by channel-concatenating [conditioning frames | noisy target].
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
            cond_and_noisy = torch.cat([conditioning_frames, noisy_images], dim=1)

            with accelerator.accumulate(model):
                # Predict the noise residual
                noise_pred = model(cond_and_noisy, timesteps, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                # Only clip once gradients are fully accumulated, i.e. on the micro-step
                # where the optimizer actually updates the weights.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        progress_bar.close()

        # After some epochs, sample demo images with evaluate() and/or save the model.
        if accelerator.is_main_process:
            is_last_epoch = epoch == config.num_epochs - 1
            sample_now = (epoch + 1) % config.save_image_epochs == 0 or is_last_epoch
            checkpoint_now = (epoch + 1) % config.save_model_epochs == 0 or is_last_epoch

            if sample_now or checkpoint_now:
                unwrapped_model = accelerator.unwrap_model(model)
                pipeline = DDPMPipeline(unet=unwrapped_model, scheduler=noise_scheduler)

                if sample_now:
                    evaluate(config, epoch, pipeline)

                if checkpoint_now:
                    # Save the unwrapped weights so the keys carry no wrapper prefix
                    # (e.g. DDP's "module.") and load into a plain UNet2DModel.
                    torch.save(
                        unwrapped_model.state_dict(),
                        os.path.join(models_dir, f"model{epoch + 1}.pt"),
                    )
                    if config.push_to_hub:
                        upload_folder(
                            repo_id=repo_id,
                            folder_path=config.output_dir,
                            commit_message=f"Epoch {epoch + 1}",
                            ignore_patterns=["step_*", "epoch_*"],
                        )
                    else:
                        pipeline.save_pretrained(config.output_dir)

    # Flush and close the TensorBoard tracker(s); Accelerate requires this when tracking.
    accelerator.end_training()

In [None]:
# Run train_loop through Accelerate's notebook launcher as a single process.
train_args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
notebook_launcher(train_loop, args=train_args, num_processes=1)

Launching training on one GPU.


  0%|          | 0/1158 [00:00<?, ?it/s]