In [5]:
%pip install --upgrade scikit-image torch torchvision torchaudio kornia pytorch-lightning lightning wandb>=0.12.10 diffusers["torch"] transformers einops matplotlib requests click

Note: you may need to restart the kernel to use updated packages.


In [6]:
from datasets.cifar_custom import __CIFAR_Customized, CIFAR10_Customized, CIFAR100_Customized
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--crop_size", type=int, default=256)
parser.add_argument("--train_data_path", type=str, default="cifar10_train")
parser.add_argument("--val_data_path", type=str, default="cifar10_validation")
parser.add_argument("--save_path", type=str, default="models")
parser.add_argument("--save_freq", type=int, default=10)
parser.add_argument("--top_k", type=int, default=1)
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--num_epochs", type=int, default=20)
parser.add_argument("--max_steps", type=int, default=1000)
parser.add_argument("--log_steps", type=int, default=100)
parser.add_argument("--num_timesteps", type=int, default=100)
parser.add_argument("--warmup_steps", type=int, default=100)
parser.add_argument("--accelerator", type=str, default="cpu")
#parser.add_argument("--devices", type=str, default=1)

args = parser.parse_args(args=[])

In [7]:
import argparse
import pytorch_lightning as pl
import kornia.augmentation as KA
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from diffusion.LatentDiffusion import LatentDiffusionConditional
from diffusion.LatentDiffusion import LatentDiffusion
from datasets.example_dataset import ExampleImageDataset
from utils.EMA import EMA


def train(args):
    transform = [
        KA.RandomCrop((2 * args.crop_size, 2 * args.crop_size)),
        KA.Resize((args.crop_size, args.crop_size), antialias=True),
        KA.RandomVerticalFlip()
    ]

    train_ds = CIFAR100_Customized(args.train_data_path, train=True, conditional=False)
    val_ds = CIFAR100_Customized(args.val_data_path, train=False, conditional=False)

    model = LatentDiffusion(train_dataset=train_ds,
                                       valid_dataset=val_ds,
                                       num_timesteps=args.num_timesteps,
                                       lr=args.lr,
                                       num_warmup_steps=args.warmup_steps,
                                       num_epochs=args.num_epochs,
                                       batch_size=args.batch_size)

    wandb_logger = WandbLogger(project="CS5340")
    wandb_logger.experiment.config["key"] = "33da90c88f8e092fbf39fa7ad18e3125504b51a6"

    trainer = pl.Trainer(
        max_steps=args.max_steps,
        max_epochs=args.num_epochs,
        callbacks=[
            EMA(0.9999),
            ModelCheckpoint(
                dirpath=args.save_path,
                every_n_epochs=args.save_freq,
                save_top_k=args.top_k
            )],
        accelerator=args.accelerator,
        #devices=args.devices,
        logger=wandb_logger,
        log_every_n_steps=args.log_steps
    )

    trainer.fit(model)

    trainer.save_checkpoint(args.save_path)

In [8]:
train(args)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to cifar10_train/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:16<00:00, 10138156.00it/s]


Extracting cifar10_train/cifar-100-python.tar.gz to cifar10_train
Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to cifar10_validation/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:16<00:00, 10085861.35it/s]


Extracting cifar10_validation/cifar-100-python.tar.gz to cifar10_validation
Is Time embed used ?  True


/Users/yulong/Desktop/fyp/nqs-tf2/.conda/lib/python3.10/site-packages/pytorch_lightning/loggers/wandb.py:391: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type                      | Params
----------------------------------------------------
0 | ae    | AutoEncoder               | 83.7 M
1 | model | DenoisingDiffusionProcess | 56.6 M
----------------------------------------------------
140 M     Trainable params
0         Non-trainable params
140 M     Total params
561.084   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

RuntimeError: Calculated padded input size per channel: (3 x 3). Kernel size: (4 x 4). Kernel size can't be greater than actual input size