<a href="https://colab.research.google.com/github/olaviinha/NeuralTextToImage/blob/main/LatentVision_rynmurdock.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#<font face="Trebuchet MS" size="6">Big Sleep: Latent vision <font color="#999" size="4">&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;</font><font color="#999" size="4">CLIP+VQGAN: Neural text-to-image</font><font color="#999" size="4">&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;</font><a href="https://github.com/olaviinha/NeuralImageGeneration" target="_blank"><font color="#999" size="4">Github</font></a>

Big Sleep generates images from text input. It's originally a combination of [CLIP](https://github.com/openai/CLIP) by OpenAI and [BigGAN](https://arxiv.org/abs/1809.11096) by Andrew Brock et al., a concept introduced by [Ryan Murdock](https://github.com/rynmurdock) in his [original notebook](https://colab.research.google.com/drive/1NCceX2mbiKOSlAd_o7IU7nA9UskKN5WR?usp=sharing). This notebook is based on another similar implementation (CLIP+[VQGAN](https://arxiv.org/abs/2012.09841)) by the same author. It generates 2 images on each run.

<hr size="1" color="#666">

### Tips
- Enter a simple string of text in the `generate_image_of` field. You may also use a semicolon `;` as a separator to batch-process multiple text strings into images in one go, and/or a pipe `|` to train the image on multiple strings of text. If the field is left empty, a random blog headline will be used.
- Enter `output_dir` path relative to your Google Drive root, or leave blank to not save output anywhere outside this notebook. Each run of the _Sleep_ cell will **create a new subdirectory** under `output_dir`, under which all material will be saved.
- In the vast majority of cases, running more than 400 iterations seems to be a waste of time.
- Setup cell will say you need to restart runtime. You can ignore it and not restart runtime.

In [None]:
#@title #Setup
#@markdown This cell needs to be run only once. It will mount your Google Drive and setup prerequisites.
 
# Setup switches.
# NOTE(review): every guarded step below tests `force_setup == False`, so
# setting force_setup to True makes those steps be skipped rather than forced.
# Confirm whether that is the intended meaning of the flag.
force_setup = False
# Space-separated package list interpolated into the `pip install` line below.
pip_packages = 'kornia ftfy regex tqdm einops omegaconf==2.0.0 pytorch-lightning==1.0.8'
# Not read anywhere in this cell.
main_repository = ''
 
import os
from google.colab import output
import warnings
# Suppress all Python warnings from here on.
warnings.filterwarnings('ignore')
%cd /content/
 
# inhagcutils
# Install the pip packages and download the helper notebook only when the
# helper notebook is not already present in /content/.
if not os.path.isfile('/content/inhagcutils.ipynb') and force_setup == False:
  !pip -q install import-ipynb {pip_packages}
  !pip uninstall torchtext --yes
  !curl -s -O https://raw.githubusercontent.com/olaviinha/inhagcutils/master/inhagcutils.ipynb
# Presumably import_ipynb is what makes the downloaded .ipynb importable, and
# the star import supplies helpers used later in this notebook (create_dirs,
# op, c, ...) -- TODO confirm against inhagcutils.
import import_ipynb
from inhagcutils import *
 
# Mount Drive
if not os.path.isdir('/content/drive') and force_setup == False:
  from google.colab import drive
  drive.mount('/content/drive')
 
# Drive symlink
# Expose "My Drive" under a path without a space in it.
if not os.path.isdir('/content/mydrive') and force_setup == False:
  os.symlink('/content/drive/My Drive', '/content/mydrive')
  # Only assigned when the symlink is created in this run.
  drive_root_set = True
drive_root = '/content/mydrive/'
 
# Fetch the CLIP and taming-transformers sources into the current directory
# (/content/ after the %cd above).
!git clone https://github.com/openai/CLIP.git
!git clone https://github.com/CompVis/taming-transformers.git
 
# Scratch directories. steps_1/ and steps_2/ hold per-iteration frames, one
# directory per generated image.
dir_tmp = '/content/tmp/'
dir_steps1 = '/content/tmp/steps_1/'
dir_steps2 = '/content/tmp/steps_2/'
dir_initial = '/content/tmp/init/'
dir_target = '/content/tmp/target/'
create_dirs([dir_tmp, dir_steps1, dir_steps2,   dir_initial, dir_target])
 
#-----
 
import torch
import numpy as np
import torchvision
import torchvision.transforms.functional as TF
import kornia
 
import PIL
from PIL import ImageFile, Image
import matplotlib.pyplot as plt
 
import os
import random
import imageio
from IPython import display
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
 
import glob

import itertools
# from skimage import img_as_ubyte
from subprocess import Popen, PIPE
from tqdm.notebook import tqdm

#--
 
# CLIP is imported from the repository cloned into /content/CLIP earlier in this cell.
from CLIP import clip
# Load the ViT-B/32 CLIP model (non-JIT) and its preprocessing transform.
perceptor, preprocess = clip.load('ViT-B/32', jit=False)
perceptor.eval()
# Bare expressions: their values are not stored anywhere.
clip.available_models()
perceptor.visual.input_resolution
# Multiplier applied to the 224 px crop size used in augment().
scaler = 1
 
# Image geometry as [sideX, sideY, channels].
im_shape = [512, 512, 3]
sideX, sideY, channels = im_shape
# Number of latents (and therefore images) optimised per run; see Pars.
batch_size = 2
 
 
#--

def normalize8(I):
  """Linearly rescale an array to the 0-255 range and return it as uint8.

  Args:
    I: NumPy array of numeric values.

  Returns:
    Array with the same shape as `I` and dtype `np.uint8`, where the minimum
    of `I` maps to 0 and the maximum maps to 255. A constant array has no
    range to stretch and maps to all zeros.
  """
  mn = I.min()
  span = I.max() - mn
  if span == 0:
    # Without this guard the division below is 0/0 -> NaN, and casting NaN to
    # uint8 is undefined.
    return np.zeros(I.shape, dtype=np.uint8)
  return (((I - mn) / span) * 255).astype(np.uint8)
 
def displ(img, idx, i, dir, pre_scaled=True):
  """Write one decoded image to `dir` as a zero-padded, numbered PNG.

  Args:
    img: Image data convertible with `np.array`; axis 0 is moved last before
      saving (the transpose is (1, 2, 0)).
    idx: Unused; kept so existing call sites stay valid.
    i: Iteration number; the file is named `<i zero-padded to 4>.png`.
    dir: Output directory path, expected to end with a path separator because
      the file name is appended by plain string concatenation.
    pre_scaled: When False, the image is first passed through
      `scale(img, 192, 128)`.

  Returns:
    An `IPython.display.Image` built from '3.png'. This is the file name
    get_all() writes, not the frame saved here; the return value is kept
    unchanged for compatibility.
  """
  # The original wrote the same array to the same path a second time when
  # i == iterations; that rewrite changed nothing on disk and was dropped.
  img = np.transpose(np.array(img), (1, 2, 0))
  if not pre_scaled:
    img = scale(img, 48*4, 32*4)
  imgarr = normalize8(np.array(img))
  imageio.imwrite(dir + str(i).zfill(4) + '.png', imgarr)
  return display.Image(str(3)+'.png')
 
def gallery(array, ncols=2):
    """Tile a batch of equally sized images into a single grid image.

    Args:
        array: Array of shape (count, height, width, channels); `count` must
            be a multiple of `ncols`.
        ncols: Number of images per grid row.

    Returns:
        Array of shape (height * rows, width * ncols, channels), with the
        images laid out row by row in their original order.
    """
    count, img_h, img_w, depth = array.shape
    rows = count // ncols
    assert count == rows * ncols
    grid = array.reshape(rows, ncols, img_h, img_w, depth)
    # Put each image's row axis next to the grid-row axis so the final reshape
    # stitches images side by side instead of interleaving their pixels.
    grid = grid.swapaxes(1, 2)
    return grid.reshape(img_h * rows, img_w * ncols, depth)
 
def card_padded(im, to_pad=3):
  """Frame an image with three nested constant borders.

  From the image outwards: 1 pixel of 0, 2 pixels of 1, then `to_pad` pixels
  of 0. Only the first two axes are padded; the last axis is left untouched.

  Args:
    im: 3-D array; the first two axes are padded.
    to_pad: Width of the outermost zero border.

  Returns:
    The padded array, larger by 2 * (3 + to_pad) along each of the first two axes.
  """
  framed = im
  for width, fill in ((1, 0), (2, 1), (to_pad, 0)):
    framed = np.pad(framed, [[width, width], [width, width], [0, 0]], constant_values=fill)
  return framed
 
def get_all(img):
  """Frame every image of a batch, tile them into one sheet and save it as '3.png'.

  Args:
    img: Batch that is transposed with (0, 2, 3, 1), i.e. axis 1 is moved
      last before framing.

  Returns:
    An `IPython.display.Image` for the written '3.png'.
  """
  print('get all')
  channels_last = np.transpose(img, (0, 2, 3, 1))
  # +12 matches the 6-pixel border card_padded() adds on each side with its
  # default to_pad of 3.
  framed = np.zeros((channels_last.shape[0], sideX + 12, sideY + 12, 3))
  for slot, picture in enumerate(channels_last):
    framed[slot] = card_padded(picture)
  print(channels_last.shape)
  sheet_path = str(3) + '.png'
  imageio.imwrite(sheet_path, np.array(gallery(framed)))
  return display.Image(sheet_path)
  
 
#-----
 
# Work inside the cloned taming-transformers repository; the relative paths
# below and the sys.path entry both depend on this directory.
%cd /content/taming-transformers
 
!mkdir -p logs/vqgan_imagenet_f16_16384/checkpoints
!mkdir -p logs/vqgan_imagenet_f16_16384/configs
 
# Download the checkpoint and its config only when the checkpoints directory is empty.
if len(os.listdir('logs/vqgan_imagenet_f16_16384/checkpoints/')) == 0:
  !wget 'https://heibox.uni-heidelberg.de/f/867b05fc8c4841768640/?dl=1' -O 'logs/vqgan_imagenet_f16_16384/checkpoints/last.ckpt' 
  !wget 'https://heibox.uni-heidelberg.de/f/274fb24ed38341bfa753/?dl=1' -O 'logs/vqgan_imagenet_f16_16384/configs/model.yaml' 
 
 
 
# !cp /content/drive/MyDrive/vqgan_imagenet_f16_16384-20210325T002625Z-001.zip /content/vq.zip
# !unzip /content/vq.zip -d /content/taming-transformers/logs/
 
import sys
# Make the `taming` package in the current directory importable.
sys.path.append(".")
 
# First CUDA device when available, otherwise CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
import yaml
import torch
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel
 
def load_config(config_path, display=False):
  """Load an OmegaConf configuration file.

  Args:
    config_path: Path handed to `OmegaConf.load`.
    display: When true, also print the configuration as YAML.

  Returns:
    The loaded configuration object.
  """
  loaded = OmegaConf.load(config_path)
  if not display:
    return loaded
  print(yaml.dump(OmegaConf.to_container(loaded)))
  return loaded
 
def load_vqgan(config, ckpt_path=None):
  """Build a VQModel from `config.model.params`, optionally loading weights.

  Args:
    config: Configuration object exposing `model.params` as keyword arguments
      for `VQModel`.
    ckpt_path: Optional checkpoint path; its "state_dict" entry is loaded onto
      the CPU and applied to the model.

  Returns:
    The model switched to evaluation mode.
  """
  vqgan = VQModel(**config.model.params)
  if ckpt_path is None:
    return vqgan.eval()
  checkpoint = torch.load(ckpt_path, map_location="cpu")
  # strict=False: key mismatches are returned instead of raised; the returned
  # lists are not inspected.
  missing_keys, unexpected_keys = vqgan.load_state_dict(checkpoint["state_dict"], strict=False)
  return vqgan.eval()
 
def preprocess_vqgan(x):
  """Return 2*x - 1, which maps 0 to -1 and 1 to 1."""
  return 2.*x - 1.
 
def custom_to_pil(x):
  """Convert one image tensor with values in [-1, 1] to an RGB PIL image.

  Args:
    x: 3-D tensor; axis 0 is moved last via permute(1, 2, 0) before the
      conversion. Values outside [-1, 1] are clamped.

  Returns:
    A PIL image in "RGB" mode.
  """
  pixels = torch.clamp(x.detach().cpu(), -1., 1.)
  # [-1, 1] -> [0, 1], then move axis 0 to the end.
  pixels = ((pixels + 1.) / 2.).permute(1, 2, 0).numpy()
  picture = Image.fromarray((255 * pixels).astype(np.uint8))
  return picture if picture.mode == "RGB" else picture.convert("RGB")
 
def reconstruct_with_vqgan(x, model):
  """Round-trip `x` through the model's encoder and decoder.

  Encoding and decoding are called explicitly rather than through `model(x)`.
  The trailing dimensions of the latent are printed along the way.

  Args:
    x: Input passed to `model.encode`.
    model: Object exposing `encode` (returning a 3-item result whose last item
      itself has three items) and `decode`.

  Returns:
    The result of `model.decode` applied to the encoded latent.
  """
  z, _, (_, _, _indices) = model.encode(x)
  print(f"VQGAN: latent shape: {z.shape[2:]}")
  return model.decode(z)
 
class Pars(torch.nn.Module):
  """Holds the trainable latent batch that is optimised for a prompt.

  The single parameter `normu` has shape
  (batch_size, 256, sideX // 16, sideY // 16) and is created on the CUDA
  device. It is initialised from Gaussian noise scaled by 0.5 and reshaped
  with sinh(1.9 * arcsinh(.)).
  """

  def __init__(self):
    super().__init__()
    # Noise is drawn on the CPU and then moved to the GPU.
    seed = .5 * torch.randn(batch_size, 256, sideX // 16, sideY // 16).cuda()
    self.normu = torch.nn.Parameter(torch.sinh(1.9 * torch.arcsinh(seed)))

  def forward(self):
    """Return the latents with values limited to [-6, 6]."""
    return self.normu.clip(-6, 6)
      
def model(x):
  """Decode latents `x` with the loaded VQGAN.

  Applies `model16384.post_quant_conv` and then `model16384.decoder`, and
  returns the decoder output.
  """
  return model16384.decoder(model16384.post_quant_conv(x))
 
# Load the VQGAN config and checkpoint from the paths populated earlier in
# this cell, then move the model to DEVICE. The paths are relative to
# /content/taming-transformers.
config16384 = load_config("logs/vqgan_imagenet_f16_16384/configs/model.yaml", display=False)
model16384 = load_vqgan(config16384, ckpt_path="logs/vqgan_imagenet_f16_16384/checkpoints/last.ckpt").to(DEVICE)
 
 
 
# Clear everything the setup steps printed to the cell output.
output.clear()
# !nvidia-smi
op(c.ok, 'Setup finished.')

In [None]:
#@markdown <br>
 
#@markdown #S̛̞̩͎͓ ̦̤͉͚̏ ̧̠͋͘ͅl͕̞͕̝͗̐͘.̠̰̳̫̈́̚ ̡͉̼̩̬̈́̇͒͘ȩ̨͎͛̔͆͊̏͜ͅ.͕̩̹̠̕͜ ̛̦̦̮e̢͐͊͂̀̊ͅ ̜̙̝̊͋ ̬̝̱̱͗p̮̎̽̌
 
#@markdown <br>
 
# Release cached GPU memory and clear the scratch directories left by a previous run.
torch.cuda.empty_cache()
clean_dirs([dir_steps1, dir_steps2, dir_initial, dir_target])
# Prompt text. An empty value makes the cell fetch a random headline instead.
generate_image_of = "" #@param {type:"string"}
# Output directory relative to the Drive root. An empty value selects a local
# fallback directory instead of Drive.
output_dir = '' #@param {type:"string"}
 
#@markdown <hr color="#666" size="1">
#@markdown <font size="1">&nbsp;</font>
 
#@markdown ### Advanced settings
 
# Number of optimisation steps (incremented by one further down).
iterations = 400 #@param {type:"slider", min:0, max:2000, step:100}
# A progress image is saved on iterations where i % save_every == 0.
save_every = 50 #@param {type:"slider", min:0, max:500, step:1}
# Whether saved progress images are also shown inline.
display_save_every = True #@param {type:"boolean"}
# Whether every step is saved and assembled into an mp4 per image.
create_video = False #@param {type:"boolean"}
 
# Fixed values for the "very advanced" settings that are commented out below.
save_all_steps = False
remove_interrupted = True
repetitions = 0
 
#
# --- Very advanced settings ---------------------------------

# #@markdown <hr color="#666" size="1">
# #@markdown <font size="1">&nbsp;</font>
 
# #@markdown ### Very advanced settings
# repetitions = 0 #@param {type:"slider", min:0, max:20, step:1}
# save_all_steps = False #@param {type:"boolean"}
# remove_interrupted = True #@param {type:"boolean"}

# --- //Very advanced settings -------------------------------
#
 

 
# --- Resolve prompts and output location ------------------------------------
# Turns the form values into `texts` (the list of prompts to process),
# `drive_root`/`output_dir` (where results go) and a fresh `uniq_id`.

text = generate_image_of
# One extra pass, so the iteration whose index equals the slider value
# (e.g. 400) is also run and saved.
iterations = iterations+1

if save_all_steps is True:
  # NOTE(review): dir_output is first assigned inside the per-prompt loop
  # further down, so on a fresh session this branch would raise NameError.
  # It is unreachable while save_all_steps is fixed to False above.
  dir_steps1 = dir_output+'steps_1/'
  # Fixed: the second assignment previously overwrote dir_steps1 again.
  dir_steps2 = dir_output+'steps_2/'
  create_dirs([dir_steps1, dir_steps2])

if not output_dir:
  # No Drive path given: fall back to a local directory.
  drive_root = '/content/fauxdrive/'
  output_dir = 'output'
  # NOTE(review): 'output' is passed as a relative path here, not as
  # drive_root + output_dir -- confirm what create_dirs is expected to create.
  create_dirs([drive_root, output_dir])
else:
  drive_root = '/content/mydrive/'

# Strings are compared with ==/truthiness; the previous `is` comparisons
# tested object identity, which only works when the interpreter happens to
# intern the strings.
if text:
  if check_input_type(drive_root+text) == 'file':
    # The prompt field names a text file: one prompt per line.
    input_txt = drive_root+text
    with open(input_txt) as f:
      texts = [line.strip() for line in f]
  elif ";" in text:
    texts = [part.strip() for part in text.split(';')]
  else:
    texts = [text]
  # Drop blank prompts from a trailing ';' or empty lines in the file, so no
  # run is started for an empty string.
  texts = [prompt for prompt in texts if prompt]
else:
  # `requests` is not imported in this notebook's visible cells; presumably it
  # arrives through the inhagcutils star import -- TODO confirm.
  text = requests.get('https://api.inha.asia/headline/').text
  texts = [text]
if repetitions > 0:
  # Repeat each prompt `repetitions` times, keeping repeats adjacent.
  texts = list(itertools.chain.from_iterable(itertools.repeat(x, repetitions) for x in texts))

uniq_id = gen_id()
repeat_index = 1
 
#---------------
 
def keyboardInterruptHandler():
  """Clean up after a user interrupt.

  Removes the current run's output directory, reports what was removed and
  then stops execution via `sys.exit()`.
  """
  run_dir = dir_output
  op(c.warn, 'Interrupted!', 'Cleaning up...')
  remove_dirs([run_dir])
  print('Run', uniq_id, 'directory and content removed:', run_dir)
  sys.exit()

def augment(into, cutn=32):
  """Produce `cutn` randomly cropped, resized and noised views of an image batch.

  Args:
    into: 4-D tensor (it is indexed as [:, :, rows, cols]).
    cutn: Number of random crops to take from the padded, augmented batch.

  Returns:
    Tensor with the `cutn` crops concatenated along dimension 0, each resized
    to (224*scaler, 224*scaler), with random noise added.
  """
  # Zero-pad the last two dimensions by sideX//2 on every side.
  into = torch.nn.functional.pad(into, (sideX//2, sideX//2, sideX//2, sideX//2), mode='constant', value=0)
  # `augs` is the module-level augmentation pipeline built before the run starts.
  into = augs(into)
  p_s = []
  for ch in range(cutn):
    # Crop side length: sideX times a draw from N(1.2, 0.3) limited to [0.43, 1.9].
    size = int(torch.normal(1.2, .3, ()).clip(.43, 1.9) * sideX)
    # The last three crops (ch > cutn - 4) use a fixed side of 1.4 * sideX instead.
    if ch > cutn - 4:
      size = int(sideX*1.4)
    # Random top-left corner. Both axes use sideX; sideY is not consulted here.
    offsetx = torch.randint(0, int(sideX*2 - size), ())
    offsety = torch.randint(0, int(sideX*2 - size), ())
    apper = into[:, :, offsetx:offsetx + size, offsety:offsety + size]
    # Resize every crop to the same square size so they can be concatenated.
    apper = torch.nn.functional.interpolate(apper, (int(224*scaler), int(224*scaler)), mode='bilinear', align_corners=True)
    p_s.append(apper)
  into = torch.cat(p_s, 0)
  # Gaussian noise scaled by up_noise and by one uniform random factor per crop.
  into = into + up_noise*torch.rand((into.shape[0], 1, 1, 1)).cuda()*torch.randn_like(into, requires_grad=False)
  return into

def checkin(loss, i, xtype, display_image):
  """Decode the current latents and save one PNG per batch element.

  Args:
    loss: Unused; kept so existing call sites stay valid.
    i: Iteration number, used as the zero-padded file name.
    xtype: 'step' to write next to dir_steps1 (steps_<n>/), or 'progress' to
      write next to dir_progress1 (progress_<n>/).
    display_image: When True, each saved image is shown inline and its path
      is logged.

  Raises:
    ValueError: If `xtype` is neither 'step' nor 'progress'.
  """
  # Strings are compared with ==; `is` tested identity and only worked when
  # the literals happened to be interned.
  if xtype == 'step':
    xdir = dir_steps1.replace('_1/', '')
  elif xtype == 'progress':
    xdir = dir_progress1.replace('_1/', '')
  else:
    raise ValueError("xtype must be 'step' or 'progress', got %r" % (xtype,))
  with torch.no_grad():
    # The original first decoded and augmented the latents, then overwrote the
    # result with this second decode. That discarded pass has been removed.
    decoded = (model(lats()).cpu().clip(-1, 1) + 1) / 2
    for idx, frame in enumerate(decoded):
      # Rebuild the per-image directory: <leaf> -> <leaf>_<n>/ with n = idx + 1.
      out_dir = xdir.replace(path_leaf(xdir), path_leaf(xdir)+'_'+str(idx+1)+'/')
      displ(frame, idx, i, out_dir)
      if display_image is True:
        display.display(display.Image(out_dir+str(i).zfill(4)+'.png'))
        op(c.ok, '^Iteration image saved as:', out_dir.replace(drive_root, '')+str(i).zfill(4)+'.png\n')
      elif xtype == 'progress':
        op(c.ok, 'Iteration image saved as:', out_dir.replace(drive_root, '')+str(i).zfill(4)+'.png\n')
 
def ascend_txt():
  """Compute the two loss terms for the current latents.

  The latents are decoded, augmented into random cutouts, normalised with
  `nom` and encoded by CLIP. The image features are compared with the encoded
  negative prompt (`t_not`) and the encoded prompt (`t`).

  Returns:
    A two-item list: 3 times the cosine similarity to `t_not`, and -10 times
    the cosine similarity to `t`.
  """
  decoded = model(lats())
  cutouts = nom(augment((decoded.clip(-1, 1) + 1) / 2))
  image_features = perceptor.encode_image(cutouts)
  negative_term = torch.cosine_similarity(t_not, image_features, -1)
  prompt_term = torch.cosine_similarity(t, image_features, -1)
  return [3 * negative_term, -10 * prompt_term]
  
def train(i):
  """Run one optimisation step and save images when due.

  Args:
    i: Zero-based iteration index.

  Side effects:
    Updates the optimizer state, decays the module-level `dec` after
    iteration 400, and calls checkin() to write step/progress images.
  """
  global dec
  loss1 = ascend_txt()
  loss = (loss1[0] + loss1[1]).mean()
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  if i > 400:
    # Past iteration 400, decay the learning rate (never below 0.1) and the
    # weight-decay strength by 0.5% per step.
    for g in optimizer.param_groups:
      g['lr'] = max(g['lr'] * .995, .1)
    dec *= .995
  # Weight decay is only applied while some latent magnitude exceeds 5.
  weight_decay = dec if torch.abs(lats()).max() > 5 else 0
  for g in optimizer.param_groups:
    g['weight_decay'] = weight_decay
  if save_all_steps is True or create_video is True or i == iterations-1:
    checkin(loss1, i, 'step', display_image=False)
  # The save_every slider can be set to 0; treat that as "no progress images"
  # instead of letting the modulo raise ZeroDivisionError.
  if save_every and i % save_every == 0:
    checkin(loss1, i, 'progress', display_image=display_save_every)
 
def loop():
  """Run train() `iterations` times, advancing the module-level counter `itt`.

  A KeyboardInterrupt stops the loop. When `remove_interrupted` is set, the
  interrupt handler is then invoked; otherwise the interrupt is swallowed and
  the function returns normally.
  """
  global itt
  try:
    for _ in range(iterations):
      train(itt)
      itt += 1
  except KeyboardInterrupt:
    if remove_interrupted:
      keyboardInterruptHandler()
    
#---------------
 
for text_input in texts:
 
  torch.cuda.empty_cache()
 
  display_text = text_input
  title = text_input.split("|")[0].title()
  file_title = ''.join(e for e in title if e.isalnum())
 
  id = uniq_id+'_'+file_title
  if repetitions > 0:
    id = uniq_id+'_'+str(repeat_index)+'_'+file_title
 
  dir_output = fix_path(drive_root+output_dir)+id+'/'
  dir_progress1 = dir_output+'progress_1/'
  dir_progress2 = dir_output+'progress_2/'
  create_dirs([dir_output, dir_progress1, dir_progress2])
 
  text_not = '''disconnected, confusing, incoherent, watermarks, text, writing'''
 
  dec = .1
  lats = Pars().cuda()
  mapper = [lats.normu]
  optimizer = torch.optim.AdamW([{'params': mapper, 'lr': .5}], weight_decay=dec)
  eps = 0
  tx = clip.tokenize(text_input)
  t = perceptor.encode_text(tx.cuda()).detach().clone()
  t_not = clip.tokenize(text_not)
  t_not = perceptor.encode_text(t_not.cuda()).detach().clone()
  nom = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
 
  augs = torch.nn.Sequential(
      torchvision.transforms.RandomHorizontalFlip(),
      torchvision.transforms.RandomAffine(24, (.1, .1), fill=0)
  ).cuda()
 
  up_noise = .11
  itt = 0
 
  output.clear()
  op(c.title, '\nGenerating image of', display_text)
  op(c.title, 'Run ID:', uniq_id)
  if repetitions > 0:
    op(c.title, 'Repetition:', repeat_index)
  op(c.okb, 'Sweet dreams.\n')
 
  with torch.no_grad():
    al = (model(lats()).cpu().clip(-1, 1) + 1) / 2
    for allls in al:
      # displ(allls[:3])
      print('\n')
 
  loop()

  last_step1 = dir_steps1+str(iterations-1).zfill(4)+'.png'
  last_step2 = dir_steps2+str(iterations-1).zfill(4)+'.png'
  fin_out1 = dir_output+file_title+'_1.png'
  fin_out2 = dir_output+file_title+'_2.png'
  !cp {last_step1} {fin_out1}
  !cp {last_step2} {fin_out2}
  display_fin1 = fin_out1.replace(drive_root, '')
  display_fin2 = fin_out2.replace(drive_root, '')
  op(c.ok, '\nFinal images saved as')
  print('-', display_fin1)
  print('-', display_fin2)

  if create_video is True:
    op(c.title, '\nGenerating video')

    init_frame = 1
    last_frame = iterations-1

    fps = 30
    output_video1 = dir_output+file_title+'_1.mp4'
    output_video2 = dir_output+file_title+'_2.mp4'

    frames1 = []
    for i in range(init_frame,last_frame): #
      filename1 = f"{dir_steps1}/{i:04}.png"
      frames1.append(Image.open(filename1))
    p = Popen(['ffmpeg', '-y', '-f', 'image2pipe', '-vcodec', 'png', '-r', str(fps), '-i', '-', '-vcodec', 'libx264', '-r', str(fps), '-pix_fmt', 'yuv420p', '-crf', '13', '-preset', 'veryslow', output_video1], stdin=PIPE)
    for im in tqdm(frames1):
      im.save(p.stdin, 'PNG')
    p.stdin.close()
    p.wait()
    fin_vid1 = fin_out1.replace('.png', '.mp4')

    frames2 = []
    for i in range(init_frame,last_frame): #
      filename2 = f"{dir_steps2}/{i:04}.png"
      frames2.append(Image.open(filename2))
    p = Popen(['ffmpeg', '-y', '-f', 'image2pipe', '-vcodec', 'png', '-r', str(fps), '-i', '-', '-vcodec', 'libx264', '-r', str(fps), '-pix_fmt', 'yuv420p', '-crf', '13', '-preset', 'veryslow', output_video2], stdin=PIPE)
    for im in tqdm(frames2):
      im.save(p.stdin, 'PNG')
    p.stdin.close()
    p.wait()
    fin_vid2 = fin_out2.replace('.png', '.mp4')

    op(c.ok, 'Videos saved as')
    print('-', fin_vid1)
    print('-', fin_vid2)

  if repeat_index is repetitions:
    repeat_index = 1
  else:
    repeat_index += 1

op(c.title, '\nFIN.')