In [1]:
import argparse

from dropout_diffusion import dist_util, logger
from dropout_diffusion.image_datasets import load_data
from dropout_diffusion.resample import create_named_schedule_sampler
from dropout_diffusion.script_util import (
    add_dict_to_argparser,
    args_to_dict,
    create_model_and_diffusion,
    model_and_diffusion_defaults,
)
from dropout_diffusion.train_util import TrainLoop


In [2]:
# Default training options, keyed by option name.
default_training = {
    "data_dir": "",
    "schedule_sampler": "uniform",
    "lr": 1e-4,
    "weight_decay": 0.0,
    "lr_anneal_steps": 0,
    "batch_size": 1,
    "microbatch": -1,  # -1 switches microbatching off
    "ema_rate": "0.9999",  # EMA values as one comma-separated string
    "log_interval": 10,
    "save_interval": 10000,
    "resume_checkpoint": "",
    "use_fp16": False,
    "fp16_scale_growth": 1e-3,
}

In [3]:
# Model and diffusion options; this dict is later expanded as the keyword
# arguments of create_model_and_diffusion.
model_and_diffusion_default = {
    "image_size": 32,
    "num_channels": 128,
    "num_res_blocks": 2,
    "num_heads": 4,
    "num_heads_upsample": -1,
    "attention_resolutions": "16,8",
    "dropout": 0.0,
    "learn_sigma": False,
    "sigma_small": False,
    "class_cond": False,
    "diffusion_steps": 1000,
    "noise_schedule": "linear",
    "timestep_respacing": "",
    "use_kl": False,
    "predict_xstart": False,
    "rescale_timesteps": True,
    "rescale_learned_sigmas": True,
    "use_checkpoint": False,
    "use_scale_shift_norm": True,
}

In [4]:
# Merge the model/diffusion options into the training options so that one
# dict holds every setting. The two dicts share no keys, so nothing that was
# already in default_training is overwritten.
default_training.update(**model_and_diffusion_default)

In [5]:
# Build the model and the diffusion object from the model/diffusion options.
(model, diffusion) = create_model_and_diffusion(
    **model_and_diffusion_default,
)

In [6]:
# Create the schedule sampler selected by name in default_training
# ("uniform" here) for the diffusion object built above.
schedule_sampler = create_named_schedule_sampler(
    default_training["schedule_sampler"],
    diffusion,
)

In [7]:
# Create the data source passed to TrainLoop as `data`.
# image_size and class_cond are read from model_and_diffusion_default rather
# than repeated as literals, so the loader cannot drift from the model that
# was built from that same dict (the values are currently 32 and False).
data = load_data(
    data_dir="cifar_train",
    batch_size=2,  # TrainLoop is also constructed with batch_size=2
    image_size=model_and_diffusion_default["image_size"],
    class_cond=model_and_diffusion_default["class_cond"],
)

In [8]:
# Run the project's distributed setup before the training loop is constructed.
# NOTE(review): presumably this initialises the torch.distributed state that
# TrainLoop relies on — confirm against dist_util. The model is created before
# this call and is never moved to a device in the visible cells; verify that
# this is intended.
dist_util.setup_dist()

In [9]:
# Construct the training loop from the objects built in the earlier cells.
# Options whose value matches default_training are read from that dict instead
# of being re-typed as literals (microbatch already was), so the dict stays the
# single place to change them. Only the two values that differ from the
# defaults are spelled out.
trainloop = TrainLoop(
    model=model,
    diffusion=diffusion,
    data=data,
    schedule_sampler=schedule_sampler,
    # Values that differ from default_training:
    batch_size=2,  # default_training has 1; load_data is also called with 2
    log_interval=100,  # default_training has 10
    # Values taken unchanged from default_training:
    microbatch=default_training["microbatch"],
    lr=default_training["lr"],
    ema_rate=default_training["ema_rate"],
    save_interval=default_training["save_interval"],
    resume_checkpoint=default_training["resume_checkpoint"],
    use_fp16=default_training["use_fp16"],
    fp16_scale_growth=default_training["fp16_scale_growth"],
    weight_decay=default_training["weight_decay"],
    lr_anneal_steps=default_training["lr_anneal_steps"],
)

In [10]:
# Display the training loop's master parameters before they are clamped, for
# comparison with the same display after the clamp cell.
trainloop.master_params

[Parameter containing:
 tensor([[-0.0470, -0.0765, -0.0856,  ...,  0.0051, -0.0794,  0.0239],
         [ 0.0342, -0.0088,  0.0582,  ..., -0.0422, -0.0392,  0.0094],
         [-0.0504,  0.0820, -0.0318,  ..., -0.0055, -0.0047, -0.0332],
         ...,
         [-0.0132, -0.0459,  0.0323,  ...,  0.0802, -0.0509,  0.0511],
         [ 0.0101, -0.0156, -0.0447,  ...,  0.0004, -0.0529, -0.0301],
         [ 0.0377,  0.0567,  0.0434,  ...,  0.0687,  0.0459,  0.0545]],
        requires_grad=True),
 Parameter containing:
 tensor([-0.0689, -0.0124, -0.0059, -0.0865, -0.0719, -0.0728,  0.0792, -0.0462,
         -0.0805, -0.0191, -0.0265,  0.0103,  0.0671, -0.0524,  0.0533, -0.0356,
          0.0012, -0.0542,  0.0288,  0.0646, -0.0559,  0.0440, -0.0520, -0.0881,
          0.0320,  0.0613,  0.0073, -0.0842, -0.0328,  0.0352, -0.0596,  0.0223,
         -0.0807,  0.0122,  0.0875, -0.0457, -0.0692, -0.0835,  0.0349, -0.0071,
         -0.0436,  0.0779, -0.0465,  0.0263,  0.0777, -0.0654,  0.0551, -0.0320

In [11]:
import torch

In [12]:
# Clamp every master parameter into [CLAMP_MIN, CLAMP_MAX], in place.
# The bounds are named so the range is tunable in one spot, and the unused
# enumerate index is dropped so it no longer leaks into the notebook namespace.
CLAMP_MIN, CLAMP_MAX = -1, 1

# no_grad is needed because the parameters have requires_grad=True (see the
# displayed output) and autograd rejects in-place edits of such leaf tensors.
with torch.no_grad():
    for param in trainloop.master_params:
        param.clamp_(CLAMP_MIN, CLAMP_MAX)

In [13]:
# Display the master parameters again after the in-place clamp to [-1, 1].
# The values visible in the recorded output are the same as before the clamp,
# since they were already inside that range.
trainloop.master_params

[Parameter containing:
 tensor([[-0.0470, -0.0765, -0.0856,  ...,  0.0051, -0.0794,  0.0239],
         [ 0.0342, -0.0088,  0.0582,  ..., -0.0422, -0.0392,  0.0094],
         [-0.0504,  0.0820, -0.0318,  ..., -0.0055, -0.0047, -0.0332],
         ...,
         [-0.0132, -0.0459,  0.0323,  ...,  0.0802, -0.0509,  0.0511],
         [ 0.0101, -0.0156, -0.0447,  ...,  0.0004, -0.0529, -0.0301],
         [ 0.0377,  0.0567,  0.0434,  ...,  0.0687,  0.0459,  0.0545]],
        requires_grad=True),
 Parameter containing:
 tensor([-0.0689, -0.0124, -0.0059, -0.0865, -0.0719, -0.0728,  0.0792, -0.0462,
         -0.0805, -0.0191, -0.0265,  0.0103,  0.0671, -0.0524,  0.0533, -0.0356,
          0.0012, -0.0542,  0.0288,  0.0646, -0.0559,  0.0440, -0.0520, -0.0881,
          0.0320,  0.0613,  0.0073, -0.0842, -0.0328,  0.0352, -0.0596,  0.0223,
         -0.0807,  0.0122,  0.0875, -0.0457, -0.0692, -0.0835,  0.0349, -0.0071,
         -0.0436,  0.0779, -0.0465,  0.0263,  0.0777, -0.0654,  0.0551, -0.0320