In [1]:
import torch
from bit_diffusion import Unet, Trainer, BitDiffusion

# U-Net handed to BitDiffusion further down: base dim 32, 3 channels,
# and dim_mults (1, 2, 4, 8).
model = Unet(dim=32, channels=3, dim_mults=(1, 2, 4, 8))

# Diffusion wrapper around `model`.
#
# time_difference: the paper reports that with a low number of timesteps, a
# sampling time difference greater than 0 helps FID; as the number of
# timesteps increases it stops helping and can be set to 0.
bit_diffusion = BitDiffusion(
    model,
    image_size=128,
    timesteps=100,
    time_difference=0.1,
    use_ddim=True,  # sample with DDIM
)

# Trainer for `bit_diffusion`; the second positional argument is the path to
# the folder of training images.
trainer = Trainer(
    bit_diffusion,
    'data/train/',
    # --- optimisation ---
    train_batch_size=4,            # training batch size
    gradient_accumulate_every=4,   # gradient accumulation
    train_lr=1e-4,                 # learning rate
    train_num_steps=30000,         # total training steps
    ema_decay=0.995,               # exponential moving average decay
    # --- checkpoints and sample output ---
    results_folder='./results',    # where to save results
    save_and_sample_every=1000,    # how often to save and sample
    num_samples=16,                # number of samples
)

# Milestone identifier passed to trainer.load (presumably a previously saved
# checkpoint -- confirm against Trainer.load).
MILESTONE = "13"

# Load that milestone into the trainer, then run training.
trainer.load(MILESTONE)
trainer.train()

  from .autonotebook import tqdm as notebook_tqdm
sampling loop time step: 100%|██████████| 100/100 [00:08<00:00, 11.72it/s]
sampling loop time step: 100%|██████████| 100/100 [00:08<00:00, 11.70it/s]
sampling loop time step: 100%|██████████| 100/100 [00:08<00:00, 11.63it/s]
sampling loop time step: 100%|██████████| 100/100 [00:08<00:00, 11.71it/s]
  ma_params.data.lerp_(current_params.data, 1. - current_decay)
sampling loop time step: 100%|██████████| 100/100 [00:12<00:00,  8.10it/s] 
sampling loop time step: 100%|██████████| 100/100 [00:12<00:00,  7.76it/s]
sampling loop time step: 100%|██████████| 100/100 [00:12<00:00,  7.82it/s]
sampling loop time step: 100%|██████████| 100/100 [00:12<00:00,  7.87it/s]
sampling loop time step: 100%|██████████| 100/100 [00:08<00:00, 12.24it/s] 
sampling loop time step: 100%|██████████| 100/100 [00:08<00:00, 12.25it/s]
sampling loop time step: 100%|██████████| 100/100 [16:30<00:00,  9.91s/it]
sampling loop time step: 100%|██████████| 100/100 [00:08<00