<a href="https://colab.research.google.com/github/frasercrichton/police-and-thieves/blob/main/jupyter_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


This notebook is copied from Dave Davies' original, made available here: [SOLVED - DALL-E Mini: Too much traffic, please try again](https://wandb.ai/onlineinference/ml-news/reports/Beating-The-DALL-E-Mini-Too-much-traffic-please-try-again---VmlldzoyMTg4Mjkz) by Dave Davies.

**NOTICE:** A wandb API key is currently required to run this Colab (it is used to download the models). If you're signed in to a Google account associated with a wandb account, it should work automatically.

[Click Here for full details and a visual walkthrough of this Colab.](https://wandb.ai/onlineinference/ml-news/reports/Beating-The-DALL-E-Mini-Too-much-traffic-please-try-again---VmlldzoyMTg4Mjkz) 

# Before Starting

Click "Connect" in the upper right to start a runtime session. Once you are connected, you'll instead see "RAM" and "Disk" next to two partially filled bars.

To make sure you are using a GPU, click that button again to open the resources tab. If "GPU RAM" is listed there, you're good to continue. If not, click "Change runtime type" at the bottom of the resources tab and set "Hardware accelerator" to "GPU".
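If you'd rather confirm from code that JAX can see the GPU, a quick check like the following should work. This is a minimal sketch, assuming JAX is importable in the runtime (it is normally preinstalled on Colab); `jax.default_backend()` reports `"gpu"` on a GPU runtime and `"cpu"` when no accelerator is attached.

```python
import jax

# Name of the platform JAX will run on: "gpu", "tpu", or "cpu"
backend = jax.default_backend()
devices = jax.devices()

print(f"JAX backend: {backend}")
print(f"Devices: {devices}")
if backend != "gpu":
    print("No GPU detected: change the runtime type to GPU before continuing.")
```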

# If You Run Into Issues

If you encounter any errors, open the session resources tab again and click "Manage sessions". Terminate all listed sessions, close the dialog, and click "Reconnect" in the top right. Make sure you're using a GPU, then continue through the rest of the Colab.

# Credit

This Colab is a lightly modified version of the [DALL·E Mini inference pipeline](https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb). More information about using DALL·E Mini through code can be found at the [DALL·E Mini GitHub repository](https://github.com/borisdayma/dalle-mini).

# Initialize

Run this code section to initialize everything required to use DALL·E Mini. It may take a minute or so to process. Please wait for it to finish working before continuing.

Note: DALL·E Mini is the default option. If you would like to use DALL·E Mega, check the "USE_MEGA" box. If you choose to use DALL·E Mega and are not using Colab Pro, you will likely run into RAM errors.

In [None]:
USE_MEGA = False #@param {type:"boolean"}

!pip install -q dalle-mini
!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git

import jax
import jax.numpy as jnp
import numpy as np
import random
from dalle_mini import DalleBart, DalleBartProcessor
from flax.jax_utils import replicate
from flax.training.common_utils import shard_prng_key
from functools import partial
from google.colab import widgets
from PIL import Image
from tqdm.notebook import trange
from vqgan_jax.modeling_flax_vqgan import VQModel

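# Model checkpoints are fetched as wandb artifacts (hence the API key notice above)
DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0" if not USE_MEGA else "dalle-mini/dalle-mini/mega-1-fp16:latest"
DALLE_COMMIT_ID = None
# VQGAN decoder that turns image tokens back into pixels, pinned to a fixed commit
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

# _do_init=False returns the weights separately instead of randomly initializing them
model, params = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False)
vqgan, vqgan_params = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False)
# Copy the weights onto every available device for pmap
params = replicate(params)
vqgan_params = replicate(vqgan_params)

# Generate image tokens on all devices in parallel; the sampling settings
# (arguments 3-6) are static, so changing them triggers a recompile
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale):
  return model.generate(
      **tokenized_prompt,
      prng_key=key,
      params=params,
      top_k=top_k,
      top_p=top_p,
      temperature=temperature,
      condition_scale=condition_scale,
  )

# Decode image tokens to pixels on all devices in parallel
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
  return vqgan.decode_code(indices, params=params)

# Random seed so each session produces different images
seed = random.randint(0, 2**32 - 1)
key = jax.random.PRNGKey(seed)
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)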
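The initialization cell copies the weights to every device with `replicate` and wraps generation and decoding in `jax.pmap`, so each device runs the same function on its own slice of the leading axis. Below is a toy sketch of that pattern; the `scaled_dot` function (standing in for `model.generate`) and the use of `jnp.broadcast_to` (standing in for flax's `replicate`) are illustrative assumptions, not part of DALL·E Mini.

```python
from functools import partial

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# Stand-in for flax.jax_utils.replicate: one copy of the weights per local device
weights = jnp.arange(3.0)
replicated_weights = jnp.broadcast_to(weights, (n_dev,) + weights.shape)

# Like p_generate, argument 2 is static: the same value is used on every device,
# and passing a new value triggers recompilation
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(2,))
def scaled_dot(x, w, scale):
    return scale * jnp.dot(x, w)

# The leading axis of each mapped input must equal the number of devices
inputs = jnp.ones((n_dev, 3))
out = scaled_dot(inputs, replicated_weights, 2.0)
print(out)  # one result per device
```

On a single-GPU Colab runtime `n_dev` is 1, so the mapped axis has length 1; on an 8-core TPU runtime the same code would run on all 8 cores.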


# Generate Images

This is where you generate images with DALL·E Mini. Enter your prompt into the "prompt" field to the right and run the code block. It generates `n_predictions` images (100 by default) and displays them in a 10×10 grid. The first run takes longer than later runs because JAX compiles the generation and decoding functions the first time they are called.

You can change the prompt and re-run the code as many times as you like without re-running the initialization code above.
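The number of images per run depends on the device count: the loop runs `max(n_predictions // jax.device_count(), 1)` times, and because the tokenized prompt is replicated to every device, each iteration yields one image per device per prompt. A plain-Python sketch of that arithmetic (the helper name `images_per_run` is hypothetical):

```python
def images_per_run(n_predictions: int, device_count: int, n_prompts: int = 1) -> int:
    """Images produced by the generation loop for the given settings."""
    # Same iteration count as the notebook's trange(...) call
    iterations = max(n_predictions // device_count, 1)
    # Each iteration: every device generates one image for every prompt
    return iterations * device_count * n_prompts

print(images_per_run(100, 1))  # single-GPU runtime
print(images_per_run(100, 8))  # 8-core TPU runtime: 12 iterations x 8 devices
```

This is why a 10×10 grid is filled exactly on a single GPU, while a multi-device runtime can produce a slightly different number of images than `n_predictions`.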


The following code mounts your Google Drive at /content/drive so that you can save generated images to Drive.

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive
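Files written to a bare relative path such as `met-police/` land in the runtime's working directory (usually `/content`), not in Drive, and are lost when the runtime is recycled. A small helper sketch, assuming Drive's "My Drive" folder appears at `/content/drive/MyDrive` after mounting (the usual Colab layout); `resolve_save_dir` is a hypothetical name:

```python
import os

def resolve_save_dir(folder: str, drive_root: str = "/content/drive/MyDrive") -> str:
    """Return a directory for generated images, preferring the mounted Drive."""
    # If Drive is not mounted, fall back to the current working directory
    base = drive_root if os.path.isdir(drive_root) else os.getcwd()
    path = os.path.join(base, folder)
    # Create the folder (and any parents) if it does not exist yet
    os.makedirs(path, exist_ok=True)
    return path
```

For example, `resolve_save_dir("met-police")` would return `/content/drive/MyDrive/met-police` once Drive is mounted.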


In [None]:
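import os

n_predictions = 100  # total images to generate
gen_top_k = None
gen_top_p = None
temperature = None
cond_scale = 10.0

# Save into the mounted Drive; a bare relative path like "met-police/" would
# resolve to /content/met-police, which is not on Drive and may not exist
save_dir = "/content/drive/MyDrive/met-police"
os.makedirs(save_dir, exist_ok=True)

try:
  prompt = "metropolitan police uk" #@param {type:"string"}
  prompts = [prompt]
  tokenized_prompts = processor(prompts)
  tokenized_prompt = replicate(tokenized_prompts)
  images = []
  for i in trange(max(n_predictions // jax.device_count(), 1)):
    # Fresh subkey per batch so every batch samples different images
    key, subkey = jax.random.split(key)
    encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), params, gen_top_k, gen_top_p, temperature, cond_scale)
    # Drop the leading BOS token before decoding with VQGAN
    encoded_images = encoded_images.sequences[..., 1:]
    decoded_images = p_decode(encoded_images, vqgan_params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    for decoded_img in decoded_images:
      img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
      # Number files by a running count so images from different devices don't overwrite each other
      img.save(os.path.join(save_dir, f"{len(images)}.jpg"))
      images.append(img)
  print(prompt)

  # Show the results in a 10-column grid sized to the number of images actually produced
  n_cols = 10
  n_rows = (len(images) + n_cols - 1) // n_cols
  grid = widgets.Grid(n_rows, n_cols)
  for idx, img in enumerate(images):
    with grid.output_to(idx // n_cols, idx % n_cols):
      display(img)
except NameError:
  print("Please run the initialization code block first.")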
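`cond_scale` is passed to `model.generate` as `condition_scale`. Assuming it works like classifier-free guidance (DALL·E Mini calls this "super conditioning"), it blends logits from a prompt-conditioned pass with logits from an unconditioned pass as `uncond + scale * (cond - uncond)`. The sketch below applies that assumed formula to made-up numbers; it is an illustration, not the library's code, and `guided_logits` is a hypothetical name.

```python
import numpy as np

def guided_logits(cond_logits, uncond_logits, condition_scale):
    """Push logits away from the unconditioned prediction toward the prompt-conditioned one."""
    # scale = 1.0 returns the conditioned logits unchanged; larger values
    # exaggerate the difference the prompt makes
    return uncond_logits + condition_scale * (cond_logits - uncond_logits)

cond = np.array([2.0, 1.0, 0.0])    # illustrative logits with the prompt
uncond = np.array([1.0, 1.0, 1.0])  # illustrative logits without the prompt
print(guided_logits(cond, uncond, 1.0))   # same as cond
print(guided_logits(cond, uncond, 10.0))  # differences from uncond scaled by 10
```

Under this assumption, the notebook's `cond_scale = 10.0` makes the images follow the prompt much more strongly than a scale of 1.0 would.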
