In [None]:
USE_MEGA = False #@param {type:"boolean"}

!pip install -q dalle-mini
!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git

import jax
import jax.numpy as jnp
import numpy as np
import random
from dalle_mini import DalleBart, DalleBartProcessor
from flax.jax_utils import replicate
from flax.training.common_utils import shard_prng_key
from functools import partial
from google.colab import widgets
from PIL import Image
from tqdm.notebook import trange
from vqgan_jax.modeling_flax_vqgan import VQModel

# Checkpoint references for the text-to-image model; the form toggle
# USE_MEGA chooses between the two variants.
if USE_MEGA:
  DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"
else:
  DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
# Passed as `revision=` when loading the model and processor; None leaves it unset.
DALLE_COMMIT_ID = None

# VQGAN checkpoint used to decode generated codes, with the revision it is loaded at.
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

# Load the DALL-E model. `_do_init=False` makes `from_pretrained` return the
# parameters separately, so they can be replicated for the pmapped functions.
model, params = DalleBart.from_pretrained(
  DALLE_MODEL,
  revision=DALLE_COMMIT_ID,
  dtype=jnp.float16,
  _do_init=False,
)
params = replicate(params)

# Same pattern for the VQGAN that decodes generated codes into images.
vqgan, vqgan_params = VQModel.from_pretrained(
  VQGAN_REPO,
  revision=VQGAN_COMMIT_ID,
  _do_init=False,
)
vqgan_params = replicate(vqgan_params)

@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale):
  """Call `model.generate` in parallel through `jax.pmap`.

  `tokenized_prompt` is unpacked into keyword arguments of `model.generate`;
  `key` is forwarded as `prng_key` and `params` as the model parameters.
  The four sampling arguments (positions 3-6) are declared static and
  broadcast, so they are passed as plain Python values rather than mapped
  arrays.

  Returns whatever `model.generate` returns.
  """
  sampling_options = dict(
    top_k=top_k,
    top_p=top_p,
    temperature=temperature,
    condition_scale=condition_scale,
  )
  return model.generate(**tokenized_prompt, prng_key=key, params=params, **sampling_options)
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
  """Call `vqgan.decode_code` in parallel through `jax.pmap`.

  `indices` are the codes to decode and `params` the VQGAN parameters; both
  are mapped arguments. Returns the result of `vqgan.decode_code`.
  """
  decoded = vqgan.decode_code(indices, params=params)
  return decoded

# Draw a fresh 32-bit seed on every run of this cell and derive the JAX PRNG
# key from it; the generation cell splits this key once per round.
seed = random.randrange(2**32)
key = jax.random.PRNGKey(seed)

# Processor that turns text prompts into model inputs, loaded from the same
# checkpoint reference and revision as the model.
processor = DalleBartProcessor.from_pretrained(
  DALLE_MODEL,
  revision=DALLE_COMMIT_ID,
)

# Settings read by the generation cell.
n_predictions = 9  # how many images to generate for the prompt
# These three are forwarded unchanged to `model.generate` via `p_generate`.
gen_top_k = gen_top_p = temperature = None
cond_scale = 10.0  # forwarded as `condition_scale`

In [None]:
# Generate `n_predictions` images for the prompt and show them in a grid.
# The whole cell is wrapped in `try/except NameError` so that running it
# before the initialization cell prints a hint instead of a traceback.
# NOTE(review): this also hides any unrelated NameError raised in the body.
try:
  prompt = "a family of monkeys swimming in the ocean" #@param {type:"string"}
  prompts = [prompt]
  tokenized_prompts = processor(prompts)
  tokenized_prompt = replicate(tokenized_prompts)

  # The single prompt is replicated across devices, so each pmapped round is
  # expected to yield one image per device. Round the number of rounds UP:
  # floor division produced fewer than `n_predictions` images whenever the
  # device count did not divide it (e.g. 9 images on 2 or 8 devices), and the
  # grid below then failed with IndexError.
  n_devices = jax.device_count()
  n_rounds = max((n_predictions + n_devices - 1) // n_devices, 1)

  images = []
  for _ in trange(n_rounds):
    # Fresh subkey per round so every round samples different images.
    key, subkey = jax.random.split(key)
    encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), params, gen_top_k, gen_top_p, temperature, cond_scale)
    # Drop the first position of every generated sequence before decoding.
    encoded_images = encoded_images.sequences[..., 1:]
    decoded_images = p_decode(encoded_images, vqgan_params)
    # Clamp to [0, 1] and flatten the leading axes into a batch of 256x256 RGB images.
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    for decoded_img in decoded_images:
      img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
      images.append(img)

  # Rounding up may overshoot; keep exactly the requested number of images.
  images = images[:n_predictions]

  print(prompt)
  # Three images per row; the row count follows from how many images exist,
  # which is a 3x3 grid for the default of 9 predictions.
  n_cols = 3
  n_rows = max((len(images) + n_cols - 1) // n_cols, 1)
  grid = widgets.Grid(n_rows, n_cols)
  for idx, img in enumerate(images):
    row, col = divmod(idx, n_cols)
    with grid.output_to(row, col):
      display(img)
except NameError:
  print("Please run the initialization code block first.")