In [1]:
import os
import sys

import torch as th

sys.path.insert(1, os.getcwd())

from diffusion_openai import dist_util, logger
from diffusion_openai.script_util import create_model_and_diffusion
from diffusion_openai.video_datasets import load_data

# Enable cuDNN with autotuning. Benchmark mode picks the fastest convolution
# algorithms for the shapes it sees; per the PyTorch reproducibility notes this
# can make results vary slightly between runs.
cudnn_backend = th.backends.cudnn
cudnn_backend.enabled = True
cudnn_backend.benchmark = True

from dataclasses import dataclass

import matplotlib.pyplot as plt

In [ ]:
MODEL_PATH: str = "conditional_ramvid.pt"  # checkpoint file, relative to the notebook's working directory

In [2]:
@dataclass
class Parameters:
    """Sampling configuration for this notebook.

    Every attribute carries a type annotation so that ``@dataclass`` registers
    it as a real field. ``Parameters()`` yields the defaults below, and any
    value can be overridden by keyword, e.g. ``Parameters(seed=0, n_samples=8)``.
    Each field is declared exactly once; a duplicate definition would silently
    override the earlier one.
    """

    clip_denoised: bool = True      # forwarded as clip_denoised= to the sampling loop
    num_samples: int = 100          # NOTE(review): appears unused; the loop reads n_samples — confirm
    batch_size: int = 1             # videos per sampling call
    use_ddim: bool = False          # True selects diffusion.ddim_sample_loop
    model_path: str = ""            # optional checkpoint path
    seq_len: int = 20               # total frames per video
    sampling_type: str = "generation"
    cond_frames: str = "0,1,2,3,"   # comma-separated indices of conditioning frames
    cond_generation: bool = True    # condition on frames loaded from data_dir
    resample_steps: int = 1         # stored as cond_kwargs["resampling_steps"]
    data_dir: str = "../../data_samples/gif_64"
    save_gt: bool = False
    seed: int = 42
    image_size: int = 64
    class_cond: bool = False
    deterministic: bool = False
    rgb: bool = True                # 3 channels when True, otherwise 1
    n_samples: int = 1              # number of videos the sampling loop produces
    output_dir: str = "generated_for_evaluation"


args = Parameters()

In [3]:
# Architecture and diffusion hyper-parameters for create_model_and_diffusion.
# These must describe the network the checkpoint was trained as. A mismatch is
# expected to make load_state_dict fail on missing/unexpected or mis-shaped keys.
# Data-dependent settings are read from `args` so they cannot drift from the
# shapes used by the data loader and the sampling loop.
model_parameters = dict(
    image_size=args.image_size,
    class_cond=args.class_cond,
    learn_sigma=False,
    sigma_small=False,
    num_channels=128,
    num_res_blocks=3,
    scale_time_dim=0,
    num_heads=4,
    num_heads_upsample=1,
    attention_resolutions="16,8",
    dropout=0.0,
    diffusion_steps=1000,
    noise_schedule="linear",
    timestep_respacing="",
    use_kl=False,
    predict_xstart=False,
    rescale_timesteps=True,
    rescale_learned_sigmas=True,
    use_checkpoint=False,
    use_scale_shift_norm=True,
    rgb=args.rgb,
)

In [4]:
# Initialise the runtime through the project's dist_util helper (presumably
# torch.distributed / device selection — confirm). Then direct logger output
# into the ./logs_sampling directory.
dist_util.setup_dist()
log_dir = "logs_sampling"
logger.configure(dir=log_dir)

1
Logging to /home/s_gladkykh/thesis/sky-diffusion/ramvid_notebooks/logs_sampling


In [5]:
# Build the conditioning kwargs for the sampler. The conditioning frames are
# taken from dataset clips; the remaining ("ref") frames are left to the model.
cond_kwargs = {}
cond_frames = []
if args.cond_generation:
    # Reuse the notebook configuration instead of repeating literals. The
    # loaded clips then match the (batch, channels, seq_len, H, W) sample shape.
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=args.class_cond,
        deterministic=args.deterministic,
        rgb=args.rgb,
        seq_len=args.seq_len,
    )

    # "0,1,2,3," -> [0, 1, 2, 3]. Empty tokens (e.g. from a trailing comma) are
    # skipped, and the last index is kept even without a trailing comma.
    cond_frames = [int(tok) for tok in args.cond_frames.split(",") if tok.strip()]
    out_of_range = [f for f in cond_frames if not 0 <= f < args.seq_len]
    if out_of_range:
        raise ValueError(
            f"cond_frames {out_of_range} out of range for seq_len={args.seq_len}"
        )
    ref_frames = [f for f in range(args.seq_len) if f not in cond_frames]
    logger.log(f"cond_frames: {cond_frames}")
    logger.log(f"ref_frames: {ref_frames}")
    logger.log(f"seq_len: {args.seq_len}")
    cond_kwargs["resampling_steps"] = args.resample_steps
cond_kwargs["cond_frames"] = cond_frames
cond_kwargs["saver"] = None

channels = 3 if args.rgb else 1


cond_frames: [0, 1, 2, 3]
ref_frames: [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
seq_len: 20


In [6]:
import numpy as np

In [None]:
from PIL import Image
def create_gif(arr, gif_path, duration=100, size=64):
    """Write a sequence of frames to an animated GIF.

    Args:
        arr: iterable of frames. Each is an (H, W) or (H, W, C) array whose
            values are expected in [0, 1]; values outside that range are clipped.
        gif_path: destination file path.
        duration: display time of each frame, in milliseconds.
        size: accepted for backward compatibility but currently unused; frames
            are written at their native resolution.

    Raises:
        ValueError: if ``arr`` yields no frames.
    """
    frames = []
    for frame in arr:
        frame = np.asarray(frame)
        # PIL cannot build an image from an (H, W, 1) array; drop the singleton channel.
        if frame.ndim == 3 and frame.shape[-1] == 1:
            frame = frame[..., 0]
        # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
        frames.append((np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8))
    if not frames:
        raise ValueError(f"create_gif: no frames to write to {gif_path!r}")

    image_list = [Image.fromarray(frame) for frame in frames]
    image_list[0].save(
        gif_path,
        save_all=True,
        append_images=image_list[1:],  # the remaining frames after the first
        duration=duration,  # per-frame display time in milliseconds
        loop=1,
    )
    

# Build the network / diffusion pair described by model_parameters, then load weights.
model, diffusion = create_model_and_diffusion(**model_parameters)

# An explicit args.model_path wins; otherwise fall back to the MODEL_PATH constant.
checkpoint_path = args.model_path or MODEL_PATH
model.load_state_dict(
    dist_util.load_state_dict(checkpoint_path, map_location="cpu")
)
model.to(dist_util.dev())
model.eval()  # inference mode (disables dropout etc.)


logger.log("sampling...")
# Seed torch (CPU and CUDA) so the sampled noise is reproducible across runs.
th.manual_seed(args.seed)

all_videos = []
all_gt = []
generated_num = 0
os.makedirs(args.output_dir, exist_ok=True)
gt_dir = "original_for_evaluation"
if args.save_gt:
    os.makedirs(gt_dir, exist_ok=True)

# The sampler and the output shape do not change between batches; build them once.
sample_fn = diffusion.ddim_sample_loop if args.use_ddim else diffusion.p_sample_loop
sample_shape = (args.batch_size, channels, args.seq_len, args.image_size, args.image_size)

while generated_num < args.n_samples:
    # Index of the first video in this batch. Files are named 0.gif, 1.gif, ...
    # contiguously across batches.
    first_idx = generated_num
    if args.cond_generation:
        video, _ = next(data)
        # Frames live on dim 2 (as in sample_shape); keep only the conditioning ones.
        cond_kwargs["cond_img"] = video[:, :, cond_frames].to(dist_util.dev())
        if args.save_gt:
            all_gt.append(video)
            gt = ((video + 1) / 2).permute(0, 2, 3, 4, 1).cpu().numpy()
            for j in range(gt.shape[0]):
                create_gif(gt[j], os.path.join(gt_dir, f"{first_idx + j}.gif"))

    sample = sample_fn(
        model,
        sample_shape,
        clip_denoised=args.clip_denoised,
        progress=False,
        cond_kwargs=cond_kwargs,
    )

    # Keep a host copy so device memory does not grow with the number of batches.
    all_videos.append(sample.cpu())
    generated_num += args.batch_size
    # (B, C, T, H, W), assumed in [-1, 1] -> (B, T, H, W, C) in [0, 1] for GIF writing.
    sample = ((sample + 1) / 2).permute(0, 2, 3, 4, 1).cpu().numpy()
    for i in range(sample.shape[0]):
        create_gif(sample[i], os.path.join(args.output_dir, f"{first_idx + i}.gif"))


sampling...


In [None]:
# Shape of the last post-processed batch: a numpy array laid out (B, T, H, W, C).
np.shape(sample)

In [None]:
# Reorder the first batch from (B, C, T, H, W) to frame-major (B, T, C, H, W).
# permute moves axes. A .view(-1, 20, 3, 64, 64) would only reinterpret memory
# and interleave channel and frame data; it also hard-codes the sizes.
all_videos[0].permute(0, 2, 1, 3, 4).shape

In [None]:
# Show every frame of the first generated video; `sample` is (B, T, H, W, C) here.
for frame_idx, frame in enumerate(sample[0]):
    # imshow rejects (H, W, 1); show single-channel frames as 2-D grayscale.
    if frame.shape[-1] == 1:
        frame = frame[..., 0]
    fig, ax = plt.subplots()
    ax.imshow(frame, cmap="gray" if frame.ndim == 2 else None)
    ax.set_title(f"Generated frame {frame_idx}")
    ax.axis("off")
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per frame