<a href="https://colab.research.google.com/github/danakhang/freethinker/blob/main/text2image.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#set up


In [1]:
import jax
import jax.numpy as jnp

In [2]:
jax.local_device_count()

1

In [3]:
jax.devices()

[GpuDevice(id=0, process_index=0)]

In [4]:
!pip install -q dalle-mini

[K     |████████████████████████████████| 216 kB 11.5 MB/s 
[K     |████████████████████████████████| 53 kB 1.8 MB/s 
[K     |████████████████████████████████| 4.9 MB 65.5 MB/s 
[K     |████████████████████████████████| 180 kB 73.4 MB/s 
[K     |████████████████████████████████| 1.8 MB 41.1 MB/s 
[K     |████████████████████████████████| 235 kB 60.5 MB/s 
[K     |████████████████████████████████| 145 kB 50.9 MB/s 
[K     |████████████████████████████████| 217 kB 17.2 MB/s 
[K     |████████████████████████████████| 51 kB 5.9 MB/s 
[K     |████████████████████████████████| 85 kB 4.4 MB/s 
[K     |████████████████████████████████| 163 kB 6.7 MB/s 
[K     |████████████████████████████████| 6.6 MB 33.9 MB/s 
[K     |████████████████████████████████| 181 kB 74.9 MB/s 
[K     |████████████████████████████████| 162 kB 76.5 MB/s 
[K     |████████████████████████████████| 63 kB 1.9 MB/s 
[K     |████████████████████████████████| 158 kB 62.9 MB/s 
[K     |████████████████████████

In [5]:
!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git

  Building wheel for vqgan-jax (setup.py) ... [?25l[?25hdone


In [6]:
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
from transformers import CLIPProcessor, FlaxCLIPModel

In [7]:
from huggingface_hub import hf_hub_url, cached_download, hf_hub_download

In [8]:
dalle_mini_files_list = ['config.json', 'tokenizer.json', 
             'tokenizer_config.json', 'merges.txt', 'vocab.json', 
             'special_tokens_map.json', 'enwiki-words-frequency.txt', 
             'flax_model.msgpack']

In [9]:
import shutil

In [10]:
for each_file in dalle_mini_files_list:
  downloaded_file = hf_hub_download("dalle-mini/dalle-mini", filename=each_file)
  target_path = '/content/dalle-mini/' + each_file
  shutil.copy(downloaded_file, target_path)

Downloading:   0%|          | 0.00/1.33k [00:00<?, ?B/s]

FileNotFoundError: ignored

In [None]:
!ls -lah /content/dalle-mini

In [None]:
vqgan_files_list = ['config.json',  'flax_model.msgpack']

In [None]:
for each_file in vqgan_files_list:
  downloaded_file = hf_hub_download("dalle-mini/vqgan_imagenet_f16_16384", filename=each_file)
  target_path = '/content/dalle-mini/vqgan/' + each_file
  shutil.copy(downloaded_file, target_path)

In [None]:
!ls -lah /content/dalle-mini/vqgan

In [None]:
DALLE_MODEL_LOCATION = '/content/dalle-mini'
DALLE_COMMIT_ID = None
model, params = DalleBart.from_pretrained(    
    DALLE_MODEL_LOCATION, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False,
)

In [None]:
model.config

In [None]:
VQGAN_LOCAL_REPO = '/content/dalle-mini/vqgan'
VQGAN_LCOAL_COMMIT_ID = None
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_LOCAL_REPO, revision=VQGAN_LCOAL_COMMIT_ID, _do_init=False
)

In [None]:
DALLE_MODEL_LOCATION = '/content/dalle-mini'
DALLE_COMMIT_ID = None
processor = DalleBartProcessor.from_pretrained(
    DALLE_MODEL_LOCATION, 
    revision=DALLE_COMMIT_ID)

In [None]:
processor

In [None]:
# Works for all available devices to replicate the module
from flax.jax_utils import replicate

params = replicate(params)
vqgan_params = replicate(vqgan_params)

#Model Inference


In [None]:
from functools import partial

In [None]:
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
  return model.generate(
      **tokenized_prompt,
      prng_key=key,
      params=params,
      top_k=top_k,
      top_p=top_p,
      temperature=temperature,
      condition_scale=condition_scale,
  )

#Decode Images

In [None]:
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    return vqgan.decode_code(indices, params=params)

#Setting up the Text input prompt

In [None]:
prompts = ['vincent van gogh paintings mixed with car paintings']

prompts = ['MICHELANGELO paintings mixed with Pink Moon paintings']

In [None]:
tokenized_prompts = processor(prompts)

In [None]:
tokenized_prompt = replicate(tokenized_prompts)

#Defining model parameters

##random key

In [None]:
import random

# create a random key
seed = random.randint(0, 2**32 - 1)
key = jax.random.PRNGKey(seed)

In [None]:
# number of predictions per prompt
n_predictions = 2

# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)
gen_top_k = None
gen_top_p = None
temperature = None
cond_scale = 10.0

In [None]:
from flax.training.common_utils import shard_prng_key
import numpy as np
from PIL import Image
from tqdm.notebook import trange

In [None]:
print(f"Prompts: {prompts}\n")

In [None]:
images = []
for i in trange(max(n_predictions // jax.device_count(), 1)):
    # get a new key
    key, subkey = jax.random.split(key)
    # generate images
    encoded_images = p_generate(
        tokenized_prompt,
        shard_prng_key(subkey),
        params,
        gen_top_k,
        gen_top_p,
        temperature,
        cond_scale,
    )
    # remove BOS
    encoded_images = encoded_images.sequences[..., 1:]
    # decode images
    decoded_images = p_decode(encoded_images, vqgan_params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    for decoded_img in decoded_images:
        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
        images.append(img)
        display(img)
        print()