# Stable Diffusion KLMC2 Animation

<div>
<img src="https://images.squarespace-cdn.com/content/v1/6213c340453c3f502425776e/a432c21c-bb12-4f38-b5e2-1c12a3c403f6/Animated-Logo_1.gif" width="150"/>
</div>


Notebook by [Katherine Crowson](https://twitter.com/RiversHaveWings)

Sponsored by [StabilityAI](https://twitter.com/stabilityai)

Generate animations with [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) 1.4, using the [KLMC2 discretization of underdamped Langevin dynamics](https://arxiv.org/abs/1807.09382). The notebook is largely inspired by [Ajay Jain](https://twitter.com/ajayj_) and [Ben Poole](https://twitter.com/poolio)'s paper [Journey to the BAOAB-limit](https://www.ajayjain.net/journey)&mdash;thank you so much for it!

---

## Modifications Provenance

Original notebook URL - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1m8ovBpO2QilE2o4O-p2PONSwqGn4_x2G)

Features and QOL Modifications by [David Marx](https://twitter.com/DigThatData) - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dmarx/notebooks/blob/main/Stable_Diffusion_KLMC2_Animation.ipynb)

* Keyframed prompts and settings
* Multiprompt conditioning with independent prompt schedules
* Set seed for deterministic output
* Mount Google Drive
* Faster Setup
* Set output filename
* Fancy GPU info
* Video embed optional
* Cheaper default runtime

In [None]:
#@title Check GPU
#!nvidia-smi

import pandas as pd
import subprocess

def gpu_info():
    outv = subprocess.run([
        'nvidia-smi',
            # these lines concatenate into a single query string
            '--query-gpu='
            'timestamp,'
            'name,'
            'utilization.gpu,'
            'utilization.memory,'
            'memory.used,'
            'memory.free',
        '--format=csv'
        ],
        stdout=subprocess.PIPE).stdout.decode('utf-8')

    header, rec = outv.split('\n')[:-1]
    return pd.DataFrame({' '.join(k.strip().split('.')).capitalize():v for k,v in zip(header.split(','), rec.split(','))}, index=[0]).T

gpu_info()
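
The parsing in `gpu_info()` assumes a single GPU (one header line plus one record). As a minimal sketch of a more general approach, the snippet below parses CSV text of that shape with the standard `csv` module and returns one dict per GPU row. The `parse_gpu_csv` name and the sample string are made up for illustration; real `nvidia-smi` output will differ.

```python
import csv
import io

def parse_gpu_csv(text):
    """Parse `nvidia-smi --format=csv` style text into one dict per GPU row."""
    reader = csv.reader(io.StringIO(text), skipinitialspace=True)
    rows = [[cell.strip() for cell in row] for row in reader if row]
    header, records = rows[0], rows[1:]
    return [dict(zip(header, rec)) for rec in records]

# Hypothetical sample shaped like nvidia-smi CSV output (values are invented).
sample = (
    "name, memory.used [MiB], memory.free [MiB]\n"
    "GPU-A, 100 MiB, 900 MiB\n"
    "GPU-B, 250 MiB, 750 MiB\n"
)
gpus = parse_gpu_csv(sample)
print(gpus[0]["name"], gpus[1]["memory.free [MiB]"])
```

Returning a list of dicts keeps the multi-GPU case simple; it can be passed to `pd.DataFrame(gpus).T` for a display similar to the cell above.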

In [None]:
#@title Install Dependencies

# @markdown Your runtime will automatically restart after running this cell.
# @markdown You should only need to run this cell when setting up a new runtime. After future runtime restarts,
# @markdown you should be able to skip this cell.

!pip install ftfy einops braceexpand requests transformers clip open_clip_torch omegaconf pytorch-lightning kornia k-diffusion ninja
#!pip install -U torch torchvision
!pip install -U git+https://github.com/huggingface/huggingface_hub
!pip install napm keyframed

# for deforum loading code
!pip install omegaconf

#####################
# Install more Deps #
#####################

#!git clone https://github.com/Stability-AI/stablediffusion
#!git clone https://github.com/CompVis/stable-diffusion
#!git clone https://github.com/CompVis/taming-transformers
#!git clone https://github.com/CompVis/latent-diffusion

# !pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers

exit()  # ends the kernel process, which makes Colab restart the runtime with the new packages

In [None]:
# @markdown # Setup Workspace { display-mode: "form" }

###################
# Setup Workspace #
###################

import os
from pathlib import Path

mount_gdrive = True # @param {type:'boolean'}

# defaults
outdir = Path('./frames')
if not os.environ.get('XDG_CACHE_HOME'):
    os.environ['XDG_CACHE_HOME'] = str(Path('~/.cache').expanduser())

if mount_gdrive:
    from google.colab import drive
    drive.mount('/content/drive')
    Path('/content/drive/MyDrive/AI/models/.cache/').mkdir(parents=True, exist_ok=True) 
    os.environ['XDG_CACHE_HOME']='/content/drive/MyDrive/AI/models/.cache'
    outdir = Path('/content/drive/MyDrive/AI/klmc2/frames/')

# make sure the paths we need exist
outdir.mkdir(parents=True, exist_ok=True)
os.environ['NAPM_PATH'] = str( Path(os.environ['XDG_CACHE_HOME']) / 'napm' )
Path(os.environ['NAPM_PATH']).mkdir(parents=True, exist_ok=True)


import napm

url = 'https://github.com/Stability-AI/stablediffusion'
napm.pseudoinstall_git_repo(url, add_install_dir_to_path=True)


In [None]:
# @markdown # Imports and Definitions { display-mode: "form" }

###########
# imports #
###########

import napm

from base64 import b64encode
from collections import defaultdict
from concurrent import futures
import math
from pathlib import Path
import sys

import functorch
from IPython import display
import k_diffusion as K
from omegaconf import OmegaConf
from PIL import Image
import torch
from torch import nn
from tqdm.auto import tqdm, trange

#sys.path.extend(['./stablediffusion'])
from ldm.util import instantiate_from_config

from requests.exceptions import HTTPError
import huggingface_hub

from urllib.parse import urlparse

from keyframed import Curve, ParameterGroup, Keyframe
import os  # needed below to read XDG_CACHE_HOME


#########################
# Define useful globals #
#########################

cpu = torch.device("cpu")
device = torch.device("cuda")

############################

model_dir_str=str(Path(os.environ['XDG_CACHE_HOME']))

sdmodelid2hfrepo = {
    "sd-v1-4":"CompVis/stable-diffusion-v-1-4-original",
    "sd-v1-5":"runwayml/stable-diffusion-v1-5",
}

sdmodelid2hfckpt = {
    "sd-v1-4":"sd-v1-4.ckpt",
    "sd-v1-5":"v1-5-pruned-emaonly.ckpt",
}

sdmodelid2yamlurl = {
    "sd-v1-4":"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml",
    "sd-v1-5":"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml",
}

sdmodelid2ckptstyle = {
    "sd-v1-4":"compvis",
    "sd-v1-5":"compvis",
}

############################

vaemodelid2hfrepo = {
    "vae-ft-mse-840k":"stabilityai/sd-vae-ft-mse-original",
    "vae-ft-ema-560k":"stabilityai/sd-vae-ft-ema-original",
    #"vae-orig":,
}

vaemodelid2hfckpt = {
    "vae-ft-mse-840k":"vae-ft-mse-840000-ema-pruned.ckpt",
    "vae-ft-ema-560k":"vae-ft-ema-560000-ema-pruned.ckpt",
    #"vae-orig":,
}

vaemodelid2yamlurl = {
    "vae-ft-mse-840k":"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/models/first_stage_models/kl-f8/config.yaml",
    "vae-ft-ema-560k":"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/models/first_stage_models/kl-f8/config.yaml",
    #"vae-orig":"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/models/first_stage_models/kl-f8/config.yaml",
}


##############################
# Define necessary functions #
##############################

class Prompt:
    def __init__(
        self,
        text,
        weight_schedule,
        ease_in=None,
        ease_out=None,
        ):
      c = sd_model.get_learned_conditioning([text])
      self.text=text
      self.encoded=c
      self.weight=Curve(
          weight_schedule, 
          default_interpolation='linear', 
          ease_in=ease_in, 
          ease_out=ease_out)

##################

class NormalizingCFGDenoiser(nn.Module):
    def __init__(self, model, g):
        super().__init__()
        self.inner_model = model
        self.g = g
        self.eps_norms = defaultdict(lambda: (0, 0))

    def mean_sq(self, x):
        return x.pow(2).flatten(1).mean(1)

    @torch.no_grad()
    def update_eps_norm(self, eps, sigma):
        sigma = sigma[0].item()
        eps_norm = self.mean_sq(eps).mean()
        eps_norm_avg, count = self.eps_norms[sigma]
        eps_norm_avg = eps_norm_avg * count / (count + 1) + eps_norm / (count + 1)
        self.eps_norms[sigma] = (eps_norm_avg, count + 1)
        return eps_norm_avg

    def forward(self, x, sigma, uncond, cond, g):
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])

        denoised = self.inner_model(x_in, sigma_in, cond=cond_in)
        eps = K.sampling.to_d(x_in, sigma_in, denoised)
        eps_uc, eps_c = eps.chunk(2)
        eps_norm = self.update_eps_norm(eps, sigma).sqrt()
        c = eps_c - eps_uc
        cond_scale = g * eps_norm / self.mean_sq(c).sqrt()
        eps_final = eps_uc + c * K.utils.append_dims(cond_scale, x.ndim)
        return x - eps_final * K.utils.append_dims(sigma, eps.ndim)


@torch.no_grad()
def sample_mcmc_klmc2(
    model, x, 
    sigma_min, sigma, sigma_max, 
    n, 
    hvp_method='reverse', 
    callback=None, 
    disable=None, 
    prompts=None,
    settings=None, # g, h, gamma, alpha, tau, prompt
):
    s_in = x.new_ones([x.shape[0]])
    sigma = torch.tensor(sigma, device=x.device)
    sigmas = K.sampling.get_sigmas_karras(6, sigma_min, sigma.item(), device=x.device)[:-1]

    uc = sd_model.get_learned_conditioning([''])
    extra_args = {'uncond': uc}
    v = torch.randn_like(x) * sigma  # initial velocity, drawn at the sampling noise scale (a heuristic choice)

    for i in trange(n, disable=disable):

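        # Read this frame's values from the keyframed settings. `g` has to be read here
        # as well; otherwise the global `g` is used and the 'g' curve has no effect.
        # NOTE: the 'sigma' curve is not read; the `sigma` argument is used for every frame.
        g = settings[i]['g']
        h = settings[i]['h']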
        gamma = settings[i]['gamma']
        alpha = settings[i]['alpha']
        tau = settings[i]['tau']

        h = torch.tensor(h, device=x.device)
        gamma = torch.tensor(gamma, device=x.device)
        alpha = torch.tensor(alpha, device=x.device)
        tau = torch.tensor(tau, device=x.device)

        # Model helper functions

        def hvp_fn_forward_functorch(x, sigma, v, **extra_args):
            def grad_fn(x, sigma):
                denoised = model(x, sigma * s_in, **extra_args)
                return (x - denoised) + alpha * x
            jvp_fn = lambda v: functorch.jvp(grad_fn, (x, sigma), (v, torch.zeros_like(sigma)))
            grad, jvp_out = functorch.vmap(jvp_fn)(v)
            return grad[0], jvp_out

        def hvp_fn_reverse(x, sigma, v, **extra_args):
            def grad_fn(x, sigma):
                denoised = model(x, sigma * s_in, **extra_args)
                return (x - denoised) + alpha * x
            vjps = []
            with torch.enable_grad():
                x_ = x.clone().requires_grad_()
                grad = grad_fn(x_, sigma)
                for k, item in enumerate(v):
                    vjp_out = torch.autograd.grad(grad, x_, item, retain_graph=k < len(v) - 1)[0]
                    vjps.append(vjp_out)
            return grad, torch.stack(vjps)

        def hvp_fn_zero(x, sigma, v, **extra_args):
            def grad_fn(x, sigma):
                denoised = model(x, sigma * s_in, **extra_args)
                return (x - denoised) + alpha * x
            return grad_fn(x, sigma), torch.zeros_like(v)

        def hvp_fn_fake(x, sigma, v, **extra_args):
            def grad_fn(x, sigma):
                denoised = model(x, sigma * s_in, **extra_args)
                return (x - denoised) + alpha * x
            return grad_fn(x, sigma), (1 + alpha) * v

        hvp_fns = {'forward-functorch': hvp_fn_forward_functorch,
                  'reverse': hvp_fn_reverse,
                  'zero': hvp_fn_zero,
                  'fake': hvp_fn_fake}

        hvp_fn = hvp_fns[hvp_method]

        # KLMC2 helper functions
        def psi_0(gamma, t):
            return torch.exp(-gamma * t)

        def psi_1(gamma, t):
            return -torch.expm1(-gamma * t) / gamma

        def psi_2(gamma, t):
            return (torch.expm1(-gamma * t) + gamma * t) / gamma ** 2

        def phi_2(gamma, t_):
            t = t_.double()
            out = (torch.exp(-gamma * t) * (torch.expm1(gamma * t) - gamma * t)) / gamma ** 2
            return out.to(t_)

        def phi_3(gamma, t_):
            t = t_.double()
            out = (torch.exp(-gamma * t) * (2 + gamma * t + torch.exp(gamma * t) * (gamma * t - 2))) / gamma ** 3
            return out.to(t_)


        # Compute model outputs and sample noise
        x_trapz = torch.linspace(0, h, 1001, device=x.device)
        y_trapz = [fun(gamma, x_trapz) for fun in (psi_0, psi_1, phi_2, phi_3)]
        noise_cov = torch.tensor([[torch.trapz(y_trapz[i] * y_trapz[j], x=x_trapz) for j in range(4)] for i in range(4)], device=x.device)
        noise_v, noise_x, noise_v2, noise_x2 = torch.distributions.MultivariateNormal(x.new_zeros([4]), noise_cov).sample(x.shape).unbind(-1)
            
        extra_args['g']=g

        # loop over prompts and aggregate gradients for multicond
        grad = torch.zeros_like(x)
        h2_v = torch.zeros_like(x)
        h2_noise_v2 = torch.zeros_like(x)
        h2_noise_x2 = torch.zeros_like(x)
        wt_norm = 0
        for prompt in prompts:
            wt = prompt.weight[i]
            if wt == 0:
                continue
            wt_norm += wt
            wt = torch.tensor(wt, device=x.device)
            extra_args['cond'] = prompt.encoded

            # Estimate gradient and hessian
            grad_, (h2_v_, h2_noise_v2_, h2_noise_x2_) = hvp_fn(
                x, sigma, torch.stack([v, noise_v2, noise_x2]),
                **extra_args
            )

            grad = grad + grad_ * wt 
            h2_v = h2_v + h2_v_ * wt
            h2_noise_v2 = h2_noise_v2 + h2_noise_v2_ * wt
            h2_noise_x2 = h2_noise_x2 + h2_noise_x2_ * wt

        # Normalize gradient to magnitude it'd have if just single prompt w/ wt=1.
        # simplifies multicond w/o deep frying image or adding hyperparams
        grad = grad / wt_norm 
        h2_v = h2_v / wt_norm
        h2_noise_v2 = h2_noise_v2 / wt_norm
        h2_noise_x2 = h2_noise_x2 / wt_norm
        

        # DPM-Solver++(2M) refinement steps
        x_refine = x
        use_dpm = True
        old_denoised = None
        for j in range(len(sigmas) - 1):
            if j == 0:
                denoised = x_refine - grad
            else:
                denoised = model(x_refine, sigmas[j] * s_in, **extra_args)
            dt_ode = sigmas[j + 1] - sigmas[j]
            if not use_dpm or old_denoised is None or sigmas[j + 1] == 0:
                eps = K.sampling.to_d(x_refine, sigmas[j], denoised)
                x_refine = x_refine + eps * dt_ode
            else:
                h_ode = sigmas[j].log() - sigmas[j + 1].log()
                h_last = sigmas[j - 1].log() - sigmas[j].log()
                fac = h_ode / (2 * h_last)
                denoised_d = (1 + fac) * denoised - fac * old_denoised
                eps = K.sampling.to_d(x_refine, sigmas[j], denoised_d)
                x_refine = x_refine + eps * dt_ode
            old_denoised = denoised
        if callback is not None:
            callback({'i': i, 'denoised': x_refine})

        # Update the chain
        noise_std = (2 * gamma * tau * sigma ** 2).sqrt()
        v_next = 0 + psi_0(gamma, h) * v - psi_1(gamma, h) * grad - phi_2(gamma, h) * h2_v + noise_std * (noise_v - h2_noise_v2)
        x_next = x + psi_1(gamma, h) * v - psi_2(gamma, h) * grad - phi_3(gamma, h) * h2_v + noise_std * (noise_x - h2_noise_x2)
        v, x = v_next, x_next

    x = x - grad
    return x


def show_video(video_path, video_width=512):
  # Read the file in binary mode and close the handle right away.
  with open(video_path, "rb") as f:
    video_file = f.read()
  video_url = f"data:video/mp4;base64,{b64encode(video_file).decode()}"
  return display.HTML(f"""<video width={video_width} controls><source src="{video_url}"></video>""")
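
`NormalizingCFGDenoiser.update_eps_norm` keeps a running average of the squared eps norm per sigma without storing past values. Here is a minimal pure-Python sketch of that same update rule, using a hypothetical `RunningMean` helper and made-up inputs:

```python
from collections import defaultdict

class RunningMean:
    """Keeps (average, count) per key, updated the same way as update_eps_norm."""
    def __init__(self):
        self.stats = defaultdict(lambda: (0.0, 0))

    def update(self, key, value):
        avg, count = self.stats[key]
        # new_avg = old_avg * n / (n + 1) + value / (n + 1)
        avg = avg * count / (count + 1) + value / (count + 1)
        self.stats[key] = (avg, count + 1)
        return avg

rm = RunningMean()
values = [4.0, 2.0, 6.0]           # made-up measurements
for v in values:
    latest = rm.update(0.25, v)    # the key plays the role of sigma
print(latest, sum(values) / len(values))
```

The two printed numbers should agree, since the incremental update is just the arithmetic mean computed one value at a time.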


In [None]:
#@markdown **Select and Load Model**

# scavenged from:
#   https://github.com/deforum/stable-diffusion/blob/main/Deforum_Stable_Diffusion.ipynb

from omegaconf import OmegaConf
import requests
import torch

import napm
from ldm.util import instantiate_from_config

models_path = "/content/models" #@param {type:"string"}
if mount_gdrive:
  models_path_gdrive = "/content/drive/MyDrive/AI/models" #@param {type:"string"}
  models_path = models_path_gdrive

model_config = "v1-inference.yaml" #@param ["custom","v1-inference.yaml"]
model_checkpoint =  "sd-v1-4.ckpt" #@param ["custom","sd-v1-4-full-ema.ckpt","sd-v1-4.ckpt","sd-v1-3-full-ema.ckpt","sd-v1-3.ckpt","sd-v1-2-full-ema.ckpt","sd-v1-2.ckpt","sd-v1-1-full-ema.ckpt","sd-v1-1.ckpt", "robo-diffusion-v1.ckpt","waifu-diffusion-v1-3.ckpt"]
if model_checkpoint == "waifu-diffusion-v1-3.ckpt":
    model_checkpoint = "model-epoch05-float16.ckpt"
custom_config_path = "" #@param {type:"string"}
custom_checkpoint_path = "" #@param {type:"string"}

load_on_run_all = True #@param {type: 'boolean'}
half_precision = True # check
check_sha256 = True #@param {type:"boolean"}

model_map = {
    "sd-v1-4-full-ema.ckpt": {
        'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4-full-ema.ckpt',
        'requires_login': True,
        },
    "sd-v1-4.ckpt": {
        'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt',
        'requires_login': True,
        },
    "sd-v1-3-full-ema.ckpt": {
        'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/resolve/main/sd-v1-3-full-ema.ckpt',
        'requires_login': True,
        },
    "sd-v1-3.ckpt": {
        'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/resolve/main/sd-v1-3.ckpt',
        'requires_login': True,
        },
    "sd-v1-2-full-ema.ckpt": {
        'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/resolve/main/sd-v1-2-full-ema.ckpt',
        'requires_login': True,
        },
    "sd-v1-2.ckpt": {
        'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/resolve/main/sd-v1-2.ckpt',
        'requires_login': True,
        },
    "sd-v1-1-full-ema.ckpt": {
        'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829',
        'url':'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1-full-ema.ckpt',
        'requires_login': True,
        },
    "sd-v1-1.ckpt": {
        'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea',
        'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1.ckpt',
        'requires_login': True,
        },
    "robo-diffusion-v1.ckpt": {
        'sha256': '244dbe0dcb55c761bde9c2ac0e9b46cc9705ebfe5f1f3a7cc46251573ea14e16',
        'url': 'https://huggingface.co/nousr/robo-diffusion/resolve/main/models/robo-diffusion-v1.ckpt',
        'requires_login': False,
        },
    "model-epoch05-float16.ckpt": {
        'sha256': '26cf2a2e30095926bb9fd9de0c83f47adc0b442dbfdc3d667d43778e8b70bece',
        'url': 'https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/model-epoch05-float16.ckpt',
        'requires_login': False,
        },
}

# config path
ckpt_config_path = custom_config_path if model_config == "custom" else os.path.join(models_path, model_config)
if os.path.exists(ckpt_config_path):
    print(f"{ckpt_config_path} exists")
else:
    #ckpt_config_path = "./stable-diffusion/configs/stable-diffusion/v1-inference.yaml"
    ckpt_config_path = "./v1-inference.yaml"
    if not Path(ckpt_config_path).exists():
        !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
    
print(f"Using config: {ckpt_config_path}")

# checkpoint path or download
ckpt_path = custom_checkpoint_path if model_checkpoint == "custom" else os.path.join(models_path, model_checkpoint)
ckpt_valid = True
if os.path.exists(ckpt_path):
    print(f"{ckpt_path} exists")
elif 'url' in model_map[model_checkpoint]:
    url = model_map[model_checkpoint]['url']

    # CLI dialogue to authenticate download
    if model_map[model_checkpoint]['requires_login']:
        print("This model requires an authentication token")
        print("Please ensure you have accepted its terms of service before continuing.")

        username = input("What is your huggingface username?:")
        token = input("What is your huggingface token?:")

        _, path = url.split("https://")

        url = f"https://{username}:{token}@{path}"

    # contact server for model
    print(f"Attempting to download {model_checkpoint}...this may take a while")
    ckpt_request = requests.get(url)
    request_status = ckpt_request.status_code

    # inform user of errors
    if request_status == 403:
      raise ConnectionRefusedError("You have not accepted the license for this model.")
    elif request_status == 404:
      raise ConnectionError("Checkpoint not found at the requested URL (HTTP 404)")
    elif request_status != 200:
      raise ConnectionError(f"Some other error has occurred - response code: {request_status}")

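    # write to model path (create the directory first; it may not exist on a fresh runtime)
    os.makedirs(models_path, exist_ok=True)
    with open(os.path.join(models_path, model_checkpoint), 'wb') as model_file:
        model_file.write(ckpt_request.content)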
else:
    print(f"Please download model checkpoint and place in {os.path.join(models_path, model_checkpoint)}")
    ckpt_valid = False

if check_sha256 and model_checkpoint != "custom" and ckpt_valid:
    import hashlib
    print("\n...checking sha256")
    with open(ckpt_path, "rb") as f:
        # avoid shadowing the builtins `bytes` and `hash`
        ckpt_bytes = f.read()
        ckpt_hash = hashlib.sha256(ckpt_bytes).hexdigest()
        del ckpt_bytes
    if model_map[model_checkpoint]["sha256"] == ckpt_hash:
        print("hash is correct\n")
    else:
        print("hash is not correct\n")
        ckpt_valid = False

if ckpt_valid:
    print(f"Using ckpt: {ckpt_path}")

def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True):
    map_location = "cuda" #@param ["cpu", "cuda"]
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location=map_location)
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    if half_precision:
        model = model.half().to(device)
    else:
        model = model.to(device)
    model.eval()
    return model

if load_on_run_all and ckpt_valid:
    local_config = OmegaConf.load(f"{ckpt_config_path}")
    model = load_model_from_config(local_config, f"{ckpt_path}", half_precision=half_precision)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    sd_model=model
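
The SHA-256 check in the cell above reads the whole checkpoint into memory before hashing. Below is a sketch of a chunked alternative that hashes the file incrementally. The `sha256_of_file` name and the chunk size are illustrative choices, not part of the original notebook.

```python
import hashlib
import os
import tempfile

def sha256_of_file(path, chunk_size=1 << 20):
    """Hash a file in fixed-size chunks so memory use stays bounded."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# Self-check on a small temporary file (a stand-in, not a real checkpoint).
data = b"not a real checkpoint" * 1000
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(data)
    tmp_path = tmp.name
try:
    file_hash = sha256_of_file(tmp_path, chunk_size=4096)
finally:
    os.remove(tmp_path)
print(file_hash == hashlib.sha256(data).hexdigest())
```

For multi-gigabyte checkpoints this should keep peak memory near the chunk size rather than the file size.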

In [None]:
# @title Settings

# @markdown The number of frames to sample:
n = 500 # @param {type:"integer"}

# @markdown If seed is negative, a random seed will be used
seed = 3126649841  # @param {type:"number"}

init_image = "profile_arcanegan-pinkfloyd.png" # @param {type:'string'}

# @markdown ---

# @markdown The strength of the conditioning on the prompt:
g = 0.09 # @param {type:"number"}

# @markdown The noise level to sample at:
sigma = 0.25 # @param {type:"number"}

# @markdown Step size (range 0 to 1):
h = 0.1 # @param {type:"number"}

# @markdown Friction (2 is critically damped, lower -> smoother animation):
gamma = 1.0 # @param {type:"number"}

# @markdown Quadratic penalty ("weight decay") strength:
alpha = 0.005 # @param {type:"number"}

# @markdown Temperature (adjustment to the amount of noise added per step):
tau = 1.0 # @param {type:"number"}

# @markdown The HVP method:
# @markdown <br><small>`forward-functorch` and `reverse` provide real second derivatives. Compatibility, speed, and memory usage vary by model and xformers configuration.
# @markdown `fake` is very fast and low memory but inaccurate. `zero` (fallback to first order KLMC) is not recommended.</small>
hvp_method = 'fake' # @param ["forward-functorch", "reverse", "fake", "zero"]
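
To get a feel for how the step size `h` and friction `gamma` enter the update, the sketch below evaluates in plain Python the same `psi_0`, `psi_1`, and `psi_2` expressions that the sampler defines with torch, plus the noise scale `sqrt(2 * gamma * tau * sigma**2)`. It is a scalar illustration only, not a replacement for the sampler code.

```python
import math

def psi_0(gamma, t):
    # multiplies the old velocity in the velocity update
    return math.exp(-gamma * t)

def psi_1(gamma, t):
    # multiplies the velocity in the position update (and the gradient in the velocity update)
    return -math.expm1(-gamma * t) / gamma

def psi_2(gamma, t):
    # multiplies the gradient in the position update
    return (math.expm1(-gamma * t) + gamma * t) / gamma ** 2

def noise_std(gamma, tau, sigma):
    return math.sqrt(2 * gamma * tau * sigma ** 2)

# Default values from the settings cell above.
h, gamma, tau, sigma = 0.1, 1.0, 1.0, 0.25
print(f"psi_0: {psi_0(gamma, h):.4f}")
print(f"psi_1: {psi_1(gamma, h):.4f}")
print(f"psi_2: {psi_2(gamma, h):.5f}")
print(f"noise std: {noise_std(gamma, tau, sigma):.4f}")
```

For small `h`, `psi_1` approaches `h` and `psi_2` approaches `h**2 / 2`, which is a quick sanity check when experimenting with other values.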


In [None]:
from PIL import Image
from pathlib import Path

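if Path(init_image).exists():
  init_im_pil = Image.open(init_image)
else:
  # The next cell needs init_im_pil, so fail early with a clear message.
  raise FileNotFoundError(f"init_image not found: {init_image}")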


In [None]:
#init_im_pil.size
import numpy as np

x_pil = init_im_pil.resize([512,512])
x_np = np.array(x_pil.convert('RGB')).astype(np.float16) / 255.0

#image = np.array(image).astype(np.float16) / 255.0
x = x_np[None].transpose(0, 3, 1, 2)
x = 2.*x - 1.
x = torch.from_numpy(x).to('cuda')
x = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(x))
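
The cell above rescales 8-bit pixel values to the range [-1, 1] before encoding. Here is a small sketch of that arithmetic, plus the latent shape one would expect assuming the usual 8x spatial downsampling and 4 latent channels of the SD v1 autoencoder. Both helper names are hypothetical.

```python
def pixel_to_model_range(p):
    """Map an 8-bit pixel value in [0, 255] to [-1, 1], as 2 * (p / 255) - 1."""
    return 2.0 * (p / 255.0) - 1.0

def latent_shape(height, width, factor=8, channels=4, batch=1):
    """Expected latent tensor shape, assuming an f=8 autoencoder with 4 channels."""
    return (batch, channels, height // factor, width // factor)

print(pixel_to_model_range(0), pixel_to_model_range(255))
print(latent_shape(512, 512))
```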

In [None]:
#@title Prompts

#  [  
#    ["first prompt will be used to initialize the image", {time:weight, time:weight...}], 
#    ["more prompts if you want", {...}], 
#  ...]

# if a weight for time=0 isn't specified, the weight is assumed to be zero.


prompt_params = [
    # FIRST PROMPT INITIALIZES IMAGE
    #["portrait of queen elizabeth at 20 years old", {0:1, 50:1, 100:0}],
    #["portrait of queen elizabeth at 82 years old", {50:0, 100:1}],
    ["a man with glasses standing in front of a pink sunset, inspired by Android Jones, league of legends, computer art, drachenlord, lucy in the sky with diamonds, profile picture, recolored, thumbnail, swirling energy, in style of south park, infini - d - render, three handed god, stoner beanie hipster, fiberpunk", {0:1}]
]
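
To make the `{time: weight}` format concrete, here is a rough pure-Python sketch of how such a schedule could be evaluated at a given frame. It assumes linear interpolation between keyframes, an implicit weight of 0 at time 0 when none is given, and that the last weight is held afterwards. It ignores the `ease_in`/`ease_out` easing and is not the `keyframed` library's implementation; `weight_at` is a hypothetical helper.

```python
def weight_at(schedule, t):
    """Linearly interpolate a {time: weight} schedule at time t.

    Assumptions: an implicit weight of 0 at time 0 when none is given,
    and the last keyframe's weight is held after the final keyframe.
    """
    points = dict(schedule)
    points.setdefault(0, 0.0)
    times = sorted(points)
    if t >= times[-1]:
        return float(points[times[-1]])
    for t0, t1 in zip(times, times[1:]):
        if t0 <= t < t1:
            frac = (t - t0) / (t1 - t0)
            return points[t0] + frac * (points[t1] - points[t0])
    return float(points[times[0]])  # t before the first keyframe

# The two commented-out schedules above: one prompt fades out while the other fades in.
fade_out = {0: 1, 50: 1, 100: 0}
fade_in = {50: 0, 100: 1}
for frame in (0, 50, 75, 100):
    print(frame, weight_at(fade_out, frame), weight_at(fade_in, frame))
```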


In [None]:
# @title Build prompt and settings objects

plot_prompt_weight_curves = True # @param {type: 'boolean'}

#################

def sin2(t):
    return (math.sin(t * math.pi / 2)) ** 2

prompts = [
    Prompt(text, weight_schedule, ease_in=sin2, ease_out=sin2) 
    for (text, weight_schedule) in prompt_params
]


curved_settings = ParameterGroup({
    #'g':Curve(g),
    #'sigma':Curve(sigma),
    'g':Curve({0:0.08,50:1.1}), # warm up cfg
    'sigma':Curve({0:.25,50:1, 125:1, 200:2}), # warm up noise w init image
    'h':Curve(h),
    'gamma':Curve(gamma),
    'alpha':Curve(alpha),
    'tau':Curve(tau),
    'seed':Curve(seed),
})



if plot_prompt_weight_curves:

    import matplotlib.pyplot as plt
    import numpy as np 


    ytot = np.zeros(n)
    for prompt in prompts:#[:3]:
      xs = np.array(range(n))
      ys = np.array([prompt.weight[x] for x in xs])
      ytot=ytot+ys
      plt.plot(xs, ys)
    plt.title("prompt weight schedules")
    plt.show()

    plt.plot(xs, ytot)
    plt.title("sum weight\n(aka: why weights get normalized)")
    plt.show()

    for prompt in prompts:#[:3]:
      xs = np.array(range(n))
      ys = np.array([prompt.weight[x] for x in xs])
      plt.plot(xs, ys/ytot)
    plt.title("normalized weights\n(aka: why prompts might seem weighted differently than I asked)")
    plt.show()
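
The last plot shows the effective weights because the sampler divides the weighted sum of gradients by the sum of the active prompt weights. A minimal sketch of that normalization for a single frame follows; `normalize_weights` is a hypothetical helper, and the zero-total guard is an addition that the sampler itself does not have.

```python
def normalize_weights(raw_weights):
    """Divide each weight by the total so the active weights sum to 1."""
    total = sum(raw_weights)
    if total == 0:
        # No active prompt at this frame; nothing to normalize.
        return [0.0 for _ in raw_weights]
    return [w / total for w in raw_weights]

print(normalize_weights([1.0, 1.0]))
print(normalize_weights([1.0, 0.25]))
```

Two prompts that are both scheduled at weight 1 each end up contributing half of the gradient, which is why a prompt can seem weaker than its schedule suggests.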

In [None]:
#@title Generate Animation Frames


###################

import random

# If the seed is negative, pick a random one and report it back.
if seed < 0:
    seed = random.randrange(0, 4294967295)
seed = int(seed)
torch.manual_seed(seed)  # seed in both cases so the reported seed reproduces the run
print(f"using seed: {seed}")

wrappers = {'eps': K.external.CompVisDenoiser, 'v': K.external.CompVisVDenoiser}
model_wrap = wrappers[sd_model.parameterization](sd_model)
model_wrap_cfg = NormalizingCFGDenoiser(model_wrap, g)
sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()

uc = sd_model.get_learned_conditioning([''])
c = prompts[0].encoded
extra_args = {'cond': c, 'uncond': uc}

def save_image_fn(image, name, i):
    pil_image = K.utils.to_pil_image(image)
    if i % 10 == 0 or i == n - 1:
        print(f'\nIteration {i}/{n}:')
        display.display(pil_image)
    if i == n - 1:
        print('\nDone!')
    name = outdir / name
    pil_image.save(name)

# to do: add archival
# Clean up old images and video - save them elsewhere before running this if you want to keep them!
#for p in Path('.').glob('out_*.png'):
for p in outdir.glob('out_*.png'):
    p.unlink()
(outdir / 'out.mp4').unlink(missing_ok=True)  # the video is written into outdir

torch.cuda.empty_cache()

with torch.cuda.amp.autocast(), futures.ThreadPoolExecutor() as ex:
    def callback(info):
        i = info['i']
        #rgb = vae_model.decode(info['denoised'] / sd_model.scale_factor)
        rgb = sd_model.decode_first_stage(info['denoised'] )
        ex.submit(save_image_fn, rgb, f'out_{i:05}.png', i)

    #x = torch.randn([1, 4, 64, 64], device=device) * sigma_max
    
    # Initialize the chain
    print('Initializing the chain...')
    sigmas_pre = K.sampling.get_sigmas_karras(15, sigma, sigma_max, device=x.device)[:-1]

    extra_args['g'] = curved_settings[0]['g']
    #x = K.sampling.sample_dpmpp_sde(model_wrap_cfg, x, sigmas_pre, extra_args=extra_args)

    print('Actually doing the sampling...')
    sample_mcmc_klmc2(
        model=model_wrap_cfg,
        x=x,
        sigma_min=sigma_min,
        sigma=sigma,
        sigma_max=sigma_max,
        n=n,
        hvp_method=hvp_method,
        callback=callback,
        prompts=prompts,
        settings=curved_settings,
    )
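
Both `sigmas_pre` here and the refinement schedule inside the sampler come from `K.sampling.get_sigmas_karras`. As a rough pure-Python sketch of the Karras et al. schedule it computes, assuming `rho = 7` (which I believe is the k-diffusion default) and using made-up endpoint values:

```python
def karras_sigmas(n, sigma_min, sigma_max, rho=7.0):
    """n sigmas from sigma_max down to sigma_min, spaced in sigma**(1/rho), then a final 0."""
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = []
    for k in range(n):
        ramp = k / (n - 1) if n > 1 else 0.0
        sigmas.append((max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho)
    return sigmas + [0.0]

# Illustrative endpoints, not the values of any particular model.
schedule = karras_sigmas(6, sigma_min=0.03, sigma_max=0.25)
print([round(s, 4) for s in schedule])
print(len(schedule[:-1]))  # the notebook drops the trailing zero with [:-1]
```

The spacing concentrates steps near the low-noise end, which is where the refinement steps spend most of their effort.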


In [None]:
#@title Make the video

outdir_str = str(outdir)

fps = 20 # @param {type:"integer"}
out_fname = "out.mp4" # @param {type: "string"}

print('\nMaking the video...\n')
!cd {outdir_str}; ffmpeg -y -r {fps} -i 'out_%*.png' -crf 15 -preset veryslow -pix_fmt yuv420p {out_fname}

# @markdown If your video is larger than a few MB, attempting to embed it will probably crash
# @markdown the session. If this happens, view the generated video after downloading it first.
embed_video = True # @param {type:'boolean'}

if embed_video:
  print('\nThe video:')
  display.display(show_video(outdir / out_fname))
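
`show_video` embeds the file as a base64 data URL, so the notebook output carries roughly 4/3 of the file size as text. That is likely part of why large videos strain the session. A quick sketch of the size arithmetic, with a hypothetical `embedded_size_bytes` helper and a dummy payload standing in for a video file:

```python
from base64 import b64encode

def embedded_size_bytes(raw_size):
    """Length of the base64 text for raw_size bytes: 4 characters per 3-byte group, padded."""
    return 4 * ((raw_size + 2) // 3)

payload = b"\x00" * 3_000_000            # stand-in for a ~3 MB video file
encoded = b64encode(payload)
print(len(encoded), embedded_size_bytes(len(payload)))
```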

In [None]:
#@title Licensed under the MIT License { display-mode: "form" }

# Copyright (c) 2022 Katherine Crowson <crowsonkb@gmail.com>
# Copyright (c) 2023 David Marx <david.marx84@gmail.com>

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.