In [1]:
from unet import UNet
from ddim import DDIMScheduler
from sky_dataset import SkyDataset
from ema import EMA
import torch
import torchvision.transforms as transforms
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from diffusers.optimization import get_scheduler
from torch.cuda.amp import GradScaler, autocast
from tqdm.auto import tqdm
from utils import save_images, normalize_to_neg_one_to_one
import torch.nn.functional as F

In [2]:
from dataclasses import dataclass


@dataclass
class Config:
    """Training hyper-parameters for the sky-image DDIM model.

    Every attribute carries a type annotation so that ``@dataclass`` turns it
    into a real field: ``Config()`` keeps the defaults below, and individual
    values can be overridden at construction time, e.g.
    ``Config(resolution=64, num_epochs=10)``. Without annotations the decorator
    registers no fields and such overrides raise ``TypeError``.
    """

    resolution: int = 32  # square side images are resized to; also the U-Net image_size
    n_timesteps: int = 1000  # training diffusion steps (DDIMScheduler num_train_timesteps)
    learning_rate: float = 5e-6  # AdamW learning rate
    adam_beta1: float = 0.9
    adam_beta2: float = 0.99
    adam_weight_decay: float = 0.0
    train_batch_size: int = 16  # DataLoader batch size
    eval_batch_size: int = 16  # number of preview images sampled after each epoch
    num_epochs: int = 50
    # Only divides the step count handed to the LR scheduler / EMA.
    # NOTE(review): the training loop does not actually accumulate gradients.
    gradient_accumulation_steps: int = 1
    gamma: float = 0.996  # initial EMA decay passed to EMA(...)
    lr_scheduler: str = "cosine"  # scheduler name for diffusers.get_scheduler
    lr_warmup_steps: int = 100
    fp16_precision: bool = True  # enables GradScaler / autocast mixed precision
    use_clip_grad: bool = False  # clip the gradient norm to 1.0 in the training loop
    # NOTE(review): not read by any visible cell; checkpoints are written once per epoch.
    save_model_steps: int = 1000
    samples_dir: str = "samples"  # presumably where save_images writes previews -- confirm
    dataset_name: str = "SkyDiffusion3"  # NOTE(review): not read by any visible cell
    n_inference_timesteps: int = 100  # DDIM sampling steps for the per-epoch previews
    # Checkpoint *file* path passed to torch.save (despite the name, not a directory).
    output_dir: str = "models/SkyDiffusion3.pth"


args = Config()

In [3]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
# U-Net noise predictor sized to the configured resolution, moved onto `device`.
# NOTE(review): the leading positional 3 is presumably the RGB channel count -- confirm in unet.py.
model = UNet(
    3,
    image_size=args.resolution,
    hidden_dims=[64, 128, 256, 512],
).to(device)

In [5]:
# DDIM noise scheduler: `args.n_timesteps` training steps with a cosine beta schedule.
noise_scheduler = DDIMScheduler(
    num_train_timesteps=args.n_timesteps,
    beta_schedule="cosine",
)

In [6]:
# AdamW over every U-Net parameter; all hyper-parameters come from `Config`.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=args.learning_rate,
    weight_decay=args.adam_weight_decay,
    betas=(args.adam_beta1, args.adam_beta2),
)

In [7]:
# Resize each image to a `resolution` x `resolution` square, then convert it to a tensor.
# Rescaling to [-1, 1] is done later in the training loop via normalize_to_neg_one_to_one.
tfms = transforms.Compose(
    [
        transforms.Resize(size=(args.resolution, args.resolution)),
        transforms.ToTensor(),
    ]
)

In [8]:
# Training data: SkyDataset images passed through `tfms`, shuffled each epoch.
dataset = SkyDataset(transform=tfms)
train_dataloader = DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)

# Step budget handed to the LR scheduler and EMA. An extra 10% is added on top of
# the steps the loop actually runs, so the schedule is never fully consumed by
# the end of training.
steps_per_epcoch = len(train_dataloader)
total_num_steps = steps_per_epcoch * args.num_epochs // args.gradient_accumulation_steps
total_num_steps = total_num_steps + int(total_num_steps * 10 / 100)

In [9]:
# Exponential-moving-average copy of the model, seeded with the configured decay.
ema = EMA(model, args.gamma, total_num_steps)
# Running decay value; the training loop refreshes it each step via ema.update_gamma.
gamma = args.gamma

In [10]:
# Learning-rate schedule selected by name (`args.lr_scheduler`), with
# `args.lr_warmup_steps` warm-up steps, spanning `total_num_steps` optimizer steps.
lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_training_steps=total_num_steps,
    num_warmup_steps=args.lr_warmup_steps,
)

In [11]:
# Mixed-precision loss scaler (passes through unchanged when fp16 is disabled),
# plus the global step counter and per-epoch mean-loss history.
scaler = GradScaler(enabled=args.fp16_precision)
global_step, losses = 0, []

In [None]:
import os

# Make sure the checkpoint's parent directory exists. Otherwise torch.save would
# fail only after a full epoch of training.
checkpoint_dir = os.path.dirname(args.output_dir)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(args.num_epochs):
    # Defensive: re-enter training mode every epoch. NOTE(review): this only
    # matters if `ema.ema_model` shares modules with `model` (the EMA class is not
    # visible here), because `.eval()` is called on it after every epoch.
    model.train()

    progress_bar = tqdm(total=steps_per_epcoch)
    progress_bar.set_description(f"Epoch {epoch}")
    losses_log = 0
    num_batches = 0
    for step, batch in enumerate(train_dataloader):
        clean_images = batch.to(device)
        clean_images = normalize_to_neg_one_to_one(clean_images)

        batch_size = clean_images.shape[0]
        # Target noise drawn directly with the images' device/dtype
        # (no CPU allocation followed by a host-to-device copy).
        noise = torch.randn_like(clean_images)

        # One uniformly sampled diffusion timestep per image.
        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (batch_size,), device=device).long()
        noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

        optimizer.zero_grad()
        with autocast(enabled=args.fp16_precision):
            noise_pred = model(noisy_images, timesteps)
            loss = F.l1_loss(noise_pred, noise)

        scaler.scale(loss).backward()
        if args.use_clip_grad:
            # Clip BEFORE the optimizer step, on unscaled gradients. Clipping after
            # the step has no effect on the update, and clipping scaled gradients
            # applies the wrong threshold.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        ema.update_params(gamma)
        gamma = ema.update_gamma(global_step)

        lr_scheduler.step()

        loss_value = loss.detach().item()  # one host sync per step
        losses_log += loss_value
        num_batches += 1

        progress_bar.update(1)
        logs = {
            "loss_avg": losses_log / num_batches,
            "loss": loss_value,
            "lr": lr_scheduler.get_last_lr()[0],
            "step": global_step,
            "gamma": gamma,
        }
        progress_bar.set_postfix(**logs)
        global_step += 1

    progress_bar.close()
    # Mean training loss for the epoch (guarded against an empty dataloader).
    losses.append(losses_log / max(num_batches, 1))

    ema.ema_model.eval()
    # Sample previews with a fixed seed so they are comparable across epochs.
    # fork_rng saves and restores the global CPU/CUDA RNG state. Without it, the
    # reseed would reset the RNG that drives shuffling, noise and timestep
    # sampling, making every later epoch replay the same random stream.
    with torch.no_grad(), torch.random.fork_rng():
        generator = torch.manual_seed(0)
        generated_images = noise_scheduler.generate(
            ema.ema_model,
            num_inference_steps=args.n_inference_timesteps,
            generator=generator,
            eta=1.0,
            use_clipped_model_output=True,
            batch_size=args.eval_batch_size,
            output_type="numpy",
        )

    save_images(generated_images, epoch, args)

    # Overwrite a single checkpoint each epoch. Scheduler/scaler state and the
    # counters are stored alongside the weights so training can be resumed.
    torch.save(
        {
            "model_state": model.state_dict(),
            "ema_model_state": ema.ema_model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "lr_scheduler_state": lr_scheduler.state_dict(),
            "scaler_state": scaler.state_dict(),
            "epoch": epoch,
            "global_step": global_step,
        },
        args.output_dir,
    )

    

  0%|          | 0/1178 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 61.66it/s]


  0%|          | 0/1178 [00:00<?, ?it/s]