# Text to Image tool

Based on [CLIP](https://github.com/openai/CLIP) + FFT from [Lucent](https://github.com/greentfrapp/lucent) // made by [eps696](https://github.com/eps696) [Vadim Epstein]  
thanks to [Ryan Murdock](https://twitter.com/advadnoun), [Jonathan Fly](https://twitter.com/jonathanfly), [Hannu Toyryla](https://twitter.com/htoyryla) for ideas

## Features 
* complex requests:
  * image and/or text as main prompts  
   (composition similarity controlled with [SSIM](https://github.com/Po-Hsun-Su/pytorch-ssim) loss)
  * additional text prompts for fine details and to subtract (avoid) topics
  * criteria inversion (show "the opposite")

* generates [FFT-encoded](https://github.com/greentfrapp/lucent/blob/master/lucent/optvis/param/spatial.py) image (massive detailed textures, a la deepdream)
* fast convergence
* undemanding for RAM - fullHD/4K and above
* saving/loading FFT snapshots to resume processing
* selectable CLIP model


**Run the cell below after each session restart**

Check `resume` and upload the `.pt` file if you're resuming from saved params.

In [None]:
#@title General setup

import subprocess
CUDA_version = [s for s in subprocess.check_output(["nvcc", "--version"]).decode("UTF-8").split(", ") if s.startswith("release")][0].split(" ")[-1]
print("CUDA version:", CUDA_version)

if CUDA_version == "10.0":
    torch_version_suffix = "+cu100"
elif CUDA_version == "10.1":
    torch_version_suffix = "+cu101"
elif CUDA_version == "10.2":
    torch_version_suffix = ""
else:
    torch_version_suffix = "+cu110"

!pip install torch==1.7.1{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex

try: 
  !pip3 install googletrans==3.1.0a0
  from googletrans import Translator, constants
  # from pprint import pprint
  translator = Translator()
except: pass
!pip install ftfy==5.8

import os
import io
import time
from math import exp
import random
import imageio
import numpy as np
import PIL
from skimage import exposure
from base64 import b64encode

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.autograd import Variable

from IPython.display import HTML, Image, display, clear_output
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import ipywidgets as ipy
# import glob
from google.colab import output, files

import warnings
warnings.filterwarnings("ignore")

!pip install git+https://github.com/openai/CLIP.git
import clip
!pip install git+https://github.com/Po-Hsun-Su/pytorch-ssim
import pytorch_ssim as ssim

# Fetch a fresh copy of the aphantasia repository and work from inside it,
# so that its modules (clip_fft, utils, progress_bar) are importable below.
%cd /content
!rm -rf aphantasia
!git clone https://github.com/eps696/aphantasia
%cd aphantasia/
from clip_fft import to_valid_rgb, fft_image, slice_imgs, checkout
from utils import pad_up_to, basename, img_list, img_read
from progress_bar import ProgressIPy as ProgressBar

# Output locations: per-step frames are written into `tempdir`.
workdir = '_out'
tempdir = os.path.join(workdir, 'ttt')
os.makedirs(tempdir, exist_ok=True)

# Hide the installation log printed so far.
clear_output()

resume = False #@param {type:"boolean"}
if resume:
  # Ask the user for a previously saved snapshot; the first uploaded file is used.
  resumed = files.upload()
  # NOTE(review): this keeps the value returned by files.upload() as-is
  # (presumably the raw file bytes) - confirm that fft_image(resume=...)
  # accepts it, or whether it needs torch.load(io.BytesIO(...)) first.
  params_pt = list(resumed.values())[0]

def makevid(seq_dir, size=None):
  out_sequence = seq_dir + '/%04d.jpg'
  out_video = seq_dir + '.mp4'
  !ffmpeg -y -v warning -i $out_sequence $out_video
  data_url = "data:video/mp4;base64," + b64encode(open(out_video,'rb').read()).decode()
  wh = '' if size is None else 'width=%d height=%d' % (size, size)
  return """<video %s controls><source src="%s" type="video/mp4"></video>""" % (wh, data_url)

# List the GPU(s) attached to this session, then signal that setup has finished.
!nvidia-smi -L
print('\nDone!')

Type some `text` and/or upload some image to start.  
`fine_details` input would make micro details follow that topic.  
Put the topics you would like to avoid in the result into `subtract`.  
*NB: more prompts = more memory! (handled by auto-decreasing `samples` amount, hopefully you don't need to act).*  
Use `invert` to flip the whole criteria if you want to see "the total opposite".

In [None]:
#@title Input

# Prompt fields (Colab form). Empty strings mean "not used".
text = "" #@param {type:"string"}
fine_details = "" #@param {type:"string"}
subtract = "" #@param {type:"string"}
# Options: translate prompts to English, flip the optimisation criteria,
# and/or upload an image to use as a prompt.
translate = False #@param {type:"boolean"}
invert = False #@param {type:"boolean"}
upload_image = False #@param {type:"boolean"}

# Only the main `text` prompt is translated here.
# NOTE(review): relies on the `translator` object created in the setup cell;
# it is undefined if the googletrans install/import failed there.
if translate:
  text = translator.translate(text, dest='en').text
# Ask for the image file(s); the result is kept in `uploaded` for the Generate cell.
if upload_image:
  uploaded = files.upload()

Select CLIP `model` (results do vary!). I prefer ViT for consistency.  
`overscan` option produces semi-seamlessly tileable texture (when off, it's more centered).  
The `sync` value adds an SSIM loss between the output and the input image (if there is one), allowing you to "redraw" it with controlled similarity.

Decrease `samples` if you face OOM (it's the main RAM eater).  
Setting `steps` much higher (1000-..) will elaborate details and make tones smoother, but may start throwing texts like graffiti.  
`progressive_grow` may boost macro forms creation (especially with lower `learning_rate`), see more [here](https://github.com/eps696/aphantasia/issues/2).  

In [None]:
#@title Generate

# from google.colab import drive
# drive.mount('/content/GDrive')
# clipsDir = '/content/GDrive/MyDrive/T2I ' + dtNow.strftime("%Y-%m-%d %H%M")

!rm -rf $tempdir

sideX = 1280 #@param {type:"integer"}
sideY = 720 #@param {type:"integer"}
#@markdown > Config
model = 'ViT-B/32' #@param ['ViT-B/32', 'RN101', 'RN50x4', 'RN50']
overscan = True #@param {type:"boolean"}
sync =  0.01 #@param {type:"number"}
contrast = 1. #@param {type:"number"}
#@markdown > Training
steps = 200 #@param {type:"integer"}
samples = 200 #@param {type:"integer"}
learning_rate = .05 #@param {type:"number"}
progressive_grow = False #@param {type:"boolean"}
save_freq = 1 #@param {type:"integer"}

if len(fine_details) > 0:
  samples = int(samples * 0.75)
if len(subtract) > 0:
  samples = int(samples * 0.75)
print(' using %d samples' % samples)

model_clip, _ = clip.load(model)
modsize = 288 if model == 'RN50x4' else 224
xmem = {'RN50':0.5, 'RN50x4':0.16, 'RN101':0.33}
if 'RN' in model:
  samples = int(samples * xmem[model])

norm_in = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
sign = 1. if invert is True else -1.

# --- image prompt -----------------------------------------------------------
if upload_image:
  in_img = list(uploaded.values())[0]
  print(' image:', list(uploaded)[0])
  # Decode to float in [0, 1], add a batch axis, move channels first and keep
  # only the first three channels.
  img_in = torch.from_numpy(imageio.imread(in_img).astype(np.float32)/255.)
  img_in = img_in.unsqueeze(0).permute(0,3,1,2).cuda()[:,:3,:,:]
  in_sliced = slice_imgs([img_in], samples, modsize, transform=norm_in)[0]
  img_enc = model_clip.encode_image(in_sliced).detach().clone()
  if sync > 0:
    # The resized source image is kept only when the SSIM term is active.
    ssim_loss = ssim.SSIM(window_size = 11)
    img_in = F.interpolate(img_in, (sideY, sideX)).float()
  else:
    del img_in
  del in_sliced
  torch.cuda.empty_cache()

# --- main text prompt (used only when longer than two characters) -----------
if len(text) > 2:
  print(' macro:', text)
  if translate:
    # NOTE(review): the Input cell already translates `text` when `translate`
    # is set, so this is a second pass over it.
    translator = Translator()
    text = translator.translate(text, dest='en').text
    print(' translated to:', text) 
  tx = clip.tokenize(text)
  txt_enc = model_clip.encode_text(tx.cuda()).detach().clone()

# --- text prompt for fine details --------------------------------------------
if fine_details:
  print(' micro:', fine_details)
  if translate:
    translator = Translator()
    fine_details = translator.translate(fine_details, dest='en').text
    print(' translated to:', fine_details) 
  tx2 = clip.tokenize(fine_details)
  txt_enc2 = model_clip.encode_text(tx2.cuda()).detach().clone()

# --- text prompt for topics to avoid -----------------------------------------
if subtract:
  print(' without:', subtract)
  if translate:
    translator = Translator()
    subtract = translator.translate(subtract, dest='en').text
    print(' translated to:', subtract) 
  tx0 = clip.tokenize(subtract)
  txt_enc0 = model_clip.encode_text(tx0.cuda()).detach().clone()

# Build the trainable image parameterization, starting from the uploaded
# snapshot only when `resume` was ticked.
shape = [1, 3, sideY, sideX]
param_f = fft_image 
# param_f = pixel_image
# learning_rate = 1.
init_pt = None
if resume is True:
  init_pt = params_pt
params, image_f = param_f(shape, resume=init_pt)
image_f = to_valid_rgb(image_f)

# Starting learning rate. With progressive growth the rate starts at 1% of
# the target `lr1` (twice the requested rate); train() raises it step by step.
lr0 = learning_rate
if progressive_grow is True:
  lr1 = learning_rate * 2
  lr0 = lr1 * 0.01
optimizer = torch.optim.Adam(params, lr=lr0)

def save_img(img, fname=None):
  """Convert a channels-first image to 8-bit and write it to disk.

  The input is reordered from (channels, height, width) to (height, width,
  channels), scaled by 255, clipped to [0, 255] and cast to uint8. When
  `fname` is given, the result is written both to `fname` and to
  'result.jpg'; when it is None nothing is written.
  """
  # The three-slice index rejects inputs with fewer than three dimensions.
  chw = np.array(img)[:, :, :]
  hwc = chw.transpose(1, 2, 0)
  pixels = np.clip(hwc * 255, 0, 255).astype(np.uint8)
  if fname is None:
    return
  for target in (fname, 'result.jpg'):
    imageio.imsave(target, np.array(pixels))

def checkout(num):
  """Render the current image, save it as frame `num` and refresh the preview.

  The frame goes to `<tempdir>/<num as 4 digits>.jpg`; save_img also writes
  'result.jpg', which is what gets shown in the `outpic` widget.
  """
  frame_path = os.path.join(tempdir, '%04d.jpg' % num)
  # Rendering for a snapshot needs no gradients.
  with torch.no_grad():
    rendered = image_f(contrast=contrast).cpu().numpy()
    frame = rendered[0]
  save_img(frame, frame_path)
  outpic.clear_output()
  with outpic:
    display(Image('result.jpg'))

def train(i):
  """Run one optimisation step (index `i`, zero-based) on the image parameters.

  The current image is rendered, cut into samples and encoded with CLIP. Its
  cosine similarity to every active prompt (uploaded image, main text,
  subtracted text, fine-detail text) is accumulated into the loss, weighted by
  the global `sign`. The loss is then back-propagated and the optimizer
  stepped. Every `save_freq` steps a preview frame is saved via checkout().

  Raises:
    ValueError: if no prompt is active, so there is nothing to optimise.
  """
  loss = 0
  img_out = image_f()

  # `micro` is forwarded to slice_imgs: False when a separate fine-detail pass
  # follows below, None otherwise.
  micro = False if len(fine_details) > 0 else None
  imgs_sliced = slice_imgs([img_out], samples, modsize, norm_in, overscan=overscan, micro=micro)
  out_enc = model_clip.encode_image(imgs_sliced[-1])
  if upload_image:
    loss += sign * torch.cosine_similarity(img_enc, out_enc, dim=-1).mean()
  # `txt_enc` is only computed for text longer than two characters, so the
  # same threshold must be used here to avoid a NameError on shorter text.
  if len(text) > 2: # input text
    loss += sign * torch.cosine_similarity(txt_enc, out_enc, dim=-1).mean()
  if len(subtract) > 0: # subtract text
    loss += -sign * torch.cosine_similarity(txt_enc0, out_enc, dim=-1).mean()
  if sync > 0 and upload_image: # image composition sync
    # Scale the loss by the SSIM term; its weight decays as steps/(i+1).
    loss *= 1. + sync * (steps/(i+1) * ssim_loss(img_out, img_in) - 1)
  if len(fine_details) > 0: # input text for micro details
    imgs_sliced = slice_imgs([img_out], samples, modsize, norm_in, overscan=overscan, micro=True)
    out_enc2 = model_clip.encode_image(imgs_sliced[-1])
    loss += sign * torch.cosine_similarity(txt_enc2, out_enc2, dim=-1).mean()
    del out_enc2; torch.cuda.empty_cache()
  del img_out, imgs_sliced, out_enc; torch.cuda.empty_cache()

  # With no active prompt `loss` is still the plain number 0 and cannot be
  # back-propagated; fail with an explicit message instead.
  if not torch.is_tensor(loss):
    raise ValueError('nothing to optimise: provide text (3+ characters), '
                     'fine_details, subtract or an uploaded image')

  if progressive_grow is True:
    # Linear ramp of the learning rate from lr0 towards lr1 over the run.
    lr_cur = lr0 + (i / steps) * (lr1 - lr0)
    for g in optimizer.param_groups:
      g['lr'] = lr_cur

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  if i % save_freq == 0:
    checkout(i // save_freq)

# Live preview widget: the bare expression displays it, and checkout()
# redraws its contents after every saved frame.
outpic = ipy.Output()
outpic

# Optimisation loop with a progress bar.
pbar = ProgressBar(steps)
for step_idx in range(steps):
  train(step_idx)
  _ = pbar.upd()

# Assemble the saved frames into a video and show it inline.
HTML(makevid(tempdir))

# Save the trained parameters next to the video and offer both for download.
params_file = tempdir + '.pt'
torch.save(params, params_file)
for artifact in (params_file, tempdir + '.mp4'):
  files.download(artifact)
