In [11]:
import logging
import os

import torch

# Seed torch's RNG (CPU and CUDA generators) so runs are repeatable.
TORCH_SEED = 69
torch.manual_seed(TORCH_SEED)

# Run log lives next to the saved weights.
RUN_LOG_FILE = os.path.join('weights', 'runs.log')
os.makedirs(os.path.dirname(RUN_LOG_FILE), exist_ok=True)

run_logger = logging.getLogger('run_logger')
# Without an explicit level this logger inherits the root logger's level (WARNING by
# default), so the run_logger.debug(...) / run_logger.info(...) calls in the training
# cell would be filtered out and never reach the log file.
run_logger.setLevel(logging.DEBUG)

# Attach the file handler only once: re-running this cell would otherwise add another
# handler and every record would be written to the file multiple times.
_run_log_abspath = os.path.abspath(RUN_LOG_FILE)
file_handler = next(
    (
        handler
        for handler in run_logger.handlers
        if isinstance(handler, logging.FileHandler) and handler.baseFilename == _run_log_abspath
    ),
    None,
)
if file_handler is None:
    file_handler = logging.FileHandler(RUN_LOG_FILE)
    run_logger.addHandler(file_handler)


In [12]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torch.nn.functional import one_hot

# Data-pipeline configuration.
NUM_WORKERS = 4
BATCH_SIZE = 128
INPUT_SIZE = 32
NUM_CLASSES = 10
dataset = datasets.CIFAR10

# Augmentation: random horizontal flip with probability 0.5 (applied before tensor conversion).
AUGMENTATIONS = (transforms.RandomHorizontalFlip(p=0.5),)

# Tensor conversion followed by per-channel (x - 0.5) / 0.5 normalization on all three channels.
NORMALIZATIONS = (
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
)

# Augmentations run first, then normalization.
train_transform = transforms.Compose([*AUGMENTATIONS, *NORMALIZATIONS])

# Training split only (downloaded into ./data on first use).
training_data = dataset(
    root='./data',
    train=True,
    download=True,
    transform=train_transform,
)

# Shuffled mini-batches of BATCH_SIZE, loaded by NUM_WORKERS worker processes.
train_dataloader = DataLoader(
    training_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

In [13]:
import model_factory

MODEL_NAME = 'salun-ddpm'

# Build the network through the project-local factory.
model = model_factory.create_model(MODEL_NAME, NUM_CLASSES, input_size=INPUT_SIZE)

# Total element count over parameters that have requires_grad set.
trainable_params = sum(
    p.numel() for p in filter(lambda p: p.requires_grad, model.parameters())
)
print(f"Number of trainable parameters: {trainable_params}")

Number of trainable parameters: 38632323


In [14]:
import numpy as np

from torch import optim, nn
from models.ema import EMAHelper

LEARNING_RATE = 2e-4
WEIGHT_DECAY = 0
# Endpoints of the linear diffusion noise schedule (first and last beta).
# These are NOT Adam's beta1/beta2; Adam below uses its default betas.
BETA1 = 1e-4
BETA2 = 0.02
TIMESTEPS = 1000
EMA_RATE = 0.9999

# Move the model to the GPU *before* constructing the optimizer and the EMA snapshot,
# as the torch.optim documentation recommends. Whatever EMAHelper.register copies from
# the model (presumably parameter clones — TODO confirm against models/ema.py) then
# lives on the same device as the weights it will track. Calling .to('cuda') again in
# the training cell is a no-op.
model.to('cuda')

# Optimizer over all model parameters.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Exponential moving average helper with decay EMA_RATE, seeded from the current model.
ema_helper = EMAHelper(mu=EMA_RATE)
ema_helper.register(model)

# Linear beta schedule: TIMESTEPS values from BETA1 to BETA2, computed in float64 and
# then cast to a float32 tensor on the GPU.
betas = np.linspace(BETA1, BETA2, TIMESTEPS, dtype=np.float64)
betas = torch.from_numpy(betas).float().to('cuda')

In [15]:
def criterion(
    model,
    x0: torch.Tensor,
    t: torch.LongTensor,
    c: torch.LongTensor,
    e: torch.Tensor,
    b: torch.Tensor,
    cond_drop_prob=0.1,
    keepdim=False,
):
    """Noise-prediction loss for one diffusion training step.

    Noises ``x0`` to timestep ``t`` in closed form,
    ``x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * e`` where ``abar`` is the
    cumulative product of ``1 - b``. It then compares ``model``'s prediction
    for ``x_t`` against the true noise ``e``.

    Args:
        model: Callable invoked as
            ``model(x_t, t.float(), c, cond_drop_prob=cond_drop_prob, mode="train")``.
            Its output is subtracted from ``e``, so it must broadcast against ``e``.
        x0: Clean samples. Dim 0 is the batch; any number of trailing dims is
            accepted (previously only 4-D inputs worked).
        t: Integer timestep per sample, shape ``(batch,)``, used to index ``b``.
        c: Conditioning input, forwarded unchanged to ``model``.
        e: Noise, same shape as ``x0``.
        b: 1-D beta schedule indexed by ``t``.
        cond_drop_prob: Forwarded unchanged to ``model``.
        keepdim: If True, return per-sample losses of shape ``(batch,)``.
            Otherwise return their mean as a 0-d tensor.

    Returns:
        Squared error summed over all non-batch dims: per sample when ``keepdim``
        is True, otherwise averaged over the batch.
    """
    # abar_t reshaped to broadcast over every non-batch dim of x0 (e.g. (-1, 1, 1, 1) for images).
    broadcast_shape = (-1,) + (1,) * (x0.dim() - 1)
    a = (1 - b).cumprod(dim=0).index_select(0, t).view(broadcast_shape)
    x = x0 * a.sqrt() + e * (1.0 - a).sqrt()
    output = model(x, t.float(), c, cond_drop_prob=cond_drop_prob, mode="train")

    squared_error = (e - output).square()
    reduce_dims = tuple(range(1, squared_error.dim()))
    # For 1-D inputs there is nothing to reduce per sample. Avoid sum(dim=()),
    # which would reduce over every dimension.
    per_sample = squared_error.sum(dim=reduce_dims) if reduce_dims else squared_error
    return per_sample if keepdim else per_sample.mean(dim=0)

In [None]:
from tqdm import tqdm

SAVE_FILE = 'weights/cifar10-salun.pth'
EPOCHS = 200
EVAL_EVERY = 5  # checkpoint interval in epochs (no evaluation is run in this loop)
GRAD_CLIP = 1.0  # max global L2 norm of gradients

model.to('cuda')
model.train()

test_loss = None
test_accuracy = None
# Defined up front so the final log/save below still works when EPOCHS == 0.
train_loss = float('nan')

# Normal training
with tqdm(range(EPOCHS), unit='epoch') as pbar:
    for epoch in pbar:
        running_loss = 0.0

        for inputs, labels in train_dataloader:
            # Move inputs and labels to the specified device
            inputs, labels = inputs.to('cuda'), labels.to('cuda')

            batch_size = inputs.shape[0]

            # Antithetic timestep sampling: draw ~half the timesteps uniformly, pair each
            # t with its mirror TIMESTEPS - 1 - t, then trim back to the batch size.
            t = torch.randint(low=0, high=TIMESTEPS, size=(batch_size // 2 + 1,), device='cuda')
            t = torch.cat([t, TIMESTEPS - t - 1], dim=0)[:batch_size]
            noise = torch.randn_like(inputs, device='cuda')

            # Compute the loss and its gradients
            loss = criterion(model, inputs, t, labels, noise, betas)
            running_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()

            # Clip the global gradient norm. This used to be wrapped in
            # `except Exception: pass`, which hid genuine failures; let them surface.
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)

            optimizer.step()
            # TODO(review): `ema_helper` is registered in an earlier cell (EMA_RATE) but is
            # never updated here and its weights are never saved. Either update it after each
            # optimizer step and persist it, or drop the EMA setup.

        # Mean of per-batch losses for this epoch.
        train_loss = running_loss / len(train_dataloader)

        if (epoch + 1) % EVAL_EVERY == 0:
            run_logger.debug(f'Normal training checkpoint: save_file={SAVE_FILE}, epoch={epoch}, train_loss={train_loss}')
            torch.save(model, f'{SAVE_FILE}.checkpoint')

        pbar.set_postfix(train_loss=train_loss)

run_logger.info(f'Normal training complete: save_file={SAVE_FILE}, train_loss={train_loss}')
torch.save(model, SAVE_FILE)

  0%|          | 0/200 [00:00<?, ?epoch/s]