Set up the notebook. The first two cells (installing dependencies and mounting Google Drive) only need to be run if you are using Google Colab.

In [None]:
# Install required libraries
!pip install -q dalle-mini==0.1.3
!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git

In [None]:
# Mount Google Drive under /content/drive; generated images are written there.
import google.colab.drive as drive

drive.mount('/content/drive')

Set up the DALL·E mini model. You will need a Weights & Biases account to get an API key (don't worry, creating an account is free).

In [None]:
# --- Model references -------------------------------------------------------
# DALLE_MODEL can be a wandb artifact, a 🤗 Hub repo, a local folder or a Google
# bucket. The default below is the larger dalle-mega checkpoint.
DALLE_MODEL, DALLE_COMMIT_ID = "dalle-mini/dalle-mini/mega-1-fp16:latest", None

# If the notebook crashes too often, switch to the smaller dalle-mini
# checkpoint by uncommenting the next line:
# DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"

# VQGAN image decoder, pinned to a specific commit.
VQGAN_REPO, VQGAN_COMMIT_ID = (
    "dalle-mini/vqgan_imagenet_f16_16384",
    "e93a26e7707683d349bf5d5c41c5b0ef69b677a9",
)

In [None]:
import jax
import jax.numpy as jnp

# Number of devices addressable by this JAX process. The generation code
# replicates each prompt onto every one of them.
len(jax.local_devices())

In [None]:
# Load models & tokenizer
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
from transformers import CLIPProcessor, FlaxCLIPModel

# Load the DALL·E BART model selected by DALLE_MODEL (dalle-mega by default) in
# float16. With _do_init=False the weights come back as a separate `params`
# tree, which is what the pmapped helpers below take as an argument.
model, params = DalleBart.from_pretrained(
    DALLE_MODEL,
    revision=DALLE_COMMIT_ID,
    dtype=jnp.float16,
    _do_init=False,
)

# Load the VQGAN that decodes generated image tokens into pixels; its weights
# are likewise returned separately as `vqgan_params`.
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO,
    revision=VQGAN_COMMIT_ID,
    _do_init=False,
)

Import everything we'll need.

In [None]:
import csv
import torch
import random
import os
from flax.jax_utils import replicate
from functools import partial
from dalle_mini import DalleBartProcessor
from flax.training.common_utils import shard_prng_key
import numpy as np
from PIL import Image
from tqdm.notebook import trange

In [None]:
# Copy both parameter trees onto every local device (adding a leading device
# axis) so the pmapped generate/decode functions can run on all devices.
params, vqgan_params = replicate(params), replicate(vqgan_params)

# Model functions are compiled and parallelized to take advantage of multiple devices.
# Arguments 3-6 (the sampling settings) are static broadcast arguments. Per the
# jax.pmap docs, passing a different value for any of them triggers a recompile.
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
    """Sample image-token sequences for a tokenized prompt on every local device.

    ``tokenized_prompt`` and ``params`` are replicated across devices and
    ``key`` is sharded (one PRNG key per device). The sampling settings are
    plain Python values (possibly None) shared by all devices.
    Returns whatever ``model.generate`` returns; callers read ``.sequences``.
    """
    sampling = {
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "condition_scale": condition_scale,
    }
    return model.generate(**tokenized_prompt, prng_key=key, params=params, **sampling)


# Decode image tokens to pixels, compiled and run in parallel on every device.
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    """Turn VQGAN code indices into images using the replicated VQGAN ``params``."""
    images = vqgan.decode_code(indices, params=params)
    return images

# Build a JAX PRNG key from a fresh random seed in [0, 2**32 - 1].
# NOTE(review): the generation loop later rebuilds `key` from each prompt's own
# seed, so this key only affects code run before that loop.
seed = random.randint(0, (1 << 32) - 1)
key = jax.random.PRNGKey(seed)

# Processor that converts prompt strings into model tokens for the chosen checkpoint.
processor = DalleBartProcessor.from_pretrained(
    DALLE_MODEL,
    revision=DALLE_COMMIT_ID,
)

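Load the prompts dataset. There are 200 prompts, each a variation on the following vague theme:

> A {gender} with an object

where {gender} is replaced with "man", "woman", "boy" or "girl". The rows come in consecutive groups of four (one per variation); the generation loop below relies on this ordering.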



In [None]:
# Load the dataset of prompts. The file is '|'-separated; its first row holds
# the column headers and blank rows are ignored.
prompts_db = []
with open('/content/prompts.csv', newline='') as csvfile:
    # QUOTE_NONE disables quote handling entirely. The previous call used the
    # delimiter '|' as quotechar too. Python 3.13+ rejects that dialect with
    # ValueError, and older versions mis-parse empty fields ("a||b") as the start
    # of a quoted field. Rows without empty fields parse exactly as before.
    prompt_reader = csv.reader(csvfile, delimiter='|', quoting=csv.QUOTE_NONE)
    next(prompt_reader, None)  # skip the header row
    prompts_db = [row for row in prompt_reader if row]

# Summarise instead of printing every row, which floods the cell output.
print(f"Loaded {len(prompts_db)} prompts")


In [None]:
# Create output directory
output_dir = "/content/drive/My Drive/dalle_results/"
try: 
    os.mkdir(output_dir) 
except OSError as error: 
    print(error)  

# number of predictions per prompt
n_predictions = 1

# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)
gen_top_k = None
gen_top_p = None
temperature = None
cond_scale = 10.0

# Generate images for consecutive groups of four prompts (the man/woman/boy/girl
# variations). Columns read from each prompts_db row: [0] identifier used in
# the output filename, [1] integer seed base, [3] prompt text.

# First group to process. Set above 0 to resume a run that stopped part-way;
# earlier groups are then skipped. (Previously hard-coded to 3, which silently
# skipped the first 12 prompts.)
start_group = 0
# Generation calls per prompt. Each call yields one image per local device
# because the tokenized prompt is replicated onto every device.
repetitions = 5

n_groups = len(prompts_db) // 4
for i in trange(start_group, n_groups, desc="prompt groups"):
    idx = i * 4
    # Iterate over the man/woman/boy/girl variations
    for j in range(4):
        row = prompts_db[idx + j]
        # Strip surrounding whitespace and collapse every run of whitespace to a
        # single space. The stored prompts contain doubled spaces and,
        # presumably, a leading space. Unlike slicing off the first character,
        # this never drops a real letter.
        prompt = " ".join(row[3].split())
        print(j, prompt)

        # convert string to tokens and replicate them onto each device
        tokenized_prompt = replicate(processor([prompt]))

        for r in range(repetitions):
            # Deterministic seed per (prompt, repetition) for reproducibility.
            # NOTE(review): seeds collide when two prompts' seed bases differ
            # by 100*k with 0 < k < repetitions. The scheme is kept so existing
            # results remain reproducible.
            key = jax.random.PRNGKey(int(row[1]) + r * 100)
            key, subkey = jax.random.split(key)
            # generate images (one per device)
            encoded_images = p_generate(
                tokenized_prompt,
                shard_prng_key(subkey),
                params,
                gen_top_k,
                gen_top_p,
                temperature,
                cond_scale,
            )
            # remove BOS
            encoded_images = encoded_images.sequences[..., 1:]
            # decode images
            decoded_images = p_decode(encoded_images, vqgan_params)
            decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
            for device_idx, decoded_img in enumerate(decoded_images):
                img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
                display(img)
                # Every device produces an image for the same (prompt, r).
                # Extra images get the device index appended so they no longer
                # overwrite each other. With a single device the filename is
                # unchanged: "<id>_<r>.png".
                suffix = f"{r}" if device_idx == 0 else f"{r}_{device_idx}"
                img.save(os.path.join(output_dir, f"{row[0]}_{suffix}.png"))