<a href="https://colab.research.google.com/github/etuckerman/Text-to-Image_Dalle-E-mini/blob/main/Text_to_Image_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# for colab environments + GPU
!pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html


!pip install -q dalle-mini
!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git

[0mLooking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
[0m  Preparing metadata (setup.py) ... [?25l[?25hdone
[0m

In [2]:
import jax
import jax.numpy as jnp

In [3]:
# Returns the number of local devices available to JAX.
jax.local_device_count()

1

In [4]:
#check if using cuda gpu
jax.devices()

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]

In [5]:
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
from transformers import CLIPProcessor, FlaxCLIPModel

In [6]:
from huggingface_hub import hf_hub_url, cached_download, hf_hub_download

In [7]:
dalle_mini_files_list = ['config.json', 'tokenizer.json',
             'tokenizer_config.json', 'merges.txt', 'vocab.json',
             'special_tokens_map.json', 'enwiki-words-frequency.txt',
             'flax_model.msgpack']

In [8]:
import shutil

for each_file in dalle_mini_files_list:
  downloaded_file = hf_hub_download("dalle-mini/dalle-mini", filename=each_file)
  target_path = '/content/dalle-mini/' + each_file
  shutil.copy(downloaded_file, target_path)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [9]:
!ls -lah /content/dalle-mini

total 1.7G
drwxr-xr-x 4 root root 4.0K Jun 24 18:41 .
drwxr-xr-x 1 root root 4.0K Jun 24 18:32 ..
-rw-r--r-- 1 root root 1.3K Jun 24 18:50 config.json
-rw-r--r-- 1 root root  33M Jun 24 18:50 enwiki-words-frequency.txt
-rw-r--r-- 1 root root 1.7G Jun 24 18:50 flax_model.msgpack
drwxr-xr-x 2 root root 4.0K Jun 24 18:33 .ipynb_checkpoints
-rw-r--r-- 1 root root 450K Jun 24 18:50 merges.txt
-rw-r--r-- 1 root root  239 Jun 24 18:50 special_tokens_map.json
-rw-r--r-- 1 root root  497 Jun 24 18:50 tokenizer_config.json
-rw-r--r-- 1 root root 2.1M Jun 24 18:50 tokenizer.json
-rw-r--r-- 1 root root 783K Jun 24 18:50 vocab.json
drwxr-xr-x 2 root root 4.0K Jun 24 18:44 vqgan


In [10]:
import os
vqgan_files_list = ['config.json',  'flax_model.msgpack']

for each_file in vqgan_files_list:
  downloaded_file = hf_hub_download("dalle-mini/vqgan_imagenet_f16_16384", filename=each_file)
  target_path = '/content/dalle-mini/vqgan/' + each_file
  if os.path.exists(downloaded_file): # Check if file exists before copying
    shutil.copy(downloaded_file, target_path)
  else:
    print(f"Warning: Downloaded file {downloaded_file} not found.")

In [11]:
#for each_file in vqgan_files_list:
  #downloaded_file = hf_hub_download("dalle-mini/vqgan_imagenet_f16_16384", filename=each_file)
  #target_path = '/content/dalle-mini/vqgan/' + each_file
  #shutil.copy(downloaded_file, target_path)

In [12]:
!ls -lah /content/dalle-mini/vqgan

total 291M
drwxr-xr-x 2 root root 4.0K Jun 24 18:44 .
drwxr-xr-x 4 root root 4.0K Jun 24 18:41 ..
-rw-r--r-- 1 root root  434 Jun 24 18:50 config.json
-rw-r--r-- 1 root root 291M Jun 24 18:50 flax_model.msgpack


In [13]:
DALLE_MODEL_LOCATION = '/content/dalle-mini'
DALLE_COMMIT_ID = None
model, params  = DalleBart.from_pretrained(
    DALLE_MODEL_LOCATION, revision= DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False)

In [14]:
model

<dalle_mini.model.modeling.DalleBart at 0x7bab5a39b3a0>

In [15]:
model.dtype

In [16]:
model.config

DalleBartConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "architectures": [
    "eBart"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 16385,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 2730,
  "decoder_layers": 12,
  "decoder_start_token_id": 16384,
  "do_sample": true,
  "dropout": 0.0,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 2730,
  "encoder_layers": 12,
  "encoder_vocab_size": 50264,
  "eos_token_id": 16385,
  "force_ln_scale": false,
  "gradient_checkpointing": true,
  "image_length": 256,
  "image_vocab_size": 16384,
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "ln_positions": "normformer",
  "ln_type": "layernorm",
  "max_length": 257,
  "max_text_length": 64,
  "min_length": 257,
  "model_type": "dallebart",
  "normalize_text": true,
  "pad_token_id": 16385,
  "scale_embedding": false,
  "sinkhorn_iters": 1,
  "tau_init": 0.05,
  "tie_word_embeddings": false,
  "transformers_version": "4.25.1",
  "us

In [17]:
params

{'lm_head': {'kernel': DeviceArray([[-0.00301968, -0.00115364,  0.00724407, ..., -0.00811347,
                -0.00773314,  0.01539494],
               [ 0.06210343,  0.04317437,  0.05253811, ...,  0.06021526,
                 0.06636694,  0.0462695 ],
               [ 0.04940993,  0.06090811,  0.07043562, ...,  0.05165476,
                 0.06932093,  0.05365875],
               ...,
               [ 0.2637832 ,  0.2502213 ,  0.25847328, ...,  0.2649368 ,
                 0.26205418,  0.26459113],
               [ 0.09630896,  0.09840915,  0.08403565, ...,  0.08070044,
                 0.09171668,  0.0958437 ],
               [-0.10084821, -0.10031693, -0.10371408, ..., -0.10406047,
                -0.10876177, -0.10392802]], dtype=float32)},
 'model': {'decoder': {'embed_positions': {'embedding': DeviceArray([[ 0.03459017, -0.0065838 , -0.11748601, ..., -0.01451578,
                  -0.03927238, -0.00266367],
                 [-0.03116009,  0.00438436,  0.02691377, ..., -0.02886203