### Colab by mega b#6696

##### A simplified version of the [original models](https://github.com/robvanvolt/DALLE-models) and [Colab notebook](https://github.com/afiaka87/dalle-pytorch-pretrained).

---

#### Join the [Dall-E PyTorch Discord server](https://discord.gg/VHqAXKF3p9) to help with recreating Dall-E!


In [1]:
#@markdown # **1** Install required components.
from IPython.display import clear_output
from google.colab.output import eval_js
eval_js('google.colab.output.setIframeHeight("500")')

!nvidia-smi -L
print("Installing...")
!pip -q install tqdm
# Import only the progress bar; a star import pulled every public tqdm name
# into the notebook namespace.
from tqdm.notebook import tqdm

# One shell command per step. Driving the progress bar from this list keeps its
# total in sync with the number of steps instead of a hand-maintained count.
INSTALL_STEPS = [
  "pip -q install keras==2.4.0",
  "pip -q install git+https://github.com/afiaka87/CLIP.git",
  "pip -q install taming-transformers",
  "pip -q install dalle-pytorch==0.14.3",
  "pip -q install tokenizers",
  "pip -q install ftfy",
  "pip -q install regex",
  "pip -q install triton==0.4.2",
  # Absolute target so the /content/dalle-pytorch-pretrained paths used by later
  # cells hold regardless of the current working directory.
  "git clone https://github.com/johnpaulbin/dalle-pytorch-pretrained.git /content/dalle-pytorch-pretrained",
  "pip -q install wandb",
]
for step in tqdm(INSTALL_STEPS, desc="Installing"):
  get_ipython().system(step)

# Absolute path: a relative %cd breaks when the cell is re-run from another directory.
%cd /content/dalle-pytorch-pretrained

clear_output()

print("Done! We can now go to the next cell.")

Done, move onto the next cell.


In [2]:
#@markdown # **2** Install required dependencies.

# Resize the Colab output frame for this cell.
eval_js('google.colab.output.setIframeHeight("250")')

%pip install tokenizers
import torch
from tokenizers import Tokenizer

# BPE vocabulary file shipped in the dalle-pytorch-pretrained repository cloned in cell 1.
TOKENIZER_PATH = "/content/dalle-pytorch-pretrained/cc12m_tokenizer.json"
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
VOCAB_SIZE = tokenizer.get_vocab_size()

def tokenize(texts, context_length = 256, add_start = False, add_end = False, truncate_text = False):
    """Encode text(s) with the global `tokenizer` into a zero-padded id matrix.

    Args:
        texts: a single string or a list of strings.
        context_length: number of columns in the result; shorter sequences are
            right-padded with 0.
        add_start: prepend the ids of "<|startoftext|>".
        add_end: append the ids of "<|endoftext|>".
        truncate_text: cut over-long sequences to `context_length` instead of raising.

    Returns:
        torch.LongTensor of shape (len(texts), context_length).

    Raises:
        RuntimeError: a sequence is longer than `context_length` and
            `truncate_text` is False.
    """
    batch = [texts] if isinstance(texts, str) else texts

    prefix = tokenizer.encode("<|startoftext|>").ids if add_start else []
    suffix = tokenizer.encode("<|endoftext|>").ids if add_end else []

    encoded = []
    for caption in batch:
        encoded.append(prefix + tokenizer.encode(caption).ids + suffix)

    padded = torch.zeros(len(encoded), context_length, dtype=torch.long)
    for row, ids in enumerate(encoded):
        too_long = len(ids) > context_length
        if too_long and not truncate_text:
            raise RuntimeError(f"Input {batch[row]} is too long for context length {context_length}")
        if too_long:
            ids = ids[:context_length]
        padded[row, :len(ids)] = torch.tensor(ids)

    return padded

# Build DALLE-pytorch 0.14.3 and DeepSpeed, then fetch the megadown helper.
# The former `!wget --no-clobber <dropbox_url>` placeholder was removed: `<...>`
# is shell redirection syntax and made the command fail.
import os

%pip install gpustat

DALLE_ZIP = "/content/0.14.3.zip"
DALLE_SRC_DIR = "/content/dalle-pytorch-pretrained/DALLE-pytorch"
# Only download/unpack once. On a re-run, unzip would prompt before overwriting
# and mv would nest the fresh copy inside the existing folder.
if not os.path.isdir(DALLE_SRC_DIR):
  # -O must name a file: `-O /content/` (a directory) made wget fail, so unzip had nothing to read.
  !wget "https://github.com/lucidrains/DALLE-pytorch/archive/refs/tags/0.14.3.zip" -O $DALLE_ZIP
  !unzip -q $DALLE_ZIP -d /content/dalle-pytorch-pretrained
  !mv /content/dalle-pytorch-pretrained/DALLE-pytorch-0.14.3 $DALLE_SRC_DIR
# Absolute path so this cell does not depend on the directory cell 1 left us in.
%cd /content/dalle-pytorch-pretrained/DALLE-pytorch/
!python3 setup.py install

# Presumably build prerequisites for the sparse-attention DeepSpeed build below.
!sudo apt-get -y install llvm-9-dev cmake
if not os.path.isdir("/tmp/Deepspeed"):
  !git clone https://github.com/microsoft/DeepSpeed.git /tmp/Deepspeed
%cd /tmp/Deepspeed
!DS_BUILD_SPARSE_ATTN=1 ./install.sh -r

# NOTE(review): this installs the PyPI package on top of the source build above
# and may replace it - confirm this is intended.
!pip install deepspeed

%cd /content/
# Helper tools installed alongside megadown.sh; -y avoids an interactive confirmation prompt.
!apt-get install -y pv jq
!wget https://raw.githubusercontent.com/tonikelope/megadown/master/megadown -O megadown.sh
!chmod 755 megadown.sh

clear_output()

print("Finished, move onto the next cell.")

Finished, move onto the next cell.


In [None]:
%cd /content/

import glob
import os
import random

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import wandb


eval_js('google.colab.output.setIframeHeight("500")')

#@markdown # **3** Choose the Dall-E Model.

# Checkpoint sources are listed in MODEL_URLS below
# (see also https://github.com/robvanvolt/DALLE-models).

chosen_model = "3 Epoch 12M (by robvanvolt) (Highly recommended!)" #@param ["Old Model (by robvanvolt)", "May 20 model (More General Outputs)", "Old COCO model (by afiaka87)", "New model illustrations_imagenetvqgan (by afiaka87)", "2 Epoch 12M (by robvanvolt)", "3 Epoch 12M (by robvanvolt) (Highly recommended!)"]

# The one model fetched as a wandb artifact instead of from a URL.
WANDB_MODEL = "New model illustrations_imagenetvqgan (by afiaka87)"

# Form label -> checkpoint URL. Labels not in this table (the wandb model) keep
# their label, which cell 4 compares against to choose the checkpoint path.
MODEL_URLS = {
  "Old Model (by robvanvolt)": "https://mega.nz/#!kShC2QjR!5BEPvrouy89XgRFo130hYdSLZu_hyz9s7oWUnhQsXb4",
  "May 20 model (More General Outputs)": "https://mega.nz/#!5PhBUCSD!kVzo_VFJde1kxPn3-cPpu2cN5dZwaBxFEO_o4hm9RSM",
  "Old COCO model (by afiaka87)": "https://www.dropbox.com/s/oper4enc0s0r738/vg_coco_oi_cc100k_latest.pt?dl=1",
  "2 Epoch 12M (by robvanvolt)": "https://api.wandb.ai/files/robvanvolt/dalle_train_transformer/3bwiteds/dalle.pt",
  "3 Epoch 12M (by robvanvolt) (Highly recommended!)": "https://github.com/johnpaulbin/DALLE-models/releases/download/model/16L_64HD_8H_512I_128T_cc12m_cc3m_3E.pt",
}

chosen_model = MODEL_URLS.get(chosen_model, chosen_model)

if "https://mega.nz" in chosen_model:
  # Quoted: mega URLs contain `#` and `!`, which should reach megadown untouched.
  !/content/megadown.sh "$chosen_model" --o dalle_checkpoint.pt
elif chosen_model == WANDB_MODEL:
  run = wandb.init()
  artifact = run.use_artifact('dalle-pytorch-replicate/royalty_free_illustrations/trained-dalle:v7', type='model')
  artifact_dir = artifact.download()
else:
  !wget "$chosen_model" -O dalle_checkpoint.pt

clear_output()

# VQGAN weights/config are downloaded to /content and copied into ~/.cache/dalle.
# -nc skips files a previous run already fetched instead of creating `.1` duplicates.
!mkdir -p ~/.cache/dalle
!wget -nc https://www.dropbox.com/s/15mhdhy57y6qttf/vqgan.1024.model.ckpt
!wget -nc https://www.dropbox.com/s/q8nayimg4skf0pl/vqgan.1024.config.yml
# Quoted so the shell does not treat `?` as a glob character.
!wget "https://www.dropbox.com/s/r4uukngelv2vhk3/variety.bpe?dl=1" -O variety.bpe
!cp vqgan.1024.model.ckpt vqgan.1024.config.yml ~/.cache/dalle/

clear_output()

print("Finished downloading the selected model.")

/content
--2023-02-17 13:08:07--  https://github.com/johnpaulbin/DALLE-models/releases/download/model/16L_64HD_8H_512I_128T_cc12m_cc3m_3E.pt
Resolving github.com (github.com)... 140.82.113.3
Connecting to github.com (github.com)|140.82.113.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/379488067/2c1416ce-9ca2-4292-85c2-1f9f61ce642f?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20230217%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20230217T130807Z&X-Amz-Expires=300&X-Amz-Signature=ac7b123f9e1765baee0ab6c9e9980c9b42f04921a3a0c47b9b5e07ac89219671&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=379488067&response-content-disposition=attachment%3B%20filename%3D16L_64HD_8H_512I_128T_cc12m_cc3m_3E.pt&response-content-type=application%2Foctet-stream [following]
--2023-02-17 13:08:07--  https://objects.githubusercontent.com/github-production-release-asset-

In [None]:
#@markdown # **4** Try out the model.
#@markdown #### Results will be saved in the outputs directory. Refresh (right-click the folder -> Refresh) if you don't see the result inside the folder.
import glob
import math
import shlex

import matplotlib.image as mpimg
import matplotlib.pyplot as plt

if chosen_model != "New model illustrations_imagenetvqgan (by afiaka87)":
  checkpoint_path = "/content/dalle_checkpoint.pt"
else:
  checkpoint_path = "/content/artifacts/trained-dalle:v7/dalle.pt"

text = "the grand canyon with snow on it. snow located on the grand canyon. a snowy grand canyon." #@param {type:"string"}

generate_16_images = False #@param {type:"boolean"}

num_images = 16 if generate_16_images else 8
batch_size = num_images

text_cleaned = text.replace(" ", "_")
_folder = f"/content/outputs/{text_cleaned}/"

# Models that must be run with the downloaded variety.bpe vocabulary
# (the COCO checkpoint URL and the wandb model label).
needs_variety_bpe = ["https://www.dropbox.com/s/oper4enc0s0r738/vg_coco_oi_cc100k_latest.pt?dl=1", "New model illustrations_imagenetvqgan (by afiaka87)"]
bpe_args = "--bpe_path /content/variety.bpe" if chosen_model in needs_variety_bpe else ""

# shlex.quote keeps prompts containing quotes, `$` or backticks from breaking
# (or being expanded by) the shell command below.
quoted_text = shlex.quote(text)
quoted_folder = shlex.quote(_folder)
!python /content/dalle-pytorch-pretrained/DALLE-pytorch/generate.py --dalle_path=$checkpoint_path --taming --text=$quoted_text --num_images=$num_images --batch_size=$batch_size --outputs_dir=$quoted_folder $bpe_args

print("Finished generating images, attempting to display results...")

%matplotlib inline

# Images are read from a sub-folder named after the first 100 characters of the
# underscored prompt.
output_dir = f"{_folder}{text_cleaned[:100]}/"
images = [mpimg.imread(path) for path in sorted(glob.glob(f"{output_dir}*.jpg"))]

columns = 4
# plt.subplot needs integer grid sizes; `len / columns + 1` was a float and
# added a spare empty row.
rows = math.ceil(len(images) / columns)

# Create exactly one figure (previously a second figure left an empty one behind).
plt.figure(figsize=(64, 64) if generate_16_images else (32, 32))
for i, image in enumerate(images):
  plt.subplot(rows, columns, i + 1)
  plt.imshow(image)


# Optional cells
### These cells may break the session. If that happens, factory-reset the runtime.

In [None]:
import glob
import math

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
%matplotlib inline

text_cleaned = text.replace(" ", "_")
output_dir = f"/content/outputs/{text_cleaned}/" #@param
# Search recursively: the generated images are written to a per-prompt
# sub-folder below output_dir, so a flat `*.jpg` glob found nothing.
image_paths = sorted(glob.glob(f"{output_dir}**/*.jpg", recursive=True))
images = [mpimg.imread(path) for path in image_paths]

columns = 5
# plt.subplot requires integer grid sizes (`len / columns + 1` was a float).
rows = math.ceil(len(images) / columns)
# Scale the figure with the grid so thumbnails stay readable.
plt.figure(figsize=(4 * columns, 4 * max(rows, 1)))
for i, image in enumerate(images):
    plt.subplot(rows, columns, i + 1)
    plt.imshow(image)

In [None]:
# Load OpenAI's CLIP, used by the next cell to rank the generated images.
# NOTE(review): this may replace the CLIP fork installed in cell 1 - confirm.
%pip install "git+https://github.com/openai/CLIP.git"
import torch
import clip

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

CLIP_MODEL_NAME = "ViT-B/32"
model, preprocess = clip.load(CLIP_MODEL_NAME, device=device)

In [None]:
""" Get rank by CLIP! """
image = F.interpolate(images, size=224)
text = clip.tokenize(["this colorful bird has a yellow breast , with a black crown and a black cheek patch."]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_text.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)

In [None]:
import numpy as np

# show_reranking expects channel-first (N, C, H, W) images, while `images` is a
# list of (H, W, C) arrays read by mpimg.imread (it has no `.cpu()`).
np_images = np.stack([np.asarray(img) for img in images]).transpose(0, 3, 1, 2)
# A single caption was scored, so row 0 holds one probability per image.
scores = probs[0]

def show_reranking(images, scores, sort=True, per_row=4):
    """Plot images in rows of `per_row`, each titled with its score as a percentage.

    Args:
        images: numpy array of channel-first images, shape (N, C, H, W); each
            image is transposed to (H, W, C) for display.
        scores: 1-D numpy array with one score per image (e.g. CLIP probabilities).
        sort: if True, show the images in descending score order.
        per_row: images per row; one figure is created per row (default 4).

    A final partial row is still drawn, with its unused panels blanked, rather
    than silently dropping the last N % per_row images.
    """
    import numpy as np

    if sort:
        # argsort is ascending; reverse once and reuse for both arrays.
        order = scores.argsort()[::-1]
        scores = scores[order]
        images = images[order]

    total = len(images)
    for start in range(0, total, per_row):
        fig, axs = plt.subplots(1, per_row, figsize=(20, 20), squeeze=False)
        plt.subplots_adjust(wspace=0.01)
        for slot, ax in enumerate(axs[0]):
            ax.set_xticks([])
            ax.set_yticks([])
            idx = start + slot
            if idx >= total:
                ax.axis("off")  # nothing left to show in this panel
                continue
            ax.imshow(np.transpose(images[idx], (1, 2, 0)), interpolation='nearest')
            ax.set_title("{}%".format(np.around(scores[idx] * 100, 5)))

# Display the CLIP-ranked images, best match first.
show_reranking(np_images, scores, sort=True)

In [None]:
# Image-prompted generation: seed DALL-E with a random crop of an existing image.
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from torchvision.utils import make_grid

# NOTE(review): this cell was carried over from another notebook. It needs a DALLE
# model object `dalle`, its checkpoint dict `dalle_dict` and a tokenizer `vocab`,
# none of which are created by the cells above. Stop with a clear message instead
# of a NameError halfway through.
_missing = [name for name in ("dalle", "dalle_dict", "vocab") if name not in globals()]
if _missing:
    raise NameError(
        "This cell needs " + ", ".join(_missing)
        + " (a loaded DALLE model, its checkpoint dict and its tokenizer); "
        "the cells above do not define them."
    )

txt = "this bird has wings that are brown with a white belly"
# NOTE(review): this file is not created by the notebook - point it at your own image.
img_path = "images/Yellow_Headed_Blackbird_0013_8362.jpg"

tf = transforms.Compose([
    transforms.ToPILImage(),  # mpimg.imread returns an ndarray; the steps below expect a PIL image
    transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
    transforms.RandomResizedCrop(256, scale=(0.6, 1.0), ratio=(1.0, 1.0)),
    transforms.ToTensor(),
])
img = tf(mpimg.imread(img_path)).cuda()

text_seq_len = dalle_dict['hparams']['text_seq_len']
sot_token = vocab.encode("<|startoftext|>").ids[0]
eot_token = vocab.encode("<|endoftext|>").ids[0]
tokens = [sot_token] + vocab.encode(txt).ids + [eot_token]
# Truncate before padding: slice-assigning a longer list into a fixed-size
# list grows it past text_seq_len.
tokens = tokens[:text_seq_len]
codes = tokens + [0] * (text_seq_len - len(tokens))
caption_token = torch.LongTensor(codes).cuda()

num_samples = 16
imgs = img.repeat(num_samples, 1, 1, 1)
caps = caption_token.repeat(num_samples, 1)

mask = (caps != 0).cuda()

# Named `generated` so the `images` list used by the display/ranking cells is not clobbered.
generated = dalle.generate_images(
        caps,
        mask = mask,
        img = imgs,
        num_init_img_tokens = 100,  # you can set the size of the initial crop, defaults to a little less than ~1/2 of the tokens, as done in the paper
        filter_thres = 0.9,
        temperature = 1.0
)

# make_grid's `range` argument was deprecated/removed in newer torchvision, and it
# is only used when normalize=True anyway.
grid = make_grid(generated, nrow=4, normalize=False).cpu()
plt.figure(figsize=(16, 16))
# Assumes generate_images returns values in [0, 1] - TODO confirm.
plt.imshow(grid.permute(1, 2, 0).clamp(0, 1))
plt.axis("off")