<a href="https://colab.research.google.com/github/keirwilliamsxyz/keirxyz/blob/main/VQGAN%2BCLIP_Batch_generate_%2B_canvas.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Generating images using VQGAN+CLIP (+ mass generating)

Originally made by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings).
The original BigGAN+CLIP method was by https://twitter.com/advadnoun.
Added some explanations and modifications by Eleiber#8347, pooling trick by Crimeacs#8222 (https://twitter.com/EarthML1) and the GUI was made with the help of Abulafia#3734.

This notebook extends the original by making it easy to run many generations in a row from a list of prompts. We use it in some of our experiments with prompts and inputs - and perhaps you will find it useful too! If you see any ways to improve it or spot any mistakes, please do let me know: https://twitter.com/neurowelt :)

### Licensed under the MIT License

In [None]:
# Copyright (c) 2021 Katherine Crowson

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.


### Installations

If you are using permanent storage of any sort then you do not have to repeat this step.

(If on Colab you're asked to restart runtime to update `pydevd` package - you don't have to do that, I found it actually breaking the generation process)

In [None]:
!git clone https://github.com/openai/CLIP
!git clone https://github.com/CompVis/taming-transformers.git
!pip install ftfy regex tqdm omegaconf pytorch-lightning
!pip install kornia
!pip install imageio-ffmpeg   
!pip install einops
!pip install psutil

## Set up necessary modules

### Libraries

In [None]:
import argparse
import math
from pathlib import Path
import sys

import pandas as pd
import numpy as np
import itertools
import psutil
import random
import copy
from matplotlib import pyplot as plt

%matplotlib inline

import cv2
import matplotlib.pyplot as plt

sys.path.insert(1, 'taming-transformers')
from IPython import display
from base64 import b64encode
from omegaconf import OmegaConf
from PIL import Image
from taming.models import cond_transformer, vqgan
import taming.modules 
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm

from CLIP import clip
import kornia.augmentation as K
import numpy as np
import imageio
from PIL import ImageFile, Image
ImageFile.LOAD_TRUNCATED_IMAGES = True

def sinc(x):
    """Evaluate the normalized sinc function sin(pi*x) / (pi*x) element-wise.

    Entries of ``x`` that are exactly zero map to 1 instead of the 0/0 the
    raw quotient would produce.
    """
    scaled = math.pi * x
    quotient = torch.sin(scaled) / scaled
    return torch.where(x != 0, quotient, x.new_ones([]))


def lanczos(x, a):
    """Build a Lanczos kernel of order ``a`` sampled at positions ``x``.

    Samples outside the open interval (-a, a) are zeroed, and the kernel is
    divided by its sum so the returned weights add up to one.
    """
    inside_window = torch.logical_and(-a < x, x < a)
    weights = sinc(x) * sinc(x/a)
    kernel = torch.where(inside_window, weights, x.new_zeros([]))
    return kernel / kernel.sum()


def ramp(ratio, width):
    """Return symmetric sample positions spaced ``ratio`` apart around zero.

    ``ceil(width / ratio + 1)`` non-negative positions are generated by
    repeatedly adding ``ratio``; they are mirrored to the negative side and
    the outermost sample on each end is dropped.
    """
    count = math.ceil(width / ratio + 1)
    positions = torch.empty([count])
    # Accumulate instead of multiplying so rounding matches step by step.
    position = 0
    for idx in range(count):
        positions[idx] = position
        position += ratio
    mirrored = -positions[1:].flip([0])
    return torch.cat([mirrored, positions])[1:-1]


def resample(input, size, align_corners=True):
    """Resize an (n, c, h, w) tensor to ``size`` = (dh, dw).

    Along every axis that shrinks, the image is first filtered with a 1-D
    Lanczos kernel (reflect-padded so the spatial size is unchanged); the
    final resize is a bicubic ``F.interpolate``.
    """
    n, c, h, w = input.shape
    dh, dw = size

    # Fold channels into the batch so one single-channel kernel filters them all.
    planes = input.view([n * c, 1, h, w])

    if dh < h:
        kernel_h = lanczos(ramp(dh / h, 2), 2).to(planes.device, planes.dtype)
        pad_h = (kernel_h.shape[0] - 1) // 2
        padded = F.pad(planes, (0, 0, pad_h, pad_h), 'reflect')
        planes = F.conv2d(padded, kernel_h[None, None, :, None])

    if dw < w:
        kernel_w = lanczos(ramp(dw / w, 2), 2).to(planes.device, planes.dtype)
        pad_w = (kernel_w.shape[0] - 1) // 2
        padded = F.pad(planes, (pad_w, pad_w, 0, 0), 'reflect')
        planes = F.conv2d(padded, kernel_w[None, None, None, :])

    restored = planes.view([n, c, h, w])
    return F.interpolate(restored, size, mode='bicubic', align_corners=align_corners)


class ReplaceGrad(torch.autograd.Function):
    """Return one tensor in the forward pass but route gradients to another.

    ``forward`` outputs ``x_forward`` unchanged; ``backward`` sends the
    incoming gradient (summed down to ``x_backward``'s shape) to
    ``x_backward`` and nothing to ``x_forward``.
    """

    @staticmethod
    def forward(ctx, x_forward, x_backward):
        # Only the shape of the gradient target is needed later.
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        grad_for_backward_input = grad_in.sum_to_size(ctx.shape)
        return None, grad_for_backward_input


replace_grad = ReplaceGrad.apply


class ClampWithGrad(torch.autograd.Function):
    """Clamp to [min, max] while keeping gradients that point back in range.

    The forward pass is a plain ``clamp``. In the backward pass an element's
    gradient is kept when the product of the gradient and the element's
    overshoot past the bounds is non-negative, and zeroed otherwise.
    """

    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min, ctx.max = min, max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        (saved,) = ctx.saved_tensors
        # Zero inside the bounds, signed distance past them outside.
        overshoot = saved - saved.clamp(ctx.min, ctx.max)
        keep = grad_in * overshoot >= 0
        return grad_in * keep, None, None


clamp_with_grad = ClampWithGrad.apply


def vector_quantize(x, codebook):
    """Snap every vector in ``x`` (last dim) to its nearest ``codebook`` row.

    Nearness is squared Euclidean distance. The quantized values are returned
    in the forward pass while gradients are routed to ``x`` via
    ``replace_grad``.
    """
    x_sq = x.pow(2).sum(dim=-1, keepdim=True)
    code_sq = codebook.pow(2).sum(dim=1)
    # ||x - c||^2 expanded as ||x||^2 + ||c||^2 - 2 x.c
    d = x_sq + code_sq - (2 * x) @ codebook.T
    nearest = d.argmin(-1)
    selector = F.one_hot(nearest, codebook.shape[0]).to(d.dtype)
    return replace_grad(selector @ codebook, x)


class Prompt(nn.Module):
    """Loss term measuring how far image embeddings are from a target embedding.

    ``embed``, ``weight`` and ``stop`` are stored as buffers. The sign of
    ``weight`` flips the distance and its magnitude scales the result.
    ``stop`` is a floor applied to the distances only on the gradient path
    (through ``replace_grad``); the forward value is the unfloored mean.
    """

    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))
        self.register_buffer('stop', torch.as_tensor(stop))

    def forward(self, input):
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        # Distance between the normalized vectors, mapped through arcsin.
        chord = input_normed.sub(embed_normed).norm(dim=2)
        dists = chord.div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        floored = torch.maximum(dists, self.stop)
        return self.weight.abs() * replace_grad(dists, floored).mean()


def parse_prompt(prompt):
    """Split a prompt of the form ``text[:weight[:stop]]`` into its parts.

    Returns ``(text, weight, stop)`` with ``weight`` defaulting to 1.0 and
    ``stop`` to ``-inf``.

    Trailing fields are only treated as weight/stop when they parse as
    floats. This keeps prompts that merely contain a colon (for example
    ``"Title: a cat"`` or an ``http://`` URL) from crashing on ``float()``;
    such text is returned whole with the default weight and stop.
    """
    defaults = ['', '1', '-inf']
    # Try text:weight:stop first, then text:weight, keeping more of the
    # string as text each time the numeric fields fail to parse.
    for max_fields in (2, 1):
        vals = prompt.rsplit(':', max_fields)
        vals = vals + defaults[len(vals):]
        try:
            return vals[0], float(vals[1]), float(vals[2])
        except ValueError:
            continue
    return prompt, 1.0, float('-inf')


class MakeCutouts(nn.Module):
    """Produce ``cutn`` augmented, ``cut_size`` x ``cut_size`` views of an image.

    This variant uses the "pooling trick": rather than taking random crops,
    the whole image is reduced to ``cut_size`` by averaging an adaptive
    average pool and an adaptive max pool. That view is repeated ``cutn``
    times, pushed through random kornia augmentations, and perturbed with
    Gaussian noise of a random per-cutout strength in [0, noise_fac).

    Args:
        cut_size: side length of every output view.
        cutn: number of views to produce per call.
        cut_pow: stored but unused by the pooling strategy; kept so callers
            that pass it keep working.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

        self.augs = nn.Sequential(
            K.RandomAffine(degrees=15, translate=0.1, p=0.7, padding_mode='border'),
            K.RandomPerspective(0.7, p=0.7),
            K.ColorJitter(hue=0.1, saturation=0.1, p=0.7),
            K.RandomErasing((.1, .4), (.3, 1/.3), same_on_batch=True, p=0.7),
        )
        self.noise_fac = 0.1
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))

    def forward(self, input):
        """Return the augmented batch of cutouts for ``input``."""
        # The pooled view is deterministic, so compute it once and repeat it
        # instead of re-pooling the full image ``cutn`` times.
        pooled = (self.av_pool(input) + self.max_pool(input)) / 2
        batch = self.augs(torch.cat([pooled] * self.cutn, dim=0))
        if self.noise_fac:
            facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch

def load_vqgan_model(config_path, checkpoint_path):
    """Instantiate the VQGAN described by ``config_path`` and load its weights.

    Supports ``VQModel``, ``GumbelVQ`` and ``Net2NetTransformer`` targets (for
    the transformer, its ``first_stage_model`` is returned). The network is
    put in eval mode with gradients disabled before the checkpoint is loaded,
    and the returned model has its ``loss`` attribute deleted.

    Raises:
        ValueError: if the config names any other target class.
    """
    config = OmegaConf.load(config_path)
    target = config.model.target
    transformer_target = 'taming.models.cond_transformer.Net2NetTransformer'

    if target == 'taming.models.vqgan.VQModel':
        network = vqgan.VQModel(**config.model.params)
    elif target == 'taming.models.vqgan.GumbelVQ':
        network = vqgan.GumbelVQ(**config.model.params)
    elif target == transformer_target:
        network = cond_transformer.Net2NetTransformer(**config.model.params)
    else:
        raise ValueError(f'unknown model type: {config.model.target}')

    network.eval().requires_grad_(False)
    network.init_from_ckpt(checkpoint_path)

    # For the transformer wrapper only the first-stage model is returned.
    model = network.first_stage_model if target == transformer_target else network
    del model.loss
    return model

def resize_image(image, out_size):
    """Resize ``image`` with Lanczos filtering, preserving its aspect ratio.

    The target area is the smaller of the image's own area and the area of
    ``out_size`` (width, height), so the image is never enlarged.
    """
    src_w, src_h = image.size[0], image.size[1]
    ratio = src_w / src_h
    area = min(src_w * src_h, out_size[0] * out_size[1])
    new_w = round((area * ratio)**0.5)
    new_h = round((area / ratio)**0.5)
    return image.resize((new_w, new_h), Image.LANCZOS)

class AttrDict(dict):
    """A ``dict`` whose items can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # Use the mapping itself as the instance namespace, so ``d.key`` and
        # ``d['key']`` always refer to the same entry.
        self.__dict__ = self

### Google Drive

If you want to connect your Drive to access your models, or to save the final iteration of each image to Google Drive, connect your drive here.

In [None]:
### COLAB
# Mount Google Drive at /content/drive so later cells can read or write it.
# NOTE(review): `google.colab` is presumably only importable on Colab - skip
# this cell elsewhere.
from google.colab import drive
drive.mount('/content/drive')

## Prepare data and parameters

### Input and output data

#### Input data by hand

You can easily just write down all the data you want to use for generation below.

* `texts` - write prompts for generation here, separating them with a `|`
* `input_images` - write the names in square brackets of initial image files on which you would like to generate images, separate them with `|`
* `iterations` - write down iterations in square brackets for each image, separating them with `|`

Each prompt will be paired with input image name and iteration on the corresponding position. If you would like to generate a given prompt on more than one image and on more iterations, write as follows (an example):
* `texts`: "black sheep by Picasso | red apple on a tree"
* `input_images`: "[b.png] | [c.png, b.png]"
* `iterations`: [20, 30, 40] | [20]

So you get:
* black sheep by Picasso – [20, 30, 40] – [b.png]
* red apple on a tree – [20] – [c.png, b.png]

In [None]:
texts = "" #@param {type:"string"}
input_images = "" #@param {type:"string"}
iterations = "" #@param {type:"string"}

# Each form field is a "|"-separated list with one entry per prompt.
texts = [entry.strip() for entry in texts.split("|")]
# `[1:-1]` drops the square brackets around each image / iteration group.
input_images = [group.strip()[1:-1] for group in input_images.split("|")]
iterations = [group.strip()[1:-1] for group in iterations.split("|")]

# Hand-written input uses the same table layout as the all-in-one spreadsheet.
prompts = pd.DataFrame({"prompt":texts,"filename":input_images,"iterations":iterations})
way_chosen = 1

#### Import data from .xlsx

You can import data from .xlsx if you want to generate in a more factory fashion. You can choose one of two ways: the **artist+prompt** way or the **all-in-one** way.

##### Artists and prompts way

You will need two .xlsx files in order to proceed with this method. The first one with just the artists, the second one with all the rest of the prompts.

Artists excel should have only one column with artists (there may be many artists in one cell).

Prompts should be constructed in this way:
* `prompt` - text used for generating
* `filename` - the names of the input files separated with `, `
* `iterations` - iterations for each prompted image, there may be just one, but you can write them down in the same manner as filenames, with `,`

Then, the final prompts will be generated in a form:
"{`prompt`} by {`artist`}" and will be created for all possible combinations between artists and prompts.

This will create prompts for every artist and prompt, giving you  `len(artists) * len(prompts)`  of generations.

In [None]:
artists_excel = "" #@param {type:"string"}
prompts_excel = "" #@param {type:"string"}

# Load the artists sheet and the prompts sheet for the "artist + prompt" way.
# On success `artists`, `prompts` and `way_chosen = 0` are set; on any failure
# the error is printed and those names are left untouched.
try:
    # Read the two files named in the form fields above (previously both
    # reads used a hard-coded 'your_excel.xlsx' and ignored the fields).
    arts = pd.read_excel(artists_excel)#,header=[1],index_col=0) # sometimes pandas adds Unnamed column, that helps fixing it
    other = pd.read_excel(prompts_excel)

    # Replace empty cells with '' and force the expected column names.
    data_a = pd.DataFrame(np.where(arts.isna(),'',arts),columns=['artist'])
    data_o = pd.DataFrame(np.where(other.isna(),'',other),columns=['prompt','filename','iterations'])

    prompts = data_o.fillna('')
    artists = data_a.fillna('').sort_index()
    way_chosen = 0
    print(artists.head())
    print(prompts.head())
except Exception as e:
    print(f'Failed to load dataset because of a following exception: {e}')

##### All-in-one way

You will need just one excel spreadsheet filled with prompts. In order for this spreadsheet to work it requires to have the following columns:
* `prompt` - text used for generating
* `filename` - name of the file to base the generated image on (if no input then leave the cell empty), there can be multiple files, if so please separate with `,`
* `iterations` - iterations for each prompted image, you can put many iterations for given image, just separate them with `,`

In [None]:
your_excel = "" #@param {type: "string"}

# Load the single "all-in-one" spreadsheet. Empty cells become '' and the
# columns are named prompt / filename / iterations.
try:
    d = pd.read_excel(your_excel)
    prompts = pd.DataFrame(
        np.where(d.isna(), '', d),
        columns=['prompt', 'filename', 'iterations'],
    )
except Exception as e:
    print(f'Failed to load dataset because of a following exception: {e}')
else:
    # Only switch to the all-in-one way once the sheet loaded cleanly.
    way_chosen = 1

#### Output folders

Create folders for storing generated images, by default it will create a steps/full folder. **If steps folder exists, it will be deleted!**

In [None]:
import os
import shutil

path0 = './steps'
path1 = './steps/full'

# Start from an empty ./steps tree: delete whatever a previous run left, then
# recreate ./steps/full.
try:
    shutil.rmtree(path0)
except FileNotFoundError:
    pass  # first run - nothing to clean up
except OSError as e:
    print("Failed: {}".format(e))

# `makedirs(exist_ok=True)` creates both levels and still succeeds when the
# cleanup above could not remove the old folder, so ./steps/full always
# exists afterwards unless the filesystem itself refuses.
try:
    os.makedirs(path1, exist_ok=True)
except OSError as e:
    print("Failed: {}".format(e))

### Parameters

First generate a random seed. If set to $-1$ a random seed will be generated.

In [None]:
seed =  -1#@param {type:"number", min:0}

# -1 means "pick a seed for me". Re-seed Python's RNG from OS entropy and draw
# a seed in [0, 1000]. The previous approach parsed the last three characters
# of a memory-usage ratio with int(), which raises ValueError whenever those
# characters include the decimal point (e.g. "2.5").
if seed == -1:
    random.seed()
    seed = random.randint(0,1000)

#### Model choice

Choose the model to use, according to the platform you're on. As in the original notebook, the cell below lets you download the model.

In [None]:
#@markdown By default, the notebook downloads the 1024 and 16384 models from ImageNet. There are others like COCO-Stuff, WikiArt 1024, WikiArt 16384, FacesHQ or S-FLCKR, which are heavy, and if you are not going to use them it would be pointless to download them, so if you want to use them, simply select the models to download.

imagenet_1024 = True #@param {type:"boolean"}
imagenet_16384 = False #@param {type:"boolean"}
coco = False #@param {type:"boolean"}
faceshq = False #@param {type:"boolean"}
wikiart_1024 = False #@param {type:"boolean"}
wikiart_16384 = False #@param {type:"boolean"}
sflckr = False #@param {type:"boolean"}
openimages_8192 = False #@param {type:"boolean"}

if imagenet_1024:
  #!curl -L -o vqgan_imagenet_f16_1024.yaml -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_1024.yaml' #ImageNet 1024
  #!curl -L -o vqgan_imagenet_f16_1024.ckpt -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_1024.ckpt'  #ImageNet 1024
  !curl -L -o vqgan_imagenet_f16_1024.yaml -C - 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1'
  !curl -L -o vqgan_imagenet_f16_1024.ckpt -C - 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fckpts%2Flast.ckpt&dl=1'
if imagenet_16384:
  #!curl -L -o vqgan_imagenet_f16_16384.yaml -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.yaml' #ImageNet 16384
  #!curl -L -o vqgan_imagenet_f16_16384.ckpt -C - 'http://mirror.io.community/blob/vqgan/vqgan_imagenet_f16_16384.ckpt' #ImageNet 16384
  !curl -L -o vqgan_imagenet_f16_16384.ckpt -C - 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fckpts%2Flast.ckpt' #ImageNet 16384
  !curl -L -o vqgan_imagenet_f16_16384.yaml -C - 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fconfigs%2Fmodel.yaml' #ImageNet 16384
if openimages_8192:
  !curl -L -o vqgan_openimages_f16_8192.yaml -C - 'https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' #ImageNet 16384
  !curl -L -o vqgan_openimages_f16_8192.ckpt -C - 'https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fckpts%2Flast.ckpt&dl=1' #ImageNet 16384

if coco:
  !curl -L -o coco.yaml -C - 'https://dl.nmkd.de/ai/clip/coco/coco.yaml' #COCO
  !curl -L -o coco.ckpt -C - 'https://dl.nmkd.de/ai/clip/coco/coco.ckpt' #COCO
if faceshq:
  !curl -L -o faceshq.yaml -C - 'https://drive.google.com/uc?export=download&id=1fHwGx_hnBtC8nsq7hesJvs-Klv-P0gzT' #FacesHQ
  !curl -L -o faceshq.ckpt -C - 'https://app.koofr.net/content/links/a04deec9-0c59-4673-8b37-3d696fe63a5d/files/get/last.ckpt?path=%2F2020-11-13T21-41-45_faceshq_transformer%2Fcheckpoints%2Flast.ckpt' #FacesHQ
if wikiart_1024: 
  !curl -L -o wikiart_1024.yaml -C - 'http://mirror.io.community/blob/vqgan/wikiart.yaml' #WikiArt 1024
  !curl -L -o wikiart_1024.ckpt -C - 'http://mirror.io.community/blob/vqgan/wikiart.ckpt' #WikiArt 1024
if wikiart_16384: 
  !curl -L -o wikiart_16384.yaml -C - 'http://mirror.io.community/blob/vqgan/wikiart_16384.yaml' #WikiArt 16384
  !curl -L -o wikiart_16384.ckpt -C - 'http://mirror.io.community/blob/vqgan/wikiart_16384.ckpt' #WikiArt 16384
if sflckr:
  !curl -L -o sflckr.yaml -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1' #S-FLCKR
  !curl -L -o sflckr.ckpt -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1' #S-FLCKR

An option for a locally stored model.

In [None]:
### COLAB
# Paths to a model config (.yaml) and checkpoint (.ckpt) already on disk.
# The parameters cell uses these instead of the downloaded files when
# `local_model` is True.
LOCAL_MODEL_YAML = 'your.yaml'
LOCAL_MODEL_CKPT = 'your.ckpt'

#### Model parameters

Set up the parameters and prepare the final settings list for the model.

Following parameters can be changed below:
* `width` and `height` - regulate the size of the image
* `model` - choose model based on the one you picked before
* `target_images` - choose the target image for the generator (leave empty if not used)
* `seed` - the generator's seed, if the same allows for comparison between outcomes of generations
* `step_size` - effectively the size of generation steps, the larger it is the further the generation process goes within the same time span
* `local_model` - if `True` then will use a model available in the storage the notebook is on (depends on what you picked before)

In [None]:
#@title Parameters
width =  512 #@param {type:"number"}
height = 512 #@param {type:"number"}
model = "vqgan_imagenet_f16_16384" #@param ["vqgan_imagenet_f16_16384", "vqgan_imagenet_f16_1024", "vqgan_openimages_f16_8192", "wikiart_1024", "wikiart_16384", "coco", "faceshq", "sflckr"]
target_images = "" #@param {type:"string"}
step_size = 0.1 #@param {type:"slider", min:0.001, max:0.9, step:0.001}
local_model = True #@param {type: "boolean"}

# A seed of -1 means "let torch pick one" further down the pipeline.
if seed == -1:
    seed = None

# Target images are a "|"-separated list; empty or "None" means no targets.
if target_images == "None" or not target_images:
    target_images = []
else:
    target_images = [image.strip() for image in target_images.split("|")]

if local_model:
    model_yaml = LOCAL_MODEL_YAML
    model_ckpt = LOCAL_MODEL_CKPT
else:
    model_yaml = f'{model}.yaml'
    model_ckpt = f'{model}.ckpt'


def _parse_iterations(raw):
    """Return the iteration stops listed in one "iterations" cell.

    A string cell is a comma-separated list of integers (``ValueError`` if any
    entry is not one). Any other value, e.g. a numeric spreadsheet cell, is
    used as a single stop.
    """
    if isinstance(raw, str):
        return [int(x.strip()) for x in raw.split(",")]
    return [raw]


def _parse_init_images(raw):
    """Return the init image names in one "filename" cell, or ['None'] if empty."""
    names = [x.strip() for x in str(raw).split(",")]
    return [name for name in names if name] or ['None']


def _usable_rows(frame):
    """Yield (row number, prompt, init images, iteration stops) per usable row.

    Rows are visited by position, so duplicated prompt texts keep their own
    files and iterations. Rows whose iterations cannot be parsed are reported
    and skipped instead of silently reusing another row's values.
    """
    for row_number, (_, row) in enumerate(frame.iterrows()):
        try:
            iters = _parse_iterations(row['iterations'])
        except ValueError:
            print('No numerical values in the "iterations" column. Please repair.')
            continue
        yield row_number, row['prompt'], _parse_init_images(row['filename']), iters


def _make_run(texts, init_images, iters, row_ids):
    """Bundle one run; ``texts``, ``init_images`` and ``row_ids`` are parallel lists."""
    return AttrDict({'texts':texts,
                     'width':width,
                     'height':height,
                     'model':[model_yaml, model_ckpt, model],
                     'images_interval':1,
                     'init_images':init_images,
                     'target_images':target_images,
                     'seed':seed,
                     'max_iterations':iters[-1],
                     'iter_stops':iters,
                     'step_size':step_size,
                     'id':row_ids,
                     })


# Way 1 (hand-written / all-in-one): one run per prompt row, with one
# generation per init image listed in that row.
if way_chosen == 1:
    try:
        multiple_runs = []

        for row_tracker, prompt, init_image, iters in _usable_rows(prompts):
            # Repeat the text so every init image of the row gets generated;
            # the launch cell pairs texts[i] with init_images[i].
            texts = ["{prompt}".format(prompt=prompt)] * len(init_image)
            row_ids = [row_tracker] * len(init_image)
            multiple_runs.append(_make_run(texts, init_image, iters, row_ids))
    except NameError as e:
        print(f"Prompts dataframe doesn't exist ({e})")


# Way 0 (artists x prompts): one run per (artist, prompt row) pair so every
# prompt keeps its own iteration stops and init images. Rows with an empty
# prompt are skipped.
if way_chosen == 0:
    try:
        multiple_runs = []

        for artist in artists.artist:
            for row_tracker, text, init_image, iters in _usable_rows(prompts):
                if text == '':
                    continue
                texts = ["{text} by {art}".format(text=text, art=artist)] * len(init_image)
                row_ids = [row_tracker] * len(init_image)
                multiple_runs.append(_make_run(texts, init_image, iters, row_ids))
    except NameError as e:
        print(f"Artists or prompts dataframe doesn't exist ({e})")

### Prepare model and generation procedure

#### Function preparing the model

In [None]:
def prepare_model(settings):
    """Translate one run's ``settings`` into the argument namespace for generation.

    ``settings.model`` is ``[config path, checkpoint path, model key]``. The
    model key must be one of the known models; an unknown key raises
    ``KeyError``. Everything not taken from ``settings`` (CLIP model, number
    of cutouts, init weight, ...) is fixed here.
    """
    known_models = {"vqgan_imagenet_f16_16384": 'ImageNet 16384',"vqgan_imagenet_f16_1024":"ImageNet 1024", 'vqgan_openimages_f16_8192':'OpenImages 8912',
                    "wikiart_1024":"WikiArt 1024", "wikiart_16384":"WikiArt 16384", "coco":"COCO-Stuff", "faceshq":"FacesHQ", "sflckr":"S-FLCKR"}
    # The display name is not used further; the lookup rejects unknown keys.
    name_model = known_models[settings.model[2]]

    options = dict(
        prompts=[settings.texts],
        image_prompts=settings.target_images,
        noise_prompt_seeds=[],
        noise_prompt_weights=[],
        size=[settings.width, settings.height],
        init_image=settings.init_images,
        init_weight=0.,
        clip_model='ViT-B/32',
        vqgan_config=settings.model[0],
        vqgan_checkpoint=settings.model[1],
        step_size=settings.step_size,
        cutn=32,
        cut_pow=1.,
        display_freq=settings.images_interval,
        seed=settings.seed,
    )
    return argparse.Namespace(**options)

#### Function generating images

In [None]:
def generate_image(settings, args):
    """Optimize a VQGAN latent against CLIP prompts and save snapshot images.

    Args:
        settings: per-generation settings. Reads ``texts``, ``target_images``,
            ``iter_stops``, ``max_iterations``, ``step_size``, ``id``,
            ``gdrive_save`` and ``gdrive_path``.
        args: namespace built by ``prepare_model`` (model paths, prompts,
            image size, init image, optimizer step size, ...).
            ``args.display_freq`` is accepted but unused: the progress
            display was disabled for batch runs.

    Side effects:
        At every iteration listed in ``settings.iter_stops`` the current image
        is written to ``steps/full/`` and, when ``settings.gdrive_save`` is
        set, also into ``settings.gdrive_path``. Optimization runs for
        iterations 0..``settings.max_iterations`` inclusive; a
        ``KeyboardInterrupt`` ends it early without raising.
    """
    import os
    from urllib.request import urlopen

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)
    if settings.texts:
        print('Using texts:', settings.texts)
    if settings.target_images:
        print('Using image prompts:', settings.target_images)
    seed = torch.seed() if args.seed is None else args.seed
    torch.manual_seed(seed)
    print('Using seed:', seed)

    model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
    perceptor = clip.load(args.clip_model, jit=False)[0].eval().requires_grad_(False).to(device)

    cut_size = perceptor.visual.input_resolution

    f = 2**(model.decoder.num_resolutions - 1)
    make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow)

    toksX, toksY = args.size[0] // f, args.size[1] // f
    sideX, sideY = toksX * f, toksY * f

    # Find the codebook by attribute rather than by checkpoint file name, so
    # a locally stored Gumbel model with another file name works too.
    # NOTE(review): assumes the quantizer exposes its codebook as `embedding`
    # (regular VQ) or `embed` (Gumbel), as the per-model branches used to.
    quantizer = model.quantize
    if hasattr(quantizer, 'embedding'):
        codebook = quantizer.embedding.weight
    else:
        codebook = quantizer.embed.weight
    n_toks, e_dim = codebook.shape
    z_min = codebook.min(dim=0).values[None, :, None, None]
    z_max = codebook.max(dim=0).values[None, :, None, None]

    # '' and 'None' both mean "no init image".
    init_image = args.init_image
    if init_image in (None, '', 'None'):
        init_image = None

    if init_image:
        source = urlopen(init_image) if 'http' in init_image else init_image
        pil_image = Image.open(source).convert('RGB')
        pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
        pil_tensor = TF.to_tensor(pil_image)
        z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
        # File name without folder or extension (any extension, not just .png).
        init_name = os.path.splitext(os.path.basename(init_image))[0]
    else:
        # The random token draw is overwritten by the uniform noise below but
        # is kept so seeded runs use the same random numbers as before.
        one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()
        z = one_hot @ codebook
        z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
        z = torch.rand_like(z)*2
        init_name = 'None'
    z.requires_grad_(True)
    opt = optim.Adam([z], lr=args.step_size)

    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                    std=[0.26862954, 0.26130258, 0.27577711])

    pMs = []

    for prompt in args.prompts:
        txt, weight, stop = parse_prompt(prompt)
        embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
        pMs.append(Prompt(embed, weight, stop).to(device))

    for prompt in args.image_prompts:
        path, weight, stop = parse_prompt(prompt)
        img = resize_image(Image.open(path).convert('RGB'), (sideX, sideY))
        batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
        embed = perceptor.encode_image(normalize(batch)).float()
        pMs.append(Prompt(embed, weight, stop).to(device))

    for noise_seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
        gen = torch.Generator().manual_seed(noise_seed)
        embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
        pMs.append(Prompt(embed, weight).to(device))

    # Path separators in the prompt would otherwise be read as folders.
    title_text = str(settings.texts).replace('/', '_')

    def synth(z):
        """Decode latent ``z`` into an image tensor with values in [0, 1]."""
        z_q = vector_quantize(z.movedim(1, 3), codebook).movedim(3, 1)
        return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)

    def save_snapshot(out, n):
        """Write the first image of batch ``out`` for iteration ``n`` to disk."""
        frame = out.mul(255).clamp(0, 255)[0].detach().cpu().numpy().astype(np.uint8)
        frame = np.transpose(frame, (1, 2, 0))
        image_title = "{text} on {init} (seed {seed}, step {step}) [iter {n}] [row {id}].png".format(
            text=title_text, init=init_name, n=n, seed=seed, step=settings.step_size, id=settings.id)
        imageio.imwrite('steps/full/' + image_title, frame)
        if settings.gdrive_save:
            imageio.imwrite(os.path.join(settings.gdrive_path, image_title), frame)

    def ascend_txt(n):
        """Return the list of loss terms for iteration ``n``, saving if due."""
        out = synth(z)
        iii = perceptor.encode_image(normalize(make_cutouts(out))).float()

        result = []
        if args.init_weight:
            # Pull of z towards zero that decays with the iteration number.
            # This uses ``n``; it used to read an unrelated global ``i``.
            result.append(F.mse_loss(z, torch.zeros_like(z)) * (args.init_weight / (n * 2 + 1)) / 2)
        for prompt in pMs:
            result.append(prompt(iii))

        # Only build the image array on iterations that are actually saved.
        if n in settings.iter_stops:
            save_snapshot(out, n)

        return result

    def train(i):
        """Run one optimizer step and clamp ``z`` to the codebook's range."""
        opt.zero_grad()
        loss = sum(ascend_txt(i))
        loss.backward()
        opt.step()
        with torch.no_grad():
            z.copy_(z.maximum(z_min).minimum(z_max))

    i = 0
    try:
        with tqdm() as pbar:
            while True:
                train(i)
                if i == settings.max_iterations:
                    break
                i += 1
                pbar.update()
    except KeyboardInterrupt:
        pass

## Generate images

The code below generates images using the prepared model settings. It allows for multiple runs on many prompts and input images.

If you decide to save images to your Google Drive, please specify the path. If not, leave empty.

In [None]:
#@title Launch parameters
gdrive_save = False #@param {type:"boolean"}
GDRIVE_PATH = '/your/gdrive/path' #@param {type:"string"}


torch.cuda.empty_cache()

# Execute every configured run. Within a run, texts[i] is paired with
# init_images[i] and id[i]; all other settings are shared by the whole run.
try:
    runs_count = 0

    for run in multiple_runs:
        for i, text in enumerate(run.texts):
            temp_settings = AttrDict({'texts':text,
                                    'width':run.width,
                                    'height':run.height,
                                    'model':run.model,
                                    'images_interval':run.images_interval,
                                    'init_images':run.init_images[i],
                                    'target_images':run.target_images,
                                    'seed':run.seed,
                                    'max_iterations':run.max_iterations,
                                    'iter_stops':run.iter_stops,
                                    'step_size':run.step_size,
                                    'id':run.id[i],
                                    'gdrive_save':gdrive_save,
                                    'gdrive_path':GDRIVE_PATH,
                                    })

            arguments = prepare_model(temp_settings)
            generate_image(temp_settings, arguments)

        # The log is overwritten each time, so it names the last finished run.
        with open("batch_log.txt","w+") as f:
            f.write("Finished run: {r}".format(r=runs_count))
        runs_count += 1

except Exception as e:
    # Best-effort batch: record the failure instead of raising, but also show
    # it in the notebook so a stopped batch does not go unnoticed.
    error_message = "{type}: {err} - {arg}".format(type=str(type(e)), err=str(e), arg=str(e.args))
    with open("batch_error_log.txt","w+") as f:
        f.write(error_message)
    print("Batch stopped, see batch_error_log.txt: " + error_message)

Code to .tar.gz the generated files' folder to easily download it to your computer.

In [None]:
# tar flags: c = create, h = follow symlinks, v = verbose, f = archive file, z = gzip.
# GZIP=-9 asks gzip for its highest compression level.
!GZIP=-9 tar chvfz name.tar.gz steps/full

# Creating canvas with generated images

Once you have generated images you may want to present them in a more concise way. To do that you can use the code below - it will iterate over the steps/full folder (or any folder you point it to) and create collages from the images.

First of all: point the folders where the images are and where you would like to save them (make sure the paths exist)

In [None]:
# Folder that holds the generated images to be placed on the canvas.
PATH = "steps/full"
# Folder the finished canvas images are saved to.
SAVE_PATH = "canvas_folder"

# pandas truncates long strings when it renders cells as text, so raise the
# limit to keep whole filenames intact
pd.options.display.max_colwidth = 200

Now let's list the images we have. Sometimes a `.ipynb_checkpoints` folder will be present, in which case we remove it from the list. If you're using macOS, there may be a `.DS_Store` file there as well.

In [None]:
# List the generated files in reverse alphabetical order, leaving out hidden
# entries such as .ipynb_checkpoints or .DS_Store. Filtering by name never
# discards real images, whether or not those hidden entries exist.
file_list = sorted(
    (entry for entry in os.listdir(PATH) if not entry.startswith('.')),
    reverse=True,
)
len(file_list)

Then we need a dataframe with all the files for generating the canvas. We save the filename in the last column so the image can be loaded onto the canvas later.

In [None]:
# Parse every generated .png name into (medium, artist, iterations, filename).
word_list = []
png_files = []  # file name of each row in word_list, kept in the same order

for el in file_list:
    if '.png' not in el:
        continue

    med_oth = el.split('by')[0]
    art = ' '.join(el.split(' by ')[1:])
    try:
        it = int(' '.join(el.split('iter ')[1:]).split(']')[0])
    except ValueError as e:
        # int() raises ValueError when the name has no parsable "[iter N]"
        # part; skip the file so it cannot reuse a previous file's value.
        print("ValueError: {arg}".format(arg=e.args))
        continue
    art = art.split(' on ')[0]
    name = (' '.join(el.split('on ')[1:])).split(' ')[0]

    word_list.append([med_oth, art, it, name])
    png_files.append(el)

df = pd.DataFrame(word_list, columns=['medium', 'artist', 'iterations', 'filename'])
# Use the names collected alongside the rows so 'path' always lines up, even
# when some listed entries were skipped.
df['path'] = png_files

# Rows without an artist cannot be placed on the canvas.
df_c = df[df['artist'] != '']
df_c.head()

For now, the drawing function creates canvases where the **rows** are artists and the **columns** are input files. It creates a separate canvas for each prompt and divides the artists into groups of size equal to the number of input files.

In [None]:
def canva_artist_filename(df):
    """Draw and save one canvas per medium and artist group.

    Each canvas is a square grid whose side is the number of unique values in
    ``df['filename']``: columns are input files and rows are artists. Artists
    are split into consecutive groups of at most that many, and every
    (medium, group) pair is saved to ``SAVE_PATH`` as
    ``art_gr<group>_<medium>.png``. Only rows whose ``iterations`` equal the
    highest value found in ``df`` are drawn.

    Args:
        df: DataFrame with the columns ``medium``, ``artist``, ``iterations``,
            ``filename`` and ``path``; ``path`` is joined onto ``PATH`` to
            locate the image file.

    Returns:
        None. Nothing is drawn when ``df`` has no rows.
    """
    if df.empty:
        print('No images to draw.')
        return

    ART = pd.unique(df["artist"])
    M_O = pd.unique(df["medium"])
    NAM = pd.unique(df["filename"])

    F_L = len(NAM) # this defines the size of canvas
    ART_GR = int(np.ceil(len(ART) / F_L))
    MAX_IT = df["iterations"].max()

    hfont = {"fontname":"Helvetica"}

    # Set once, before any axes exist, so every subplot gets the same ticks.
    plt.rc('xtick', labelsize=0)
    plt.rc('ytick', labelsize=0)

    final = df[df['iterations'] == MAX_IT]

    n = 0
    for medium in M_O:
        by_medium = final[final['medium'] == medium]
        for a_gr in range(ART_GR):
            group_artists = ART[a_gr * F_L:(a_gr + 1) * F_L]
            # A fresh figure per group keeps images and labels of an earlier
            # group from showing through on this group's canvas.
            fig = plt.figure(figsize=(80,80))
            for fl, name in enumerate(NAM):
                for row, artist in enumerate(group_artists):
                    match = by_medium.loc[
                        (by_medium['filename'] == name) & (by_medium['artist'] == artist),
                        'path']

                    img_rgb = None
                    if match.empty:
                        print('No image for {} / {} / {}'.format(medium, artist, name))
                    else:
                        img_path = match.iloc[0]
                        # imread returns None when the file cannot be read.
                        img = cv2.imread(os.path.join(PATH, img_path))
                        if img is None:
                            print('Could not read {}'.format(img_path))
                        else:
                            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    print(n)
                    n += 1

                    # The cell is created even without an image so the grid
                    # and its labels stay aligned.
                    ax = fig.add_subplot(F_L, F_L, row * F_L + fl + 1)
                    if img_rgb is not None:
                        ax.imshow(img_rgb)

                    if row == 0:
                        ax.set_xlabel('\n'+name, va='top', fontsize=60, labelpad=100, **hfont)
                        ax.xaxis.set_label_position('top')
                    if fl == 0:
                        ax.set_ylabel(str(artist), fontsize=60, rotation='horizontal', labelpad=200, **hfont)

            fig.savefig(SAVE_PATH+'/'+'art_gr'+str(a_gr)+"_"+str(medium)+".png", dpi=150, format='png')
            # Release the figure; canvases this large add up quickly.
            plt.close(fig)
            print('Group {} finished!'.format(a_gr))

Finally, draw the canvas.

In [None]:
# Draw and save the canvases for the rows that have an artist.
canva_artist_filename(df_c)