In [1]:
import argparse

import numpy as np
import torch as th
import torch.distributed as dist

import os, sys
sys.path.insert(1, os.getcwd()) 
import random

from diffusion_openai.video_datasets import load_data
from diffusion_openai import dist_util, logger
from diffusion_openai.script_util import (
    NUM_CLASSES,
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
)
# Explicitly switch cuDNN on for this session.
th.backends.cudnn.enabled = True
# Let cuDNN's auto-tuner choose the best-performing algorithms.
th.backends.cudnn.benchmark = True

from dataclasses import dataclass

In [2]:
@dataclass
class Parameters:
    """Settings that the later cells read through the module-level ``args``.

    Every attribute carries a type annotation so that ``@dataclass`` turns it
    into a real field: values can be overridden per run, e.g.
    ``Parameters(num_samples=16)``, and show up in ``repr``/``==``.
    Each setting is declared exactly once.
    """

    # --- sampling ---
    clip_denoised: bool = True
    num_samples: int = 8
    batch_size: int = 8
    use_ddim: bool = False
    model_path: str = ""
    seq_len: int = 20
    sampling_type: str = "generation"
    # Comma-separated frame indices, e.g. "0," or "0,1,2".
    cond_frames: str = "0,"
    cond_generation: bool = True
    resample_steps: int = 1
    save_gt: bool = False
    seed: int = 42

    # --- data ---
    data_dir: str = "/home/s_gladkykh/thesis/gif_dataset_64"
    image_size: int = 64
    class_cond: bool = False
    deterministic: bool = False
    rgb: bool = True


args = Parameters()

In [3]:
# Keyword arguments that a later cell unpacks into create_model_and_diffusion().
model_parameters = {
    "image_size": 64,
    "class_cond": False,
    "learn_sigma": False,
    "sigma_small": False,
    "num_channels": 128,
    "num_res_blocks": 3,
    "scale_time_dim": 0,
    "num_heads": 4,
    "num_heads_upsample": 1,
    "attention_resolutions": "16,8",
    "dropout": 0.0,
    "diffusion_steps": 1000,
    "noise_schedule": "linear",
    "timestep_respacing": "",
    "use_kl": False,
    "predict_xstart": False,
    "rescale_timesteps": True,
    "rescale_learned_sigmas": True,
    "use_checkpoint": False,
    "use_scale_shift_norm": True,
    "rgb": True,
}

In [4]:
# Presumably initialises torch.distributed (later cells call dist.get_world_size()
# and dist.all_gather) -- confirm against dist_util.
dist_util.setup_dist()
logger.configure(dir="/home/s_gladkykh/thesis/sky-diffusion/ramvid_notebooks/logs_sampling")

# Seed every RNG in play before the data loader and sampler are created, so a
# Restart & Run All draws the same conditioning clips and the same noise.
# `args.seed` and the `random` import were otherwise unused.
# NOTE(review): cudnn.benchmark is switched on in the first cell, so GPU results
# may still differ slightly between runs -- confirm if bitwise equality matters.
if args.seed is not None:
    th.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

1
Logging to /home/s_gladkykh/thesis/sky-diffusion/ramvid_notebooks/logs_sampling


In [5]:
# Build the network and the diffusion object from the shared keyword dictionary.
logger.log("creating model and diffusion...")
model, diffusion = create_model_and_diffusion(**model_parameters)

creating model and diffusion...


In [6]:
# Honour `args.model_path` when it is set; it defaults to "", in which case the
# checkpoint this notebook was originally run with is used.
checkpoint_path = (
    args.model_path
    or "/home/s_gladkykh/thesis/sky-diffusion/ramvid_notebooks/logs/model006000.pt"
)
# Load the weights onto the CPU first, then move the whole model to the device.
state_dict = dist_util.load_state_dict(checkpoint_path, map_location="cpu")
model.load_state_dict(state_dict)
model.to(dist_util.dev())
# The trailing semicolon keeps Jupyter from echoing the returned module repr.
model.eval();




In [7]:
# Extra keyword arguments handed to the sampler, plus the indices of the frames
# that are taken from real data (conditioning frames).
cond_kwargs = {}
cond_frames = []
if args.cond_generation:
    # Read the data settings from `args` so this cell cannot drift from the
    # configuration cell.
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=args.class_cond,
        deterministic=args.deterministic,
        rgb=args.rgb,
        seq_len=args.seq_len,
    )

    # `args.cond_frames` is a comma-separated list of indices such as "0," or
    # "0,5". Splitting on "," keeps the last index even without a trailing
    # comma and skips empty pieces.
    cond_frames = [int(tok) for tok in args.cond_frames.split(",") if tok.strip()]
    # Every frame index that is not a conditioning frame.
    ref_frames = [i for i in range(args.seq_len) if i not in cond_frames]
    logger.log(f"cond_frames: {cond_frames}")
    logger.log(f"ref_frames: {ref_frames}")
    logger.log(f"seq_len: {args.seq_len}")
    cond_kwargs["resampling_steps"] = args.resample_steps
cond_kwargs["cond_frames"] = cond_frames
cond_kwargs["saver"] = None

# Number of image channels in the sample shape requested later.
channels = 3 if args.rgb else 1


cond_frames: [0]
ref_frames: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
seq_len: 20


In [8]:
def _to_uint8_channels_last(batch):
    """Rescale ``batch`` with ``(x + 1) * 127.5`` (so -1 -> 0 and 1 -> 255),
    clamp to [0, 255], cast to uint8 and move axis 1 to the last position."""
    as_bytes = ((batch + 1) * 127.5).clamp(0, 255).to(th.uint8)
    return as_bytes.permute(0, 2, 3, 4, 1).contiguous()


def _all_gather_numpy(tensor):
    """All-gather ``tensor`` across ranks and return one numpy array per rank."""
    buffers = [th.zeros_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(buffers, tensor)  # gather not supported with NCCL
    return [chunk.cpu().numpy() for chunk in buffers]


logger.log("sampling...")
all_videos = []
all_gt = []

# The sampler does not change between iterations, so pick it once.
if args.use_ddim:
    sample_fn = diffusion.ddim_sample_loop
else:
    sample_fn = diffusion.p_sample_loop

sample_shape = (args.batch_size, channels, args.seq_len, args.image_size, args.image_size)
keep_ground_truth = args.cond_generation and args.save_gt

# Keep producing batches until at least `num_samples` samples exist.
while len(all_videos) * args.batch_size < args.num_samples:

    if args.cond_generation:
        # Take the conditioning frames (axis 2) from the next real clip.
        video, _ = next(data)
        cond_kwargs["cond_img"] = video[:, :, cond_frames].to(dist_util.dev())
        video = video.to(dist_util.dev())

    sample = sample_fn(
        model,
        sample_shape,
        clip_denoised=args.clip_denoised,
        progress=False,
        cond_kwargs=cond_kwargs,
    )

    all_videos.extend(_all_gather_numpy(_to_uint8_channels_last(sample)))
    logger.log(f"created {len(all_videos) * args.batch_size} samples")

    if keep_ground_truth:
        # Store the real clips in the same layout as the generated samples.
        all_gt.extend(_all_gather_numpy(_to_uint8_channels_last(video)))
        logger.log(f"created {len(all_gt) * args.batch_size} videos")

sampling...
created 8 samples


In [9]:
# NOTE(review): leftover smoke-test cell; no later cell uses it, so it can be deleted.
print("hello")

hello


In [13]:
# Merge the per-rank / per-batch arrays and drop any surplus: the sampling loop
# works in whole batches on every rank, so it can produce more than
# `args.num_samples` samples.
arr = np.concatenate(all_videos, axis=0)[: args.num_samples]

if args.cond_generation and args.save_gt:
    arr_gt = np.concatenate(all_gt, axis=0)[: args.num_samples]


# Only rank 0 writes to disk.
if dist.get_rank() == 0:

    # The file is named after the array shape, e.g. "8x20x64x64x3.npz".
    # np.savez appends ".npz" when it is missing, so spelling it out keeps the
    # same file name while letting the log show the path actually written.
    shape_str = "x".join(str(dim) for dim in arr.shape)
    out_path = os.path.join(logger.get_dir(), f"{shape_str}.npz")
    logger.log(f"saving samples to {out_path}")
    np.savez(out_path, arr)

    if args.cond_generation and args.save_gt:
        shape_str_gt = "x".join(str(dim) for dim in arr_gt.shape)
        out_path_gt = os.path.join(logger.get_dir(), f"{shape_str_gt}.npz")
        logger.log(f"saving ground_truth to {out_path_gt}")
        np.savez(out_path_gt, arr_gt)

# Hold every rank here until the barrier is reached by all of them.
dist.barrier()
logger.log("sampling complete")

saving samples to /home/s_gladkykh/thesis/sky-diffusion/ramvid_notebooks/logs_sampling/8x20x64x64x3
sampling complete


In [20]:
# Shape of the first gathered batch; needs `all_videos` to still be the list
# built by the sampling loop.
all_videos[0].shape

(8, 20, 64, 64, 3)

In [22]:
import imageio

# Bind the stacked clips to a new name instead of overwriting `all_videos`:
# the cell can then be re-run safely, and every gathered batch is exported
# rather than only the first one.
videos = np.concatenate(all_videos, axis=0)

# Create the output folder if needed so the writer does not fail on a fresh checkout.
os.makedirs("samples", exist_ok=True)

for clip_idx, clip in enumerate(videos):
    # The context manager closes the GIF writer even if a frame fails to encode.
    with imageio.get_writer(f'samples/video_{clip_idx}.gif', mode='I', duration=0.1) as writer:  # Adjust duration as needed
        # Append the frames of this clip in order.
        for frame in clip:
            writer.append_data(frame)
