In [1]:
import torch as th
from diffusion_openai import dist_util, logger
from diffusion_openai.resample import create_named_schedule_sampler
from diffusion_openai.script_util import create_model_and_diffusion
from diffusion_openai.train_util import TrainLoop
from diffusion_openai.video_datasets import load_data

th.backends.cudnn.enabled = True
th.backends.cudnn.benchmark = True

In [2]:
model_parameters = dict(
    image_size=64,
    class_cond=False,
    learn_sigma=False,
    sigma_small=False,
    num_channels=128,
    num_res_blocks=3,
    scale_time_dim=0,
    num_heads=4,
    num_heads_upsample=1,
    attention_resolutions="16,8",
    dropout=0.0,
    diffusion_steps=1000,
    noise_schedule="linear",
    timestep_respacing="",
    use_kl=False,
    predict_xstart=False,
    rescale_timesteps=True,
    rescale_learned_sigmas=True,
    use_checkpoint=False,
    use_scale_shift_norm=True,
    rgb=True
)

In [3]:
dist_util.setup_dist()
logger.configure(dir="/home/s_gladkykh/thesis/sky-diffusion/ramvid_notebooks/logs3")
for key, item in model_parameters.items():
    logger.logkv(key, item)
logger.dumpkvs()

1
Logging to /home/s_gladkykh/thesis/sky-diffusion/ramvid_notebooks/logs3
-------------------------------------
| attention_resolutions  | 16,8     |
| class_cond             | 0        |
| diffusion_steps        | 1e+03    |
| dropout                | 0        |
| image_size             | 64       |
| learn_sigma            | 0        |
| noise_schedule         | linear   |
| num_channels           | 128      |
| num_heads              | 4        |
| num_heads_upsample     | 1        |
| num_res_blocks         | 3        |
| predict_xstart         | 0        |
| rescale_learned_sigmas | 1        |
| rescale_timesteps      | 1        |
| rgb                    | 1        |
| scale_time_dim         | 0        |
| sigma_small            | 0        |
| timestep_respacing     |          |
| use_checkpoint         | 0        |
| use_kl                 | 0        |
| use_scale_shift_norm   | 1        |
-------------------------------------


defaultdict(float,
            {'image_size': 64,
             'class_cond': False,
             'learn_sigma': False,
             'sigma_small': False,
             'num_channels': 128,
             'num_res_blocks': 3,
             'scale_time_dim': 0,
             'num_heads': 4,
             'num_heads_upsample': 1,
             'attention_resolutions': '16,8',
             'dropout': 0.0,
             'diffusion_steps': 1000,
             'noise_schedule': 'linear',
             'timestep_respacing': '',
             'use_kl': False,
             'predict_xstart': False,
             'rescale_timesteps': True,
             'rescale_learned_sigmas': True,
             'use_checkpoint': False,
             'use_scale_shift_norm': True,
             'rgb': True})

In [4]:
model, diffusion = create_model_and_diffusion(
    **model_parameters
)

In [5]:
model.to(dist_util.dev())
schedule_sampler = create_named_schedule_sampler("uniform", diffusion)

In [6]:
data = load_data(
    data_dir="/home/s_gladkykh/thesis/gif_dataset_64",
    batch_size=12,
    image_size=64,
    class_cond=False,
    deterministic=False,
    rgb=True,
    seq_len=20
)

In [7]:
mask_range = [0, 20]

In [None]:
TrainLoop(
    model=model,
    diffusion=diffusion,
    data=data,
    batch_size=12,
    microbatch=-1,
    lr=1e-5,
    ema_rate="0.9999",
    log_interval=100,
    save_interval=1000,
    resume_checkpoint="",
    use_fp16=False,
    fp16_scale_growth=1e-3,
    schedule_sampler=schedule_sampler,
    weight_decay=0.0,
    lr_anneal_steps=0,
    clip=1,
    anneal_type=None,
    steps_drop=0.0,
    drop=0.0,
    decay=0.0,
    max_num_mask_frames=4,
    mask_range=mask_range, 
    uncondition_rate=0.5,
    exclude_conditional=True,
).run_loop()

global batch size = 12
world_size: 1
------------------------
| grad_norm | 1        |
| loss      | 1        |
| loss_q0   | 0.999    |
| loss_q1   | 1        |
| loss_q2   | 1        |
| loss_q3   | 1        |
| mse       | 1        |
| mse_q0    | 0.999    |
| mse_q1    | 1        |
| mse_q2    | 1        |
| mse_q3    | 1        |
| samples   | 12       |
| step      | 0        |
------------------------
saving model 0...
saving model 0.9999...
------------------------
| grad_norm | 1        |
| loss      | 0.824    |
| loss_q0   | 0.834    |
| loss_q1   | 0.823    |
| loss_q2   | 0.826    |
| loss_q3   | 0.814    |
| mse       | 0.824    |
| mse_q0    | 0.834    |
| mse_q1    | 0.823    |
| mse_q2    | 0.826    |
| mse_q3    | 0.814    |
| samples   | 1.21e+03 |
| step      | 100      |
------------------------
------------------------
| grad_norm | 1        |
| loss      | 0.485    |
| loss_q0   | 0.499    |
| loss_q1   | 0.481    |
| loss_q2   | 0.483    |
| loss_q3   | 0.478   