<a href="https://colab.research.google.com/github/jags111/Neuralism-Jax-2.6/blob/main/Jags_V1_JAX_2_6_Neuralism_Edit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Generates images from text prompts with CLIP guided diffusion.

Based on my previous jax port of Katherine Crowson's CLIP guided diffusion notebook.
 - [nshepperd's JAX CLIP Guided Diffusion 512x512.ipynb](https://colab.research.google.com/drive/1ZZi1djM8lU4sorkve3bD6EBHiHs6uNAi)
 - [CLIP Guided Diffusion HQ 512x512.ipynb](https://colab.research.google.com/drive/1V66mUeJbXrTuQITvJunvnWVn96FEbSI3)

Supports both 256x256 and 512x512 OpenAI models.
v2.6?:
 - Added small secondary model for clip guidance.
 - Added anti-jpeg model for clearer samples.
 - Added secondary anti-jpeg classifier.
 - Added Katherine Crowson's v diffusion models (https://github.com/crowsonkb/v-diffusion-jax).
 - Added pixel art model.
 - Added cc12m_1 model (https://github.com/crowsonkb/v-diffusion-pytorch)
 - Reparameterized in terms of cosine t, to allow different schedules; added spliced ddpm+cosine schedule.
 - Added cc12m_1_cfg model (https://github.com/crowsonkb/v-diffusion-pytorch) and more pixel art models.
 ---

## [NeuralismAI](https://twitter.com/NeuralismAI) edit of nshepperd's JAX 2.6 notebook.

This edit is a fork of the original notebook with the following small changes:
- Cleaner Interface (Less raw code)
- Prompt Queuing system
- Video progress output
- Intermediate saves
- Diffusion model selection
- [Huemin](https://twitter.com/huemin_art)'s simple symmetry

[Original notebook](https://colab.research.google.com/drive/1fW_tPEX7iD3xZK3VBDQ_Y2WnfdSzpacM?usp=sharing#scrollTo=zxGgJmRzq3Cs)
 
Please share any suggestions and feedback in the Neuralism Discord.

In [None]:
#@title Licensed under the MIT License { display-mode: "form" }

# Copyright (c) 2021 Katherine Crowson; nshepperd

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

In [None]:
#@title Check GPU
# Print the GPU(s) attached to this runtime (the install cell below also greps this output for "A100").
!nvidia-smi


In [None]:
#@title Mount Google Drive or save on colab session
#@markdown You must run this cell either way.

import os # imports here just to prioritize connecting the drive first

MOUNT_DRIVE=True #@param {type:"boolean"}

# Choose where generated samples and downloaded model checkpoints live:
# inside Google Drive when it is mounted, otherwise in the Colab session itself.
if not MOUNT_DRIVE:
  save_location = "/content/output_images"
  model_location = 'models'
else:
  from google.colab import drive
  drive.mount('/content/drive')
  save_location = '/content/drive/MyDrive/samples/v2'
  model_location = '/content/drive/MyDrive/models'

# Both folders must exist before anything is saved or downloaded.
os.makedirs(save_location, exist_ok=True)
os.makedirs(model_location, exist_ok=True)

# Initial setup:
- Install dependencies
- Import modules
- Define functions

In [None]:
# Workaround for https://github.com/googlecolab/colabtools/issues/2452
# Install a CUDA-enabled jaxlib wheel chosen by the GPU model Colab assigned.
import os
# os.system returns the shell's exit status; grep exits with 0 only when "A100" appears in the nvidia-smi output.
if os.system("nvidia-smi | grep A100") == 0:
  # A100: cuda111 build of jaxlib 0.1.72, plus a pinned jax.
  # NOTE(review): jax==0.1.76 is a much older release line than jaxlib 0.1.72 (jax 0.1.x predates the
  # jaxlib 0.1.7x wheels) — confirm this pin is intended; a jax 0.2.x version may have been meant.
  !pip install -U https://storage.googleapis.com/jax-releases/cuda111/jaxlib-0.1.72+cuda111-cp37-none-manylinux2010_x86_64.whl "jax==0.1.76"
  # https://storage.googleapis.com/jax-releases/cuda111/jaxlib-0.1.71+cuda111-cp37-none-manylinux2010_x86_64.whl
else:
  # Other GPUs: cuda11/cudnn805 build of jaxlib 0.1.75; jax itself is not pinned here.
  !pip install https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.1.75%2Bcuda11.cudnn805-cp37-none-manylinux2010_x86_64.whl

In [None]:
# Install dependencies
# Only dm-haiku is pinned; cbor2, ftfy, einops and braceexpand install their latest versions.
!pip install dm-haiku==0.0.5 cbor2 ftfy einops braceexpand 
# Source-only dependencies; the next cell adds these folders to sys.path.
# NOTE(review): git clone prints an error and skips if the folder already exists (e.g. when the cell is re-run).
!git clone https://github.com/nshepperd/CLIP_JAX
!git clone https://github.com/nshepperd/jax-guided-diffusion -b v2
!git clone https://github.com/crowsonkb/v-diffusion-jax

In [None]:
import sys
# Make the repositories cloned in the previous cell importable.
sys.path.append('./CLIP_JAX')
sys.path.append('./jax-guided-diffusion')
sys.path.append('./v-diffusion-jax')
# Ask XLA to allocate GPU memory on demand instead of preallocating a large pool.
# Set before jax is imported below. `os` comes from an earlier cell.
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'

from PIL import Image
from braceexpand import braceexpand
from dataclasses import dataclass
from functools import partial
from subprocess import Popen, PIPE
from google.colab import files
import functools
import io
import math
import re
import requests
import time
import shutil

import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jaxtorch
from jaxtorch import PRNG, Context, Module, nn, init
from tqdm import tqdm

from lib.script_util import create_model_and_diffusion, model_and_diffusion_defaults
from lib import util, openai

from IPython import display
from torchvision import datasets, transforms, utils
from torchvision.transforms import functional as TF
from subprocess import Popen, PIPE
import torch.utils.data
import torch

import diffusion as v_diffusion

from diffusion_models.common import DiffusionOutput, Partial, make_partial, blur_fft, norm1
from diffusion_models.cache import WeakCache
from diffusion_models.schedules import cosine, ddpm, ddpm2, spliced
from diffusion_models.perceptor import vit32, vit16, clip_size, normalize, get_vitl14

from diffusion_models.secondary import secondary1_wrap, secondary2_wrap
from diffusion_models.antijpeg import jpeg_wrap, jpeg_classifier_wrap
from diffusion_models.pixelart import pixelartv4_wrap, pixelartv6_wrap
from diffusion_models.pixelartv7 import pixelartv7_ic_wrap, pixelartv7_ic_attn_wrap
from diffusion_models.cc12m_1 import cc12m_1_wrap, cc12m_1_cfg_wrap
from diffusion_models.openai import make_openai_model, make_openai_finetune_model

In [None]:
# Report the accelerators JAX picked up and remember how many there are.
devices = jax.devices()
print('Using device:', devices)
n_devices = len(devices)

In [None]:
# Define necessary functions

def fetch(url_or_path):
    """Return a readable binary file object for a URL or a local path.

    http:// and https:// URLs are downloaded completely into an in-memory
    buffer positioned at the start (``raise_for_status`` raises on an HTTP
    error status). Anything else is opened as a local file in ``'rb'`` mode;
    the caller is responsible for closing it.
    """
    location = str(url_or_path)
    if not location.startswith(('http://', 'https://')):
        return open(url_or_path, 'rb')
    response = requests.get(url_or_path)
    response.raise_for_status()
    return io.BytesIO(response.content)

def fetch_model(url_or_path, model_dir=None):
    """Return a local path to a model checkpoint, downloading it on first use.

    The file is cached as ``{model_dir}/{basename(url_or_path)}``. If that file
    already exists it is returned immediately, without network access.
    Otherwise it is downloaded with curl into ``{model_dir}/tmp/``. It is moved
    into the cache only when curl reports success, so a failed or interrupted
    download can never be mistaken for a valid cached checkpoint later.

    Args:
        url_or_path: URL handed to curl; its basename is used as the cache key.
        model_dir: cache directory. Defaults to the notebook-level ``model_location``.

    Returns:
        Path of the cached file.

    Raises:
        RuntimeError: if curl exits with a non-zero status. HTTP errors count
            as failures because of ``--fail``.
    """
    if model_dir is None:
        model_dir = model_location
    basename = os.path.basename(url_or_path)
    local_path = os.path.join(model_dir, basename)
    if os.path.exists(local_path):
        return local_path

    tmp_dir = os.path.join(model_dir, 'tmp')
    os.makedirs(tmp_dir, exist_ok=True)
    tmp_path = os.path.join(tmp_dir, basename)
    # --fail: exit non-zero on HTTP errors instead of saving the error page as the "model";
    # --location: follow redirects.
    returncode = Popen(['curl', '--fail', '--location', url_or_path, '-o', tmp_path]).wait()
    if returncode != 0:
        # Never leave a partial/garbage file behind where it could be picked up later.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise RuntimeError(f'Failed to download {url_or_path} (curl exit code {returncode})')
    os.rename(tmp_path, local_path)
    return local_path

# Implement lazy loading and caching of model parameters for all the different models.

# Device copies of host (numpy) parameter arrays, reused while they are still alive.
gpu_cache = WeakCache(jnp.array)

def to_gpu(params):
  """Convert a pytree of params to jax, using cached arrays if they are still alive.

  Leaves whose type is exactly ``np.ndarray`` go through ``gpu_cache``; every
  other leaf is returned unchanged.
  """
  def convert(leaf):
    if type(leaf) is not np.ndarray:
      return leaf
    return gpu_cache(leaf)
  return jax.tree_util.tree_map(convert, params)

class LazyParams(object):
  """Deferred model parameters: loaded on first use, then kept in host memory.

  ``load`` is called only the first time the object is called. Its result is
  converted leaf-by-leaf with ``np.array`` and cached in ``self.params`` (CPU
  memory). Every call returns that tree passed through ``to_gpu``, which reuses
  device arrays that are still alive.
  """
  def __init__(self, load):
    self.load = load
    self.params = None

  @staticmethod
  def pt(url, key=None):
    """LazyParams for a PyTorch checkpoint read with ``jaxtorch.pt.load``.

    The checkpoint is fetched through ``fetch_model``. When ``key`` is given
    (e.g. ``'params_ema'``), only that entry of the loaded checkpoint is kept.
    """
    def load():
      checkpoint = jaxtorch.pt.load(fetch_model(url))
      return checkpoint if key is None else checkpoint[key]
    return LazyParams(load)

  def __call__(self):
    if self.params is None:
      loaded = self.load()
      self.params = jax.tree_util.tree_map(np.array, loaded)
    return to_gpu(self.params)


def grey(image):
    """Greyscale images by averaging over the channel axis.

    The channel axis is third from the end (``[..., c, h, w]``). The channel
    mean is broadcast back, so the output has the same shape as the input.
    """
    *_, channels, height, width = image.shape  # requires at least 3 dimensions
    channel_mean = image.mean(axis=-3, keepdims=True)
    return jnp.broadcast_to(channel_mean, image.shape)

def cutout_image(image, offsetx, offsety, size, output_size=224):
    """Computes (square) cutouts of an image given x and y offsets and size.

    ``image`` is ``[c, h, w]``. The square of side ``size`` whose top-left corner
    is at (offsetx, offsety) is resampled to ``[c, output_size, output_size]``
    with ``jax.image.scale_and_translate`` and a Lanczos-3 kernel.
    """
    (channels, _, _) = image.shape
    zoom = output_size / size
    shift_y = -offsety * output_size / size
    shift_x = -offsetx * output_size / size
    return jax.image.scale_and_translate(
        image,
        shape=(channels, output_size, output_size),
        spatial_dims=(1, 2),
        scale=jnp.stack([zoom, zoom]),
        translation=jnp.stack([shift_y, shift_x]),
        method='lanczos3',
    )

def cutouts_images(image, offsetx, offsety, size, output_size=224):
    """Batched ``cutout_image``: every cutout parameter set applied to every image.

    ``image`` is ``[n, c, h, w]``; ``offsetx``, ``offsety`` and ``size`` are ``[k]``.
    Returns ``[k, n, c, output_size, output_size]``.
    """
    one_cut = partial(cutout_image, output_size=output_size)                 # [c h w] [] [] [] -> [c h w]
    cut_all_images = jax.vmap(one_cut, in_axes=(0, None, None, None), out_axes=0)   # -> [n c h w]
    cut_all_params = jax.vmap(cut_all_images, in_axes=(None, 0, 0, 0), out_axes=0)  # -> [k n c h w]
    return cut_all_params(image, offsetx, offsety, size)

@jax.tree_util.register_pytree_node_class
class MakeCutouts(object):
    """Random CLIP cutouts of a batch of images, with greyscale and flip augmentation.

    Calling the object with images ``[b, c, h, w]`` and a PRNG key returns
    ``[cutn, b, c, 224, 224]`` cutouts (resampled by ``cutouts_images`` at its
    default output size):

    * ``cutn - cutn // 2`` square crops. Each side is drawn between
      ``min(h, w, cut_size)`` and ``min(h, w)`` (skewed by ``cut_pow``), and
      each crop is placed uniformly inside the image.
    * ``cutn // 2`` "large" cutouts of side ``max(h, w) + border``. The border
      is a random 40-80 px, and the cutouts are roughly centred, so they extend
      past the image edges.

    Each cutout of each image is then:

    * mixed towards greyscale with probability ``p_mixgrey`` (random mix factor);
    * fully greyscaled with probability ``p_grey``;
    * flipped horizontally with probability 0.5.

    ``cut_size`` and ``cutn`` are static pytree data (changing them triggers
    recompilation). ``p_grey``, ``cut_pow`` and ``p_mixgrey`` are dynamic leaves.

    NOTE(review): the output resolution is the ``cutouts_images`` default (224),
    not ``cut_size`` — confirm ``cut_size`` always equals the perceptor input size.
    """
    def __init__(self, cut_size, cutn, cut_pow=1., p_grey=0.2, p_mixgrey=0.0):
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.p_grey = p_grey
        self.p_mixgrey = p_mixgrey

    def _cut_counts(self):
        """Split ``cutn`` into (number of random crops, number of large cutouts).

        The crops take the extra cutout when ``cutn`` is odd, so the two groups
        always add up to exactly ``cutn`` (the augmentation masks and
        downstream rearranges expect exactly ``cutn`` cutouts).
        """
        n_large = self.cutn // 2
        return self.cutn - n_large, n_large

    def __call__(self, input, key):
        [b, c, h, w] = input.shape
        rng = PRNG(key)
        n_small, n_large = self._cut_counts()

        # Random square crops fully inside the image.
        max_size = min(h, w)
        min_size = min(h, w, self.cut_size)
        cut_us = jax.random.uniform(rng.split(), shape=[n_small])**self.cut_pow
        sizes = (min_size + cut_us * (max_size - min_size + 1)).astype(jnp.int32).clamp(min_size, max_size)
        offsets_x = jax.random.uniform(rng.split(), [n_small], minval=0, maxval=w - sizes)
        offsets_y = jax.random.uniform(rng.split(), [n_small], minval=0, maxval=h - sizes)
        cutouts = cutouts_images(input, offsets_x, offsets_y, sizes)

        # Large, roughly centred cutouts bigger than the image: border = B1 + U(0, 1) * B2 pixels.
        B1 = 40
        B2 = 40
        lcut_us = jax.random.uniform(rng.split(), shape=[n_large])
        border = B1 + lcut_us * B2
        lsizes = (max(h,w) + border).astype(jnp.int32)
        loffsets_x = jax.random.uniform(rng.split(), [n_large], minval=w/2-lsizes/2-border, maxval=w/2-lsizes/2+border)
        loffsets_y = jax.random.uniform(rng.split(), [n_large], minval=h/2-lsizes/2-border, maxval=h/2-lsizes/2+border)
        lcutouts = cutouts_images(input, loffsets_x, loffsets_y, lsizes)

        cutouts = jnp.concatenate([cutouts, lcutouts], axis=0)

        greyed = grey(cutouts)

        # Partial greyscale augmentation
        grey_us = jax.random.uniform(rng.split(), shape=[self.cutn, b, 1, 1, 1])
        grey_rs = jax.random.uniform(rng.split(), shape=[self.cutn, b, 1, 1, 1])
        cutouts = jnp.where(grey_us < self.p_mixgrey, grey_rs * greyed + (1 - grey_rs) * cutouts, cutouts)

        # Greyscale augmentation
        grey_us = jax.random.uniform(rng.split(), shape=[self.cutn, b, 1, 1, 1])
        cutouts = jnp.where(grey_us < self.p_grey, greyed, cutouts)

        # Flip augmentation
        flip_us = jax.random.bernoulli(rng.split(), 0.5, [self.cutn, b, 1, 1, 1])
        cutouts = jnp.where(flip_us, jnp.flip(cutouts, axis=-1), cutouts)
        return cutouts

    def tree_flatten(self):
        return ([self.p_grey, self.cut_pow, self.p_mixgrey], (self.cut_size, self.cutn))

    @staticmethod
    def tree_unflatten(static, dynamic):
        (cut_size, cutn) = static
        (p_grey, cut_pow, p_mixgrey) = dynamic
        return MakeCutouts(cut_size, cutn, cut_pow, p_grey, p_mixgrey)

@jax.tree_util.register_pytree_node_class
class MakeCutoutsPixelated(object):
    """Nearest-neighbour upscale images by ``factor`` before delegating to ``make_cutouts``.

    Upscaling with ``method='nearest'`` keeps pixel-art edges hard in the cutouts.
    ``cutn`` is copied from the wrapped cutout maker.
    """
    def __init__(self, make_cutouts, factor=4):
        self.make_cutouts = make_cutouts
        self.factor = factor
        self.cutn = make_cutouts.cutn

    def __call__(self, input, key):
        [n, c, h, w] = input.shape
        input = jax.image.resize(input, [n, c, h*self.factor, w * self.factor], method='nearest')
        return self.make_cutouts(input, key)

    def tree_flatten(self):
        # JAX requires pytree aux data to be hashable, so use a tuple rather than a list.
        return ([self.make_cutouts], (self.factor,))

    @staticmethod
    def tree_unflatten(static, dynamic):
        return MakeCutoutsPixelated(*dynamic, *static)

def spherical_dist_loss(x, y):
    """Return ``2 * arcsin(d / 2) ** 2``, where d is the distance between ``norm1(x)`` and ``norm1(y)``.

    The distance is taken over the last axis. If ``norm1`` normalises to unit
    length (presumably — it is defined elsewhere), d is the chord length between
    the two directions, and the result is half the squared angle between them.
    """
    chord = (norm1(x) - norm1(y)).square().sum(axis=-1).sqrt()
    half_angle = chord.div(2).arcsin()
    return half_angle.square().mul(2)


In [None]:
# Define combinators.

# These (ab)use the jax pytree registration system to define parameterised
# objects for doing various things, which are compatible with jax.jit.

# For jit compatibility an object needs to act as a pytree, which means implementing two methods:
#  - tree_flatten(self): returns two lists of the object's fields:
#       1. 'dynamic' parameters: things which can be jax tensors, or other pytrees
#       2. 'static' parameters: arbitrary python objects, will trigger recompilation when changed
#  - tree_unflatten(static, dynamic): reconstitutes the object from its parts

# With these tricks, you can simply define your cond_fn as an object, as is done
# below, and pass it into the jitted sample step as a regular argument. JAX will
# handle recompiling the jitted code whenever a control-flow affecting parameter
# is changed (such as cut_batches).

@jax.tree_util.register_pytree_node_class
class LerpModels(object):
    """Linear combination of diffusion models.

    ``models`` is a list of ``(model, weight)`` pairs. Every model is called
    with the same ``(x, t, key)``. Its ``v``, ``pred`` and ``eps`` outputs are
    summed with the given weights into a single ``DiffusionOutput``.
    """
    def __init__(self, models):
        self.models = models

    def __call__(self, x, t, key):
        # Accumulate from 0, matching the built-in sum() semantics.
        v = pred = eps = 0
        for model, weight in self.models:
            out = model(x, t, key)
            v = v + out.v * weight
            pred = pred + out.pred * weight
            eps = eps + out.eps * weight
        return DiffusionOutput(v, pred, eps)

    def tree_flatten(self):
        # Models and weights are children; JAX requires hashable aux data, hence a tuple.
        return [self.models], ()

    @staticmethod
    def tree_unflatten(static, dynamic):
        return LerpModels(*dynamic)

@jax.tree_util.register_pytree_node_class
class KatModel(object):
    """Adapt a v-objective diffusion model to the ``DiffusionOutput`` interface.

    ``model.apply(params, key, x, t, kwargs)`` predicts v at cosine time ``t``
    (broadcast to the batch). ``pred`` and ``eps`` are recovered from v as
    ``pred = alpha*x - sigma*v`` and ``eps = sigma*x + alpha*v``. This holds
    when ``alpha**2 + sigma**2 == 1`` (assumed for ``cosine.to_alpha_sigma``).

    ``params`` may be a ``LazyParams``; it is materialised on construction.
    """
    def __init__(self, model, params, **kwargs):
        if isinstance(params, LazyParams):
            params = params()
        self.model = model
        self.params = params
        self.kwargs = kwargs

    @jax.jit
    def __call__(self, x, cosine_t, key):
        n = x.shape[0]
        alpha, sigma = cosine.to_alpha_sigma(cosine_t)
        v = self.model.apply(self.params, key, x, cosine_t.broadcast_to([n]), self.kwargs)
        pred = x * alpha - v * sigma
        eps = x * sigma + v * alpha
        return DiffusionOutput(v, pred, eps)

    def tree_flatten(self):
        # The model object is static; JAX requires hashable aux data, hence a tuple.
        return [self.params, self.kwargs], (self.model,)

    @staticmethod
    def tree_unflatten(static, dynamic):
        [params, kwargs] = dynamic
        [model] = static
        return KatModel(model, params, **kwargs)

# A wrapper that causes the diffusion model to generate tileable images, by
# randomly shifting the image with wrap around.

def xyroll(x, shifts):
  """Roll each image of a batch by its own offset along axes 1 and 2 of the image, with wrap-around.

  ``x`` is ``[n, ...]`` and ``shifts`` is ``[n, 2]``. Image ``i`` is rolled by
  ``shifts[i]`` over its axes (1, 2), i.e. (h, w) for ``[n, c, h, w]`` input.
  """
  roll_one = lambda image, shift: jnp.roll(image, shift, axis=[1, 2])
  return jax.vmap(roll_one, in_axes=(0, 0))(x, shifts)

@make_partial
def TilingModel(model, x, cosine_t, key):
  """Wrap ``model`` so that samples become tileable.

  Each image is rolled (with wrap-around) by a random (dy, dx) drawn from
  [-50, 50) before the model call. Every leaf of the model output is then
  rolled back by the same amount.
  """
  rng = PRNG(key)
  n, _, _, _ = x.shape
  shift = jax.random.randint(rng.split(), [n, 2], -50, 50)
  out = model(xyroll(x, shift), cosine_t, rng.split())
  return jax.tree_util.tree_map(lambda val: xyroll(val, -shift), out)

@make_partial
def PanoramaModel(model, x, cosine_t, key):
  """Wrap ``model`` so that samples wrap around horizontally (panoramas).

  The per-image shift is drawn with maxval ``[1, w]``. The vertical offset is
  therefore always 0, and the horizontal offset is uniform in [0, w). The
  input is rolled before the model call, and every output leaf is rolled back.
  """
  rng = PRNG(key)
  n, _, _, w = x.shape
  shift = jax.random.randint(rng.split(), [n, 2], 0, [1, w])
  out = model(xyroll(x, shift), cosine_t, rng.split())
  return jax.tree_util.tree_map(lambda val: xyroll(val, -shift), out)


In [None]:
# Make a video of the progress after diffusion is complete.
# Scratch folder for per-step frames, relative to the working directory; make_video reads
# frames from /content/imagesteps/<batchnum>/. Prints an error (harmless) if it already exists.
!mkdir imagesteps
# Default video length in seconds; make_video derives fps = steps // secondsOfVideo when custom_fps is off.
secondsOfVideo = 16


def make_video(batchnum, fps=None):
    """Encode the saved per-step frames of one batch into an mp4 progress video.

    Frames ``/content/imagesteps/{batchnum}/{i}.png`` for ``i`` in
    ``range(steps - 1)`` are streamed to ffmpeg (libx264, yuv420p, crf 17).
    Afterwards all PNGs in that folder are deleted, and the video is moved to
    ``{save_location}/videos/video_{all_title}_{batchnum}_{timestamp}.mp4``.

    Reads the notebook globals ``steps``, ``all_title``, ``custom_fps``,
    ``secondsOfVideo`` and ``save_location``.

    Args:
        batchnum: batch index; selects the frame folder and is part of the file name.
        fps: frame rate used when ``custom_fps`` is enabled. When omitted, a
            notebook-level ``fps`` variable is used if defined, else 15.

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status.
    """
    videoOutputFolder = f"{save_location}/videos/"
    os.makedirs(videoOutputFolder, exist_ok=True)
    timestring = time.strftime('%Y%m%d%H%M%S')
    totalFrames = steps - 1
    videoName = f'{all_title}_{batchnum}_{timestring}'
    frameFolder = f"/content/imagesteps/{batchnum}"
    # Absolute path so encoding does not depend on the current working directory.
    tmpVideoPath = f"/content/video_{videoName}.mp4"

    if custom_fps:
        if fps is None:
            # The parameter shadows any notebook-level `fps`, so look it up explicitly.
            fps = globals().get('fps', 15)
    else:
        # Fit the video into secondsOfVideo, but never drop to 0 fps (steps < secondsOfVideo).
        fps = max(1, steps // secondsOfVideo)

    tqdm.write(f'Generating video for batch {batchnum}...')
    p = Popen(['ffmpeg', '-y', '-f', 'image2pipe', '-vcodec', 'png', '-r', str(fps), '-i', '-', '-vcodec', 'libx264', '-r', str(fps), '-pix_fmt', 'yuv420p', '-crf', '17', '-preset', 'medium', tmpVideoPath], stdin=PIPE)
    try:
        # Stream one frame at a time so only a single image file is open at once.
        for i in tqdm(range(totalFrames)):
            with Image.open(f"{frameFolder}/{i}.png") as im:
                im.save(p.stdin, 'PNG')
    finally:
        p.stdin.close()
        p.wait()
    if p.returncode != 0:
        raise RuntimeError(f'ffmpeg failed with exit code {p.returncode} for batch {batchnum}')

    # Remove the frames of this batch (same effect as `rm {frameFolder}/*.png`).
    if os.path.isdir(frameFolder):
        for name in os.listdir(frameFolder):
            if name.endswith('.png'):
                os.remove(os.path.join(frameFolder, name))

    shutil.move(tmpVideoPath, f"{videoOutputFolder}video_{videoName}.mp4")

# Models & Parameters

In [None]:
# Secondary Model
# Checkpoint handles. LazyParams defers downloading/loading until the params are first used,
# so the LazyParams definitions below are cheap. The model constructors
# (v_diffusion.get_model, openai.create_openai_*_model) and init_weights calls do run eagerly.
secondary1_params = LazyParams.pt('https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet.pth')
secondary2_params = LazyParams.pt('https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth')

# Anti-JPEG model
# Only the 'params_ema' entry of these checkpoints is kept.
jpeg_params = LazyParams.pt('https://set.zlkj.in/models/diffusion/jpeg-db-oi-614.pt', key='params_ema')
jpeg_classifier_params = LazyParams.pt('https://set.zlkj.in/models/diffusion/jpeg-classifier-72.pt', 'params_ema')

# Pixel art model
# There are many checkpoints supported with this model
# Exactly one URL below is left uncommented; swap the comments to pick a different checkpoint.
pixelartv4_params = LazyParams.pt(
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v4_34.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v4_63.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v4_150.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v5_50.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v5_65.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v5_97.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v5_173.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_344.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_432.pt'
    'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_600.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_700.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_800.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_1000.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_2000.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_3000.pt'
    , key='params_ema'
)

pixelartv6_params = LazyParams.pt(
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v6-1000.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v6-2000.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v6-3000.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v6-4000.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v6-aug-900.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v6-aug-1300.pt'
    'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v6-aug-3000.pt'
    , key='params_ema'
)

pixelartv7_ic_params = LazyParams.pt(
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v6-ic-1400.pt'
    'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v7-large-ic-700.pt'
    , key='params_ema'
)

pixelartv7_ic_attn_params = LazyParams.pt(
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v6-ic-1400.pt'
    'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v7-large-ic-attn-600.pt'
    , key='params_ema'
)

# Kat models
# v-diffusion-jax models: architecture from get_model, params loaded lazily from pickled checkpoints.

danbooru_128_model = v_diffusion.get_model('danbooru_128')
danbooru_128_params = LazyParams(lambda: v_diffusion.load_params(fetch_model('https://v-diffusion.s3.us-west-2.amazonaws.com/danbooru_128.pkl')))

wikiart_256_model = v_diffusion.get_model('wikiart_256')
wikiart_256_params = LazyParams(lambda: v_diffusion.load_params(fetch_model('https://v-diffusion.s3.us-west-2.amazonaws.com/wikiart_256.pkl')))

wikiart_128_model = v_diffusion.get_model('wikiart_128')
wikiart_128_params = LazyParams(lambda: v_diffusion.load_params(fetch_model('https://v-diffusion.s3.us-west-2.amazonaws.com/wikiart_128.pkl')))

imagenet_128_model = v_diffusion.get_model('imagenet_128')
imagenet_128_params = LazyParams(lambda: v_diffusion.load_params(fetch_model('https://v-diffusion.s3.us-west-2.amazonaws.com/imagenet_128.pkl')))

# CC12M_1 model

cc12m_1_params = LazyParams.pt('https://v-diffusion.s3.us-west-2.amazonaws.com/cc12m_1.pth')
cc12m_1_cfg_params = LazyParams.pt('https://v-diffusion.s3.us-west-2.amazonaws.com/cc12m_1_cfg.pth')

# OpenAI models.

use_checkpoint = False # Set to True to save some memory

openai_512_model = openai.create_openai_512_model(use_checkpoint=use_checkpoint)
# NOTE(review): the params returned here are replaced on the next line by the checkpoint.
# The call is presumably kept for its side effects on the module (cf. aesthetic_model.init_weights below,
# whose result is also discarded) — confirm before removing.
openai_512_params = openai_512_model.init_weights(jax.random.PRNGKey(0))
openai_512_params = LazyParams.pt('https://set.zlkj.in/models/diffusion/512x512_diffusion_uncond_finetune_008100.pt')
openai_512_wrap = make_openai_model(openai_512_model)

openai_256_model = openai.create_openai_256_model(use_checkpoint=use_checkpoint)
# NOTE(review): same pattern as above — result overwritten immediately; presumably kept for side effects.
openai_256_params = openai_256_model.init_weights(jax.random.PRNGKey(0))
openai_256_params = LazyParams.pt('https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt')
openai_256_wrap = make_openai_model(openai_256_model)

# Fine-tuned 512 checkpoint reusing the 512x512 architecture built above.
openai_512_finetune_wrap = make_openai_finetune_model(openai_512_model)
openai_512_finetune_params = LazyParams.pt('https://set.zlkj.in/models/diffusion/512x512_diffusion_uncond_openimages_epoch28_withfilter.pt')

# Aesthetic Model

def apply_partial(*args, **kwargs):
  """Decorator factory: ``@apply_partial(a, ...)`` replaces ``f`` with ``Partial(f, a, ...)``."""
  return lambda f: Partial(f, *args, **kwargs)

# Linear head mapping a 512-d embedding to 10 classes. AestheticLoss/AestheticExpected treat the classes
# as ratings 1..10. Presumably applied to CLIP image embeddings, per the checkpoint name — confirm.
aesthetic_model = nn.Linear(512, 10)
# Return value discarded; presumably called for its side effects on the module (e.g. registering
# parameter names used by Context lookups) — TODO confirm before removing.
aesthetic_model.init_weights(jax.random.PRNGKey(0))
aesthetic_model_params = jaxtorch.pt.load(fetch_model('https://v-diffusion.s3.us-west-2.amazonaws.com/ava_vit_b_16_full.pth'))

def exec_aesthetic_model(params, embed):
  # Log-probabilities over the 10 classes (log_softmax over the last axis).
  return jax.nn.log_softmax(aesthetic_model(Context(params, None), embed), axis=-1)
# Rebind the name to a Partial with the loaded params bound; callers then pass only the embeddings.
exec_aesthetic_model = Partial(exec_aesthetic_model, aesthetic_model_params)


In [None]:
# Losses and cond fn.

@make_partial
@apply_partial(exec_aesthetic_model)
def AestheticLoss(exec_aesthetic_model, target, scale, image_embeds):
    """Negative log-probability of aesthetic rating ``target`` (1..10).

    ``image_embeds`` is ``[k, n, d]`` (k cutouts of n images). The log-prob of
    class ``target - 1`` is averaged over the cutouts, weighted by ``scale``,
    summed over images and negated.
    """
    k, n, d = image_embeds.shape
    target_log_prob = exec_aesthetic_model(image_embeds)[:, :, target - 1]
    return -(scale * target_log_prob.mean(0)).sum()

@make_partial
@apply_partial(exec_aesthetic_model)
def AestheticExpected(exec_aesthetic_model, scale, image_embeds):
    """Negative expected aesthetic rating (classes 0..9 valued as ratings 1..10).

    ``image_embeds`` is ``[k, n, d]``. The expected rating is averaged over the
    k cutouts, weighted by ``scale``, summed over images and negated.
    """
    k, n, d = image_embeds.shape
    ratings = 1 + jnp.arange(10)
    probs = jax.nn.softmax(exec_aesthetic_model(image_embeds))
    expected_rating = (probs * ratings).sum(-1)
    return -(scale * expected_rating.mean(0)).sum()

@jax.tree_util.register_pytree_node_class
class CondCLIP(object):
    """Backward a loss function through clip.

    Calling the object returns the gradient of ``sum(loss(image_embeds) for loss
    in losses)`` w.r.t. ``x_in``, averaged over ``cut_batches`` independent
    batches of cutouts. For each batch:

    1. ``x_in`` is mapped with ``(x + 1) / 2`` (presumably [-1, 1] -> [0, 1]).
    2. It is cut up by ``make_cutouts`` and normalised.
    3. The cutouts are embedded with ``perceptor.embed_cutouts``.
    4. The embeddings are reshaped to ``[cutn, n, dim]`` and passed to each loss.

    ``cut_batches`` is static pytree data: changing it triggers recompilation.
    """
    def __init__(self, perceptor, make_cutouts, cut_batches, *losses):
        self.perceptor = perceptor
        self.make_cutouts = make_cutouts
        self.cut_batches = cut_batches
        self.losses = losses

    def __call__(self, x_in, key):
        n = x_in.shape[0]
        def main_clip_loss(x_in, key):
            cutouts = normalize(self.make_cutouts(x_in.add(1).div(2), key)).rearrange('k n c h w -> (k n) c h w')
            image_embeds = self.perceptor.embed_cutouts(cutouts)
            image_embeds = image_embeds.rearrange('(k n) c -> k n c', k=self.make_cutouts.cutn, n=n)
            return sum(loss_fn(image_embeds) for loss_fn in self.losses)
        clip_grad = jax.grad(main_clip_loss)
        num_cuts = self.cut_batches
        keys = jnp.stack(jax.random.split(key, num_cuts))

        def accumulate(total, key):
            # No per-step outputs are needed; only the running gradient sum is carried.
            return total + clip_grad(x_in, key), None

        total_grad, _ = jax.lax.scan(accumulate, jnp.zeros_like(x_in), keys)
        return total_grad / num_cuts

    def tree_flatten(self):
        # JAX requires hashable aux data, hence a tuple.
        return [self.perceptor, self.make_cutouts, self.losses], (self.cut_batches,)

    @classmethod
    def tree_unflatten(cls, static, dynamic):
        [perceptor, make_cutouts, losses] = dynamic
        [cut_batches] = static
        return cls(perceptor, make_cutouts, cut_batches, *losses)

@make_partial
def SphericalDistLoss(text_embed, clip_guidance_scale, image_embeds):
    """Spherical-distance CLIP loss.

    ``image_embeds`` is ``[k, n, d]``. The distance to ``text_embed`` is
    averaged over the k cutouts, weighted by ``clip_guidance_scale``, and summed.
    """
    per_image = spherical_dist_loss(image_embeds, text_embed).mean(0)
    weighted = clip_guidance_scale * per_image
    return weighted.sum()

@make_partial
def InfoLOOB(text_embed, clip_guidance_scale, inv_tau, lm, image_embeds):
    """InfoLOOB contrastive loss between cutout-averaged image embeddings and text embeddings.

    ``image_embeds`` is ``[k, n, d]``. It is averaged over the k cutouts and
    normalised with ``norm1``. Similarities with the normalised ``text_embed``
    (presumably one row per image) are scaled by ``inv_tau`` (presumably an
    inverse temperature).

    In each direction (image->text and text->image), the loss is:
    ``-sum(positive diagonal) + lm * sum_rows logsumexp(off-diagonal entries)``.
    The total is scaled by ``mean(clip_guidance_scale) / inv_tau``.

    NOTE(review): with a single image there are no negative pairs; the
    log-sum-exp term is then -inf and contributes no gradient.
    """
    all_image_embeds = norm1(image_embeds.mean(0))
    all_text_embeds = norm1(text_embed)
    sim_matrix = inv_tau * jnp.einsum('nc,mc->nm', all_image_embeds, all_text_embeds)
    xn = sim_matrix.shape[0]
    is_diag = jnp.eye(xn, dtype=bool)
    def loob(sim_matrix):
        positives = jnp.where(is_diag, sim_matrix, 0.0)
        # Exclude the positive pair from the denominator with jnp.where: the previous
        # `jnp.eye(xn) * float('-inf')` computed 0 * -inf = NaN off the diagonal,
        # making the whole loss and gradient NaN.
        negatives = jnp.where(is_diag, -jnp.inf, sim_matrix)
        return -positives.sum() + lm * jsp.special.logsumexp(negatives, axis=-1).sum()
    losses = loob(sim_matrix) + loob(sim_matrix.transpose())
    return losses.sum() * clip_guidance_scale.mean() / inv_tau

@make_partial
def CondTV(tv_scale, x_in, key):
    """Gradient of a multi-scale total-variation penalty w.r.t. ``x_in``.

    The L2 TV loss is evaluated at full resolution and after cubic downscaling
    of the two trailing axes by 2 and by 4. Each scale's summed loss is
    multiplied by ``tv_scale``, and the three gradients are added. ``key`` is unused.
    """
    def tv_loss(img):
        """L2 total variation loss, as in Mahendran et al. (one value per item along axis 0 for 4-D input)."""
        dx = img[..., :, 1:] - img[..., :, :-1]
        dy = img[..., 1:, :] - img[..., :-1, :]
        return dx.square().mean([1,2,3]) + dy.square().mean([1,2,3])

    def scaled_tv(img, factor=None):
        if factor is not None:
            # Keep the two leading axes, shrink the two trailing (spatial) axes.
            [d0, d1, h, w] = img.shape
            img = jax.image.resize(img, [d0, d1, h//factor, w//factor], method='cubic')
        return tv_loss(img).sum() * tv_scale

    grad_full, grad_half, grad_quarter = (
        jax.grad(partial(scaled_tv, factor=f))(x_in) for f in (None, 2, 4))
    return grad_full + grad_half + grad_quarter

@make_partial
def CondRange(range_scale, x_in, key):
    """Gradient (times ``range_scale``) of the mean absolute overshoot of ``x_in`` outside [-1, 1].

    ``key`` is unused.
    """
    def out_of_range(x):
        overshoot = x - x.clamp(minval=-1,maxval=1)
        return jnp.abs(overshoot).mean()
    range_grad = jax.grad(out_of_range)(x_in)
    return range_scale * range_grad

@make_partial
def CondMSE(target, mse_scale, x_in, key):
    """Gradient (times ``mse_scale``) of the mean squared error between ``x_in`` and ``target``.

    ``key`` is unused.
    """
    mse_grad = jax.grad(lambda x: (x - target).square().mean())(x_in)
    return mse_scale * mse_grad

@jax.tree_util.register_pytree_node_class
class MaskedMSE(object):
    """Masked MSE loss that targets the output towards an image.

    Calling the object returns ``mse_scale * grad(loss)(x_in)``, where
    ``loss = mean(mask * (x_in - target)**2)``. With ``grey=True``, the
    difference is first passed through ``grey`` (channel mean) before squaring.
    ``key`` is unused. ``grey`` is static pytree data; the other fields are leaves.
    """
    def __init__(self, target, mse_scale, mask, grey=False):
        self.target = target
        self.mse_scale = mse_scale
        self.mask = mask
        self.grey = grey

    def __call__(self, x_in, key):
        def mse_loss(x_in):
            if self.grey:
                # `grey` here is the module-level greyscale function, not the flag.
                return (self.mask * grey(x_in - self.target).square()).mean()
            else:
                return (self.mask * (x_in - self.target).square()).mean()
        return self.mse_scale * jax.grad(mse_loss)(x_in)

    def tree_flatten(self):
        # JAX requires hashable aux data, hence a tuple.
        return [self.target, self.mse_scale, self.mask], (self.grey,)

    @staticmethod
    def tree_unflatten(static, dynamic):
        return MaskedMSE(*dynamic, *static)


@jax.tree_util.register_pytree_node_class
class MainCondFn(object):
    """Main cond_fn: backpropagate summed guidance gradients through the denoiser.

    Called as ``cond_fn(key, x, cosine_t)``, it returns a score for ``x``. The
    steps are:

    1. Compute a denoised view of ``x`` (see ``use``).
    2. Sum the gradients returned by each ``cond(x_in, key)`` in ``conditions``.
    3. If ``blur_amount`` is set, blur the sum with radius
       ``blur_amount * sigma / alpha`` (clamped to [0.05, 512], averaged).
    4. Pull the sum back to ``x`` with the denoiser's VJP and negate it.
    5. Rescale so the per-sample RMS is at most 0.2.

    Args:
        diffusion: model called as ``diffusion(x, cosine_t, key)``; its ``.pred`` is used.
        conditions: gradient functions of the denoised image; ``None`` entries are dropped.
        blur_amount: optional blur strength applied to the summed gradient.
        use: ``'pred'`` to use the model prediction as the denoised image, or
            ``'x_in'`` for ``pred * sigma + x * alpha``.

    Raises:
        ValueError: if ``use`` is neither ``'pred'`` nor ``'x_in'``. Previously
            such a value made ``denoise`` return None and failed obscurely inside ``jax.vjp``.
    """
    def __init__(self, diffusion, conditions, blur_amount=None, use='pred'):
        if use not in ('pred', 'x_in'):
            raise ValueError(f"use must be 'pred' or 'x_in', got {use!r}")
        self.diffusion = diffusion
        self.conditions = [c for c in conditions if c is not None]
        self.blur_amount = blur_amount
        self.use = use

    @jax.jit
    def __call__(self, key, x, cosine_t):
        rng = PRNG(key)

        alphas, sigmas = cosine.to_alpha_sigma(cosine_t)

        def denoise(key, x):
            pred = self.diffusion(x, cosine_t, key).pred
            if self.use == 'pred':
                return pred
            # self.use == 'x_in' (validated in __init__)
            return pred * sigmas + x * alphas
        (x_in, backward) = jax.vjp(partial(denoise, rng.split()), x)

        total = jnp.zeros_like(x_in)
        for cond in self.conditions:
            total += cond(x_in, rng.split())
        if self.blur_amount is not None:
            blur_radius = (self.blur_amount * sigmas / alphas).clamp(0.05,512)
            total = blur_fft(total, blur_radius.mean())
        final_grad = -backward(total)[0]

        # Clamp gradients: rescale any sample whose RMS exceeds 0.2 down to exactly 0.2.
        magnitude = final_grad.square().mean(axis=(1,2,3), keepdims=True).sqrt()
        final_grad = final_grad * jnp.where(magnitude > 0.2, 0.2 / magnitude, 1.0)
        return final_grad

    def tree_flatten(self):
        # JAX requires hashable aux data, hence a tuple.
        return [self.diffusion, self.conditions, self.blur_amount], (self.use,)

    @staticmethod
    def tree_unflatten(static, dynamic):
        return MainCondFn(*dynamic, *static)


@jax.tree_util.register_pytree_node_class
class CondFns(object):
    """Sum of several cond_fns.

    Each condition is called as ``cond(key, x, t)`` with its own split key, and
    the results are added.
    """
    def __init__(self, *conditions):
        self.conditions = conditions

    def __call__(self, key, x, t):
        rng = PRNG(key)
        total = jnp.zeros_like(x)
        for cond in self.conditions:
            total += cond(rng.split(), x, t)
        return total

    def tree_flatten(self):
        # JAX requires hashable aux data, hence an empty tuple.
        return [self.conditions], ()

    @staticmethod
    def tree_unflatten(static, dynamic):
        [conditions] = dynamic
        return CondFns(*conditions)

def clamp_score(score):
  """Rescale each sample of ``score`` so its RMS over axes (1, 2, 3) is at most 0.1."""
  max_rms = 0.1
  rms = score.square().mean(axis=(1,2,3), keepdims=True).sqrt()
  scale = jnp.where(rms > max_rms, max_rms / rms, 1.0)
  return score * scale


@make_partial
def BlurRangeLoss(scale, key, x, cosine_t):
    """Score that pushes a blurred, rescaled view of ``x`` back into [-1, 1].

    ``x`` is blurred with radius ``2 * sigma / alpha`` (clamped to [0.05, 512])
    and divided by ``alpha`` (clamped below at 0.01). The squared overshoot
    outside [-1, 1] is summed; its negative gradient times ``scale`` is passed
    through ``clamp_score``. ``key`` is unused.
    """
    # alpha/sigma depend only on cosine_t, not on x, so they can be computed outside the gradient.
    alpha, sigma = cosine.to_alpha_sigma(cosine_t)
    radius = (sigma / alpha * 2).clamp(0.05,512)

    def overshoot_energy(x):
        pred = blur_fft(x, radius) / alpha.clamp(0.01)
        excess = pred - pred.clamp(minval=-1,maxval=1)
        return excess.square().sum()

    return clamp_score(-scale * jax.grad(overshoot_energy)(x))


In [None]:
def sample_step(key, x, t1, t2, diffusion, cond_fn, eta):
    """One guided sampling step from cosine time t1 to t2 (DDIM with eta-controlled noise).

    Args:
        key: PRNG key, split for the model call, the cond_fn call and the fresh noise.
        x: current noisy sample.
        t1, t2: current and next cosine times (converted with ``cosine.to_alpha_sigma``).
        diffusion: called as ``diffusion(x, t1, key)``; its ``.eps`` and ``.pred`` are used.
        cond_fn: called as ``cond_fn(key, x, t1)``; the returned score adjusts eps by ``-sigma1 * score``.
        eta: noise level. 0 gives a deterministic DDIM step; positive values interpolate towards
            DDPM noise; negative values interpolate towards re-noising pred with sigma2.

    Returns:
        (x, pred0): the sample at t2, and the model's *unconditioned* clean-image prediction at t1.

    NOTE(review): when ddim_sigma exceeds sigma2 (e.g. eta < -1), adjusted_sigma is the
    square root of a negative number (NaN).
    """
    rng = PRNG(key)

    n = x.shape[0]  # unused
    alpha1, sigma1 = cosine.to_alpha_sigma(t1)
    alpha2, sigma2 = cosine.to_alpha_sigma(t2)

    # Run the model
    out = diffusion(x, t1, rng.split())
    eps = out.eps
    pred0 = out.pred

    # # Predict the denoised image
    # pred0 = (x - eps * sigma1) / alpha1

    # Adjust eps with conditioning gradient
    cond_score = cond_fn(rng.split(), x, t1)
    eps = eps - sigma1 * cond_score

    # Predict the denoised image with conditioning
    pred = (x - eps * sigma1) / alpha1

    # Negative eta allows more extreme levels of noise.
    ddpm_sigma = (sigma2**2 / sigma1**2).sqrt() * (1 - alpha1**2 / alpha2**2).sqrt()
    ddim_sigma = jnp.where(eta >= 0.0,
                           eta * ddpm_sigma, # Normal: eta interpolates between ddim and ddpm
                           -eta * sigma2)    # Extreme: eta interpolates between ddim and q_sample(pred)
    # Deterministic share of sigma2; the remaining ddim_sigma comes from fresh noise below.
    adjusted_sigma = (sigma2**2 - ddim_sigma**2).sqrt()

    # Recombine the predicted noise and predicted denoised image in the
    # correct proportions for the next step
    x = pred * alpha2 + eps * adjusted_sigma

    # Add the correct amount of fresh noise
    x += jax.random.normal(rng.split(), x.shape) * ddim_sigma
    return x, pred0

def process_prompt(clip, prompt):
  """Embed a possibly brace-expanded text prompt into one normalised CLIP text embedding.

  ``prompt`` is expanded with ``braceexpand`` (``"a {cat,dog}"`` -> two prompts).
  Each expansion is embedded with ``clip.embed_text``. Expansions containing
  ``~`` are negated, and every ``~`` is stripped before embedding. The signed
  embeddings are summed and passed through ``norm1``.
  """
  # Brace expansion might change later, not sure this is the best way to do it.
  def signed_embedding(sub):
    sign = -1.0 if '~' in sub else 1.0
    return sign * clip.embed_text(sub.replace('~', ''))
  return norm1(sum(signed_embedding(sub) for sub in braceexpand(prompt)))

def process_prompts(clip, prompts):
  """Embed each prompt with ``process_prompt`` and stack the results (one row per prompt)."""
  embeddings = [process_prompt(clip, text) for text in prompts]
  return jnp.stack(embeddings)

def expand(xs, batch_size):
  """Repeat or cut the list of prompts so it has exactly `batch_size` entries.

  The list is tiled end-to-end and the first `batch_size` items are kept,
  e.g. ``expand(['a', 'b'], 3) == ['a', 'b', 'a']``. An empty list stays empty.
  """
  tiled = xs * batch_size
  return tiled[:batch_size]

# Configuration for the run (tutorial)



Some parameters are not listed here because their function is self-explanatory.

- `seed`: allows an image to be reproduced. `None` uses a seed based on the current time.
- `use_model`: which diffusion model to use (the usual model in JAX is `openai`)
- `batch_size`: how many images to render next to each other
- `n_batches`: how many times to run the current prompt
- `clip_guidance_scale`: how much the image should look like the prompt
- `cfg_guidance_scale`: same as `clip_guidance_scale` but for cc12m_cfg
- `tv_scale`: controls the smoothness of the image
- `sat_scale`: controls the saturation of the image
- `cutn`: number of cutouts CLIP evaluates per cut batch
- `cut_batches` -> multiplier for `cutn` (`cutn` * `cut_batches`) 
- `intermediate_saves`: saves the image at the listed steps while it renders.
For example, `[50, 225]` saves an image at steps 50 and 225. The final iteration is always saved, with `_final` at the end of the file name. Leave this parameter as `[]` to skip intermediate saves.
- `secondsOfVideo`: sets a fixed length (in seconds) for the progress video; the fps is calculated automatically.
- With `custom_fps` enabled, `fps` sets a fixed frame rate for the progress video, and the video lasts `steps / fps` seconds.



The batch settings will apply individually to each prompt in the queue.


# Configuration for the run

In [None]:
image_size = (640, 512) #@param {type:"raw"}
batch_size = 1 #@param {type:"integer"}
n_batches =  1#@param {type:"integer"}
use_model = 'openai' #@param ["openai", "openai_finetune",  "cc12m_1", "cc12m_1_cfg", "wikiart_256", "wikiart_128", "danbooru_128", "imagenet_128", "pixelartv4", "pixelartv6", "pixelartv7_ic_attn"]
steps = 250 #@param {type:"integer"}    # Number of steps for sampling. Generally, more = better.

#@markdown ---
#@markdown Separate prompts to queue by a ``|``
prompts = "sharp cyberpunk robots :1 | Pop art bold pastel colors : 1 |  strong sharp  twisted  | Max Ernst  | trending in artstation:1 |metallic:-0.1" #@param {type:"string"}
# Each "|"-separated piece is rendered as its own run. Filter on the *stripped* text so
# blank pieces such as "a | | b" or a trailing "|" do not queue an empty prompt.
# NOTE(review): weights like ":1" are not parsed by process_prompt; they are embedded
# as literal prompt text (only "~" negates a prompt).
prompts = [prompt.strip() for prompt in prompts.split("|") if prompt.strip()]
switch_seed_per_prompt = True #@param {type:"boolean"}

# Conditioning image for the pixelartv7_ic_attn model.
ic_cond = 'https://irc.zlkj.in/uploads/eebeaf1803e898ac/88552154_p0%20-%20Coral.png'
# 'https://irc.zlkj.in/uploads/eebeaf1803e898ac/88552154_p0%20-%20Coral.png'
# 'https://cdn.discordapp.com/emojis/916943952597360690.png?size=240&quality=lossless' # pizagal

clip_guidance_scale =  40000#@param {type:"integer"}
clip_guidance_scale = jnp.array([clip_guidance_scale]*batch_size) # Note: with two perceptors, effective guidance scale is ~2x because they are added together.
cfg_guidance_scale = 2.0  #@param {type:"number"}
tv_scale =   150#@param {type:"integer"}# Smooths out the image
range_scale =  600#@param {type:"integer"}# Tries to prevent pixel values from going out of range
cutn =         8#@param {type:"integer"}# Effective cutn is cut_batches * this
cut_pow = 1.0   # Affects the size of cutouts. Larger cut_pow -> smaller cutouts (down to the min of 224x224)
cut_batches = 4 #@param {type:"integer"}
make_cutouts = MakeCutouts(clip_size, cutn, cut_pow=cut_pow, p_mixgrey=0.0)

eta = 1.0       # 0.0: DDIM | 1.0: DDPM | -1.0: Extreme noise (q_sample)
#@markdown ---
# "None" (or an empty string) disables the init image. Otherwise give a path or URL
# (brace patterns allowed). The previous default was a personal Google Drive path that
# does not exist for other users.
init_image = "None"    #@param {type:"string"}
starting_noise = 1.0  #@param {type:"number"} # Between 0 and 1. When using init image, generally 0.5-0.8 is good. Lower starting noise makes the result look more like the init.
init_weight_mse = 0    # MSE loss between the output and the init makes the result look more like the init (should be between 0 and width*height*3).

#@markdown ---
saveVideo = False #@param {type:"boolean"}
#@markdown For automatic FPS:
secondsOfVideo =  7#@param {type:"number"}
#@markdown For manual FPS:
custom_fps = False #@param {type:"boolean"}
fps =  14#@param {type:"number"}

for k in range(batch_size): # creates folder that will hold steps for video
    imagestepsFolder = f'/content/imagesteps/{k}'
    os.makedirs(imagestepsFolder, exist_ok=True)

# Linear noise levels from starting_noise down to 0, mapped onto the cosine time scale.
schedule = jnp.linspace(starting_noise, 0, steps+1)
schedule = spliced.to_cosine(schedule)

# Treat "None", blank, or whitespace-only form input as "no init image".
if isinstance(init_image, str) and init_image.strip() in ("", "None"):
    init_image = None

def load_image(url):
    """Load `url` through the notebook's `fetch` helper as a batch of one RGB image in [-1, 1].

    The picture is resized to `image_size` with Lanczos resampling before conversion.
    """
    picture = Image.open(fetch(url)).convert('RGB').resize(image_size, Image.LANCZOS)
    # Assumes TF.to_tensor yields values in [0, 1]; *2 - 1 maps them to [-1, 1].
    # unsqueeze(0) adds the leading batch axis.
    return jnp.array(TF.to_tensor(picture)).unsqueeze(0).mul(2).sub(1)
# Stack every image named by the (brace-expandable) init path along the batch axis.
# Without an init image, init_array is None and sampling starts from pure noise.
init_array = (
    jnp.concatenate([load_image(path) for path in braceexpand(init_image)], axis=0)
    if type(init_image) is str
    else None
)

def _clip_cond(perceptor, cutouts, *extra):
    """Build a CondCLIP term that guides `perceptor`'s view of the image towards the current `title` prompts.

    `extra` is forwarded positionally after the loss (used for the aesthetic
    term of the openai_finetune model).
    """
    return CondCLIP(perceptor, cutouts, cut_batches,
                    SphericalDistLoss(process_prompts(perceptor, title), clip_guidance_scale),
                    *extra)


def _regularizer_conds():
    """Return the TV, init-MSE and range conditioning terms; each is None when its scale is not positive."""
    return [
        CondTV(tv_scale) if tv_scale > 0 else None,
        CondMSE(init_array, init_weight_mse) if init_weight_mse > 0 else None,
        CondRange(range_scale) if range_scale > 0 else None,
    ]


def config():
    """Build the diffusion model and conditioning function for the selected `use_model`.

    Returns:
      ``(diffusion, cond_fn)`` for `run()`.

    Raises:
      ValueError: if `use_model` (or the pixelart variant) is not recognised.
        Previously this surfaced as an UnboundLocalError on `diffusion`.
    """
    # Configure models and load parameters onto gpu.
    # We do this in a function to avoid leaking gpu memory.
    if use_model == 'openai':
        # -- Openai with anti-jpeg --
        openai = openai_512_wrap(openai_512_params())
        secondary2 = secondary2_wrap(secondary2_params())
        # Load the anti-jpeg weights once and share them between both class
        # conditionings instead of putting two copies on the GPU.
        jpeg_p = jpeg_params()
        jpeg_0 = jpeg_wrap(jpeg_p, cond=jnp.array([0]*batch_size)) # Clean class
        jpeg_1 = jpeg_wrap(jpeg_p, cond=jnp.array([2]*batch_size)) # Unconditional class

        jpeg_classifier_fn = jpeg_classifier_wrap(jpeg_classifier_params(),
                                                  guidance_scale=10000.0, # will generally depend on image size
                                                  flood_level=0.7, # Prevent over-optimization
                                                  blur_size=3.0)

        # openai output plus (clean - unconditional) anti-jpeg direction.
        diffusion = LerpModels([(openai, 1.0),
                                (jpeg_0, 1.0),
                                (jpeg_1, -1.0)])
        cond_model = secondary2

        cond_fn = CondFns(MainCondFn(cond_model, [
            _clip_cond(vit32, make_cutouts),
            _clip_cond(vit16, make_cutouts),
            *_regularizer_conds(),
        ]), jpeg_classifier_fn)

    elif use_model in ('wikiart_256', 'wikiart_128', 'danbooru_128', 'imagenet_128'):
        if use_model == 'wikiart_256':
            diffusion = KatModel(wikiart_256_model, wikiart_256_params())
        elif use_model == 'wikiart_128':
            diffusion = KatModel(wikiart_128_model, wikiart_128_params())
        elif use_model == 'danbooru_128':
            diffusion = KatModel(danbooru_128_model, danbooru_128_params())
        else:  # 'imagenet_128'
            diffusion = KatModel(imagenet_128_model, imagenet_128_params())
        cond_model = diffusion
        cond_fn = MainCondFn(cond_model, [
            _clip_cond(vit32, make_cutouts),
            _clip_cond(vit16, make_cutouts),
            *_regularizer_conds(),
        ])

    elif 'pixelart' in use_model:
        if use_model == 'pixelartv7_ic_attn':
            # -- pixel art model --
            cond = jnp.array(TF.to_tensor(Image.open(fetch(ic_cond)).convert('RGB').resize(image_size,Image.BICUBIC))) * 2 - 1
            cond = cond.broadcast_to([batch_size, 3, image_size[1], image_size[0]])
            diffusion = pixelartv7_ic_attn_wrap(pixelartv7_ic_attn_params(), cond=cond, cfg_guidance_scale=cfg_guidance_scale)
        elif use_model == 'pixelartv6':
            diffusion = pixelartv6_wrap(pixelartv6_params())
        elif use_model == 'pixelartv4':
            diffusion = pixelartv4_wrap(pixelartv4_params())
        else:
            raise ValueError(f'Unknown pixelart model: {use_model!r}')
        cond_model = diffusion
        # Pixel art uses pixelated cutouts and no TV/range regularisers.
        cond_fn = MainCondFn(cond_model, [
            _clip_cond(vit32, MakeCutoutsPixelated(make_cutouts)),
            _clip_cond(vit16, MakeCutoutsPixelated(make_cutouts)),
            CondMSE(init_array, init_weight_mse) if init_weight_mse > 0 else None,
        ])

    elif use_model == 'cc12m_1_cfg':
        # Guidance comes from classifier-free guidance inside the model; no CLIP cond_fn terms.
        diffusion = cc12m_1_cfg_wrap(cc12m_1_cfg_params(), clip_embed=vit16.embed_texts(title), cfg_guidance_scale=cfg_guidance_scale)
        cond_fn = CondFns()

    elif use_model == 'cc12m_1':
        diffusion = cc12m_1_wrap(cc12m_1_params(), clip_embed=vit16.embed_texts(title))
        cond_model = diffusion
        cond_fn = MainCondFn(cond_model, [
            _clip_cond(vit32, make_cutouts),
            _clip_cond(vit16, make_cutouts),
            *_regularizer_conds(),
        ])

    elif use_model == 'openai_finetune':
        diffusion = openai_512_finetune_wrap(openai_512_finetune_params())
        cond_model = secondary2_wrap(secondary2_params())
        cond_fn = CondFns(MainCondFn(cond_model, [
            _clip_cond(vit32, make_cutouts),
            _clip_cond(vit16, make_cutouts, AestheticExpected(jnp.array([16.0,16.0,16.0,16.0]))),
            *_regularizer_conds(),
        ]))

    else:
        raise ValueError(f'Unknown use_model: {use_model!r}')

    return diffusion, cond_fn



In [None]:
#@title Huemin's simple symmetry
#@markdown `symmetry_schedule`: fractions of the total step count (e.g. `0.5` = halfway) at which to apply the symmetry. To apply it multiple times, separate the values with commas.
def simple_symmetry(x_in):
  """Mirror the left half of an image batch onto its right half along the last (width) axis.

  The output has the same shape as the input. With an even width the columns
  ``[0, 1, 2, 3]`` become ``[0, 1, 1, 0]`` (as before). With an odd width the
  centre column is kept once, so ``[0, 1, 2, 3, 4]`` becomes
  ``[0, 1, 2, 1, 0]``. The previous version dropped a column, which broke
  every later shape.
  """
  w = x_in.shape[-1]
  left = x_in[..., :(w + 1) // 2]
  # Mirror only the first w//2 columns so an odd centre column is not duplicated.
  right = jnp.flip(x_in[..., :w // 2], -1)
  return jnp.concatenate([left, right], -1)

#@markdown Symmetry Settings
use_symmetry = True #@param {type:"boolean"}
symmetry_schedule = "0.01,0.5"#@param {type:"string"}

# Fractions of `steps` at which the symmetry is applied. Blank entries (an empty
# schedule or a trailing comma) are skipped instead of making float('') raise.
symmetry_percents = [float(vals) for vals in symmetry_schedule.split(",") if vals.strip()]

# Do the run

In [None]:
#@markdown If `switch_seed_per_prompt` is enabled, a new time-based seed is used for every prompt, even if you set `seed` to a specific number.

# `seed` is only used when switch_seed_per_prompt is off; otherwise each prompt gets a time-based seed.
seed = None #@param # if None, uses the current time in seconds.
# A preview grid is shown every `display_frequency` steps and on the last step.
display_frequency = 50 #@param {type:"number"}
# Step indices at which intermediate images are saved (e.g. 100,200 or [50, 225]; [] for none).
intermediate_saves = 100,200 #@param

def sanitize(title):
  """Make `title` usable inside a file name: keep at most 100 characters and replace path separators with '_'."""
  cleaned = title[:100]
  for separator in ('/', '\\'):
    cleaned = cleaned.replace(separator, '_')
  return cleaned

def _pred_to_images(pred):
    """Map a model prediction in [-1, 1] to a torch tensor clamped to [0, 1] for torchvision's image helpers."""
    return torch.tensor(np.array(pred.add(1).div(2).clamp(0, 1)))


@torch.no_grad()
def run():
    """Render `n_batches` batches for the current prompt.

    Reads its configuration from notebook globals set by earlier cells and by
    the prompt loop (schedule, image_size, batch_size, init_array, diffusion,
    cond_fn, eta, title, all_title, local_seed, save_location, ...).

    For each batch it:
      - shows a preview every `display_frequency` steps and on the last step;
      - saves each image at the steps listed in `intermediate_saves`;
      - optionally writes video frames;
      - finally saves the grid and one `_final` image per batch element under
        `samples/`, and also under `save_location` when that is set.
    """
    rng = PRNG(jax.random.PRNGKey(local_seed))

    # Accept a single step number (e.g. `intermediate_saves = 100`) as well as a list/tuple.
    if isinstance(intermediate_saves, int):
        save_steps = {intermediate_saves}
    else:
        save_steps = set(intermediate_saves or ())

    for _ in range(n_batches):
        timestring = time.strftime('%Y%m%d%H%M%S')

        ts = schedule
        alphas, sigmas = cosine.to_alpha_sigma(ts)

        x = jax.random.normal(rng.split(), [batch_size, 3, image_size[1], image_size[0]])

        if init_array is not None:
            x = sigmas[0] * x + alphas[0] * init_array

        # Create the images folders in advance for intermediate image saves.
        os.makedirs('samples/images', exist_ok=True)
        if save_location:
            os.makedirs(f'{save_location}/images', exist_ok=True)

        # Main loop
        pred = None
        local_steps = schedule.shape[0] - 1
        for j in tqdm(range(local_steps)):
            # == Panorama ==
            # shift = jax.random.randint(rng.split(), [batch_size, 2], 0, jnp.array([1, image_size[0]]))
            # x = xyroll(x, shift)
            # == -------- ==

            # Skip steps where the ts are the same, to make it easier to
            # make complicated schedules out of cat'ing linspaces.
            if ts[j] == ts[j+1]:
                continue

            # Perform the symmetry (symmetry_steps is only defined when use_symmetry is on).
            if use_symmetry and j in symmetry_steps and j != 0:
                x = simple_symmetry(x)
                print("simple symmetry")

            x, pred = sample_step(rng.split(), x, ts[j], ts[j+1], diffusion, cond_fn, eta)
            assert x.isfinite().all().item()

            want_display = j % display_frequency == 0 or j == local_steps - 1
            want_save = j in save_steps
            if not (want_display or want_save or saveVideo):
                continue
            # Convert once per step and reuse for display, saves and video frames.
            images = _pred_to_images(pred)

            if want_display:
                display.display(TF.to_pil_image(utils.make_grid(images, 4).cpu()))

            if want_save:
                print(f" [Intermediate Save]")
                display.display(TF.to_pil_image(utils.make_grid(images, 2).cpu()))
                # Save each batch element on its own (previously the whole grid
                # was written once per element).
                for k in range(batch_size):
                    this_title = sanitize(title[k])
                    file_name = f'{this_title}_{k}_{timestring}_{j}.png'
                    pil_image = TF.to_pil_image(images[k])
                    pil_image.save(f'samples/images/{file_name}')
                    if save_location:
                        pil_image.save(f'{save_location}/images/{file_name}')

            if saveVideo:
                for k in range(batch_size):
                    TF.to_pil_image(images[k]).save(f'/content/imagesteps/{k}/{j}.png')

        if pred is None:
            # Every step was skipped (or steps == 0): nothing was sampled for this batch.
            print("No sampling steps were run; nothing to save.")
            continue

        # Save samples from the last prediction.
        images = _pred_to_images(pred)
        grid_image = TF.to_pil_image(utils.make_grid(images, 2).cpu())
        grid_name = f'{timestring}_{sanitize(all_title)}.png'
        os.makedirs('samples/grid', exist_ok=True)
        grid_image.save(f'samples/grid/{grid_name}')
        if save_location:
            os.makedirs(f'{save_location}/grid', exist_ok=True)
            grid_image.save(f'{save_location}/grid/{grid_name}')

        for k in range(batch_size):
            this_title = sanitize(title[k])
            final_name = f'{this_title}_{k}_{timestring}_final.png'
            pil_image = TF.to_pil_image(images[k])
            pil_image.save(f'samples/images/{final_name}')
            if save_location:
                pil_image.save(f'{save_location}/images/{final_name}')
            if saveVideo:
                make_video(k)


# Use one shared seed for every prompt when per-prompt seeds are off and no seed was given.
if seed is None and not switch_seed_per_prompt:
    seed = int(time.time())

# Steps at which run() applies the symmetry. Loop-invariant, so computed once;
# only read by run() when use_symmetry is on.
if use_symmetry:
    symmetry_steps = [int(steps * percent) for percent in symmetry_percents]

for all_title in prompts:
    title = expand([all_title], batch_size)

    # Rebuilt per prompt because config() embeds the current `title`.
    diffusion, cond_fn = config()

    # Fresh time-based seed when switching per prompt (or no seed is set);
    # otherwise reuse the fixed `seed`.
    if switch_seed_per_prompt or seed is None:
        local_seed = int(time.time())
    else:
        local_seed = seed

    print(f"Current prompt: {all_title}")
    print(f"Current seed: {local_seed}")

    # Errors propagate with their own traceback and stop the queue, as before.
    # The old bare `except:` also turned KeyboardInterrupt into an AssertionError.
    run()


In [None]:
#@title zip and download all of the output images

make_zip = False #@param {type:"boolean"}

zipName = "JaxZipOutput.zip"

if make_zip:  
    if os.path.exists(zipName):
        os.remove(zipName)
    #os.system(f"zip -r -j {zipName} results/*")
    !zip -r {zipName} {save_location}/images
    files.download(zipName)
