# Text to Image tool

Based on [CLIP](https://github.com/openai/CLIP) + VQGAN from [Taming Transformers](https://github.com/CompVis/taming-transformers) // made by Vadim Epstein [[eps696](https://github.com/eps696)]  
thanks to [Ryan Murdock](https://twitter.com/advadnoun), [Jonathan Fly](https://twitter.com/jonathanfly), [Hannu Toyryla](https://twitter.com/htoyryla) for ideas

## Features 
* complex requests:
  * image and/or text as main prompts  
   (composition similarity controlled with [SSIM](https://github.com/Po-Hsun-Su/pytorch-ssim) loss)
  * additional text prompt to subtract (suppress) topics
  * criteria inversion (show "the opposite")

* various CLIP models (including multi-language from [SBERT](https://sbert.net))


**Run the cell below after each session restart**



In [None]:
#@title General setup

!pip install torchtext==0.8.0 torch==1.7.1 pytorch-lightning==1.2.2 torchvision==0.8.2 ftfy==5.8 regex

try: 
  !pip3 install googletrans==3.1.0a0
  from googletrans import Translator, constants
  translator = Translator()
except: pass

import os
import io
import time
from math import exp
import random
import imageio
import numpy as np
import PIL
from base64 import b64encode
import moviepy, moviepy.editor

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.autograd import Variable

from IPython.display import HTML, Image, display, clear_output
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import ipywidgets as ipy
from google.colab import output, files

import warnings
warnings.filterwarnings("ignore")

!pip install git+https://github.com/openai/CLIP.git
import clip
!pip install sentence_transformers
from sentence_transformers import SentenceTransformer
!pip install git+https://github.com/Po-Hsun-Su/pytorch-ssim
import pytorch_ssim as ssim

%cd /content
!rm -rf aphantasia
!git clone https://github.com/eps696/aphantasia
%cd aphantasia/
from clip_fft import slice_imgs, checkout
from utils import pad_up_to, basename, img_list, img_read, plot_text
from progress_bar import ProgressIPy as ProgressBar

!pip install omegaconf>=2.0.0 # pytorch-lightning==1.0.8
# !pip install PyYAML==5.3.1 torchtext==0.8.0 pytorch-lightning==1.2.2
# !pip install git+https://github.com/PyTorchLightning/pytorch-lightning
import pytorch_lightning as pl
!git clone https://github.com/CompVis/taming-transformers
!mv taming-transformers/* ./
import yaml
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel

if not os.path.isdir('/content/models_TT'):
  !mkdir -p /content/models_TT
def getm(url, path):
  if os.path.isfile(path) and os.stat(path).st_size > 0: 
    print(' already exists', path, os.stat(path).st_size)
  else:
    !wget $url -O $path
# VQGAN checkpoints and configs: (download url, local path)
for model_url, model_path in [
  ('https://heibox.uni-heidelberg.de/f/140747ba53464f49b476/?dl=1', '/content/models_TT/last-1024.ckpt'),
  ('https://heibox.uni-heidelberg.de/f/6ecf2af6c658432c8298/?dl=1', '/content/models_TT/model-1024.yaml'),
  ('https://heibox.uni-heidelberg.de/f/867b05fc8c4841768640/?dl=1', '/content/models_TT/last-16384.ckpt'),
  ('https://heibox.uni-heidelberg.de/f/274fb24ed38341bfa753/?dl=1', '/content/models_TT/model-16384.yaml'),
]:
  getm(model_url, model_path)

workdir = '_out'
tempdir = os.path.join(workdir, 'ttt')

clear_output()

# resume = False #@param {type:"boolean"}
# if resume:
#   resumed = files.upload()
#   params_pt = list(resumed.values())[0]
#   params_pt = torch.load(io.BytesIO(params_pt))

def load_config(config_path, display=False):
    """Load an OmegaConf config from `config_path`.

    When `display` is true, the config is also printed as YAML.
    Returns the loaded config object.
    """
    cfg = OmegaConf.load(config_path)
    if not display:
        return cfg
    print(yaml.dump(OmegaConf.to_container(cfg)))
    return cfg

def load_vqgan(config, ckpt_path=None):
    """Build a VQModel from `config.model.params` and optionally load weights.

    config: object exposing `model.params` as keyword arguments for VQModel.
    ckpt_path: optional checkpoint file holding a "state_dict" entry;
        when None the model keeps its initial weights.
    Returns the model switched to eval mode.
    """
    model = VQModel(**config.model.params)
    if ckpt_path is not None:
        # NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        # strict=False: keys that do not match the model are skipped instead of raising
        model.load_state_dict(sd, strict=False)
    return model.eval()

def vqgan_image(model, z):
    """Decode `z` with `model.decode` and return (decoded + 1) / 2.

    Presumably this maps a decoder output in [-1, 1] to [0, 1] - confirm
    against the VQGAN model in use.
    """
    decoded = model.decode(z)
    return (decoded + 1.) / 2.

class latents(torch.nn.Module):
    """Trainable latent tensor wrapped in a module, so an optimizer can update it.

    shape: size of the latent tensor; it is initialised from a normal
        distribution (mean 0, std 4).
    The parameter is placed on the GPU when CUDA is available, otherwise it
    stays on the CPU (so the class can also be constructed on CPU-only hosts).
    """
    def __init__(self, shape):
        super().__init__()
        init_rnd = torch.zeros(shape).normal_(0., 4.)
        if torch.cuda.is_available():
            init_rnd = init_rnd.cuda()
        self.lats = torch.nn.Parameter(init_rnd)
    def forward(self):
        """Return the latent parameter itself (no computation)."""
        return self.lats # [1,256, h//16, w//16]

def makevid(seq_dir, size=None):
  """Encode the images listed by `img_list(seq_dir)` into `<seq_dir>.mp4` (25 fps).

  size: optional integer used for both width and height of the <video> tag;
    when None the tag carries no explicit dimensions.
  Returns an HTML <video> snippet with the file embedded as a base64 data URL.
  """
  # out_sequence = seq_dir + '/%05d.jpg'
  out_video = seq_dir + '.mp4'
  # !ffmpeg -y -v quiet -i $out_sequence $out_video
  moviepy.editor.ImageSequenceClip(img_list(seq_dir), fps=25).write_videofile(out_video, verbose=False) # , ffmpeg_params=ffmpeg_params, logger=None
  # read through a context manager so the file handle is closed deterministically
  with open(out_video, 'rb') as video_file:
    data_url = "data:video/mp4;base64," + b64encode(video_file.read()).decode()
  wh = '' if size is None else 'width=%d height=%d' % (size, size)
  return """<video %s controls><source src="%s" type="video/mp4"></video>""" % (wh, data_url)

!nvidia-smi -L
print('\nDone!')

Type some `text` and/or upload some image to start.  
Put the topics you would like to avoid in the result into `subtract`.  
Enable `invert` to reverse the whole criteria, if you want to see "the total opposite".

Options for non-English languages (use only one of them!):  
`multilang` = use multi-language model, trained with ViT  
`translate` = use Google translate (works with any visual model)

In [None]:
#@title Input

text = "" #@param {type:"string"}
subtract = "" #@param {type:"string"}
multilang = False #@param {type:"boolean"}
translate = False #@param {type:"boolean"}
invert = False #@param {type:"boolean"}
upload_image = False #@param {type:"boolean"}

# Translate only a non-empty prompt (the same guard the Generate cell applies).
if translate and len(text) > 0:
  text = translator.translate(text, dest='en').text
if upload_image:
  uploaded = files.upload()

### Settings

Select CLIP visual `model` (results do vary!). I prefer ViT for consistency (and it's the only native multi-language option).  
Select `VQGAN_size` - it also changes the result (I prefer codebook `1024`, even though it is smaller than `16384`).  
`overscan` option produces more uniform composition (when off, it's more centered).  
`sync` value adds SSIM loss between the output and the input image (if there is one), allowing you to "redraw" it with controlled similarity. 

Decrease `samples` if you run out of GPU memory (OOM).  


In [None]:
#@title Generate

!rm -rf $tempdir
os.makedirs(tempdir, exist_ok=True)

sideX = 800 #@param {type:"integer"}
sideY = 600 #@param {type:"integer"}
#@markdown > Config
model = 'ViT-B/32' #@param ['ViT-B/32', 'RN101', 'RN50x4', 'RN50']
VQGAN_size = 1024 #@param [1024, 16384]
overscan = False #@param {type:"boolean"}
sync =  0.3 #@param {type:"number"}
#@markdown > Training
steps = 300 #@param {type:"integer"}
samples = 32 #@param {type:"integer"}
learning_rate = 0.1 #@param {type:"number"}
save_freq = 1 #@param {type:"integer"}

# #@markdown > Tricks
# no_text = 0.07 #@param {type:"number"}
# enhance = 0. #@param {type:"number"}
# diverse = -enhance
# expand = abs(enhance)

if multilang: model = 'ViT-B/32' # sbert model is trained with ViT

# fewer samples when a subtract prompt is used
if len(subtract) > 0:
  samples = int(samples * 0.75)

model_clip, _ = clip.load(model)
modsize = 288 if model == 'RN50x4' else 224
# per-model scaling of the sample count for the ResNet variants
xmem = {'RN50':0.5, 'RN50x4':0.16, 'RN101':0.33}
if 'RN' in model:
  samples = int(samples * xmem[model])
samples = max(1, samples) # the scaling above can round down to 0
# report the count only after all adjustments, so it matches what is actually used
print(' using %d samples' % samples)

if multilang is True:
    model_lang = SentenceTransformer('clip-ViT-B-32-multilingual-v1').cuda()

def enc_text(txt):
    """Encode `txt` into an embedding tensor, detached from any graph.

    Uses the multi-language sentence model when `multilang` is True,
    otherwise the CLIP text encoder.
    """
    if multilang is not True:
        tokens = clip.tokenize(txt).cuda()
        return model_clip.encode_text(tokens).detach().clone()
    emb = model_lang.encode([txt], convert_to_tensor=True, show_progress_bar=False)
    return emb.detach().clone()

# if diverse != 0:
#  samples = int(samples * 0.5)
        
norm_in = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
sign = 1. if invert is True else -1.

# Encode the uploaded image with CLIP and, when `sync` is on, keep a
# downscaled copy of it as the reference for the SSIM loss.
if upload_image:
  in_img = list(uploaded.values())[0]
  print(' image:', list(uploaded)[0])
  img_np = imageio.imread(in_img).astype(np.float32)/255.
  if img_np.ndim == 2: # grayscale upload has no channel axis: replicate it to 3 channels
    img_np = np.stack([img_np] * 3, axis=-1)
  # HWC -> 1CHW; [:,:3] drops an alpha channel if there is one
  img_in = torch.from_numpy(img_np).unsqueeze(0).permute(0,3,1,2).cuda()[:,:3,:,:]
  del img_np
  in_sliced = slice_imgs([img_in], samples, modsize, transform=norm_in)[0]
  img_enc = model_clip.encode_image(in_sliced).detach().clone()
  if sync > 0:
    overscan = True
    ssim_loss = ssim.SSIM(window_size = 11)
    ssim_size = [sideY//4, sideX//4]
    img_in = F.interpolate(img_in, ssim_size).float()
    # img_in = F.interpolate(img_in, (sideY, sideX)).float()
  else:
    del img_in
  del in_sliced; torch.cuda.empty_cache()

# Main prompt: optionally translate to English, then encode
if text:
  print(' text:', text)
  if translate:
    translator = Translator()
    text = translator.translate(text, dest='en').text
    print(' translated to:', text) 
  txt_enc = enc_text(text)
#  if no_text > 0:
#      txt_plot = torch.from_numpy(plot_text(text, modsize)/255.).unsqueeze(0).permute(0,3,1,2).cuda()
#      txt_plot_enc = model_clip.encode_image(txt_plot).detach().clone()

# Prompt to suppress: same steps, stored separately as txt_enc0
if subtract:
  print(' without:', subtract)
  if translate:
      translator = Translator()
      subtract = translator.translate(subtract, dest='en').text
      print(' translated to:', subtract) 
  txt_enc0 = enc_text(subtract)

# the sentence model is not needed once the prompts are encoded
if multilang is True:
  del model_lang

config_vqgan = load_config("/content/models_TT/model-%d.yaml" % int(VQGAN_size), display=False)
model_vqgan = load_vqgan(config_vqgan, ckpt_path="/content/models_TT/last-%d.ckpt" % int(VQGAN_size)).cuda()

shape = [1, 256, sideY//16, sideX//16]
lats = latents(shape).cuda()
optimizer = torch.optim.Adam(lats.parameters(), learning_rate)

def save_img(img, fname=None):
  """Convert a channel-first float image to uint8 and write it to disk.

  img: array-like indexed as [channel, row, column]; values are scaled by 255
    and clipped to 0..255.
  fname: target file; the same image is also written to 'result.jpg'.
    When None, nothing is written.
  """
  chw = np.array(img)
  hwc = chw.transpose(1, 2, 0)
  pixels = np.clip(hwc * 255, 0, 255).astype(np.uint8)
  if fname is None:
    return
  for target in (fname, 'result.jpg'):
    imageio.imsave(target, pixels)

def checkout(num):
  """Decode the current latents, save them as frame `num` and refresh the preview.

  The frame goes to `tempdir` as '%04d.jpg'; the preview widget `outpic`
  is cleared and re-filled with 'result.jpg' (written by save_img).
  """
  frame_path = os.path.join(tempdir, '%04d.jpg' % num)
  with torch.no_grad():
    decoded = vqgan_image(model_vqgan, lats())
    frame = decoded.cpu().numpy()[0]
  save_img(frame, frame_path)
  outpic.clear_output()
  with outpic:
    display(Image('result.jpg'))

prev_enc = 0
def train(i):
  """Run one optimisation step on the latents and save a frame when due.

  i: step index; a frame is saved via checkout() every `save_freq` steps.
  The loss sums cosine similarities between the CLIP encoding of the sliced
  output and the image / text / subtract encodings (each enabled by its input),
  plus an SSIM term against the input image when `sync` > 0.

  Raises ValueError when no image, text or subtract prompt is given: the loss
  would then stay a plain number with nothing to backpropagate.
  """
  if not (upload_image or len(text) > 0 or len(subtract) > 0):
    raise ValueError('nothing to optimise: set `text`, `subtract` or upload an image')

  loss = 0
  img_out = vqgan_image(model_vqgan, lats())

  imgs_sliced = slice_imgs([img_out], samples, modsize, norm_in, overscan=overscan)
  out_enc = model_clip.encode_image(imgs_sliced[-1])
#  if diverse != 0:
#    imgs_sliced = slice_imgs([vqgan_image(model_vqgan, lats())], samples, modsize, norm_in, overscan=overscan)
#    out_enc2 = model_clip.encode_image(imgs_sliced[-1])
#    loss += diverse * torch.cosine_similarity(out_enc, out_enc2, dim=-1).mean()
#    del out_enc2; torch.cuda.empty_cache()
  # sign is -1 normally (minimising the loss raises similarity) and +1 when inverted
  if upload_image:
      loss += sign * 0.5 * torch.cosine_similarity(img_enc, out_enc, dim=-1).mean()
  if len(text) > 0: # input text
      loss += sign * torch.cosine_similarity(txt_enc, out_enc, dim=-1).mean()
#      if no_text > 0:
#          loss -= sign * no_text * torch.cosine_similarity(txt_plot_enc, out_enc, dim=-1).mean()
  if len(subtract) > 0: # subtract text
      loss += -sign * 0.5 * torch.cosine_similarity(txt_enc0, out_enc, dim=-1).mean()
  if sync > 0 and upload_image: # image composition sync
      loss -= sync * ssim_loss(F.interpolate(img_out, ssim_size).float(), img_in)
#  if expand > 0:
#    global prev_enc
#    if i > 0:
#      loss += expand * torch.cosine_similarity(out_enc, prev_enc, dim=-1).mean()
#    prev_enc = out_enc.detach()
  del img_out, imgs_sliced, out_enc; torch.cuda.empty_cache()

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  if i % save_freq == 0:
    checkout(i // save_freq)

# Output widget that checkout() clears and redraws with the latest frame
outpic = ipy.Output()
outpic # bare expression, so the notebook displays the widget here

# Optimisation loop: one train() call per step, with a progress bar
pbar = ProgressBar(steps)
for i in range(steps):
  train(i)
  _ = pbar.upd() # assigned to `_` so the return value is not echoed by the notebook

# Build the video from the saved frames and show it inline
HTML(makevid(tempdir))
# Save the final latents next to the video and offer both files for download
torch.save(lats.lats, tempdir + '.pt')
files.download(tempdir + '.pt')
files.download(tempdir + '.mp4')
