<a href="https://colab.research.google.com/github/olaviinha/NeuralTextToImage/blob/main/ruDALLE_1_1_0rc0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#<font face="Trebuchet MS" size="6">Big Sleep: ruDALLE 1.1.0rc0<font color="#999" size="4">&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;</font><font color="#999" size="4">Neural text-to-image</font><font color="#999" size="4">&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;</font><a href="https://github.com/olaviinha/NeuralImageGeneration" target="_blank"><font color="#999" size="4">Github</font></a>

ruDALLE generates images from text input. 

### Tips
- You may queue any number of text prompts by separating them with semicolons (`;`).
- If you plan to use _a vast number_ of prompts, for the sake of clarity, you may also use the `pre = ['First prompt', 'Second prompt']` cell. To use this cell, enter `pre` in the `generate_image_of` field.
- This notebook loops through the queue prompt by prompt, as opposed to generating a vast number of images before moving on to the next prompt.
- Seed `0` = random seed.

In [None]:
#@title #Setup
#@markdown This cell needs to be run only once. It will mount your Google Drive and setup prerequisites.<br>

force_setup = False
pip_packages = 'rudalle==1.1.0rc0 ruclip mtranslate'
main_repository = ''
mount_drive = True #@param {type:"boolean"}

# Download the repo from Github
import os
from google.colab import output
import warnings
warnings.filterwarnings('ignore')
%cd /content/

# inhagcutils
if not os.path.isfile('/content/inhagcutils.ipynb') and force_setup == False:
  !pip -q install import-ipynb {pip_packages}
  !curl -s -O https://raw.githubusercontent.com/olaviinha/inhagcutils/master/inhagcutils.ipynb
import import_ipynb
from inhagcutils import *

# Mount Drive
# Outcome: drive_root always ends up defined, pointing either at the
# /content/mydrive symlink to Google Drive or at a local stand-in directory.
if mount_drive is True:
  if not os.path.isdir('/content/drive'):
    from google.colab import drive
    drive.mount('/content/drive')
  # Space-free alias for "My Drive".
  if not os.path.isdir('/content/mydrive'):
    os.symlink('/content/drive/My Drive', '/content/mydrive')
  # Assign unconditionally. Previously drive_root was only set inside the two
  # "does not exist yet" branches, so re-running this cell in a fresh kernel
  # with Drive still mounted left it undefined.
  drive_root = '/content/mydrive/'
  drive_root_set = True
else:
  # No Drive requested: write outputs to a local folder instead.
  create_dirs(['/content/faux_drive'])
  drive_root = '/content/faux_drive/'

if main_repository is not '':
  !git clone {main_repository}

#---

from tqdm.auto import tqdm
import random

from timeit import default_timer as timer
from datetime import timedelta 
import time 

# setup
from mtranslate import translate
import ruclip
from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_ruclip
from rudalle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan
from rudalle.utils import seed_everything

# prepare models:
# Everything below is created on the GPU; this cell does not fall back to CPU.
device = 'cuda'
# 'Malevich' ruDALLE checkpoint, requested with pretrained weights in fp16.
dalle = get_rudalle_model('Malevich', pretrained=True, fp16=True, device=device)
tokenizer = get_tokenizer()
vae = get_vae(dwt=True).to(device)

# pipeline utils:
# realesrgan is the upscaler handed to super_resolution() in the Run cell.
realesrgan = get_realesrgan('x4', device=device)
# The ruCLIP model/predictor are loaded here but not referenced by the other
# visible cells (presumably kept for cherry_pick_by_ruclip - confirm).
clip, processor = ruclip.load('ruclip-vit-base-patch32-384', device=device)
clip_predictor = ruclip.Predictor(clip, processor, device, bs=8)
# end of model setup

import torch
import torchvision
import time

def pil_list_to_torch_tensors(pil_images):
  """Pack a sequence of images into a single uint8 tensor.

  Each image is converted to a uint8 array, which must have exactly three
  dimensions; its last axis is moved to the front. The per-image tensors are
  then joined along a new leading axis, so all images must share one shape.

  Args:
    pil_images: iterable of objects accepted by ``np.array`` (e.g. PIL images).

  Returns:
    A ``torch.uint8`` tensor with one leading entry per input image.
  """
  channel_first = [
    torch.from_numpy(np.array(picture, dtype=np.uint8)).permute(2, 0, 1)
    for picture in pil_images
  ]
  return torch.stack(channel_first, dim=0)

def save_pils(pil_images, save_dir, key=None):
  """Save every image into ``save_dir`` as a timestamped PNG.

  File names follow ``<YYYYmmdd-HHMMSS>_<index>_<key>.png``; when ``key`` is
  None the key part is left empty (``..._<index>_.png``).

  Args:
    pil_images: iterable of objects exposing ``save(path)`` (e.g. PIL images).
    save_dir: target directory, created if missing. A trailing slash is
      optional: the path is joined with os.path.join, so a directory given
      without one no longer produces files *next to* the directory.
    key: optional label appended to each file name.
  """
  os.makedirs(save_dir, exist_ok=True)
  label = '' if key is None else key
  for i, pil_image in enumerate(pil_images):
    # Taken per image, so a long batch may span more than one timestamp.
    stamp = time.strftime('%Y%m%d-%H%M%S')
    filename = os.path.join(save_dir, f'{stamp}_{i}_{label}.png')
    pil_image.save(filename)

def show(pil_images, nrow=4, size=14, save_dir=None, key=None, show=True):
  """Optionally save the images, then optionally display them as one grid.

  Args:
    pil_images: list of PIL images (each must support ``convert('RGB')``).
    nrow: forwarded to ``torchvision.utils.make_grid`` as ``nrow``.
    size: width and height of the matplotlib figure.
    save_dir: when not None, every image is written there via ``save_pils``.
    key: optional label used in the saved file names.
    show: when False, nothing is rendered and no figure is created.
  """
  if save_dir is not None:
    # Same naming scheme as before; the shared helper replaces a copy of its body.
    save_pils(pil_images, save_dir, key=key)

  # Return before building the grid so a save-only call does not leave an
  # unshown, unclosed matplotlib figure behind.
  if not show:
    return

  rgb_images = [pil_image.convert('RGB') for pil_image in pil_images]
  # make_grid returns one tensor for a batched tensor input, so exactly one
  # axes is needed.
  grid = torchvision.utils.make_grid(pil_list_to_torch_tensors(rgb_images), nrow=nrow)
  grid_image = torchvision.transforms.functional.to_pil_image(grid.detach().cpu())

  fig, axs = plt.subplots(ncols=1, squeeze=False, figsize=(size, size))
  axs[0, 0].imshow(np.asarray(grid_image))
  axs[0, 0].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
  plt.show()

# Clear the output this cell has produced so far, leaving only the final status line.
output.clear()
# !nvidia-smi
# op and c are not defined in this cell; presumably they come from the
# inhagcutils star-import above - confirm.
op(c.ok, 'Setup finished.')

In [None]:
# Optional prompt queue: one list entry per prompt.
# To use it, type pre (without quotes) into the generate_image_of field of the
# Run cell, which then iterates over this list instead of splitting the field.

pre = [
  'Enter first prompt here',
  'Second prompt here',
  'etc',
  'etc',
  'as many as you like'
]

In [None]:
#@title Run

generate_image_of = "" #@param {type:"string"}
output_dir = '' #@param {type:"string"}
# Number of prompts to fetch when generate_image_of is a URL (API mode). The
# loop below reads it through range(0, max_imgs); it used to be commented out,
# which made API mode fail with a NameError.
max_imgs = 100 #@param {type:"integer"}

seed =  0 #@param {type:"integer"}

imgs_per_prompt = 3 #@param {type:"slider", min:1, max:6, step:1}
loop = 0 #@param {type:"integer"}

# NOTE(review): top_k_ / top_p_ are not read by the generation call below, which
# uses its own (top_k, top_p) pair - confirm which values are intended.
top_k_ = 1024 #param {type:"integer"}
top_p_ = .995 #param {type:"number"}

dir_out = fix_path(drive_root+output_dir)

superres = True

# makedirs also creates missing parent folders, so a nested output_dir works.
os.makedirs(dir_out, exist_ok=True)

# A value containing '://' is treated as an endpoint that returns one prompt
# per request; anything else is the 'pre' queue or a ';'-separated prompt list.
use_api = '://' in generate_image_of
if use_api:
  iterator = range(0, max_imgs)
else:
  if generate_image_of.strip() == 'pre':
    iterator = list(pre)
  else:
    # Trim whitespace and skip empty entries so "a; b;" yields two prompts
    # rather than three (the last one being an empty string).
    iterator = [prompt.strip() for prompt in generate_image_of.split(';') if prompt.strip()]
  if loop > 1:
    iterator = iterator * loop

total = len(iterator)

start = timer()

for txt_index, title in enumerate(iterator):

  if use_api:
    # NOTE(review): requests is not imported in the visible cells; presumably it
    # arrives through the inhagcutils star-import - confirm.
    title = requests.get(generate_image_of+'/?i='+str(txt_index)).text

  # Translate the prompt to Russian (target language 'ru') before generation.
  text = translate(title, 'ru')

  if seed == 0:
    # Re-seed from system entropy on every prompt: seed_everything() below
    # presumably re-seeds Python's RNG too, which would otherwise make the
    # "random" seeds repeat.
    random.seed()
    # Upper bound is 2**32 - 1 because randint is inclusive and NumPy's legacy
    # seeding (presumably reached through seed_everything) rejects 2**32.
    random_seed = random.randint(0, 2**32 - 1)
  else:
    random_seed = seed

  seed_everything(random_seed)

  # File-name-safe label: first 80 characters, title-cased, alphanumerics only.
  file_title = ''.join(e for e in title[:80].title() if e.isalnum())

  show_imgs = txt_index+1 == total

  output.clear()
  op(c.title, str(txt_index+1)+'/'+str(total)+' '+': '+str(imgs_per_prompt*total)+' images')
  op(c.title, 'Prompt', title)
  op(c.title, 'Actual', text)

  pil_images = []
  scores = []
  for top_k, top_p, images_num in [(2048, 0.995, imgs_per_prompt),]:
    _pil_images, _scores = generate_images(text, tokenizer, dalle, vae, top_k=top_k, images_num=images_num, bs=8, top_p=top_p)
    pil_images += _pil_images
    scores += _scores

  # Upscale and save once per prompt, after every sampling configuration has
  # run. Doing this inside the loop above would re-save the accumulated images
  # as soon as a second configuration is added.
  if superres is True:
    sr_images = super_resolution(pil_images, realesrgan)
  else:
    sr_images = pil_images

  save_pils(sr_images, save_dir=dir_out, key=file_title)

end = timer()
print()
print('Generated '+str(imgs_per_prompt*total)+' images. Time elapsed', timedelta(seconds=end-start))