<a href="https://colab.research.google.com/github/olaviinha/NeuralTextToImage/blob/main/dalle_mini.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#<font face="Trebuchet MS" size="6">DALL-E MINI<font color="#999" size="4">&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;</font><font color="#999" size="4">Neural text-to-image</font><font color="#999" size="4">&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;</font><a href="https://github.com/olaviinha/NeuralImageGeneration" target="_blank"><font color="#999" size="4">Github</font></a></font>

Dalle-mini generates images from text input. 

### Tips
- You can queue any number of text prompts by separating them with semicolons (`;`).

In [None]:
#@title #Setup
#@markdown This cell needs to be run only once. It will mount your Google Drive (if enabled below) and set up the prerequisites.<br>
#@markdown <small>Mounting Drive will enable this notebook to save outputs directly to your Drive. Otherwise you will need to copy/download them manually from this notebook.</small>

# Setup switches consumed by the rest of this cell.
# force_setup: read by the install check further down, together with the
# presence of /content/inhagcutils.ipynb.
force_setup = False
# Extra package(s) handed to `pip install` alongside import-ipynb.
pip_packages = 'dalle-mini'
# Optional git URL; it is cloned later in this cell when non-empty.
main_repository = ''
# True: mount Google Drive and use it as the output root.
# False: use the local directory /content/faux_drive/ instead.
mount_drive = True #@param {type:"boolean"}

# Selects which pretrained DALL-E checkpoint is referenced by DALLE_MODEL below.
#@markdown <small>Use `mini-1` if you are on Colab free plan.</small>
dalle_model = "mega-1-fp16" #@param ["mega-1-fp16", "mini-1"]



# Download the repo from Github
import os
from google.colab import output
import warnings
warnings.filterwarnings('ignore')
%cd /content/

# inhagcutils
if not os.path.isfile('/content/inhagcutils.ipynb') and force_setup == False:
  !pip -q install import-ipynb {pip_packages}
  !curl -s -O https://raw.githubusercontent.com/olaviinha/inhagcutils/master/inhagcutils.ipynb
import import_ipynb
from inhagcutils import *

# Mount Drive
# drive_root is the directory prefix under which outputs are written. It is
# assigned on every run of this cell (not only when the mount/symlink had to
# be created), so re-running the cell or restarting the kernel with Drive
# still mounted never leaves it undefined.
if mount_drive is True:
  if not os.path.isdir('/content/drive'):
    from google.colab import drive
    drive.mount('/content/drive')
  # Space-free alias for "My Drive".
  if not os.path.isdir('/content/mydrive'):
    os.symlink('/content/drive/My Drive', '/content/mydrive')
  drive_root = '/content/mydrive/'
  drive_root_set = True
else:
  # No Drive: fall back to a local directory with the same role.
  create_dirs(['/content/faux_drive'])
  drive_root = '/content/faux_drive/'

if main_repository is not '':
  !git clone {main_repository}

import time, sys
from datetime import timedelta






#--Dalle mini--
!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git


if dalle_model == 'mega-1-fp16':
  DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"
else:
  DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
DALLE_COMMIT_ID = None
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

# Load the DALL-E (text -> image tokens) and VQGAN (image tokens -> pixels)
# models and prepare their parameters for the pmap-ed functions defined below.
import jax
import jax.numpy as jnp
# The return value is discarded here.
# NOTE(review): presumably kept to initialise the JAX backend early — confirm.
jax.local_device_count()

from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
# CLIPProcessor / FlaxCLIPModel are imported but not used in the visible cells.
from transformers import CLIPProcessor, FlaxCLIPModel

# With _do_init=False the call yields a (model, params) pair, unpacked here.
# float16 is requested regardless of which dalle_model was selected.
model, params = DalleBart.from_pretrained(
    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False
)

# The VQGAN checkpoint is pinned to a fixed commit via VQGAN_COMMIT_ID.
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False
)

from flax.jax_utils import replicate

# Both parameter sets are passed through flax's `replicate` before being given
# to the jax.pmap-decorated functions (presumably one copy per local device —
# see the flax.jax_utils documentation).
params = replicate(params)
vqgan_params = replicate(vqgan_params)

from functools import partial

# model inference
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
    """Call `model.generate` under `jax.pmap` and return its result.

    `tokenized_prompt` is unpacked into keyword arguments, `key` is passed as
    `prng_key` and `params` as the model parameters. The remaining four
    arguments (positions 3-6) are declared static/broadcast in the decorator
    and are forwarded unchanged as sampling options.
    """
    sampling_options = {
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "condition_scale": condition_scale,
    }
    return model.generate(
        **tokenized_prompt, prng_key=key, params=params, **sampling_options
    )


# decode image
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    """Decode image code `indices` with `vqgan.decode_code` under `jax.pmap`."""
    decoded = vqgan.decode_code(indices, params=params)
    return decoded

import random
from dalle_mini import DalleBartProcessor

# Seed a fresh JAX PRNG key from Python's RNG; the seed is drawn from the
# inclusive range 0 .. 2**32 - 1.
seed = random.randint(0, (1 << 32) - 1)
key = jax.random.PRNGKey(seed)

# Prompt processor loaded from the same checkpoint reference as the model.
processor = DalleBartProcessor.from_pretrained(
    DALLE_MODEL,
    revision=DALLE_COMMIT_ID,
)
#--//Dalle mini--



#--ISR--
# Image super-resolution (ISR): clone the repository, install it from source
# and create the upscaling model `rdn`, which the generation cell uses when
# `superres` is enabled.
%cd /content/
!git clone https://github.com/saadz-khan/image-super-resolution
%cd image-super-resolution
!python setup.py install
# The import runs while the working directory is still the cloned repository.
# NOTE(review): presumably this is what makes `ISR` importable right after the
# install, without restarting the kernel — confirm before reordering.
from ISR.models import RRDN
rdn = RRDN(weights='gans')
# Return to /content/ so later cells start from the usual working directory.
%cd /content/
#--//ISR--



# Wipe the accumulated install/download output and report completion through
# the `op` / `c` helpers (not defined in this cell; presumably provided by the
# inhagcutils star import).
output.clear()
# !nvidia-smi
op(c.ok, 'Setup finished.')

In [None]:
#@title # Do stuff
# Prompt text. It is split on ';' further down, so several prompts can be
# given in one string.
generate_image_of = "George Costanza ice fishing 2kg siika" #@param {type:"string"}

# Output directory, appended to drive_root. Images are only written to disk
# when this is non-empty; they are displayed in the notebook either way.
output_dir = "" #@param {type:"string"}
# When True, each saved image is first upscaled with the ISR model `rdn`.
superres = True #@param {type:"boolean"}
# Requested number of images; divided by jax.device_count() to obtain the
# number of generation iterations (at least one iteration is always run).
number_of_images = 6 #@param {type:"slider", min:1, max:9}

# Identifier for this run; printed and embedded in every output filename.
uniq_id = gen_id()


# Output
# Create the target directory, including any missing parent directories, so a
# nested output_dir such as "a/b" works; an already existing directory is fine.
os.makedirs(drive_root+output_dir, exist_ok=True)
# Directory prefix used when composing the image file paths.
dir_out = drive_root+fix_path(output_dir)

# Start of the wall-clock measurement reported at the end of the cell.
timer_start = time.time()




# -- DO THINGS --
# Split the input on ';' into individual prompts. Blank entries (for example
# from a trailing or doubled ';') are dropped so that no images are generated
# for an empty prompt. If the input contains nothing but blanks, a single
# empty prompt is kept, as before.
prompts = [p for p in (x.strip() for x in generate_image_of.split(';')) if p] or ['']
tokenized_prompts = processor(prompts)
# Passed through flax's `replicate` for the pmap-ed p_generate.
tokenized_prompt = replicate(tokenized_prompts)

# number of predictions per prompt
n_predictions = number_of_images

# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)
# None is passed through to model.generate unchanged.
gen_top_k = None
gen_top_p = None
temperature = None
cond_scale = 10.0

from flax.training.common_utils import shard_prng_key
import numpy as np
from PIL import Image
from tqdm.notebook import trange

op(c.title, 'Run id:', uniq_id)
# print(f"Prompts: {prompts}\n")

# The `datetime` module itself is never imported by name in this notebook
# (only `timedelta` is), so import what the filename timestamp needs explicitly
# instead of relying on a star import to provide it.
from datetime import date as _date

# Filename parts that are identical for every image of this run.
timestamp = _date.today().strftime('%Y%m%d')
prompt_slug = ''.join(e for e in generate_image_of[:60].title() if e.isalnum())

# generate images
images = []

for i in trange(max(n_predictions // jax.device_count(), 1)):
    # get a new key
    key, subkey = jax.random.split(key)
    # generate images
    encoded_images = p_generate(
        tokenized_prompt,
        shard_prng_key(subkey),
        params,
        gen_top_k,
        gen_top_p,
        temperature,
        cond_scale,
    )
    # remove BOS
    encoded_images = encoded_images.sequences[..., 1:]
    # decode images
    decoded_images = p_decode(encoded_images, vqgan_params)
    # Clamp to [0, 1] and flatten every leading axis into one list of
    # 256x256 RGB images.
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    for decoded_img in decoded_images:
        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
        images.append(img)
        display(img)
        if output_dir != '':
            if superres is True:
                # Upscale only the copy that is written to disk; the image
                # kept in `images` and shown above stays at 256x256.
                lr_img = np.array(img)
                sr_img = rdn.predict(lr_img, by_patch_of_size=50)
                img = Image.fromarray(sr_img)
            # Number files with a running index over all images of the run.
            # One iteration can yield several images (the reshape above
            # flattens them into one sequence), so the loop index `i` alone
            # would give them the same filename and they would overwrite
            # each other.
            img_index = len(images) - 1
            img_filename = timestamp + '_' + uniq_id + '_' + str(img_index) + '_' + prompt_slug + '.png'
            img_path = dir_out+img_filename
            img.save(img_path)
        print()
# -- END THINGS --

# Stop the clock and report the wall-clock time spent in this cell.
timer_end = time.time()

elapsed = timedelta(seconds=timer_end - timer_start)
print('\nElapsed', elapsed)
