In [1]:
%pip install --upgrade scikit-image torch torchvision torchaudio kornia pytorch-lightning lightning wandb>=0.12.10 diffusers["torch"] transformers einops matplotlib requests click

Note: you may need to restart the kernel to use updated packages.


In [2]:
import argparse
import pytorch_lightning as pl
import kornia.augmentation as KA
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from diffusion.LatentDiffusion import LatentDiffusionConditional
from datasets.example_dataset import ExampleImageDataset
from utils.EMA import EMA


def train(args):
    transform = [
        KA.RandomCrop((2 * args.crop_size, 2 * args.crop_size)),
        KA.Resize((args.crop_size, args.crop_size), antialias=True),
        KA.RandomVerticalFlip()
    ]

    train_ds = ExampleImageDataset(args.train_data_path,
                                   transforms=transform,
                                   conditional=True)
    val_ds = ExampleImageDataset(args.val_data_path,
                                 transforms=transform,
                                 conditional=True)
    
    print(train_ds)

    model = LatentDiffusionConditional(train_dataset=train_ds,
                                       valid_dataset=val_ds,
                                       num_timesteps=args.num_timesteps,
                                       lr=args.lr,
                                       num_warmup_steps=args.warmup_steps,
                                       num_epochs=args.num_epochs,
                                       batch_size=args.batch_size)

    wandb_logger = WandbLogger(project="CS5340")

    trainer = pl.Trainer(
        max_steps=args.max_steps,
        max_epochs=args.num_epochs,
        callbacks=[
            EMA(0.9999),
            ModelCheckpoint(
                dirpath=args.save_path,
                every_n_epochs=args.save_freq,
                save_top_k=args.top_k
            )],
        accelerator="cpu",
        #devices=args.devices,
        logger=wandb_logger,
        log_every_n_steps=args.log_steps
    )

    trainer.fit(model)

    trainer.save_checkpoint(args.save_path)


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
parser = argparse.ArgumentParser()
parser.add_argument("--crop_size", type=int, default=16)
parser.add_argument("--train_data_path", type=str, default="temp/train/data/0")
parser.add_argument("--val_data_path", type=str, default="temp/validation/data/0")
parser.add_argument("--save_path", type=str, default="temp/save/model")
parser.add_argument("--save_freq", type=int, default=200)
parser.add_argument("--top_k", type=int, default=1)
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--num_epochs", type=int, default=10)
parser.add_argument("--max_steps", type=int, default=100000)
parser.add_argument("--log_steps", type=int, default=100)
parser.add_argument("--num_timesteps", type=int, default=1000)
parser.add_argument("--warmup_steps", type=int, default=100)
parser.add_argument("--devices", type=str, default=1)

_StoreAction(option_strings=['--devices'], dest='devices', nargs=None, const=None, default=1, type=<class 'str'>, choices=None, help=None, metavar=None)

In [4]:
args = parser.parse_args(args=[])

In [5]:
train(args)

<datasets.example_dataset.ExampleImageDataset object at 0x7fe28076d3d0>
Is Time embed used ?  True


  rank_zero_warn(
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/yulong/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mgyulong1[0m ([33mcs5340[0m). Use [1m`wandb login --relogin`[0m to force relogin



  | Name  | Type                                 | Params
---------------------------------------------------------------
0 | ae    | AutoEncoder                          | 83.7 M
1 | model | DenoisingDiffusionConditionalProcess | 56.6 M
---------------------------------------------------------------
140 M     Trainable params
0         Non-trainable params
140 M     Total params
561.105   Total estimated model params size (MB)


RuntimeError: The NVIDIA driver on your system is too old (found version 11060). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver.