# 1. Generating Images from Text

<img src="https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true" width="200">

### For the DALL-E model architecture, see [here](https://github.com/borisdayma/dalle-mini).
If you are curious about the training process, see [here](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA).

In [None]:
!pip install transformers flax "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install git+https://github.com/patil-suraj/vqgan-jax.git # VQGAN model in JAX
!pip install git+https://github.com/borisdayma/dalle-mini.git # Dall-E model

This code is based on [JAX](https://jax.readthedocs.io/en/latest/index.html) rather than PyTorch, so instead of studying the code itself in depth, let's try a variety of prompts and get a feel for how Dall-E works overall.

In [None]:
import jax
import jax.numpy as jnp

# import jax.tools.colab_tpu            # For TPU setup
# jax.tools.colab_tpu.setup_tpu()

# import os                             # For emulating 4 GPUs
# os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'

from tqdm.notebook import tqdm, trange

# check how many devices are available
device_count = jax.device_count()
jax.devices(), device_count

In [None]:
from dalle_mini import DalleBart, DalleBartProcessor
from flax.jax_utils import replicate

# dalle-mini / dalle-mega
DALLE_MODEL = "dalle-mini/dalle-mini" # "dalle-mini/dalle-mega" # can be wandb artifact or 🤗 Hub or local folder or google bucket
DALLE_COMMIT_ID = None

# Load dalle-mini processor
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)

# Load dalle-mini
model, params = DalleBart.from_pretrained(
    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False
)

params = replicate(params)


In [None]:
# Put the (English) sentences to generate images from in `prompts` below.
prompts = ["sunset over a lake in the mountains", "cityscapes with clear sky"]
num_prompts = len(prompts)

# tokenize the prompts
tokenized = processor(prompts)

print(tokenized.keys())
print(tokenized["input_ids"].shape)
print(tokenized["attention_mask"].shape)
print(tokenized["input_ids"])

tokenized = replicate(tokenized)

Notes:
* `0`: BOS (Beginning of Sequence), marks the start of the sentence
* `2`: EOS (End of Sequence), marks the end of the sentence
* `1`: PAD, padding that fills the unused positions

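To make the special token ids in the notes above concrete, here is a minimal NumPy sketch. The `input_ids` values are made up for illustration (real ids come from `processor(prompts)`), and the sketch assumes the attention mask simply marks the non-padding positions.

```python
import numpy as np

BOS, PAD, EOS = 0, 1, 2  # special token ids listed in the notes above

# Hypothetical tokenized batch: two prompts padded to length 8.
input_ids = np.array([
    [BOS, 5012, 431, 88, EOS, PAD, PAD, PAD],
    [BOS, 977, EOS, PAD, PAD, PAD, PAD, PAD],
])

# Assumption: the attention mask is 1 on real tokens and 0 on padding.
attention_mask = (input_ids != PAD).astype(np.int32)

# Number of real tokens (including BOS and EOS) in each prompt.
lengths = attention_mask.sum(axis=1)
print(attention_mask)
print(lengths)
```
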
In [None]:
n_predictions = 8

# create random keys
seed = 42
key = jax.random.PRNGKey(seed)

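JAX random numbers are driven by explicit keys, which is why the generation loop splits the key on every iteration. The sketch below shows that pattern in isolation. It assumes that `shard_prng_key` behaves like splitting one key into one sub-key per local device; treat that equivalence as an assumption rather than a statement about the Flax API.

```python
import jax

key = jax.random.PRNGKey(42)

# Splitting returns two new keys: carry one forward, use the other once.
key, subkey = jax.random.split(key)

# One key per local device, so each device samples different images.
device_keys = jax.random.split(subkey, jax.local_device_count())
print(device_keys.shape)
```

Reusing the same key twice would reproduce the same samples, so each iteration should consume a fresh `subkey`.
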
In [None]:
from flax.training.common_utils import shard_prng_key

# generate sample predictions
@jax.pmap
def p_generate(
    tokenized_prompt, key, params
):
    return model.generate(
        **tokenized_prompt,
        prng_key=key,
        params=params
    )

encoded_images = []
for i in trange(max(n_predictions // jax.device_count(), 1)):
    key, subkey = jax.random.split(key)
    encoded_images.append(p_generate(
        tokenized,
        shard_prng_key(subkey),
        params
    )['sequences'])

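It may help to see how many images the loop above produces. Each call runs every prompt once on every device, so the totals depend on the device count. The helper `images_generated` below is a hypothetical name used only for this illustration.

```python
def images_generated(n_predictions: int, device_count: int, num_prompts: int):
    """Return (iterations, images per prompt, total images) for the loop above."""
    iterations = max(n_predictions // device_count, 1)
    per_prompt = iterations * device_count
    return iterations, per_prompt, per_prompt * num_prompts


single_gpu = images_generated(8, 1, 2)     # one GPU
eight_devices = images_generated(8, 8, 2)  # e.g. an 8-core TPU
three_devices = images_generated(8, 3, 2)  # floor division drops the remainder
print(single_gpu, eight_devices, three_devices)
```

When `n_predictions` is not a multiple of the device count, fewer than `n_predictions` images per prompt are generated.
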

In [None]:
encoded_images

The first token (`16384`) is a special token that marks the start of the decoder sequence (that is, it does not correspond to a codebook entry).

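The following NumPy sketch illustrates stripping that start token. The array contents are random stand-ins, and the figure of 256 image tokens per image (a 16 x 16 grid of codes) is an assumption about this model rather than something stated in the notebook.

```python
import numpy as np

CODEBOOK_SIZE = 16384          # valid image codes are 0 .. 16383
DECODER_BOS = CODEBOOK_SIZE    # the special start token sits just past the codebook
NUM_IMAGE_TOKENS = 256         # assumption: a 16 x 16 grid of codes per image

rng = np.random.default_rng(0)

# Hypothetical generation output shaped (devices, prompts, 1 + tokens).
codes = rng.integers(0, CODEBOOK_SIZE, size=(1, 2, NUM_IMAGE_TOKENS))
sequences = np.concatenate([np.full((1, 2, 1), DECODER_BOS), codes], axis=-1)

# Flatten the leading axes, then drop the first (BOS) column.
flat = sequences.reshape((-1, sequences.shape[-1]))[..., 1:]
print(flat.shape)
```
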


In [None]:
# remove first token (BOS)
encoded_images = jnp.concatenate(encoded_images)
encoded_images = encoded_images.reshape((-1, encoded_images.shape[-1]))[..., 1:]
encoded_images

In [None]:
encoded_images = jax.device_put_sharded(encoded_images.split(jax.device_count()), devices=jax.devices())

## Decode images
Decoding uses the [VQ-GAN architecture](https://github.com/CompVis/taming-transformers). A JAX implementation of the model is available here: [`flax-community/vqgan_f16_16384`](https://github.com/patil-suraj/vqgan-jax).

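As a rough intuition for what decoding starts from, the sketch below shows a vector-quantized codebook lookup with toy sizes. It is a simplified illustration, not the `vqgan-jax` API: the codebook here is random, and a real decoder would then run a convolutional network that upsamples the latent grid into an image.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes; the model name suggests a 16384-entry codebook in practice.
codebook_size, embed_dim, grid = 32, 4, 16
codebook = rng.normal(size=(codebook_size, embed_dim))  # stand-in for learned embeddings

# One image encoded as grid * grid code indices.
indices = rng.integers(0, codebook_size, size=(1, grid * grid))

# Look up each index in the codebook, then restore the 2D spatial layout.
latents = codebook[indices].reshape((1, grid, grid, embed_dim))
print(latents.shape)
```
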
In [None]:
import numpy as np
from vqgan_jax.modeling_flax_vqgan import VQModel

# VQGAN model
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

# Load VQGAN
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False
)

vqgan_params = replicate(vqgan_params)

In [None]:
# decode images
@jax.pmap
def p_decode(indices, params):
    decoded_images = vqgan.decode_code(indices, params=params)
    return decoded_images.clip(0., 1.)

decoded_images = p_decode(encoded_images, vqgan_params)

In [None]:
# group the flat batch so that the second axis indexes the prompt
decoded_images = decoded_images.reshape((-1, num_prompts) + decoded_images.shape[2:])

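The reshape above works because, in the flat batch, the prompt index changes fastest. A small NumPy sketch with toy image sizes shows the grouping, and how `zip(*...)` then yields one tuple of samples per prompt. All sizes and fill values here are illustrative.

```python
import numpy as np

num_prompts, n_rounds, h, w, c = 2, 3, 4, 4, 3

# Hypothetical decoder output in flat order: prompt index changes fastest.
flat = np.zeros((n_rounds * num_prompts, h, w, c))
for i in range(flat.shape[0]):
    flat[i] = i % num_prompts  # fill each image with its prompt index

# Second axis now indexes the prompt.
grouped = flat.reshape((-1, num_prompts) + flat.shape[1:])

# Transpose the first two axes: one tuple of samples per prompt.
per_prompt = list(zip(*grouped))
print(grouped.shape, len(per_prompt), len(per_prompt[0]))
```
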
In [None]:
# convert to image
images = [
    [np.asarray(image * 255, dtype=np.uint8) for image in prompt_images] 
    for prompt_images in decoded_images
]

In [None]:
import matplotlib.pyplot as plt

# display all generated images for each prompt, side by side
for prompt, samples in zip(prompts, list(zip(*images))):
    print(prompt)

    all_images = np.concatenate(samples, axis=1)
    plt.figure(figsize=(24, 4))
    plt.imshow(all_images)
    plt.show()

# 2. Rank images with CLIP
Using the CLIP model we looked at earlier, let's rank the generated images by how well they match the text.

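The ranking itself comes down to a softmax over per-image similarity scores followed by a descending sort. Here is a minimal NumPy sketch of that idea, with made-up logits standing in for CLIP's image-text similarities.

```python
import numpy as np

# Hypothetical image-text similarity logits for four generated images.
logits = np.array([21.5, 24.0, 19.0, 23.0])

# Numerically stable softmax: the scores are positive and sum to 1.
exp = np.exp(logits - logits.max())
scores = exp / exp.sum()

# Indices of the images from best to worst match.
ranking = np.argsort(-scores)
for idx in ranking:
    print(f"image {idx}: score {scores[idx]:.3f}")
```

Because the softmax is taken across the images of a single prompt, the scores are only comparable within that prompt, not between prompts.
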
In [None]:
from transformers import CLIPModel, CLIPProcessor

In [None]:
# set up model and processor
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

In [None]:
import torch
from PIL import Image

images_pil = [
    [Image.fromarray(np.asarray(image, dtype=np.uint8)) for image in prompt_images] 
    for prompt_images in images
]

for prompt, samples in zip(prompts, list(zip(*images_pil))):
    print(prompt)
    inputs = processor(text=prompt, images=samples, return_tensors="pt", padding=True, truncation=True)

    with torch.no_grad():
        outputs = clip(**inputs)

    scores = outputs.logits_per_image.squeeze(-1).softmax(-1) # softmax over the images, so the scores sum to 1

    for idx in scores.argsort(dim=0, descending=True):
        print(f'Score: {scores[idx]}')
        plt.imshow(samples[idx])
        plt.show()