In [None]:
# necessary packages and repos - may have to restart runtime/reset kernel after running
!git clone https://github.com/CompVis/latent-diffusion.git
!git clone https://github.com/CompVis/taming-transformers
!git clone https://github.com/crowsonkb/k-diffusion
!pip install -e ./taming-transformers
!pip install omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops
!pip install git+https://github.com/openai/CLIP.git
!pip install transformers kornia imageio imageio_ffmpeg pillow scikit-image jsonmerge clean-fid resize-right torchdiffeq

In [None]:
# SD model
# Creates the directory that the Stable Diffusion checkpoint path
# (models/ldm/stable-diffusion-v1/model.ckpt) lives in; -p makes this safe to re-run.
!mkdir -p models/ldm/stable-diffusion-v1/
# when weights become available drop link here
# NOTE: LINK_TO_WEIGHTS_HERE is a placeholder, not a URL - replace it before running this cell.
!wget -O models/ldm/stable-diffusion-v1/model.ckpt LINK_TO_WEIGHTS_HERE

In [None]:
# fallback LAION 400M model
# Only needed when use_laion400m is enabled; the checkpoint is written to the path
# that the LAION branch of the "variables" cell assigns to model_location.
!mkdir -p models/ldm/text2img-large/
!wget -O models/ldm/text2img-large/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt

In [None]:
# imports
import re
import os
import time
import torch
import imageio
import torch.nn as nn
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm.auto import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
from torchvision.utils import make_grid

from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler

import sys
sys.path.append(f'{os.path.abspath(os.getcwd())}/k-diffusion')
import k_diffusion as K

In [None]:
# variables
use_laion400m = False  # True -> use the LAION 400M fallback checkpoint instead of Stable Diffusion
use_plms = False       # True -> sample with PLMSSampler instead of DDIMSampler
use_k_lms = False      # True -> sample with the k-diffusion LMS sampler

# Resolve the config file, checkpoint and output directory for the chosen model.
if not use_laion400m:
    print("Using Stable Diffusion model...")
    config_location, model_location, outdir = (
        'configs/stable-diffusion/v1-inference.yaml',
        'models/ldm/stable-diffusion-v1/model.ckpt',
        'outputs/txt2img-samples',
    )
else:
    print("Falling back to LAION 400M model...")
    config_location, model_location, outdir = (
        "configs/latent-diffusion/txt2img-1p4B-eval.yaml",
        "models/ldm/text2img-large/model.ckpt",
        "outputs/txt2img-samples-laion400m",
    )

In [None]:
# utility functions
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config.model`` and load weights from ``ckpt``.

    Args:
        config: Loaded OmegaConf config; its ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint file containing a ``"state_dict"`` entry.
        verbose: When True, print the keys that were missing from or unexpected in
            the checkpoint (the load is non-strict, so mismatches are otherwise silent).

    Returns:
        The model in eval mode, moved to the GPU when one is available.
    """
    print(f"Loading model from {ckpt}")
    # Deserialize onto the CPU first so the checkpoint does not occupy GPU memory in
    # addition to the instantiated model, and so loading works on CPU-only machines.
    # NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose:
        if len(missing) > 0:
            print("missing keys:")
            print(missing)
        if len(unexpected) > 0:
            print("unexpected keys:")
            print(unexpected)
    # The notebook later moves the model to `device`, which may be the CPU; do not
    # demand CUDA here.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model

def get_model():
    """Load the model selected by the module-level ``config_location`` / ``model_location``."""
    return load_model_from_config(OmegaConf.load(config_location), model_location)

def chunk(it, size):
    """Lazily split ``it`` into consecutive tuples of at most ``size`` items.

    The returned object is a one-shot iterator; the final tuple may be shorter
    than ``size``. Iteration stops at the first empty tuple.
    """
    source = iter(it)  # resolved eagerly so a non-iterable argument fails right here

    def _batches():
        while True:
            batch = tuple(islice(source, size))
            if not batch:
                return
            yield batch

    return _batches()

def load_img(path, size):
    """Read an image file and return it as a ``(1, 3, h, w)`` float tensor in [-1, 1].

    Args:
        path: Location of the image file.
        size: ``(width, height)`` pair the image is resized to with Lanczos resampling.
    """
    image = Image.open(path).convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h}) and resized to ({size[0]}, {size[1]}) from {path}")
    target_w, target_h = size
    resized = image.resize((target_w, target_h), resample=Image.LANCZOS)
    # uint8 HWC pixels -> float32 in [0, 1]
    pixels = np.array(resized).astype(np.float32) / 255.0
    # HWC -> CHW, then add the leading batch axis
    batch = torch.from_numpy(pixels.transpose(2, 0, 1)[None])
    # rescale [0, 1] -> [-1, 1]
    return 2.*batch - 1.

def slerp(val, low, high):
    """Spherical linear interpolation between two tensors.

    Args:
        val: Interpolation fraction; 0 returns ``low`` and 1 returns ``high``.
        low: Start tensor.
        high: End tensor (same shape as ``low``).

    Returns:
        The interpolated tensor. When the two inputs are (anti)parallel - including
        identical inputs - the spherical formula divides by ``sin(0)``, so plain
        linear interpolation is used instead of returning NaN.
    """
    low_norm = low/torch.norm(low)
    high_norm = high/torch.norm(high)
    # Rounding can push the cosine marginally outside [-1, 1], where acos yields NaN.
    cos_omega = (low_norm*high_norm).sum().clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    so = torch.sin(omega)
    # `not (x > eps)` is also True for NaN, so degenerate inputs take the linear path too.
    if not (so.abs() > 1e-6):
        return (1.0-val)*low + val*high
    return (torch.sin((1.0-val)*omega)/so)*low + (torch.sin(val*omega)/so) * high

def get_slerp_vectors(start, end, frames=20):
    """Return ``frames`` vectors spherically interpolated from ``start`` to ``end``.

    Args:
        start: 1-D start tensor.
        end: 1-D end tensor with the same length as ``start``.
        frames: Number of rows to produce; the first row is ``start`` and, when
            ``frames >= 2``, the last row is ``end``.

    Returns:
        Tensor of shape ``(frames, start.shape[0])`` on the same device as ``start``.
        ``frames < 1`` yields an empty ``(0, n)`` tensor and ``frames == 1`` yields
        just ``start`` (the original divided by zero in that case).
    """
    if frames < 1:
        return torch.empty((0, start.shape[0]), dtype=start.dtype, device=start.device)
    if frames == 1:
        return start.unsqueeze(0).clone()
    factor = 1.0 / (frames - 1)
    return torch.stack([slerp(factor*i, start, end) for i in range(frames)])

def get_conditioning_vector(prompt):
    """Encode a prompt, optionally made of weighted ``text::weight`` segments.

    A prompt without any ``::weight`` marker is encoded as-is. Otherwise each
    segment is encoded separately and the encodings are combined as a weighted
    sum, divided by the sum of absolute weights when the module-level
    ``average_weights`` is true and by the plain sum of weights otherwise.

    Args:
        prompt: Prompt string, e.g. ``"a cat::1 fluffy::0.5"``.

    Returns:
        The conditioning tensor produced by ``model.get_learned_conditioning``.

    Raises:
        ValueError: If no segment has a usable weight, or the weights sum to zero.
    """
    # With one capture group, re.split returns [text, weight, text, weight, ..., tail].
    parts = re.split(r'::([-\d.]+)', prompt)
    if len(parts) == 1:
        return model.get_learned_conditioning(prompt)

    texts = parts[0::2]
    raw_weights = parts[1::2]

    c_tensors = []
    weights = []
    # Pair each text with its own weight by index, so skipping an empty text can
    # never shift the remaining text/weight pairs out of step.
    for idx, text in enumerate(texts):
        if text == '':
            continue
        text = text.strip()
        if idx >= len(raw_weights):
            # Trailing text after the last "::weight" has no weight of its own.
            print(f'Prompt: "{text}" dropped due to invalid weight')
            continue
        try:
            weight = float(raw_weights[idx])
        except ValueError:
            # The pattern also matches non-numbers such as "-" or "1.2.3".
            print(f'Prompt: "{text}" dropped due to invalid weight')
            continue
        # Encode before recording anything so both lists always stay the same length.
        c_tensors.append(model.get_learned_conditioning(text))
        weights.append(weight)

    if not c_tensors:
        raise ValueError(f'not a valid prompt: no weighted segment could be parsed from "{prompt}"')

    normalizer = sum(abs(weight) for weight in weights) if average_weights else sum(weights)
    if normalizer == 0:
        # Dividing a tensor by zero would silently fill the conditioning with inf/NaN.
        raise ValueError(f'prompt weights sum to zero, cannot normalize: "{prompt}"')

    c = c_tensors[0] * weights[0]
    for c_tensor, weight in zip(c_tensors[1:], weights[1:]):
        c += c_tensor * weight
    return c / normalizer

def get_starting_code_and_conditioning_vector(seed, prompt):
    """Seed the RNGs, then draw a starting latent and encode ``prompt``.

    Args:
        seed: Integer seed, or None to pick one at random.
        prompt: Prompt string passed to ``get_conditioning_vector``.

    Returns:
        ``(c, start_code)`` where ``start_code`` has shape ``[1, C, H // f, W // f]``
        (read from the module-level ``C``, ``H``, ``W``, ``f``) and lives on ``device``.
    """
    if seed is None:
        # Draw from the non-negative range only: seed_everything follows numpy's
        # seed bounds, so a negative seed would not be used as given.
        seed = np.random.randint(0, np.iinfo(np.int32).max)
    seed_everything(seed)
    # Noise is drawn before the prompt is encoded, so the draw depends only on the seed.
    start_code = torch.randn([1, C, H // f, W // f], device=device)
    c = get_conditioning_vector(prompt)
    return (c, start_code)

def unflatten(l, n):
    """Split sequence ``l`` into consecutive slices of length ``n``.

    Args:
        l: A sliceable sequence.
        n: Slice length; must be positive.

    Returns:
        List of slices; the last one may be shorter than ``n``. Empty input gives ``[]``.

    Raises:
        ValueError: If ``n`` is not positive (the original looped forever in that case).
    """
    if n <= 0:
        raise ValueError(f"chunk size must be positive, got {n}")
    return [l[start:start + n] for start in range(0, len(l), n)]

In [None]:
# set up model and sampler
# Builds the model from config_location / model_location (see get_model).
model = get_model()

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

if use_plms:
    print("Warning: img2img and disco style animations not compatible with PLMSSampler")
    sampler = PLMSSampler(model)
else:
    sampler = DDIMSampler(model)

# Output layout: grids are written to outdir, individual images to outdir/samples.
os.makedirs(outdir, exist_ok=True)
sample_path = os.path.join(outdir, "samples")
os.makedirs(sample_path, exist_ok=True)
# Continue file numbering after whatever is already on disk.
base_count = len(os.listdir(sample_path))
# The "samples" directory is itself one entry of outdir, hence the -1.
grid_count = len(os.listdir(outdir)) - 1

# NOTE: CFGDenoiser and model_wrap only exist when use_k_lms was True at the time this
# cell ran; re-run this cell after changing use_k_lms.
if use_k_lms:
    class CFGDenoiser(nn.Module):
        """Applies classifier-free guidance around a wrapped denoiser."""

        def __init__(self, model):
            super().__init__()
            self.inner_model = model

        def forward(self, x, sigma, uncond, cond, cond_scale):
            # Run the unconditional and conditional inputs through the inner model as a
            # single doubled batch, then split the result back into its two halves.
            x_in = torch.cat([x] * 2)
            sigma_in = torch.cat([sigma] * 2)
            cond_in = torch.cat([uncond, cond])
            uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
            # Move from the unconditional prediction towards the conditional one.
            return uncond + (cond - uncond) * cond_scale
        
    model_wrap = K.external.CompVisDenoiser(model)
    # First and last entries of the wrapper's sigma table (not read again in the cells shown here).
    sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()

Text -> Image

In [None]:
# parameters
# A prompt may hold several weighted segments written as `text::weight`
# (parsed by get_conditioning_vector); a prompt without `::` is used as-is.
prompts = [
    "joe biden::1 happy::0.5"
] # list of string prompts - negative weights don't work well, needs investigating
seed = 741 # seed for reproducible generations - use None for random seed
average_weights = True # If using prompt weights, whether or not to average them

init_image_location = "" # file location of init image (use a blank string if none desired)
init_noise_strength = 0.75 # how much to noise init image (0-1.0 where 1.0 is full destruction of init image information)

ddim_steps = 200 # ddim sampling steps
ddim_eta = 0.0 # ddim eta (eta=0.0 corresponds to deterministic sampling) (must be 0.0 if using PLMS/k_lms sampling)
unconditional_guidance_scale = 7.5 # unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))

precision = 'autocast' # precision to evaluate at (full or autocast)

n_iter = 1 # how many sample iterations
batch_size = 1 # how many prompts are sampled together in one batch (the prompts list is split into groups of this size)
n_rows = 0 # rows in the grid (will use batch_size if set to 0)

fixed_code = True # if enabled, uses the same starting code across samples
skip_grid = True # do not save a grid, only individual samples
skip_save = False # do not save individual samples
show_images = True # whether or not to show images after generation

H = 512 # height
W = 512 # width
C = 4 # channels
f = 8 # downsampling factor

####################################################################################

if n_rows == 0:
    n_rows = batch_size

if seed is None:
    # Non-negative range only: seed_everything follows numpy's seed bounds, so a
    # negative seed would not be used as given.
    seed = np.random.randint(0, np.iinfo(np.int32).max)
seed_everything(seed)

# With fixed_code the same starting latent is reused for every batch.
start_code = None
if fixed_code:
    start_code = torch.randn([batch_size, C, H // f, W // f], device=device)

precision_scope = autocast if precision=="autocast" else nullcontext
# chunk() returns a one-shot iterator; materialize it so that every one of the
# n_iter passes of the generation loop sees the prompts (not just the first pass).
data = list(chunk(prompts, batch_size))

if init_image_location != "":
    # Validate cheap inputs before doing any encoding work.
    assert 0. <= init_noise_strength <= 1., 'can only work with strength in [0.0, 1.0]'
    assert os.path.isfile(init_image_location), f'init image not found: {init_image_location}'
    # load_img expects a (width, height) pair; the empty tuple passed previously
    # raised IndexError as soon as an init image was configured.
    init_image = load_img(init_image_location, (W, H)).to(device)
    init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
    init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space

    sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)

    # Number of noising steps applied to the init latent before decoding.
    t_enc = int(init_noise_strength * ddim_steps)
    print(f"target t_enc is {t_enc} steps")

In [None]:
# run generation(s)
# For every batch of prompts: build the conditioning, sample latents (text-only or
# from the encoded init image), decode to images, then save / display / collect them.
with torch.no_grad():
    with precision_scope("cuda"):
        with model.ema_scope():
            tic = time.time()
            all_samples = list()
            for n in trange(n_iter, desc="Sampling"):
                # A separate loop name keeps the global `prompts` list intact.
                for batch_prompts in tqdm(data, desc="data"):
                    batch_prompts = list(batch_prompts)
                    # The last batch can be smaller than batch_size; size every
                    # per-batch tensor from the batch actually being processed.
                    cur_batch = len(batch_prompts)

                    uc = None
                    if unconditional_guidance_scale != 1.0:
                        uc = model.get_learned_conditioning(cur_batch * [""])
                    c = torch.cat([get_conditioning_vector(prompt) for prompt in batch_prompts])

                    if init_image_location == "":
                        shape = [C, H // f, W // f]
                        if use_k_lms:
                            sigmas = model_wrap.get_sigmas(ddim_steps)
                            torch.manual_seed(seed) # changes manual seeding procedure
                            x = torch.randn([cur_batch, *shape], device=device) * sigmas[0] # for GPU draw
                            model_wrap_cfg = CFGDenoiser(model_wrap)
                            extra_args = {'cond': c, 'uncond': uc, 'cond_scale': unconditional_guidance_scale}
                            samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args)
                        else:
                            x_T = start_code[:cur_batch] if start_code is not None else None
                            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                             conditioning=c,
                                                             batch_size=cur_batch,
                                                             shape=shape,
                                                             verbose=False,
                                                             unconditional_guidance_scale=unconditional_guidance_scale,
                                                             unconditional_conditioning=uc,
                                                             eta=ddim_eta,
                                                             x_T=x_T)
                    else:
                        # encode (scaled latent)
                        z_enc = sampler.stochastic_encode(init_latent[:cur_batch],
                                                          torch.tensor([t_enc] * cur_batch).to(device))
                        # decode it
                        samples_ddim = sampler.decode(z_enc, c, t_enc,
                                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                                      unconditional_conditioning=uc)

                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    # map the decoder output from [-1, 1] to [0, 1]
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                    # One pass over the batch handles both saving and displaying.
                    if not skip_save or show_images:
                        for x_sample in x_samples_ddim:
                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                            img = Image.fromarray(x_sample.astype(np.uint8))
                            if not skip_save:
                                img.save(
                                    os.path.join(sample_path, f"{base_count:05}.png"))
                                base_count += 1
                            if show_images:
                                display(img)

                    if not skip_grid:
                        all_samples.append(x_samples_ddim)

            if not skip_grid and all_samples:
                # additionally, save as grid
                # cat along the batch axis equals stack + flatten for equal-sized
                # batches and also works when the last batch is smaller.
                grid = torch.cat(all_samples, 0)
                grid = make_grid(grid, nrow=n_rows)

                # to image
                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outdir, f'grid-{grid_count:04}.png'))
                grid_count += 1

            toc = time.time()

print(f"Your samples are ready and waiting for you here: \n{outdir}\n"
        f"\nTime elapsed: {round(toc - tic, 2)} seconds")

Text, Text -> Interpolation

In [None]:
# parameters
# Keyframes of the animation, in order; each entry supplies the seed for the
# starting latent and the prompt text for that keyframe.
prompts = [
    (26208852, "Joe Biden exhaling a large smoke cloud from his bong, candid photography"),
    (685626752, "Barack Obama exhaling a large smoke cloud from his bong, candid photography"),
] # (seed, prompt)

save_mp4 = 'test.mp4' # output path of the rendered video

average_weights = True # If using prompt weights, whether or not to average them

ddim_steps = 50 # ddim sampling steps
ddim_eta = 0.0 # ddim eta (eta=0.0 corresponds to deterministic sampling) (must be 0.0 if using PLMS sampling)
unconditional_guidance_scale = 7.5 # unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))

precision = 'autocast' # precision to evaluate at (full or autocast)

batch_size = 10 # how many interpolation frames are sampled together in one batch (the frame list is split into groups of this size)

loop = True # if enabled, make animation loop by interpolating between end and start vectors
fixed_code = False # if enabled, uses the same starting code across samples
fixed_seed = None # fixed seed to use if using fixed_code
skip_save = True # do not save individual frames
skip_save_video = False # do not save mp4
show_images = True # whether or not to show images after generation

H = 512 # height
W = 512 # width
C = 4 # channels
f = 8 # downsampling factor

degrees_per_second = 20 # degrees to travel per second
fps = 40 # frames per second of output mp4

####################################################################################

precision_scope = autocast if precision=="autocast" else nullcontext

# Frames spent per degree of angular distance between consecutive keyframes.
frames_per_degree = fps / degrees_per_second
RAD_TO_DEG = 180.0 / np.pi


def _frames_between(a, b):
    """Frame count needed to sweep the angle between tensors ``a`` and ``b``."""
    a_flat, b_flat = a.flatten(), b.flatten()
    cos_omega = ((a_flat / torch.norm(a_flat)) * (b_flat / torch.norm(b_flat))).sum()
    # Clamp: rounding can push the cosine just outside [-1, 1], where acos is NaN.
    omega = torch.acos(cos_omega.clamp(-1.0, 1.0))
    return round(omega.item() * RAD_TO_DEG * frames_per_degree)


def _slerp_sequence(a, b, frames):
    """``frames`` tensors shaped like ``a``, interpolated from ``a`` to ``b``."""
    if torch.equal(a, b):
        # Identical endpoints (e.g. fixed_code): nothing to interpolate, and the
        # spherical formula would divide by sin(0).
        return a.unsqueeze(0).expand(frames, *a.shape).clone()
    return get_slerp_vectors(a.flatten(), b.flatten(), frames=frames).reshape(-1, *a.shape)


# Normalize every prompt to a (seed, text) tuple. Tuples are immutable, so entries
# are replaced rather than assigned into; bare strings get a seed slot as well
# (None means "draw a random seed").
if fixed_code and fixed_seed is None:
    fixed_seed = np.random.randint(0, np.iinfo(np.int32).max)
for i, entry in enumerate(prompts):
    if isinstance(entry, str):
        prompts[i] = (fixed_seed if fixed_code else None, entry)
    elif fixed_code:
        prompts[i] = (fixed_seed, entry[1])

# interpolation setup
previous_c = None
previous_start_code = None
slerp_c_vectors = []
slerp_start_codes = []
for i, (prompt_seed, prompt_text) in enumerate(prompts):
    c, start_code = get_starting_code_and_conditioning_vector(prompt_seed, prompt_text)
    if i == 0:
        slerp_c_vectors.append(c)
        slerp_start_codes.append(start_code)
    else:
        # Use whichever of the two transitions needs more frames, and never fewer
        # than 2 so the segment always contains both of its endpoints.
        frames = max(_frames_between(previous_c, c),
                     _frames_between(previous_start_code, start_code),
                     2)

        c_vectors = _slerp_sequence(previous_c, c, frames)
        slerp_c_vectors.extend(list(c_vectors[1:])) # drop first frame to prevent repeating frames
        start_codes = _slerp_sequence(previous_start_code, start_code, frames)
        slerp_start_codes.extend(list(start_codes[1:])) # drop first frame to prevent repeating frames
        if loop and i == len(prompts) - 1:
            # NOTE(review): the closing segment reuses the frame count of the last
            # forward segment instead of measuring the last->first angle - confirm intent.
            c_vectors = _slerp_sequence(c, slerp_c_vectors[0], frames)
            slerp_c_vectors.extend(list(c_vectors[1:-1])) # drop first and last frame to prevent repeating frames
            start_codes = _slerp_sequence(start_code, slerp_start_codes[0], frames)
            slerp_start_codes.extend(list(start_codes[1:-1])) # drop first and last frame to prevent repeating frames
    previous_c = c
    previous_start_code = start_code

# Group the per-frame tensors into batches of batch_size for sampling.
slerp_c_vectors = unflatten(slerp_c_vectors, batch_size)
slerp_start_codes = unflatten(slerp_start_codes, batch_size)

In [None]:
# run generation(s)
# Samples every batch of interpolated (conditioning, starting latent) pairs, decodes
# the frames and streams them to the video writer / disk / notebook output.
video_out = None
if not skip_save_video:
    video_out = imageio.get_writer(save_mp4, mode='I', fps=fps, codec='libx264')
try:
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                tic = time.time()
                for c, start_code in tqdm(zip(slerp_c_vectors, slerp_start_codes), desc="data", total=len(slerp_c_vectors)):
                    uc = None
                    if unconditional_guidance_scale != 1.0:
                        uc = model.get_learned_conditioning(len(c) * [""])
                    if isinstance(c, (tuple, list)):
                        c = torch.stack(list(c), dim=0)

                    # Merge the per-frame tensors into single batch tensors.
                    c = torch.cat(tuple(c))
                    start_code = torch.cat(tuple(start_code))

                    shape = [C, H // f, W // f]

                    if use_k_lms:
                        sigmas = model_wrap.get_sigmas(ddim_steps)
                        model_wrap_cfg = CFGDenoiser(model_wrap)
                        extra_args = {'cond': c, 'uncond': uc, 'cond_scale': unconditional_guidance_scale}
                        samples_ddim = K.sampling.sample_lms(model_wrap_cfg, start_code*sigmas[0], sigmas, extra_args=extra_args)
                    else:
                        samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                         conditioning=c,
                                                         batch_size=len(c),
                                                         shape=shape,
                                                         verbose=False,
                                                         unconditional_guidance_scale=unconditional_guidance_scale,
                                                         unconditional_conditioning=uc,
                                                         eta=ddim_eta,
                                                         x_T=start_code)

                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    # map the decoder output from [-1, 1] to [0, 1]
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                    if not skip_save or not skip_save_video or show_images:
                        for x_sample in x_samples_ddim:
                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                            # Convert to uint8 once and hand the same pixels to every
                            # consumer; the writer then receives 8-bit frames instead of
                            # floats in [0, 255] that it would have to convert itself.
                            frame = x_sample.astype(np.uint8)
                            if not skip_save_video:
                                video_out.append_data(frame)
                            if not skip_save or show_images:
                                img = Image.fromarray(frame)
                            if not skip_save:
                                img.save(os.path.join(sample_path, f"{base_count:05}.png"))
                                # advance the counter so frames do not overwrite each other
                                base_count += 1
                            if show_images:
                                display(img)

                toc = time.time()
finally:
    # Close the writer even when sampling fails, and only if one was opened.
    if video_out is not None:
        video_out.close()

print(f"Your samples are ready and waiting for you here: \n{outdir}\n"
        f"\nTime elapsed: {round(toc - tic, 2)} seconds")

Disco Diffusion Style Animation

In [None]:
# necessary packages and repos
# System OpenCV plus the Python packages and repositories used by the animation cells.
!apt update
!apt install python3-opencv -y
!pip install timm opencv-python matplotlib pandas
!git clone https://github.com/alembics/disco-diffusion.git
# Move the transform helpers into the working directory, which a later cell adds to
# sys.path, so that `import disco_xform_utils` resolves.
!mv disco-diffusion/disco_xform_utils.py disco_xform_utils.py
!git clone https://github.com/isl-org/MiDaS.git
!git clone https://github.com/MSFTserver/pytorch3d-lite.git
# Rename MiDaS's utils module - presumably so it is importable as `midas_utils`
# without clashing with another `utils` module; TODO confirm against disco_xform_utils.
!mv MiDaS/utils.py MiDaS/midas_utils.py
!git clone https://github.com/shariqfarooq123/AdaBins.git

In [None]:
# MiDaS depth model
!mkdir models/depth/
!mkdir models/depth/midas/
!wget -O models/depth/midas/model.ckpt https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt

In [None]:
# AdaBins depth model
!mkdir pretrained/
!wget -O pretrained/AdaBins_nyu.pt https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt

In [None]:
# system path links
# Make the working directory and the cloned repositories importable.
import sys
_repo_root = os.path.abspath(os.getcwd())
for _suffix in ('', '/MiDaS', '/pytorch3d-lite', '/AdaBins/'):
    sys.path.append(f'{_repo_root}{_suffix}')

In [None]:
# imports
import cv2
import gc
import math
import py3d_tools as p3dT
import disco_xform_utils as dxf
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import pandas as pd
from ipywidgets import Output
from midas.dpt_depth import DPTDepthModel
from midas.midas_net import MidasNet
from midas.midas_net_custom import MidasNet_small
from midas.transforms import Resize, NormalizeImage, PrepareForNet
from types import SimpleNamespace
from PIL import Image, ImageOps

In [None]:
# utility functions
def init_midas_depth_model():
    """Load the DPT-Large MiDaS depth model together with its input transform.

    Returns:
        ``(midas_model, midas_transform, net_w, net_h, resize_mode, normalization)``:
        the half-precision, channels-last model on ``device`` in eval mode, the
        composed preprocessing transform, and the settings that transform was built from.
    """
    print(f"Initializing MiDaS depth model...")

    # load network
    midas_model = DPTDepthModel(
        path='models/depth/midas/model.ckpt',
        backbone="vitl16_384",
        non_negative=True,
    )

    # Preprocessing settings used with this backbone.
    net_w, net_h = 384, 384
    resize_mode = "minimal"
    normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    resize_step = Resize(
        net_w,
        net_h,
        resize_target=None,
        keep_aspect_ratio=True,
        ensure_multiple_of=32,
        resize_method=resize_mode,
        image_interpolation_method=cv2.INTER_CUBIC,
    )
    midas_transform = T.Compose([resize_step, normalization, PrepareForNet()])

    # Inference mode, channels-last memory layout, half precision, then onto `device`.
    midas_model.eval()
    midas_model = midas_model.to(memory_format=torch.channels_last).half()
    midas_model.to(device)

    print(f"MiDaS depth model initialized.")
    return midas_model, midas_transform, net_w, net_h, resize_mode, normalization

def do_3d_step(img_filepath, frame_num, midas_model, midas_transform):
    """Apply the 3D camera move configured for ``frame_num`` to an image.

    Reads the per-frame translation and rotation series from the module-level
    ``args`` and passes them, with the depth model, to ``dxf.transform_image_3d``.

    Args:
        img_filepath: Path of the image to transform.
        frame_num: Index into the ``args.*_series`` keyframe series.
        midas_model: Depth model forwarded to ``transform_image_3d``.
        midas_transform: Preprocessing transform forwarded alongside the model.

    Returns:
        Whatever ``dxf.transform_image_3d`` returns for the transformed frame.

    Raises:
        ValueError: If ``args.key_frames`` is not set. The motion values are only
            defined from the keyframe series, so the original failed later with
            an unbound-variable error in that case.
    """
    if not args.key_frames:
        raise ValueError("do_3d_step requires args.key_frames: no translation/rotation "
                         "values are defined without keyframe series")

    translation_x = args.translation_x_series[frame_num]
    translation_y = args.translation_y_series[frame_num]
    translation_z = args.translation_z_series[frame_num]
    rotation_3d_x = args.rotation_3d_x_series[frame_num]
    rotation_3d_y = args.rotation_3d_y_series[frame_num]
    rotation_3d_z = args.rotation_3d_z_series[frame_num]
    print(
        f'translation_x: {translation_x}',
        f'translation_y: {translation_y}',
        f'translation_z: {translation_z}',
        f'rotation_3d_x: {rotation_3d_x}',
        f'rotation_3d_y: {rotation_3d_y}',
        f'rotation_3d_z: {rotation_3d_z}',
    )

    # x and z are negated and every axis is scaled by TRANSLATION_SCALE.
    translate_xyz = [-translation_x*TRANSLATION_SCALE, translation_y*TRANSLATION_SCALE, -translation_z*TRANSLATION_SCALE]
    rotate_xyz_degrees = [rotation_3d_x, rotation_3d_y, rotation_3d_z]
    print('translation:',translate_xyz)
    print('rotation:',rotate_xyz_degrees)
    rotate_xyz = [math.radians(angle) for angle in rotate_xyz_degrees]
    rot_mat = p3dT.euler_angles_to_matrix(torch.tensor(rotate_xyz, device=device), "XYZ").unsqueeze(0)
    print("rot_mat: " + str(rot_mat))
    # Pass the same `device` the rotation matrix was created on instead of a
    # hard-coded CUDA device, so the tensors cannot end up on different devices.
    next_step_pil = dxf.transform_image_3d(img_filepath, midas_model, midas_transform, device,
                                           rot_mat, translate_xyz, args.near_plane, args.far_plane,
                                           args.fov, padding_mode=args.padding_mode,
                                           sampling_mode=args.sampling_mode, midas_weight=args.midas_weight)
    return next_step_pil

def interp(t):
    """Evaluate the cubic ``3*t**2 - 2*t**3`` (maps 0 -> 0 and 1 -> 1)."""
    squared = t**2
    cubed = t ** 3
    return 3 * squared - 2 * cubed

def perlin(width, height, scale=10):
    """Generate one 2-D noise field of shape ``(width * scale, height * scale)``.

    Random gradients are drawn on a ``(width + 1, height + 1)`` lattice and the four
    corner contributions of every cell are blended with weights derived from ``interp``.
    """
    grad_x, grad_y = torch.randn(2, width + 1, height + 1, 1, 1, device=device)
    # In-cell sample positions in [0, 1), laid out as a column (xs) and a row (ys).
    xs = torch.linspace(0, 1, scale + 1)[:-1, None].to(device)
    ys = torch.linspace(0, 1, scale + 1)[None, :-1].to(device)
    wx = 1 - interp(xs)
    wy = 1 - interp(ys)

    # One term per lattice corner of each cell.
    corner_terms = (
        wx * wy * (grad_x[:-1, :-1] * xs + grad_y[:-1, :-1] * ys),
        (1 - wx) * wy * (-grad_x[1:, :-1] * (1 - xs) + grad_y[1:, :-1] * ys),
        wx * (1 - wy) * (grad_x[:-1, 1:] * xs - grad_y[:-1, 1:] * (1 - ys)),
        (1 - wx) * (1 - wy) * (-grad_x[1:, 1:] * (1 - xs) - grad_y[1:, 1:] * (1 - ys)),
    )
    # sum() starts from 0 and adds left to right, like the running total it replaces.
    dots = sum(corner_terms)
    # (width, height, scale, scale) -> (width, scale, height, scale) -> flat 2-D field
    return dots.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale)

def perlin_ms(octaves, width, height, grayscale):
    """Sum several ``perlin`` fields per channel, each weighted by its entry in ``octaves``.

    Each channel starts from 0.5. For every successive octave the lattice size doubles
    while the per-cell scale halves. One channel is built when ``grayscale`` is true,
    four otherwise, and the channels are concatenated along the first axis.
    """
    channel_count = 1 if grayscale else 4
    channels = []
    for _ in range(channel_count):
        channel = 0.5
        cell_scale = 2 ** len(octaves)
        lattice_w, lattice_h = width, height
        for amplitude in octaves:
            channel += perlin(lattice_w, lattice_h, cell_scale) * amplitude
            cell_scale //= 2
            lattice_w *= 2
            lattice_h *= 2
        channels.append(channel)
    return torch.cat(channels)

def create_perlin_noise(octaves=[1, 1, 1, 1], width=2, height=2, grayscale=True):
    """Build a multi-octave noise image resized to the latent resolution.

    Args:
        octaves: Per-octave weights forwarded to ``perlin_ms`` (never mutated here).
        width: Base lattice width forwarded to ``perlin_ms``.
        height: Base lattice height forwarded to ``perlin_ms``.
        grayscale: When True a single noise channel is generated and converted to
            RGBA; otherwise four channels are generated.

    Returns:
        A PIL image ``args.W // args.f`` pixels wide and ``args.H // args.f`` pixels high.
    """
    out = perlin_ms(octaves, width, height, grayscale)
    # torchvision's resize takes `size` as (height, width); passing (W, H) transposed
    # the output for non-square settings. Square sizes are unaffected.
    target_size = (args.H//args.f, args.W//args.f)
    if grayscale:
        out = TF.resize(size=target_size, img=out.unsqueeze(0))
        out = TF.to_pil_image(out.clamp(0, 1)).convert('RGBA')
    else:
        # perlin_ms concatenated the 4 channels along the first axis; split them apart.
        out = out.reshape(-1, 4, out.shape[0]//4, out.shape[1])
        out = TF.resize(size=target_size, img=out)
        out = TF.to_pil_image(out.clamp(0, 1).squeeze())

    return out


def regen_perlin():
    """Generate a fresh Perlin-noise init tensor for the current settings.

    Two noise images are created (12 octaves on a 1x1 lattice, then 8 octaves on
    a 4x4 lattice). ``args.perlin_mode`` picks which of them are grayscale:
    ``'color'`` -> neither, ``'gray'`` -> both, anything else -> only the second.
    The first image is displayed, the two are averaged, rescaled from [0, 1] to
    [-1, 1], moved to ``device`` and expanded to ``args.batch_size``.
    """
    mode = args.perlin_mode
    if mode == 'color':
        first_gray, second_gray = False, False
    elif mode == 'gray':
        first_gray, second_gray = True, True
    else:
        first_gray, second_gray = False, True

    # Order matters: the first image is generated before the second (RNG use).
    init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, first_gray)
    init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, second_gray)
    display(init)

    averaged = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2)
    scaled = averaged.to(device).unsqueeze(0).mul(2).sub(1)
    return scaled.expand(args.batch_size, -1, -1, -1)

# def save_settings():
#     setting_list = {
#       'text_prompts': text_prompts,
#       'image_prompts': image_prompts,
#       'clip_guidance_scale': clip_guidance_scale,
#       'tv_scale': tv_scale,
#       'range_scale': range_scale,
#       'sat_scale': sat_scale,
#       # 'cutn': cutn,
#       'cutn_batches': cutn_batches,
#       'max_frames': max_frames,
#       'interp_spline': interp_spline,
#       # 'rotation_per_frame': rotation_per_frame,
#       'init_image': init_image,
#       'init_scale': init_scale,
#       'skip_steps': skip_steps,
#       # 'zoom_per_frame': zoom_per_frame,
#       'frames_scale': frames_scale,
#       'frames_skip_steps': frames_skip_steps,
#       'perlin_init': perlin_init,
#       'perlin_mode': perlin_mode,
#       'skip_augs': skip_augs,
#       'randomize_class': randomize_class,
#       'clip_denoised': clip_denoised,
#       'clamp_grad': clamp_grad,
#       'clamp_max': clamp_max,
#       'seed': seed,
#       'fuzzy_prompt': fuzzy_prompt,
#       'rand_mag': rand_mag,
#       'eta': eta,
#       'width': width_height[0],
#       'height': width_height[1],
#       'diffusion_model': diffusion_model,
#       'use_secondary_model': use_secondary_model,
#       'steps': steps,
#       'diffusion_steps': diffusion_steps,
#       'diffusion_sampling_mode': diffusion_sampling_mode,
#       'ViTB32': ViTB32,
#       'ViTB16': ViTB16,
#       'ViTL14': ViTL14,
#       'ViTL14_336px': ViTL14_336px,
#       'RN101': RN101,
#       'RN50': RN50,
#       'RN50x4': RN50x4,
#       'RN50x16': RN50x16,
#       'RN50x64': RN50x64,
#       'ViTB32_laion2b_e16': ViTB32_laion2b_e16,
#       'ViTB32_laion400m_e31': ViTB32_laion400m_e31,
#       'ViTB32_laion400m_32': ViTB32_laion400m_32,
#       'ViTB32quickgelu_laion400m_e31': ViTB32quickgelu_laion400m_e31,
#       'ViTB32quickgelu_laion400m_e32': ViTB32quickgelu_laion400m_e32,
#       'ViTB16_laion400m_e31': ViTB16_laion400m_e31,
#       'ViTB16_laion400m_e32': ViTB16_laion400m_e32,
#       'RN50_yffcc15m': RN50_yffcc15m,
#       'RN50_cc12m': RN50_cc12m,
#       'RN50_quickgelu_yfcc15m': RN50_quickgelu_yfcc15m,
#       'RN50_quickgelu_cc12m': RN50_quickgelu_cc12m,
#       'RN101_yfcc15m': RN101_yfcc15m,
#       'RN101_quickgelu_yfcc15m': RN101_quickgelu_yfcc15m,
#       'cut_overview': str(cut_overview),
#       'cut_innercut': str(cut_innercut),
#       'cut_ic_pow': str(cut_ic_pow),
#       'cut_icgray_p': str(cut_icgray_p),
#       'key_frames': key_frames,
#       'max_frames': max_frames,
#       'angle': angle,
#       'zoom': zoom,
#       'translation_x': translation_x,
#       'translation_y': translation_y,
#       'translation_z': translation_z,
#       'rotation_3d_x': rotation_3d_x,
#       'rotation_3d_y': rotation_3d_y,
#       'rotation_3d_z': rotation_3d_z,
#       'midas_depth_model': midas_depth_model,
#       'midas_weight': midas_weight,
#       'near_plane': near_plane,
#       'far_plane': far_plane,
#       'fov': fov,
#       'padding_mode': padding_mode,
#       'sampling_mode': sampling_mode,
#       'video_init_path':video_init_path,
#       'extract_nth_frame':extract_nth_frame,
#       'video_init_seed_continuity': video_init_seed_continuity,
#       'turbo_mode':turbo_mode,
#       'turbo_steps':turbo_steps,
#       'turbo_preroll':turbo_preroll,
#       'use_horizontal_symmetry':use_horizontal_symmetry,
#       'use_vertical_symmetry':use_vertical_symmetry,
#       'transformation_percent':transformation_percent,
#       #video init settings
#       'video_init_steps': video_init_steps,
#       'video_init_clip_guidance_scale': video_init_clip_guidance_scale,
#       'video_init_tv_scale': video_init_tv_scale,
#       'video_init_range_scale': video_init_range_scale,
#       'video_init_sat_scale': video_init_sat_scale,
#       'video_init_cutn_batches': video_init_cutn_batches,
#       'video_init_skip_steps': video_init_skip_steps,
#       'video_init_frames_scale': video_init_frames_scale,
#       'video_init_frames_skip_steps': video_init_frames_skip_steps,
#       #warp settings
#       'video_init_flow_warp':video_init_flow_warp,
#       'video_init_flow_blend':video_init_flow_blend,
#       'video_init_check_consistency':video_init_check_consistency,
#       'video_init_blend_mode':video_init_blend_mode
#     }
#     # print('Settings:', setting_list)
#     with open(f"{batchFolder}/{batch_name}({batchNum})_settings.txt", "w+") as f:   #save settings
#         json.dump(setting_list, f, ensure_ascii=False, indent=4)

In [None]:
# Run configuration. Everything in this cell is a plain module-level setting that
# later code in the notebook reads by name (most of it is copied into `args`).

# Text prompts; with `interpolate` enabled their conditionings are slerped in order.
prompts = [
    "A dragons lair, epic matte painting, concept art, trending on artstation",
    "A wizards magical potion room, epic matte painting, concept art, trending on artstation"
]

average_weights = True
seed = None  # None -> a random seed is drawn further down in this cell
loop = True  # when interpolating, also slerp from the last prompt back to the first
interpolate = True  # slerp between prompt conditionings (needs more than one prompt)

device='cuda'

ddim_steps = 250 # ddim sampling steps
ddim_eta = 0.05 # ddim eta (eta=0.0 corresponds to deterministic sampling) (must be 0.0 if using PLMS sampling)
unconditional_guidance_scale = 12 # unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))

init_image = None
init_noise_strength = 0.4  # noise strength for frame 0; must lie in [0.0, 1.0]
previous_frame_noise_strength = 0.35  # noise strength for every later frame; must lie in [0.0, 1.0]
animation_mode = '3D'  # compared against "2D", "3D" and "Video Input" elsewhere in the notebook
perlin_init = False # whether or not to use perlin noise (it replaces start_code in the sampling loop)
perlin_mode = 'color' # color or gray

fixed_code = True  # draw the start_code latent once, up front

n_iter = 1
batch_size = 1

start_frame = 0

resume_run = False

batch_name = 'test'
batchNum = 1
batchFolder = 'outputs'

H = 512 # height
W = 512 # width
C = 4 # channels
# NOTE: `f` is a module-level name; do not reuse `f` as a loop variable in this notebook.
f = 8 # downsampling factor

degrees_per_second = 20 # degrees to travel per second
fps = 40 # frames per second of output mp4

# Frames rendered per degree of angle between consecutive prompt conditionings.
frames_per_degree = fps / degrees_per_second

precision = 'autocast' # precision to evaluate at (full or autocast)

video_init_path = "init.mp4"  # @param {type: 'string'}
extract_nth_frame = 2  # @param {type: 'number'}
persistent_frame_output_in_batch_folder = True  # @param {type: 'boolean'}
video_init_seed_continuity = False  # @param {type: 'boolean'}
# @markdown #####**Video Optical Flow Settings:**
video_init_flow_warp = True  # @param {type: 'boolean'}
# Call optical flow from video frames and warp prev frame with flow
# @param {type: 'number'} #0 - take next frame, 1 - take prev warped frame
video_init_flow_blend = 0.999
video_init_check_consistency = False  # Insert param here when ready
# @param ['None', 'linear', 'optical flow']
video_init_blend_mode = "optical flow"
# Call optical flow from video frames and warp prev frame with flow
if animation_mode == "Video Input":
    # Export every `extract_nth_frame`-th frame of the init video as numbered
    # JPEGs into `videoFramesFolder`, after clearing frames from earlier runs.
    # suggested by Chris the Wizard#8082 at discord
    if persistent_frame_output_in_batch_folder or (not is_colab):
        videoFramesFolder = f'{batchFolder}/videoFrames'
    else:
        videoFramesFolder = '/content/videoFrames'
    createPath(videoFramesFolder)
    print(f"Exporting Video Frames (1 every {extract_nth_frame})...")
    try:
        # The loop variable must NOT be named `f`: that would overwrite the
        # module-level downsampling factor `f` that is later stored in `args`.
        for stale_frame in pathlib.Path(videoFramesFolder).glob('*.jpg'):
            stale_frame.unlink()
    except OSError as err:
        # Best effort: leftover frames are not fatal, but say what went wrong.
        print(f'Could not remove old video frames from {videoFramesFolder}: {err}')
    # ffmpeg needs a literal backslash before the comma inside the filter
    # expression, so escape the backslash instead of relying on the invalid
    # `\,` escape sequence.
    vf = f'select=not(mod(n\\,{extract_nth_frame}))'
    if os.path.exists(video_init_path):
        # stdout is piped only to keep it out of the cell output.
        subprocess.run(['ffmpeg', '-i', str(video_init_path), '-vf', vf, '-vsync', 'vfr', '-q:v', '2', '-loglevel',
                        'error', '-stats', f'{videoFramesFolder}/%04d.jpg'], stdout=subprocess.PIPE)
    else:
        print(
            f'\nWARNING!\n\nVideo not found: {video_init_path}.\nPlease check your video path.\n')
    #!ffmpeg -i {video_init_path} -vf {vf} -vsync vfr -q:v 2 -loglevel error -stats {videoFramesFolder}/%04d.jpg


# @markdown ---

# @markdown ####**2D Animation Settings:**
# @markdown `zoom` is a multiplier of dimensions, 1 is no zoom.
# @markdown All rotations are provided in degrees.

# When True, the motion settings below are parsed as key-frame strings of the
# form "frame: (value), frame: (value), ..." and expanded to one value per frame.
# When False they are converted with float() instead.
key_frames = True  # @param {type:"boolean"}
# Number of frames to render; also the length of every interpolated series
# (get_inbetweens reads this module-level name directly).
max_frames = 10000  # @param {type:"number"}

# In video mode there is one output frame per extracted input frame.
if animation_mode == "Video Input":
    max_frames = len(glob(f'{videoFramesFolder}/*.jpg'))

# Do not change, currently will not look good. param ['Linear','Quadratic','Cubic']{type:"string"}
interp_spline = 'Linear'
angle = "0:(0)"  # @param {type:"string"}
zoom = "0: (1), 10: (1.0)"  # @param {type:"string"}
translation_x = "0: (0)"  # @param {type:"string"}
translation_y = "0: (0)"  # @param {type:"string"}
translation_z = "0: (5.0)"  # @param {type:"string"}
rotation_3d_x = "0: (0.2)"  # @param {type:"string"}
rotation_3d_y = "0: (0)"  # @param {type:"string"}
rotation_3d_z = "0: (0.0)"  # @param {type:"string"}
# Depth / camera settings; the ones copied into `args` are forwarded to the 3D warp step.
midas_depth_model = "dpt_large"  # @param {type:"string"}
midas_weight = 0.3  # @param {type:"number"}
near_plane = 200  # @param {type:"number"}
far_plane = 10000  # @param {type:"number"}
fov = 40  # @param {type:"number"}
padding_mode = 'border'  # @param {type:"string"}
sampling_mode = 'bicubic'  # @param {type:"string"}

# ======= TURBO MODE
# @markdown ---
# @markdown ####**Turbo Mode (3D anim only):**
# @markdown (Starts after frame 10,) skips diffusion steps and just uses depth map to warp images for skipped frames.
# @markdown Speeds up rendering by 2x-4x, and may improve image coherence between frames.
# @markdown For different settings tuned for Turbo Mode, refer to the original Disco-Turbo Github: https://github.com/zippy731/disco-diffusion-turbo

turbo_mode = False  # @param {type:"boolean"}
# Kept as a string because it comes from a form dropdown; the animation loop
# converts it with int(turbo_steps) and only diffuses frames divisible by it.
turbo_steps = "3"  # @param ["2","3","4","5","6"] {type:"string"}
# Frame number at which the animation loop starts tracking the "old" frame for blending.
turbo_preroll = 10  # frames

# insist turbo be used only w 3d anim.
if turbo_mode and animation_mode != '3D':
    print('=====')
    print('Turbo mode only available with 3D animations. Disabling Turbo.')
    print('=====')
    turbo_mode = False

# @markdown ---

# @markdown ####**Coherency Settings:**
# @markdown `frame_scale` tries to guide the new frame to looking like the old one. A good default is 1500.
# NOTE(review): none of the four settings below is copied into the `args`
# namespace built later in this cell, while the "Video Input" branch of the
# animation loop reads `args.video_init_frames_scale` -- confirm whether these
# are still consumed anywhere.
frames_scale = 1500  # @param{type: 'integer'}
# @markdown `frame_skip_steps` will blur the previous frame - higher values will flicker less but struggle to add enough new detail to zoom into.
# @param ['40%', '50%', '60%', '70%', '80%'] {type: 'string'}
frames_skip_steps = '60%'

# @markdown ####**Video Init Coherency Settings:**
# @markdown `frame_scale` tries to guide the new frame to looking like the old one. A good default is 1500.
video_init_frames_scale = 15000  # @param{type: 'integer'}
# @markdown `frame_skip_steps` will blur the previous frame - higher values will flicker less but struggle to add enough new detail to zoom into.
# @param ['40%', '50%', '60%', '70%', '80%'] {type: 'string'}
video_init_frames_skip_steps = '70%'

# ======= VR MODE
# @markdown ---
# @markdown ####**VR Mode (3D anim only):**
# @markdown Enables stereo rendering of left/right eye views (supporting Turbo) which use a different (fish-eye) camera projection matrix.
# @markdown Note the images you're prompting will work better if they have some inherent wide-angle aspect
# @markdown The generated images will need to be combined into left/right videos. These can then be stitched into the VR180 format.
# @markdown Google made the VR180 Creator tool but subsequently stopped supporting it. It's available for download in a few places including https://www.patrickgrunwald.de/vr180-creator-download
# @markdown The tool is not only good for stitching (videos and photos) but also for adding the correct metadata into existing videos, which is needed for services like YouTube to identify the format correctly.
# @markdown Watching YouTube VR videos isn't necessarily the easiest depending on your headset. For instance Oculus have a dedicated media studio and store which makes the files easier to access on a Quest https://creator.oculus.com/manage/mediastudio/
# @markdown
# @markdown The command to get ffmpeg to concat your frames for each eye is in the form: `ffmpeg -framerate 15 -i frame_%4d_l.png l.mp4` (repeat for r)

# When enabled, the turbo skip-frame path of the animation loop additionally
# calls generate_eye_views() for each saved frame.
vr_mode = False  # @param {type:"boolean"}
# @markdown `vr_eye_angle` is the y-axis rotation of the eyes towards the center
vr_eye_angle = 0.5  # @param{type:"number"}
# @markdown interpupillary distance (between the eyes)
vr_ipd = 5.0  # @param{type:"number"}

# insist VR be used only w 3d anim.
# Same guard as for turbo mode: silently fall back to non-VR rendering.
if vr_mode and animation_mode != '3D':
    print('=====')
    print('VR mode only available with 3D animations. Disabling VR.')
    print('=====')
    vr_mode = False


def parse_key_frames(string, prompt_parser=None):
    """Turn a key-frame string into a ``{frame_number: parameter}`` dict.

    Parameters
    ----------
    string: string
        Key frames written as
        'framenumber1: (parametervalues1), framenumber2: (parametervalues2), ...'.
        Each parameter runs up to the first closing parenthesis.
    prompt_parser: function or None, optional
        If given (and truthy), it is applied to every parameter string and its
        result is stored instead of the raw text.

    Returns
    -------
    dict
        Integer frame numbers mapped to the (optionally parsed) parameter
        strings. If a frame number occurs more than once, the last one wins.

    Raises
    ------
    RuntimeError
        If ``string`` is non-empty but contains no key frame in the expected
        format.

    Examples
    --------
    >>> parse_key_frames("10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)")
    {10: 'Apple: 1| Orange: 0', 20: 'Apple: 0| Orange: 1| Peach: 1'}

    >>> parse_key_frames("10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)", prompt_parser=lambda x: x.lower())
    {10: 'apple: 1| orange: 0', 20: 'apple: 0| orange: 1| peach: 1'}
    """
    key_frame_re = re.compile(r'((?P<frame>[0-9]+):[\s]*[\(](?P<param>[\S\s]*?)[\)])')
    convert = prompt_parser if prompt_parser else (lambda raw: raw)

    frames = {
        int(found.group('frame')): convert(found.group('param'))
        for found in key_frame_re.finditer(string)
    }

    # Text was supplied but nothing in it looked like "N: (...)".
    if string and not frames:
        raise RuntimeError('Key Frame string not correctly formatted')
    return frames


def get_inbetweens(key_frames, integer=False):
    """Given a dict with frame numbers as keys and a parameter value as values,
    return a pandas Series containing the value of the parameter at every frame from 0 to max_frames.
    Any values not provided in the input dict are calculated by interpolation between
    the values of the previous and next provided frames. If there is no previous provided frame, then
    the value is equal to the value of the next provided frame, or if there is no next provided frame,
    then the value is equal to the value of the previous provided frame. If no frames are provided,
    all frame values are NaN.

    The module-level names ``max_frames`` (series length) and ``interp_spline``
    ('Linear', 'Quadratic' or 'Cubic') are read at call time. 'Cubic' falls back
    to 'Quadratic' with 3 or fewer key frames, and 'Quadratic' falls back to
    'Linear' with 2 or fewer.

    Parameters
    ----------
    key_frames: dict
        A dict with integer frame numbers as keys and numerical values of a particular parameter as values.
        Values may also be numeric strings (as produced by ``parse_key_frames``);
        each value is converted with ``float()``.
    integer: Bool, optional
        If True, the values of the output series are converted to integers.
        Otherwise, the values are floats. Ignored when ``key_frames`` is empty,
        because NaN cannot be represented as an integer.

    Returns
    -------
    pd.Series
        A Series with length max_frames representing the parameter values for each frame.

    Raises
    ------
    ValueError
        If a value cannot be converted to float.

    Examples
    --------
    >>> max_frames = 5
    >>> get_inbetweens({1: 5, 3: 6})
    0    5.0
    1    5.0
    2    5.5
    3    6.0
    4    6.0
    dtype: float64

    >>> get_inbetweens({1: 5, 3: 6}, integer=True)
    0    5
    1    5
    2    5
    3    6
    4    6
    dtype: int64
    """
    key_frame_series = pd.Series(np.nan, index=range(max_frames), dtype=float)

    # Nothing to interpolate: honour the documented "all NaN" result instead of
    # failing on first_valid_index() returning None.
    if not key_frames:
        return key_frame_series

    # Convert up front so that string values are never written into the float
    # series (an incompatible-dtype assignment that newer pandas deprecates).
    for frame, value in key_frames.items():
        key_frame_series.loc[frame] = float(value)

    interp_method = interp_spline

    if interp_method == 'Cubic' and len(key_frames) <= 3:
        interp_method = 'Quadratic'

    if interp_method == 'Quadratic' and len(key_frames) <= 2:
        interp_method = 'Linear'

    # Pin both ends to the nearest key frame so every method has fixed endpoints.
    first_value = key_frame_series.loc[key_frame_series.first_valid_index()]
    last_value = key_frame_series.loc[key_frame_series.last_valid_index()]
    key_frame_series.loc[0] = first_value
    key_frame_series.loc[max_frames - 1] = last_value

    key_frame_series = key_frame_series.interpolate(
        method=interp_method.lower(), limit_direction='both')
    if integer:
        return key_frame_series.astype(int)
    return key_frame_series

def _key_frame_series(name, value):
    """Expand one motion setting into a per-frame series.

    ``value`` is parsed with ``parse_key_frames`` and interpolated with
    ``get_inbetweens``. If it is not a valid key-frame string, a warning naming
    the setting is printed and the value is retried wrapped as ``"0: (value)"``.

    Returns ``(value, series)`` where ``value`` is the string that was finally
    parsed (unchanged unless the fallback was used).
    """
    try:
        return value, get_inbetweens(parse_key_frames(value))
    except RuntimeError:
        print(
            "WARNING: You have selected to use key frames, but you have not "
            f"formatted `{name}` correctly for key frames.\n"
            f"Attempting to interpret `{name}` as "
            f'"0: ({value})"\n'
            "Please read the instructions to find out how to use key frames "
            "correctly.\n"
        )
        value = f"0: ({value})"
        return value, get_inbetweens(parse_key_frames(value))


if key_frames:
    angle, angle_series = _key_frame_series('angle', angle)
    zoom, zoom_series = _key_frame_series('zoom', zoom)
    translation_x, translation_x_series = _key_frame_series('translation_x', translation_x)
    translation_y, translation_y_series = _key_frame_series('translation_y', translation_y)
    translation_z, translation_z_series = _key_frame_series('translation_z', translation_z)
    rotation_3d_x, rotation_3d_x_series = _key_frame_series('rotation_3d_x', rotation_3d_x)
    rotation_3d_y, rotation_3d_y_series = _key_frame_series('rotation_3d_y', rotation_3d_y)
    rotation_3d_z, rotation_3d_z_series = _key_frame_series('rotation_3d_z', rotation_3d_z)

else:
    angle = float(angle)
    zoom = float(zoom)
    translation_x = float(translation_x)
    translation_y = float(translation_y)
    translation_z = float(translation_z)
    rotation_3d_x = float(rotation_3d_x)
    rotation_3d_y = float(rotation_3d_y)
    rotation_3d_z = float(rotation_3d_z)

    # The `args` dict built below reads every *_series name unconditionally, so
    # without key frames provide constant series instead of leaving the names
    # undefined (which raised NameError).
    angle_series = get_inbetweens({0: angle})
    zoom_series = get_inbetweens({0: zoom})
    translation_x_series = get_inbetweens({0: translation_x})
    translation_y_series = get_inbetweens({0: translation_y})
    translation_z_series = get_inbetweens({0: translation_z})
    rotation_3d_x_series = get_inbetweens({0: rotation_3d_x})
    rotation_3d_y_series = get_inbetweens({0: rotation_3d_y})
    rotation_3d_z_series = get_inbetweens({0: rotation_3d_z})


if seed is None:
    # Draw a non-negative seed. The previous range included negative values,
    # which are outside what NumPy and pytorch_lightning's seed_everything()
    # accept, so the seed stored in `args` would not be the one actually used.
    # Staying below int32 max leaves room for the per-frame `seed + frame_num`.
    seed = np.random.randint(0, np.iinfo(np.int32).max)

# Bundle the settings the helper functions and the animation loop read through
# `args.<name>` into a single namespace.
args = SimpleNamespace(
    # sampling / initialisation
    seed=seed,
    animation_mode=animation_mode,
    init_image=init_image,
    perlin_init=perlin_init,
    perlin_mode=perlin_mode,
    # image geometry
    H=H,
    W=W,
    C=C,
    f=f,
    # frame range and batching
    start_frame=start_frame,
    max_frames=max_frames,
    n_iter=n_iter,
    batch_size=batch_size,
    # diffusion settings
    init_noise_strength=init_noise_strength,
    previous_frame_noise_strength=previous_frame_noise_strength,
    ddim_steps=ddim_steps,
    ddim_eta=ddim_eta,
    unconditional_guidance_scale=unconditional_guidance_scale,
    fixed_code=fixed_code,
    # per-frame motion schedules
    key_frames=key_frames,
    angle_series=angle_series,
    zoom_series=zoom_series,
    translation_x_series=translation_x_series,
    translation_y_series=translation_y_series,
    translation_z_series=translation_z_series,
    rotation_3d_x_series=rotation_3d_x_series,
    rotation_3d_y_series=rotation_3d_y_series,
    rotation_3d_z_series=rotation_3d_z_series,
    # 3D warp / depth settings
    near_plane=near_plane,
    far_plane=far_plane,
    fov=fov,
    padding_mode=padding_mode,
    sampling_mode=sampling_mode,
    midas_weight=midas_weight,
)

seed_everything(seed)

# Optionally draw the initial latent once so every sample starts from the same noise.
start_code = None
if fixed_code:
    start_code = torch.randn([batch_size, C, H // f, W // f], device=device)

precision_scope = autocast if precision=="autocast" else nullcontext


def _slerp_frame_count(start_c, end_c):
    """Number of frames to spend travelling between two conditioning tensors.

    The angle between the flattened, normalised tensors is converted to degrees
    and multiplied by ``frames_per_degree``.
    """
    start_unit = start_c.flatten() / torch.norm(start_c.flatten())
    end_unit = end_c.flatten() / torch.norm(end_c.flatten())
    # Rounding can push the dot product of (near-)identical vectors slightly
    # outside [-1, 1], which would make acos return NaN and round() fail.
    cosine = (start_unit * end_unit).sum().clamp(-1.0, 1.0)
    omega = torch.acos(cosine)
    # 57.2957795 ~= degrees per radian.
    return round(omega.item() * frames_per_degree * 57.2957795)


if interpolate and len(prompts) > 1:
    # Build one conditioning tensor per output frame by spherically
    # interpolating between consecutive prompts.
    previous_c = None
    slerp_c_vectors = []
    for i, c in enumerate(map(get_conditioning_vector, prompts)):
        if i == 0:
            slerp_c_vectors.append(c)
        else:
            frames = _slerp_frame_count(previous_c, c)

            original_c_shape = c.shape
            c_vectors = get_slerp_vectors(previous_c.flatten(), c.flatten(), frames=frames)
            c_vectors = c_vectors.reshape(-1, *original_c_shape)
            slerp_c_vectors.extend(list(c_vectors[1:])) # drop first frame to prevent repeating frames
            if loop and i == len(prompts) - 1:
                # Close the loop back to the first prompt. The frame count comes
                # from the angle of this closing segment itself rather than from
                # the previous segment, so the travel speed stays constant with
                # three or more prompts (identical result for two prompts).
                closing_frames = _slerp_frame_count(c, slerp_c_vectors[0])
                c_vectors = get_slerp_vectors(c.flatten(), slerp_c_vectors[0].flatten(), frames=closing_frames)
                c_vectors = c_vectors.reshape(-1, *original_c_shape)
                slerp_c_vectors.extend(list(c_vectors[1:-1])) # drop first and last frame to prevent repeating frames
        previous_c = c
    # A single placeholder entry: the conditioning comes from slerp_c_vectors.
    data = ['']
else:
    data = list(chunk(prompts, batch_size))
    interpolate = False

In [None]:
################################################

In [None]:
# initialize midas depth model
# init_midas_depth_model() returns six values; of these, `midas_model` and
# `midas_transform` are the ones passed to do_3d_step() by the animation loop.
midas_model, midas_transform, midas_net_w, midas_net_h, midas_resize_mode, midas_normalization = init_midas_depth_model()

In [None]:
# make sure sampler is DDIM
# The animation loop calls sampler.make_schedule(ddim_num_steps=..., ddim_eta=...)
# on this object, so a DDIM sampler is created here regardless of earlier sampler flags.
sampler = DDIMSampler(model)

In [None]:
# Run animation loop - WIP
TRANSLATION_SCALE = 1.0/200.0
stop_on_next_loop = False
for frame_num in tqdm(range(args.start_frame, args.max_frames), desc='Frames'):
    if stop_on_next_loop:
        break

    if frame_num == 0:
        init_image = args.init_image

    if args.animation_mode == "2D":
        if args.key_frames:
            angle = args.angle_series[frame_num]
            zoom = args.zoom_series[frame_num]
            translation_x = args.translation_x_series[frame_num]
            translation_y = args.translation_y_series[frame_num]
            print(
                f'angle: {angle}',
                f'zoom: {zoom}',
                f'translation_x: {translation_x}',
                f'translation_y: {translation_y}',
            )

        if frame_num > 0:
            # seed += 1
            if resume_run and frame_num == start_frame:
                img_0 = cv2.imread(
                    batchFolder+f"/{batch_name}({batchNum})_{start_frame-1:04}.png")
            else:
                img_0 = cv2.imread('prevFrame.png')
                center = (1*img_0.shape[1]//2, 1*img_0.shape[0]//2)
                trans_mat = np.float32(
                    [[1, 0, translation_x],
                     [0, 1, translation_y]]
                )
                rot_mat = cv2.getRotationMatrix2D(center, angle, zoom)
                trans_mat = np.vstack([trans_mat, [0, 0, 1]])
                rot_mat = np.vstack([rot_mat, [0, 0, 1]])
                transformation_matrix = np.matmul(rot_mat, trans_mat)
                img_0 = cv2.warpPerspective(
                    img_0,
                    transformation_matrix,
                    (img_0.shape[1], img_0.shape[0]),
                    borderMode=cv2.BORDER_WRAP
                )

            cv2.imwrite('prevFrameScaled.png', img_0)
            init_image = 'prevFrameScaled.png'
        
    if args.animation_mode == "3D":
        if frame_num > 0:
            # seed += 1
            if resume_run and frame_num == start_frame:
                img_filepath = batchFolder + \
                    f"/{batch_name}({batchNum})_{start_frame-1:04}.png"
                if turbo_mode and frame_num > turbo_preroll:
                    shutil.copyfile(img_filepath, 'oldFrameScaled.png')
            else:
                img_filepath = 'prevFrame.png'

            next_step_pil = do_3d_step(
                img_filepath, frame_num, midas_model, midas_transform)
            next_step_pil.save('prevFrameScaled.png')

            # Turbo mode - skip some diffusions, use 3d morph for clarity and to save time
            if turbo_mode:
                if frame_num == turbo_preroll:  # start tracking oldframe
                    # stash for later blending
                    next_step_pil.save('oldFrameScaled.png')
                elif frame_num > turbo_preroll:
                    # set up 2 warped image sequences, old & new, to blend toward new diff image
                    old_frame = do_3d_step(
                        'oldFrameScaled.png', frame_num, midas_model, midas_transform)
                    old_frame.save('oldFrameScaled.png')
                    if frame_num % int(turbo_steps) != 0:
                        print(
                            'turbo skip this frame: skipping clip diffusion steps')
                        filename = f'{batch_name}({batchNum})_{frame_num:04}.png'
                        blend_factor = (
                            (frame_num % int(turbo_steps))+1)/int(turbo_steps)
                        print(
                            'turbo skip this frame: skipping clip diffusion steps and saving blended frame')
                        # this is already updated..
                        newWarpedImg = cv2.imread('prevFrameScaled.png')
                        oldWarpedImg = cv2.imread('oldFrameScaled.png')
                        blendedImage = cv2.addWeighted(
                            newWarpedImg, blend_factor, oldWarpedImg, 1-blend_factor, 0.0)
                        cv2.imwrite(
                            f'{batchFolder}/{filename}', blendedImage)
                        # save it also as prev_frame to feed next iteration
                        next_step_pil.save(f'{img_filepath}')
                        if vr_mode:
                            generate_eye_views(
                                TRANSLATION_SCALE, batchFolder, filename, frame_num, midas_model, midas_transform)
                        continue
                    else:
                        # if not a skip frame, will run diffusion and need to blend.
                        oldWarpedImg = cv2.imread('prevFrameScaled.png')
                        # swap in for blending later
                        cv2.imwrite(f'oldFrameScaled.png', oldWarpedImg)
                        print('clip/diff this frame - generate clip diff image')

            init_image = 'prevFrameScaled.png'

        if args.animation_mode == "Video Input":
            init_scale = args.video_init_frames_scale
            skip_steps = args.calc_frames_skip_steps
            if not video_init_seed_continuity:
                seed += 1
            if video_init_flow_warp:
                if frame_num == 0:
                    skip_steps = args.video_init_skip_steps
                    init_image = f'{videoFramesFolder}/{frame_num+1:04}.jpg'
                if frame_num > 0:
                    prev = PIL.Image.open(
                        batchFolder+f"/{batch_name}({batchNum})_{frame_num-1:04}.png")

                    frame1_path = f'{videoFramesFolder}/{frame_num:04}.jpg'
                    frame2 = PIL.Image.open(
                        f'{videoFramesFolder}/{frame_num+1:04}.jpg')
                    flo_path = f"/{flo_folder}/{frame1_path.split('/')[-1]}.npy"

                    init_image = 'warped.png'
                    print(video_init_flow_blend)
                    weights_path = None
                    if video_init_check_consistency:
                        # TBD
                        pass

                    warp(prev, frame2, flo_path, blend=video_init_flow_blend,
                         weights_path=weights_path).save(init_image)

            else:
                init_image = f'{videoFramesFolder}/{frame_num+1:04}.jpg'

    seed_everything(args.seed+frame_num)

    init = None
    if init_image is not None:
        init = load_img(init_image, (args.W, args.H)).to(device)
        init = repeat(init, '1 ... -> b ...', b=args.batch_size)
        init_latent = model.get_first_stage_encoding(
            model.encode_first_stage(init))  # move to latent space

        sampler.make_schedule(ddim_num_steps=args.ddim_steps,
                              ddim_eta=args.ddim_eta, verbose=False)

        noise_strength = args.init_noise_strength if frame_num == 0 else args.previous_frame_noise_strength
        assert 0. <= noise_strength <= 1., 'can only work with strength in [0.0, 1.0]'
        t_enc = int(noise_strength * args.ddim_steps)
        print(f"target t_enc is {t_enc} steps")

    if args.perlin_init:
        if args.perlin_mode == 'color':
            init = create_perlin_noise(
                [1.5**-i*0.5 for i in range(12)], 1, 1, False)
            init2 = create_perlin_noise(
                [1.5**-i*0.5 for i in range(8)], 4, 4, False)
        elif args.perlin_mode == 'gray':
            init = create_perlin_noise(
                [1.5**-i*0.5 for i in range(12)], 1, 1, True)
            init2 = create_perlin_noise(
                [1.5**-i*0.5 for i in range(8)], 4, 4, True)
        else:
            init = create_perlin_noise(
                [1.5**-i*0.5 for i in range(12)], 1, 1, False)
            init2 = create_perlin_noise(
                [1.5**-i*0.5 for i in range(8)], 4, 4, True)
        init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(
            2).to(device).unsqueeze(0).mul(2).sub(1)
        del init2

    print(f'Frame {frame_num}')

    image_display = Output()
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                tic = time.time()
                all_samples = list()
                print('')
                display(image_display)
                gc.collect()
                torch.cuda.empty_cache()

                if perlin_init:
                    init = regen_perlin()
                    start_code = init

                for prompts in tqdm(data, desc="data"):
                    uc = None
                    if unconditional_guidance_scale != 1.0:
                        uc = model.get_learned_conditioning(batch_size * [""])
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
                    
                    if interpolate:
                        c = slerp_c_vectors[frame_num%len(slerp_c_vectors)]
                        c = torch.cat([c])
                    else:
                        c = torch.cat([get_conditioning_vector(prompt) for prompt in prompts])

                    if init_image == None or init_image == '' or init_image == [] or init_image == ['']:
                        shape = [args.C, args.H // args.f, args.W // args.f]
                        samples_ddim, _ = sampler.sample(S=args.ddim_steps,
                                                         conditioning=c,
                                                         batch_size=args.batch_size,
                                                         shape=shape,
                                                         verbose=False,
                                                         unconditional_guidance_scale=args.unconditional_guidance_scale,
                                                         unconditional_conditioning=uc,
                                                         eta=args.ddim_eta,
                                                         x_T=start_code)

                    else:
                        # encode (scaled latent)
                        z_enc = sampler.stochastic_encode(
                            init_latent, torch.tensor([t_enc]*args.batch_size).to(device))
                        # decode it
                        samples_ddim = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=args.unconditional_guidance_scale,
                                                      unconditional_conditioning=uc)

                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    x_samples_ddim = torch.clamp(
                        (x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                    for x_sample in x_samples_ddim:
                        x_sample = 255. * \
                            rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                        img = Image.fromarray(x_sample.astype(np.uint8))
                        display(img)

                        if args.animation_mode != "None":
                            filename = f'{batch_name}({batchNum})_{frame_num}.png'
                            img.save('prevFrame.png')
                            img.save(f'{batchFolder}/{filename}')
                            # if frame_num == 0:
                            #     save_settings()
                            if args.animation_mode == "3D":
                                # If turbo, save a blended image
                                if turbo_mode and frame_num > 0:
                                    # Mix new image with prevFrameScaled
                                    blend_factor = (1)/int(turbo_steps)
                                    # This is already updated..
                                    newFrame = cv2.imread('prevFrame.png')
                                    prev_frame_warped = cv2.imread(
                                        'prevFrameScaled.png')
                                    blendedImage = cv2.addWeighted(
                                        newFrame, blend_factor, prev_frame_warped, (1-blend_factor), 0.0)
                                    cv2.imwrite(
                                        f'{batchFolder}/{filename}', blendedImage)
                                else:
                                    img.save(f'{batchFolder}/{filename}')

                                if vr_mode:
                                    generate_eye_views(
                                        TRANSLATION_SCALE, batchFolder, filename, frame_num, midas_model, midas_transform)
                toc = time.time()


Image, Mask -> Image (inpainting)

In [None]:
# Coming soon: inpainting (image + mask -> image) is not implemented yet.