### changelog/updates

* 8 Oct - seed sweep version
* sept 2021 - various KB tweaks

# Generate images from text phrases with VQGAN and CLIP (z+quantize method with augmentations)

* [How to use VQGAN+CLIP](https://docs.google.com/document/d/1Lu7XPRKlNhBQjcKr8k8qRzUzbBW7kzxb5Vu72GMRn2E/edit)
* Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings).


In [None]:
# @title Licensed under the MIT License

# Copyright (c) 2021 Katherine Crowson

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

In [None]:
#@title Google Drive Integration (optional)
#@markdown To connect Google Drive, set `root_path` to the relative drive folder path you want outputs to be saved to if you already made a directory, then execute this cell. Leaving the field blank or just not running this will have outputs save to the runtime temp storage.
import os

from google.colab import drive

root_path = "VQ" #@param {type: "string"}
# An empty field keeps the root at the runtime's own /content storage;
# otherwise the root is the named folder under the mounted Drive.
abs_root_path = "/content/drive/MyDrive/" + root_path if root_path else "/content"

drive.mount('/content/drive')
 
def ensureProperRootPath():
    """Make the notebook-level ``abs_root_path`` the working directory.

    Reads the global ``abs_root_path``. When it is non-empty the process
    changes into it and prints the resulting working directory; when it is
    empty nothing happens.

    Raises:
        OSError: propagated from ``os.chdir`` (e.g. the folder does not exist).
    """
    if abs_root_path:
        os.chdir(abs_root_path)
        print("Root path check: ")
        # Plain Python instead of the `!pwd` shell magic, so the function also
        # works when it is not executed by IPython.
        print(os.getcwd())
 
# Change into the chosen root right away; relative paths used afterwards
# resolve against it.
ensureProperRootPath()

In [None]:
#@title Make a new folder & set root path to that folder (optional)
#@markdown Saves a step if you don't have a folder in your Google Drive for this. Makes one, sets the root_path to that new folder. You can name it whatever you'd like:

folder_name = "VQ" #@param {type: "string"}
abs_root_path = "/content"
if folder_name:
    path_tmp = abs_root_path + "/drive/MyDrive/" + folder_name
    # os.mkdir (not makedirs) on purpose: if Drive is not mounted the parent
    # is missing and this fails loudly instead of creating a local look-alike.
    if not os.path.exists(path_tmp):
        os.mkdir(path_tmp)
    abs_root_path = path_tmp

print("Created folder & set root path to: " + abs_root_path)

#@markdown Make & assign path to a project subfolder (optional)
#@markdown _My practice:_ preload _all_ the models into this directory.

project_name = "Workspace" #@param {type: "string"}
if project_name:
    path_tmp = abs_root_path + "/" + project_name
    if not os.path.exists(path_tmp):
        os.mkdir(path_tmp)
    abs_root_path = path_tmp
print("Created project subfolder & set root path to: " + abs_root_path)

ensureProperRootPath()

print('Making some symlinks to work directories...')
# (link target, link location) pairs. os.symlink is used instead of
# `! ln -s {path}` so folder names containing spaces or shell characters are
# passed through untouched.
work_links = [
    ('/content/drive/MyDrive/kbImport/Pix/2021/Work2021g/starters', '/content/starters'),
    ('/content/drive/MyDrive/kbImport/Pix/2021/Work2021g', '/content/Work2021g'),
    (abs_root_path, '/content/dataRoot'),
]
for link_target, link_path in work_links:
    # lexists also sees a dangling link left by an earlier session, which
    # os.path.exists would report as absent.
    if not os.path.lexists(link_path):
        os.symlink(link_target, link_path)
#if not os.path.exists('/content/dataProj'):
#  ! ln -s {abs_proj_path} /content/dataProj


In [None]:
# @title Setup, Installing Libraries
# @markdown This cell might take some time due to installing several libraries.

# Report the GPU state of this runtime before installing anything.
!nvidia-smi
print("Downloading CLIP...")
# Clones go into the current working directory. Output is discarded
# (`&> /dev/null`), so a failed clone or install on these lines is silent.
!git clone https://github.com/openai/CLIP                 &> /dev/null
 
print("Installing Python Libraries for AI")
!git clone https://github.com/CompVis/taming-transformers &> /dev/null
!pip install ftfy regex tqdm omegaconf pytorch-lightning  &> /dev/null
!pip install kornia                                       &> /dev/null
!pip install einops                                       &> /dev/null
print("Installing transformers library...")
# This install (and the pillow one below) is not silenced, so its log is shown.
!pip install transformers   
print("Installing taming.models...")   
!pip install taming.models                           &> /dev/null
 
print("Installing libraries for managing metadata...")
!pip install stegano                                      &> /dev/null
!apt install exempi                                       &> /dev/null
!pip install python-xmp-toolkit                           &> /dev/null
!pip install imgtag                                       &> /dev/null
# Pillow is the only pinned package and is installed after all the others.
!pip install pillow==7.1.2 
%reload_ext autoreload
# NOTE(review): `%autoreload` is an IPython magic, not a shell command; the
# `&> /dev/null` tail is presumably handed to the magic as an argument rather
# than acting as a redirect. Confirm this line does what was intended.
%autoreload                  &> /dev/null
 
print("Installing ffmpeg for creating videos...")
!pip install imageio-ffmpeg &> /dev/null
# !mkdir steps
print("Installation finished.")

In [None]:
#@title Selection of models to download
#@markdown By default, the notebook downloads Model 16384 from ImageNet. There are others such as ImageNet 1024, COCO-Stuff, WikiArt 1024, WikiArt 16384, FacesHQ or S-FLCKR, which are not downloaded by default, since it would be in vain if you are not going to use them, so if you want to use them, simply select the models to download.

#@markdown WARNING: 
#@markdown Not all datasets are licensed for commercial use (i.e. selling your artwork as an NFT).


#@markdown Datasets you can use for non-commercial purposes:
# NOTE(review): every flag below defaults to False, so this cell downloads
# nothing unless a box is ticked. The form text saying that ImageNet 16384 is
# downloaded by default does not match these defaults.
imagenet_1024 = False #@param {type:"boolean"} 
imagenet_16384 = False #@param {type:"boolean"}
coco = False #@param {type:"boolean"}
wikiart_1024 = False #@param {type:"boolean"}
wikiart_16384 = False #@param {type:"boolean"}
#@markdown Datasets you can use for commercial purposes:
faceshq = False #@param {type:"boolean"}
sflckr = False #@param {type:"boolean"}

# Each selected model fetches a .yaml config and a .ckpt checkpoint into the
# current working directory. curl flags: -L follows redirects, -o names the
# output file, `-C -` resumes a partially downloaded file.
# NOTE(review): the mirror.io.community URLs are plain http.
if imagenet_1024:
  !curl -L -o vqgan_imagenet_f16_1024.yaml -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_1024.yaml' #ImageNet 1024
  !curl -L -o vqgan_imagenet_f16_1024.ckpt -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_1024.ckpt'  #ImageNet 1024
if imagenet_16384:
  !curl -L -o vqgan_imagenet_f16_16384.yaml -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.yaml' #ImageNet 16384
  !curl -L -o vqgan_imagenet_f16_16384.ckpt -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.ckpt' #ImageNet 16384
if coco:
  !curl -L -o coco.yaml -C - 'https://dl.nmkd.de/ai/clip/coco/coco.yaml' #COCO
  !curl -L -o coco.ckpt -C - 'https://dl.nmkd.de/ai/clip/coco/coco.ckpt' #COCO
if faceshq:
  !curl -L -o faceshq.yaml -C - 'https://drive.google.com/uc?export=download&id=1fHwGx_hnBtC8nsq7hesJvs-Klv-P0gzT' #FacesHQ
  !curl -L -o faceshq.ckpt -C - 'https://app.koofr.net/content/links/a04deec9-0c59-4673-8b37-3d696fe63a5d/files/get/last.ckpt?path=%2F2020-11-13T21-41-45_faceshq_transformer%2Fcheckpoints%2Flast.ckpt' #FacesHQ
# The WikiArt 1024 files are stored on the mirror as plain "wikiart.*" and are
# saved locally under the "wikiart_1024.*" names.
if wikiart_1024: 
  !curl -L -o wikiart_1024.yaml -C - 'http://mirror.io.community/blob/vqgan/wikiart.yaml' #WikiArt 1024
  !curl -L -o wikiart_1024.ckpt -C - 'http://mirror.io.community/blob/vqgan/wikiart.ckpt' #WikiArt 1024
if wikiart_16384: 
  !curl -L -o wikiart_16384.yaml -C - 'http://mirror.io.community/blob/vqgan/wikiart_16384.yaml' #WikiArt 16384
  !curl -L -o wikiart_16384.ckpt -C - 'http://mirror.io.community/blob/vqgan/wikiart_16384.ckpt' #WikiArt 16384
if sflckr:
  !curl -L -o sflckr.yaml -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1' #S-FLCKR
  !curl -L -o sflckr.ckpt -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1' #S-FLCKR

In [None]:
# @title Loading libraries and definitions
 
import argparse
import math
from pathlib import Path
import sys
 
sys.path.append('./taming-transformers')
from IPython import display
from base64 import b64encode
from omegaconf import OmegaConf
from PIL import Image
from taming.models import cond_transformer, vqgan
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm
 
from CLIP import clip
import kornia.augmentation as K
import numpy as np
import imageio
from PIL import ImageFile, ImageOps
from imgtag import ImgTag    # metadata
from libxmp import *         # metadata
import libxmp                # metadata
from stegano import lsb
import json
ImageFile.LOAD_TRUNCATED_IMAGES = True
 
def sinc(x):
    """Normalised sinc, sin(pi*x) / (pi*x), with entries where x == 0 set to 1."""
    one = x.new_ones([])
    quotient = torch.sin(math.pi * x) / (math.pi * x)
    # The quotient is NaN at x == 0; torch.where swaps in 1 for those entries.
    return torch.where(x != 0, quotient, one)
 
 
def lanczos(x, a):
    """Lanczos window of half-width ``a`` sampled at ``x``, normalised to sum to 1.

    Samples outside the open interval (-a, a) are zero.
    """
    inside = torch.logical_and(-a < x, x < a)
    window = sinc(x) * sinc(x/a)
    taps = torch.where(inside, window, x.new_zeros([]))
    return taps / taps.sum()
 
 
def ramp(ratio, width):
    """Symmetric 1-D grid of sample positions spaced ``ratio`` apart around 0.

    Builds 0, ratio, 2*ratio, ... (``ceil(width / ratio + 1)`` values), mirrors
    it to the negative side and drops the outermost sample on each end.
    """
    count = math.ceil(width / ratio + 1)
    steps = torch.empty([count])
    position = 0
    for idx in range(count):
        steps[idx] = position
        position += ratio
    negative_side = -steps[1:].flip([0])
    return torch.cat([negative_side, steps])[1:-1]
 
 
def resample(input, size, align_corners=True):
    """Resize an (n, c, h, w) batch to ``size`` = (height, width).

    An axis that shrinks is first filtered with a Lanczos kernel (reflect
    padded so the plane keeps its shape); the final resize is bicubic
    interpolation.
    """
    batch, channels, src_h, src_w = input.shape
    dst_h, dst_w = size

    # Treat every (sample, channel) pair as its own single-channel image so a
    # single 1-D kernel filters all of them.
    planes = input.view([batch * channels, 1, src_h, src_w])

    if dst_h < src_h:
        taps_h = lanczos(ramp(dst_h / src_h, 2), 2).to(planes.device, planes.dtype)
        margin_h = (taps_h.shape[0] - 1) // 2
        planes = F.pad(planes, (0, 0, margin_h, margin_h), 'reflect')
        planes = F.conv2d(planes, taps_h[None, None, :, None])

    if dst_w < src_w:
        taps_w = lanczos(ramp(dst_w / src_w, 2), 2).to(planes.device, planes.dtype)
        margin_w = (taps_w.shape[0] - 1) // 2
        planes = F.pad(planes, (margin_w, margin_w, 0, 0), 'reflect')
        planes = F.conv2d(planes, taps_w[None, None, None, :])

    restored = planes.view([batch, channels, src_h, src_w])
    return F.interpolate(restored, size, mode='bicubic', align_corners=align_corners)
 
 
class ReplaceGrad(torch.autograd.Function):
    """Return one tensor in the forward pass, route the gradient to another."""

    @staticmethod
    def forward(ctx, fwd_value, grad_target):
        # Only the shape of the gradient target is needed in backward.
        ctx.shape = grad_target.shape
        return fwd_value

    @staticmethod
    def backward(ctx, grad_in):
        # No gradient for the forward value; the incoming gradient is reduced
        # to the gradient target's shape.
        routed = grad_in.sum_to_size(ctx.shape)
        return None, routed


replace_grad = ReplaceGrad.apply
 
 
class ClampWithGrad(torch.autograd.Function):
    """Clamp whose backward pass still lets gradients through where they help.

    The gradient is kept for an element when ``grad * (input - clamped) >= 0``
    and zeroed otherwise.
    """

    @staticmethod
    def forward(ctx, input, min, max):
        ctx.save_for_backward(input)
        ctx.min, ctx.max = min, max
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        (saved,) = ctx.saved_tensors
        overshoot = saved - saved.clamp(ctx.min, ctx.max)
        keep = grad_in * overshoot >= 0
        # The two bounds receive no gradient.
        return grad_in * keep, None, None


clamp_with_grad = ClampWithGrad.apply
 
 
def vector_quantize(x, codebook):
    """Snap each vector on the last axis of ``x`` to its nearest codebook row.

    The forward value is the selected codebook rows; ``replace_grad`` sends
    the gradient to ``x``.
    """
    x_sq = x.pow(2).sum(dim=-1, keepdim=True)
    code_sq = codebook.pow(2).sum(dim=1)
    # Squared Euclidean distance expanded as |x|^2 + |c|^2 - 2 x.c
    sq_dist = x_sq + code_sq - 2 * x @ codebook.T
    nearest = sq_dist.argmin(-1)
    selector = F.one_hot(nearest, codebook.shape[0]).to(sq_dist.dtype)
    return replace_grad(selector @ codebook, x)
 
 
class Prompt(nn.Module):
    """Loss term scoring how far a batch of embeddings is from a target embedding.

    ``weight`` scales the term (its sign flips the distances) and ``stop`` is
    the floor used for the gradient path handed to ``replace_grad``.
    """

    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        buffers = (
            ('embed', embed),
            ('weight', torch.as_tensor(weight)),
            ('stop', torch.as_tensor(stop)),
        )
        for buffer_name, value in buffers:
            self.register_buffer(buffer_name, value)

    def forward(self, input):
        queries = F.normalize(input.unsqueeze(1), dim=2)
        targets = F.normalize(self.embed.unsqueeze(0), dim=2)
        # Distance between unit vectors: 2 * arcsin(|a - b| / 2) ** 2
        half_chord = queries.sub(targets).norm(dim=2).div(2)
        dists = half_chord.arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        floored = torch.maximum(dists, self.stop)
        return self.weight.abs() * replace_grad(dists, floored).mean()
 
 
def parse_prompt(prompt):
    """Split ``"text:weight:stop"`` into ``(text, weight, stop)``.

    ``weight`` defaults to 1.0 and ``stop`` to ``-inf``. A trailing
    ``:``-separated field only counts as weight/stop when it parses as a
    float, so a colon inside the text ("dawn: a new day") no longer raises
    ``ValueError``; the text is returned whole instead.

    Args:
        prompt: the raw prompt string.

    Returns:
        Tuple ``(text, weight, stop)`` with ``weight`` and ``stop`` as floats.
    """
    def _as_float(field):
        # None marks "not a number" so 0.0 stays a valid value.
        try:
            return float(field)
        except ValueError:
            return None

    fields = prompt.rsplit(':', 2)
    last = _as_float(fields[-1]) if len(fields) > 1 else None
    if last is None:
        # No numeric suffix at all: the whole string is the text.
        return prompt, 1.0, float('-inf')

    if len(fields) == 3:
        middle = _as_float(fields[1])
        if middle is not None:
            # text:weight:stop
            return fields[0], middle, last

    # text:weight (the text itself may still contain colons)
    return prompt.rsplit(':', 1)[0], last, float('-inf')
 
 
class MakeCutouts(nn.Module):
    """Cut ``cutn`` random square crops from an image batch, augment them, add noise.

    Every crop is resampled to ``cut_size`` x ``cut_size``. ``cut_pow`` is the
    exponent applied to the uniform random draw that picks each crop's size.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size, self.cutn, self.cut_pow = cut_size, cutn, cut_pow
        augmentations = [
            K.RandomHorizontalFlip(p=0.5),
            # K.RandomSolarize(0.01, 0.01, p=0.7),
            K.RandomSharpness(0.3,p=0.4),
            K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
            K.RandomPerspective(0.2,p=0.4),
            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
        ]
        self.augs = nn.Sequential(*augmentations)
        self.noise_fac = 0.1

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        largest = min(sideX, sideY)
        smallest = min(sideX, sideY, self.cut_size)
        target = (self.cut_size, self.cut_size)

        def one_cutout():
            # Draw order (size, x offset, y offset) is kept so the random
            # stream is consumed exactly as before.
            size = int(torch.rand([])**self.cut_pow * (largest - smallest) + smallest)
            left = torch.randint(0, sideX - size + 1, ())
            top = torch.randint(0, sideY - size + 1, ())
            patch = input[:, :, top:top + size, left:left + size]
            return resample(patch, target)

        crops = [one_cutout() for _ in range(self.cutn)]
        batch = self.augs(torch.cat(crops, dim=0))
        if self.noise_fac:
            # One noise amplitude per cutout, drawn from [0, noise_fac).
            facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch
 
 
def load_vqgan_model(config_path, checkpoint_path):
    """Build the VQGAN described by ``config_path`` and load ``checkpoint_path``.

    Two ``model.target`` values are handled: a plain ``VQModel`` and a
    ``Net2NetTransformer`` (whose ``first_stage_model`` is returned). The
    model is put in eval mode with gradients disabled and its ``loss``
    attribute is deleted.

    Raises:
        ValueError: for any other ``model.target``.
    """
    config = OmegaConf.load(config_path)
    target = config.model.target

    def prepared(net):
        # Same order as before: eval + no grads first, then the weights.
        net.eval().requires_grad_(False)
        net.init_from_ckpt(checkpoint_path)
        return net

    if target == 'taming.models.vqgan.VQModel':
        model = prepared(vqgan.VQModel(**config.model.params))
    elif target == 'taming.models.cond_transformer.Net2NetTransformer':
        parent_model = prepared(cond_transformer.Net2NetTransformer(**config.model.params))
        model = parent_model.first_stage_model
    else:
        raise ValueError(f'unknown model type: {config.model.target}')
    del model.loss
    return model
 
 
def resize_image(image, out_size):
    """Resize ``image`` keeping its aspect ratio, capped at the area of ``out_size``.

    The target area is the smaller of the image's own area and
    ``out_size[0] * out_size[1]``, so an image already below the cap keeps its
    size.
    """
    src_w, src_h = image.size[0], image.size[1]
    ratio = src_w / src_h
    area = min(src_w * src_h, out_size[0] * out_size[1])
    new_w = round((area * ratio)**0.5)
    new_h = round((area / ratio)**0.5)
    return image.resize((new_w, new_h), Image.LANCZOS)

In [None]:
#@title KB Tweaks & Defs

import time
import re

def generate_sequence_name(JobName, promptList, Seed=None, ZoneAdj=(8*3600)):
  """Build the filename stem used for a render's folders and images.

  The stem is ``JobName`` when given, otherwise the prompts joined with ``-``
  (spaces become ``_``; ``: , | . ?`` and path separators are removed),
  otherwise ``"out"``. It is cut to 64 characters, or to 56 plus ``_s<Seed>``
  when a seed is given, and ends with a ``_MMDD-HHMM`` stamp.

  The stamp comes from the notebook global ``renderStartTime`` shifted back by
  ``ZoneAdj`` seconds, so every name generated during one execution shares it.

  Args:
    JobName: explicit job name; "" means derive the name from the prompts.
    promptList: list of prompt strings (may be empty).
    Seed: optional seed appended as ``_s<Seed>``.
    ZoneAdj: seconds subtracted from UTC before formatting the stamp.

  Returns:
    The name as a string.
  """
  # TODO: renderStartTime scope....
  startstr = time.strftime('_%m%d-%H%M',
                           time.gmtime(renderStartTime-ZoneAdj))
  if JobName != "":
    name = JobName
  elif promptList:
    name = '-'.join(promptList).replace(' ', '_')
    # '/' and '\' are stripped too: the result is used as a directory and
    # file name, where a separator would point into a missing folder.
    name = re.sub(r'[:,|.?/\\]', '', name)
  else:
    name = "out"
  name = name[:64] if Seed is None else name[:56] + '_s' + str(Seed)
  return name + startstr

#=====

#@markdown Performance Summaries
import matplotlib.pyplot as plt

#elapsedTimeStr = f"{(renderEndTime-renderStartTime)/60:.3} mins"
#print(f"Rendering was complete after {elapsedTimeStr}")

def perfplots(Losses, Times, SeqName, DestPath, Elapsed, Legend, Show=False):
  """Save a two-panel performance chart as ``<DestPath>/perf.png``.

  Args:
    Losses: per-iteration loss values (top panel).
    Times: per-iteration durations (bottom panel).
    SeqName: sequence name shown in the loss panel title.
    DestPath: directory that receives ``perf.png``.
    Elapsed: already formatted total time shown in the timing panel title.
    Legend: free text printed below the plots after ``notes:``.
    Show: also display the figure when True.
  """
  # One figure only: the earlier extra plt.figure() call created a second
  # figure that was never closed.
  fig, (lossplt, timeplt) = plt.subplots(2, 1, figsize=(10,6))
  # fig.suptitle(f'Performance for {SeqName}')
  lossplt.plot(Losses)
  lossplt.set_title(f"Loss for {SeqName}")

  timeplt.plot(Times)
  # Use the Elapsed argument rather than a same-purpose notebook global.
  timeplt.set_title(f"Frame+write Times, Total {Elapsed}")

  fig.text(0.05, 0, f'notes: {Legend}', fontsize=14, va="top", ha="left")
  fig.tight_layout(pad=1.0)
  fig.savefig(f'{DestPath}/perf.png', bbox_inches = "tight")
  if Show:
    plt.show()
  plt.close(fig)

# perfplots(LL, FTL)

#markdown cleanup routines

def remember_progress(SeqName):
  """Rename ``progress.png`` in the working directory to ``<SeqName>.png``.

  Args:
    SeqName: sequence name used as the new file stem.

  Returns:
    The new filename, or ``''`` when there is no ``progress.png`` to move.
  """
  # Imported here so the function does not rely on an import made elsewhere.
  import shutil

  if not os.path.exists('progress.png'):
    print('No current "progress.png"')
    return ''
  progress_filename = f'{SeqName}.png'
  print("Moving progress.png to '{}'".format(progress_filename))
  # A blocking move: the file is in place when this returns. The earlier
  # background `mv` process could still be running when the caller read the
  # file or when the next render rewrote progress.png.
  shutil.move('progress.png', progress_filename)
  return progress_filename

from subprocess import Popen, PIPE
import os


def save_render_params(Filename, args):
  """Write the render settings in ``args`` to ``Filename`` as plain text.

  One line per prompt, image prompt, noise seed and noise weight, followed by
  the image size, the initial image (only when one is set), the model files
  and the optimiser/cutout/seed settings.

  Args:
    Filename: path of the text file to (over)write.
    args: the Namespace built in the Parameters cell.

  Raises:
    OSError: if the file cannot be opened or written.
  """
  nPrompts = len(args.prompts)
  report = [f"prompt {j} of {nPrompts}: '{text}'" for j, text in enumerate(args.prompts)]
  report += [f"image_prompt {j}: {path}" for j, path in enumerate(args.image_prompts)]
  report += [f"noise_prompt_seed {j}: {value}" for j, value in enumerate(args.noise_prompt_seeds)]
  # Weights get their own label; they used to be written as "noise_prompt_seed".
  report += [f"noise_prompt_weight {j}: {value}" for j, value in enumerate(args.noise_prompt_weights)]
  report.append(f"images size {args.size}")
  # Truthiness covers both '' and None (the Parameters cell can produce None).
  if args.init_image:
    report.append(f"init images {args.init_image}")
    report.append(f"init weight {args.init_weight}")
  report += [
      f"CLIP: {args.clip_model}",
      f"VQGAN Config: {args.vqgan_config}",
      f"VQGAN Checkpoint: {args.vqgan_checkpoint}",
      f"step size: {args.step_size}",
      f"Cut N: {args.cutn}",
      f"Cut Pow: {args.cut_pow}",
      f"Seed: {args.seed}",
      f"Seeds: {args.seeds}",
  ]

  print(f"Saving parameter info in {Filename}")
  # `with` closes the file even when a write fails.
  with open(Filename, 'w') as fp:
    fp.writelines(line + "\n" for line in report)




## Implementation tools
The main thing you will modify is `prompts`: that is where you place the text(s) you want to generate, separated with `|`. It is a list because you can supply more than one text; the AI then tries to 'mix' the images, giving each text the same priority.

To give the model an initial image, upload a file to the Colab environment (in the section on the left) and set `initial_image` to the exact name of the file. Example: `sample.png`

You can also change the model with the `model` selector. Currently ImageNet 1024, ImageNet 16384, WikiArt 1024, WikiArt 16384, COCO-Stuff, FacesHQ and S-FLCKR are available. To use one you must have downloaded it first; then simply select it.

You can also use `target_images`: one or more images that the AI will take as a "target", fulfilling the same function as a text input. To supply more than one, use `|` as a separator.

### Varying-Aspect Image Sizes
Near 490K pixels (T4 GPU might need smaller?):
* 1:1 - 700x700 (645^2)
* 4:3 - 808x606 (512x682)
* 16:9 - 928x522 (836x470 or 800x450 or 818x460)
* Cinemascope (2.4:1) - 1084x452
* 2:1 - 988x494
* 1.66:1 - 903x544 (833x500)
* 3:2 - 855x570 (768x512) (or 750x500)

`aspect * sqrt(490000/(aspect_x*aspect_y))` in general

In [None]:
#@title Available starter images:
# List the GPU(s) attached to this runtime.
!nvidia-smi -L
# Capture the shell output of listing the starter JPEGs on Drive.
# NOTE(review): pushd/popd print the directory stack as well, so those lines
# presumably end up in startPix next to the file names.
startPix = !pushd /content/drive/MyDrive/kbImport/Pix/2021/Work2021g/starters && ls *.jpg && popd
print(startPix)
print("A Fave: /content/drive/MyDrive/kbImport/Pix/2021/Work2021g/starters/starter-star-wide-bw1.jpg")

In [None]:
#@title Parameters <font color="red">Are Here</font>
# @markdown After the initial run, you can alter these and hit "Run After" to skip the above lib loading
import time
beginTime = time.time()

job_name = "cityXOTSel"#@param {type: "string"}
step_name = "city"#@param {type:"string"}

#@markdown ---

prompts = "destroyed victorian city at night under dark clouds:20 | smoke and deep shadows:20 | ornithopters & montgolfier balloons flying high:20 | realistic black and white engraving by Gustav Dore:40" #@param {type:"string"}
include_seeds_in_prompts = False#@param {type:"boolean"}
width =  928#@param {type:"number"}
height =  522#@param {type:"number"}
annotation = "XOT selects"#@param {type:"string"}
model = "wikiart_16384" #@param ["vqgan_imagenet_f16_16384", "vqgan_imagenet_f16_1024", "wikiart_1024", "wikiart_16384", "coco", "faceshq", "sflckr"]
initial_image = "/content/drive/MyDrive/kbImport/Pix/2021/Work2021g/starters/mist-poolx.jpg"#@param {type:"string"}
invert_initial_image = True#@param {type:"boolean"}
target_images = ""#@param {type:"string"}
seed = -1#@param {type:"number"}
seedlist = "1799, 1984, 1835,1890,31 , 1959"#@param {type:"string"}
max_iterations = 600#@param {type:"number"}
display_frequency =  50#@param {type:"number"}

# TODO: use these?
input_images = ""

#######

# Display names for the selectable model file stems.
model_names = {
    "vqgan_imagenet_f16_16384": 'ImageNet 16384',
    "vqgan_imagenet_f16_1024": "ImageNet 1024",
    "wikiart_1024": "WikiArt 1024",
    "wikiart_16384": "WikiArt 16384",
    "coco": "COCO-Stuff",
    "faceshq": "FacesHQ",
    "sflckr": "S-FLCKR",
}
model_name = model_names[model]

# -1 in the form means "no fixed seed".
if seed == -1:
    seed = None
# Blank entries (trailing or doubled commas, a whitespace-only field) are
# skipped instead of crashing int(); with no usable entry the single `seed`
# is used.
seeds = [int(token) for token in seedlist.split(',') if token.strip()]
if not seeds:
    seeds = [seed]

if initial_image == "None":
    initial_image = None
if target_images == "None" or not target_images:
    target_images = []
else:
    target_images = [image.strip() for image in target_images.split("|")]

if initial_image or target_images != []:
    input_images = True

# Keep the prompt text as typed for labelling; the list form drives the render.
origPrompts = prompts
if include_seeds_in_prompts:
  origPrompts += "(+seed)"
prompts = [frase.strip() for frase in prompts.split("|")]
if prompts == ['']:
    prompts = []

#TODO: this is... hacky and inconsistent
args = argparse.Namespace(
    job_name=job_name,
    orig_prompts=origPrompts,
    prompts=prompts,
    image_prompts=target_images,
    include_seeds=include_seeds_in_prompts,
    annotation=annotation,
    noise_prompt_seeds=[],
    noise_prompt_weights=[],
    size=[width, height],
    init_image=initial_image,
    init_weight=0.,
    clip_model='ViT-B/32',
    vqgan_config=f'{model}.yaml',
    vqgan_checkpoint=f'{model}.ckpt',
    step_size=0.1,
    cutn=64,
    cut_pow=1.,
    display_freq=display_frequency,
    seed=seed,
    seeds=seeds
)

print(f"Execution will be for {len(seeds)} seed(s), {max_iterations} iterations each")

In [None]:
#@title Execution
import time
import re
import os
# One start time per execution of this cell; it also feeds the time stamp in
# every generated sequence name.
renderStartTime = renderEndTime = time.time()

#@markdown TODO unify job storage?

#@markdown images can be intermittently displayed in this pane, but GDrive gives finer-grained viewing
display_as_you_go = False #@param {type:"boolean"}
save_step_frames = True #@param {type:"boolean"}

# Add the separator only once: re-running this cell without re-running the
# Parameters cell used to append another '_' every time.
if step_name != "" and not step_name.endswith('_'):
  step_name += '_'

#TODO: max_iterations not yet a param
#TODO - put seed into the prompt 

# Folder that collects the output of every seed in this sweep.
group_name = generate_sequence_name(args.job_name, args.prompts)
os.makedirs(group_name, exist_ok=True)

def render_frames(args, Seed, GroupPath):
  """Run one VQGAN+CLIP optimisation for a single seed and save its frames.

  Besides its arguments this reads several notebook-level globals:
  ``max_iterations``, ``invert_initial_image``, ``save_step_frames``,
  ``display_as_you_go``, ``step_name``, ``input_images`` and
  ``renderStartTime``.

  Args:
    args: Namespace built in the Parameters cell (prompts, size, model files,
      step size, cutout settings, ...).
    Seed: integer passed to ``torch.manual_seed``; ``None`` lets torch pick.
    GroupPath: directory under which ``<sequence_name>_steps`` is created.

  Returns:
    Tuple ``(sequence_name, steps_dir, i, lossLogger, frameTimeLogger)``.
    ``i`` is the index of the last iteration that was started; it equals
    ``max_iterations`` after an uninterrupted run. ``lossLogger`` holds the
    summed loss of each iteration and ``frameTimeLogger`` the wall-clock
    seconds of each iteration that finished.
  """
  # Read the notebook-level iteration budget once; the loop and the nested
  # helpers below all use this value.
  max_iter = max_iterations

  sequence_name = generate_sequence_name(args.job_name,
                                         args.prompts,
                                         Seed)
  print("Sequence Name: "+ sequence_name)
  steps_dir = os.path.join(GroupPath, sequence_name+'_steps')
  # exist_ok: a repeated seed, or a re-run that yields the same time stamp,
  # maps to the same folder name and must not abort the render.
  os.makedirs(steps_dir, exist_ok=True)

  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  print('Using device:', device)
  promptList = list(args.prompts)
  if args.include_seeds:
    promptList.append(str(Seed))
  if promptList:
    print('Using text prompt(s):', promptList)
  if args.image_prompts:
    print('Using image prompts:', args.image_prompts)

  if Seed is None:
    seed = torch.seed()
  else:
    seed = Seed
  torch.manual_seed(seed)
  print('Using seed:', seed)

  model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
  perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)

  cut_size = perceptor.visual.input_resolution
  e_dim = model.quantize.e_dim
  f = 2**(model.decoder.num_resolutions - 1)
  make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow)
  n_toks = model.quantize.n_e
  # Image sides are rounded down to a whole number of latent tokens.
  toksX, toksY = args.size[0] // f, args.size[1] // f
  sideX, sideY = toksX * f, toksY * f
  z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
  z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]

  if args.init_image:
    pil_image = Image.open(args.init_image).convert('RGB')
    if invert_initial_image:
      pil_image = ImageOps.invert(pil_image)
    pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
    z, *_ = model.encode(TF.to_tensor(pil_image).to(device).unsqueeze(0) * 2 - 1)
  else:
    # No start image: begin from randomly chosen codebook entries.
    one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()
    z = one_hot @ model.quantize.embedding.weight
    z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
  z_orig = z.clone()
  z.requires_grad_(True)
  opt = optim.Adam([z], lr=args.step_size)

  normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                   std=[0.26862954, 0.26130258, 0.27577711])

  pMs = []

  for prompt in promptList:
    txt, weight, stop = parse_prompt(prompt)
    embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
    pMs.append(Prompt(embed, weight, stop).to(device))

  for prompt in args.image_prompts:
    path, weight, stop = parse_prompt(prompt)
    img = resize_image(Image.open(path).convert('RGB'), (sideX, sideY))
    batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
    embed = perceptor.encode_image(normalize(batch)).float()
    pMs.append(Prompt(embed, weight, stop).to(device))

  # Loop variable is `noise_seed`, not `seed`: the metadata helpers below read
  # `seed` from this scope and must keep seeing the render seed.
  for noise_seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
    gen = torch.Generator().manual_seed(noise_seed)
    embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
    pMs.append(Prompt(embed, weight).to(device))

  lossLogger = []
  frameTimeLogger = []

  # ----still edit from here.....

  def synth(z):
    """Quantize the latent, decode it and map the result into [0, 1]."""
    z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1)
    return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)

  # TODO: update and add arguments
  def add_xmp_data(filename):
    """Append XMP items (XMP_NS_DC namespace) describing the current frame."""
    image = ImgTag(filename=filename)

    def put(field, value):
      # A fresh options dict per item, matching how the flags were passed before.
      image.xmp.append_array_item(libxmp.consts.XMP_NS_DC, field, value,
                                  {"prop_array_is_ordered":True, "prop_value_is_array":True})

    put('creator', 'Kevin Bjorke, VQGAN+CLIP')
    put('title', sequence_name)
    put('i', str(i))
    # put('model', model_name)
    put('seed', str(seed))
    put('input_images', str(input_images))
    if args.annotation != "" and args.annotation is not None:
      put('notes', args.annotation)
    # for frases in promptList:
    #   put('Prompt', frases)
    image.close()

  # TODO: improve
  def add_stegano_data(filename):
    """Hide a JSON description of the current frame in the image (LSB)."""
    data = {
        "title":  sequence_name,
#        "title": " | ".join(promptList) if promptList else None,
        "notebook": "VQGAN+CLIPwTweaks",
        "author": "Kevin Bjorke",
        "i": i, # rename?
        "notes": args.annotation,
        "seed": str(seed),
        "input_images": input_images
    }
    lsb.hide(filename, json.dumps(data)).save(filename)

  @torch.no_grad()
  def checkin(i, losses, maxIter):
    """Log the losses and refresh ``progress.png`` from the current latent."""
    losses_str = ', '.join(f'{loss.item():g}' for loss in losses)
    elapsedMinutes = (time.time()-renderStartTime) / 60
    tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str} after {elapsedMinutes:.4} mins')
    out = synth(z)
    TF.to_pil_image(out[0].cpu()).save('progress.png')
    add_stegano_data('progress.png')
    add_xmp_data('progress.png')
    if display_as_you_go or (i+1) == maxIter:
      display.display(display.Image('progress.png'))

  def ascend_txt(i):
    """Return the list of loss terms for the current latent; optionally save the frame."""
    out = synth(z)
    iii = perceptor.encode_image(normalize(make_cutouts(out))).float()

    result = []

    if args.init_weight:
      result.append(F.mse_loss(z, z_orig) * args.init_weight / 2)

    for prompt in pMs:
      result.append(prompt(iii))
    if save_step_frames:
      img = np.array(out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8))[:,:,:]
      # channels-first -> channels-last for the image writer
      img = np.transpose(img, (1, 2, 0))
      filename = f"{steps_dir}/{step_name}{i:04}.png"
      imageio.imwrite(filename, np.array(img))
      add_stegano_data(filename)
      add_xmp_data(filename)
    return result

  def train(i, maxIter):
    """One optimiser step, with a check-in every ``display_freq`` steps and at the end."""
    opt.zero_grad()
    lossAll = ascend_txt(i)
    # Uses the maxIter argument (previously the global was read here).
    if (i+1) % args.display_freq == 0 or (i+1) == maxIter:
      checkin(i, lossAll, maxIter)
    loss = sum(lossAll)
    lossLogger.append(loss.item())
    loss.backward()
    opt.step()
    with torch.no_grad():
      # Keep the latent inside the per-channel range of the codebook.
      z.copy_(z.maximum(z_min).minimum(z_max))

  i = 0
  try:
    with tqdm() as pbar:
      # NOTE(review): indices 0..max_iter inclusive are trained, i.e.
      # max_iter + 1 steps and frames. Left unchanged because code outside
      # this function may count on the frame numbering.
      while True:
        frameStartTime = time.time()
        train(i, max_iter)
        # Logged before the exit test so the final iteration is timed too.
        frameTimeLogger.append(time.time()-frameStartTime)
        if i == max_iter:
          break
        i += 1
        pbar.update()
  except KeyboardInterrupt:
    # An interrupt ends this render early; the caller sees i < max_iterations.
    print("keyboard interupt")
  return (sequence_name, steps_dir, i, lossLogger, frameTimeLogger)

import traceback

# One (sequence name, final image, last iteration, label) tuple per finished seed.
TestLog = []
for s in args.seeds:
  started = time.time()
  try:
    (seq_name, steps_path, I, LL, FTL) = render_frames(args, s, group_name)
  except (Exception, KeyboardInterrupt) as err:
    # The sweep still stops here as before, but the cause is now reported
    # instead of a bare "Unknown error".
    print(f"Render for seed {s} failed: {err!r}")
    traceback.print_exc()
    break
  ended = time.time()
  elapsed = ended-started
  elapsedTimeStr = f"{elapsed/60:.3} mins"
  print(f"Rendering was complete after {elapsedTimeStr}")
  legend = f": {args.annotation}" if args.annotation else ""
  legend = f"S{s}{legend}"
  perfplots(LL, FTL, seq_name, steps_path, elapsedTimeStr, legend)
  last_pic = remember_progress(seq_name)
  TestLog.append( (seq_name, last_pic, I, legend) )
  # render_frames returns early on an interrupt; do not start further seeds.
  if (I < max_iterations):
    print(f"Interuppted? {I} < {max_iterations}")
    break

params_name = f'{group_name}.txt'

# Best effort: a missing parameter file must not lose the summary below, but
# the reason is printed.
try:
  save_render_params(os.path.join(group_name, params_name), args)
except Exception as err:
  print(f'No "{params_name}" parameter file, sorry ({err!r})')


renderEndTime = time.time()
print(f"{len((TestLog))} Tests run after {(renderEndTime-renderStartTime)/60:.4} minutes")


print('\n'.join([str(t) for t in TestLog]))

In [None]:
#@title labeled montage
import math
import os
from PIL import ImageFile, Image, ImageDraw, ImageOps, ImageFont
import numpy as np
import matplotlib.pyplot as plt

#@markdown TODO: move panel to the top, include iteration count


libFontPath='/usr/share/fonts/truetype/liberation/LiberationSansNarrow-Regular.ttf'

fontBaseHt = 30

fnt = ImageFont.truetype(libFontPath, fontBaseHt)

panelHt = 100
panelTextY = 4

# The first result image that actually exists sets the aspect ratio for all
# cells. TestLog entries can carry '' as their image path, so entry 0 is not
# guaranteed to be usable.
firstPath = next((entry[1] for entry in TestLog
                  if entry[1] and os.path.exists(entry[1])), None)
if firstPath is None:
  raise RuntimeError("No rendered result image found in TestLog - run the Execution cell first")
firstImg = Image.open(firstPath).convert('RGB')
(iw, ih) = firstImg.size
aspect = iw / ih
# Sheet proportions: 10/7.5 would match 8.5x11 paper with a 1" margin; the
# 35mm ratio is the one in use. TODO: decide preference
proofRect = 3/2
mw = max(int(0.5+proofRect*math.sqrt(len(TestLog))/aspect),1)
# Round the row count up so every result gets a cell; rounding to nearest
# could leave the last results out (e.g. 4 results in 3 columns gave 1 row).
mh = max(math.ceil(len(TestLog)/mw),1)

print(f"Creating {mw}x{mh} montage of {len(TestLog)} results")

scale = 0.65 # TODO calculate properly
miw = int(iw*scale)
mih = int(ih*scale)

def contained(Img, DestSize):
  """Return a white RGBA canvas of size DestSize holding a scaled copy of Img.

  Img is scaled uniformly by the smaller of the width and height ratios, so
  the copy fits inside DestSize, and is pasted at the top-left corner (0, 0).
  """
  srcSize = Img.size
  canvas = Image.new("RGBA", DestSize, (255,255,255,255))
  # Uniform factor: whichever axis is the tighter fit decides.
  fitFactor = min(DestSize[0]/srcSize[0], DestSize[1]/srcSize[1])
  canvas.paste(ImageOps.scale(Img, fitFactor), (0, 0) )
  return canvas

# Build the montage canvas: an mw x mh grid of cells below a panelHt-high strip.
montageImg = Image.new('RGBA', (miw*mw, panelHt+mih*mh), (64, 64, 64, 255))

# Never index past TestLog and never past the grid. The previous nested loops
# only broke out of the inner loop, so any spare row would raise IndexError.
cellCount = min(len(TestLog), mw*mh)
if cellCount < len(TestLog):
  print(f"Warning: only {cellCount} of {len(TestLog)} results fit in the {mw}x{mh} montage")

for i in range(cellCount):
  k, j = divmod(i, mw)  # k = row, j = column
  srcFile = TestLog[i][1]
  if os.path.exists(srcFile):
    img = Image.open(srcFile).convert('RGBA')
    # Local helper is used rather than ImageOps.contain
    # (presumably unavailable in the installed Pillow - confirm).
    cellImg = contained(img, (miw,mih) ) # no ImageOps.contain!
  else:
    # Missing result picture: fill the cell with a reddish placeholder.
    cellImg = Image.new('RGBA', (miw, mih), (128, 64, 64, 255) )
  # Draw the legend on a transparent layer, then composite it over the cell.
  labelImg = Image.new("RGBA", cellImg.size, (255,255,255,0))
  d = ImageDraw.Draw(labelImg)
  d.text((10,4), TestLog[i][3], font=fnt, fill=(255,255,255,255))
  out = Image.alpha_composite(cellImg, labelImg)
  montageImg.paste(out, (j*miw, panelHt+k*mih) )

# Leave `i` equal to the number of cells drawn, as the previous loop did.
i = cellCount

# Info panel: a dark strip across the top of the montage with three text lines
# (prompts, annotation, total render time), then save the montage as PNG.
panelFontBaseHt = 20

panelFnt = ImageFont.truetype(libFontPath, panelFontBaseHt)

labelImg = Image.new("RGBA", (montageImg.size[0], panelHt), (40,40,40,255))
d = ImageDraw.Draw(labelImg)
d.text((10,panelTextY), args.orig_prompts,
       font=panelFnt, fill=(255,255,255,255))
panelTextY += panelFontBaseHt
# The annotation is optional (the sweep loop treats it as possibly empty),
# so draw an empty string rather than passing a falsy non-string to d.text.
d.text((10,panelTextY), args.annotation or "",
       font=panelFnt, fill=(255,255,255,255))
panelTextY += panelFontBaseHt
d.text((10,panelTextY),
       f"{(renderEndTime-renderStartTime)/60:.4} minutes for all",
       font=panelFnt, fill=(255,255,255,255))
# The panel goes on the top strip reserved by panelHt.
montageImg.paste(labelImg, (0, 0) )

# Image.convert returns a new image; the result was previously discarded,
# so the saved file stayed RGBA.
montageImg = montageImg.convert('RGB')
m_seq_name = generate_sequence_name(args.job_name,
                                    args.prompts)
montName = f'{m_seq_name}_testpanel.png'
montageImg.save(montName)
print(f"Montage image: {montName}")

# how did that go?

## Generate a video with the results

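If you want to generate a video with the images as frames, run the cell below. You can modify the frame rate (FPS), the initial frame, the last frame, etc. Note that the cell is disabled by default: remove its `%%script false` line and tick `produce_video` to enable it.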

In [None]:
%%script false 
#@title Generate video using ffmpeg SKIP
produce_video = False #@param {type:"boolean"}

from subprocess import Popen, PIPE
import time

videoStartTime = time.time()
print(f"Rendering was complete after {(renderEndTime-renderStartTime)/60:.3} minutes")

init_frame = 1 # This is the frame where the video will start
last_frame = init_frame

# Work out the frame range. `last_frame` is used as an EXCLUSIVE bound by the
# encoding loop further down.
try:
    last_frame = i # You can change i to the number of the last frame you want to generate. It will raise an error if that number of frames does not exist.
    total_frames = last_frame-init_frame
except (NameError, TypeError):
    # `i` is missing or not a number: fall back to counting the PNG files on
    # disk. Other errors are no longer swallowed.
    import glob
    total_frames = len(glob.glob(f'{steps_dir}/*.png'))
    # NOTE(review): with the exclusive range used later, this yields
    # total_frames-1 frames - confirm the intended numbering.
    last_frame = total_frames + init_frame - 1
    print(f"guessing at frame count: {total_frames}?")


# The frame rate is clamped to this range when the video is encoded.
min_fps = 10
max_fps = 30


length = 10 # 15 # Desired video runtime in seconds

# Helper: concatenate the string elements of a sequence into one string.
def listToString(s):
    """Return all elements of `s` joined together with no separator.

    `s` is any iterable of strings (a plain string is returned unchanged).
    A non-string element raises TypeError.
    """
    # str.join is the idiomatic, linear-time form of repeated `+=`.
    return "".join(s)

# Encode the saved step images into an H.264 MP4 by piping PNG frames to ffmpeg.
if not produce_video:
  print("Skipping video")
elif total_frames < 2:
  print("no frames available")
else:

  tqdm.write('Generating video...')
  # Frame files are named {steps_dir}/{step_name}NNNN.png; last_frame is exclusive.
  frame_paths = [f"{steps_dir}/{step_name}{n:04}.png"
                 for n in range(init_frame, last_frame)]

  # Frame rate that would fit the desired runtime, clamped to [min_fps, max_fps].
  fps = np.clip(total_frames/length,min_fps,max_fps)

  video_filename = f'{sequence_name}.mp4'
  print("Video filename: "+ video_filename)

  # Optional frame-number overlay (disabled), using fontfile
  # 'LiberationSansNarrow-Regular.ttf':
  # '-vf', '"drawtext=fontfile=LiberationSansNarrow-Regular.ttf: text=\'%{frame_num}\': start_number=1: x=(w-tw)/2: y=h-(2*lh): fontcolor=black: fontsize=20: box=1: boxcolor=white: boxborderw=5"',
  mpeg_args = ['ffmpeg', '-y',
               '-f', 'image2pipe',
               '-vcodec', 'png',
               '-r', str(fps),
               '-i', '-',
               '-vcodec', 'libx264',
               '-r', str(fps),
               '-pix_fmt', 'yuv420p',
               '-crf', '17',
               '-preset', 'veryslow',
               video_filename]

  mpstr = ' '.join(mpeg_args)
  print(f'ffmpeg args: "{mpstr}"')
  p = Popen(mpeg_args, stdin=PIPE)
  try:
    # Stream one frame at a time: open, write to the pipe, close. Holding every
    # frame open in a list first costs memory and file handles.
    for frame_path in tqdm(frame_paths):
      with Image.open(frame_path) as im:
        im.save(p.stdin, 'PNG')
  finally:
    # Always close stdin, even on error, so ffmpeg sees end-of-input.
    p.stdin.close()

  print("Compressing video...")
  returncode = p.wait()
  if returncode != 0:
    print(f"ffmpeg exited with status {returncode}; {video_filename} may be incomplete")
  videoEndTime = time.time()
  print(f"Video ready after {(videoEndTime-videoStartTime)/60:.3} minutes")

In [None]:
#@title final cleanup

# TODO: assemble test montage
# NOTE(review): the "labeled montage" cell earlier in the notebook appears to
# cover this TODO - confirm and remove.

# Signal that the run finished and the notebook can be used for another render.
print('\nReady for a new rendering!')
# Disabled overall timing; needs a `beginTime` recorded at the start of the run.
#endTime = time.time()
#print(f"Rendering+Video+Cleanup time: {(endTime-beginTime)/60:.3} minutes")


In [None]:
%%script false 
# @title View video in browser
# @markdown *SKIPPED* This process may take a little longer. If you don't want to wait, download it by executing the next cell instead of using this cell.
# Embed the rendered MP4 in the cell output as a base64 data URL.
# `video_filename`, `b64encode` and `display` must already be defined by
# earlier cells; they are not imported here.
with open(video_filename, 'rb') as video_file:
  # Read inside a context manager so the file handle is closed promptly.
  mp4 = video_file.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
# Kept as the last bare expression so the notebook renders it as cell output.
display.HTML("""
<video width=400 controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)

In [None]:
%%script false 
# @title *SKIPPED* Download video
# Hand the rendered video to google.colab's files.download helper.
# `video_filename` is set by the video-generation cell, so that cell must have
# run (with produce_video enabled) before this one.
from google.colab import files
files.download(video_filename)

In [None]:
#@title Close Notebook (for last session!)
close_down_at_end = False #@param {type:"boolean"}

# Optionally shut the kernel down once the run is over.
if close_down_at_end:
  print("adios")
  exit(0)
else:
  # Under an IPython kernel exit() only requests shutdown and the cell keeps
  # running, so this message must sit in the else branch to stay truthful.
  print("run ended, kernel still active")