論文<br>
https://cdn.openai.com/papers/dall-e-2.pdf<br>
<br>
GitHub<br>
https://github.com/LAION-AI/dalle2-laion<br>
<br>
<a href="https://colab.research.google.com/github/kaz12tech/ai_demos/blob/master/dalle2_laion_demo.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 環境セットアップ

## GPU確認

In [None]:
# Print the NVIDIA driver/GPU status so the attached GPU can be checked before running the notebook.
!nvidia-smi

## GitHubからコード取得

In [None]:
# Clone SwinIR; its main_test_swinir.py script is run later in this notebook
# for super-resolution (SR) of the generated images.
!git clone https://github.com/JingyunLiang/SwinIR.git

## ライブラリのインストール

In [None]:
# DALL-E 2 implementation used by this notebook, pinned to an exact version.
!pip install -q dalle2_pytorch==0.15.4

# Extra packages installed for SwinIR (presumably needed by its test script — confirm).
!pip install -q timm
!pip install -q opencv-python

## ライブラリのインポート

In [None]:
import os
import json
from IPython.display import Image as IPythonImage

from PIL import Image
import numpy as np

import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP, OpenAIClipAdapter

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# 学習済みモデルのセットアップ

In [None]:
!mkdir decoder prior

# Laion2B
!wget -c https://huggingface.co/laion/DALLE2-PyTorch/resolve/main/decoder/1.5B/latest.pth \
      -O ./decoder/latest.pth
!wget -c https://huggingface.co/laion/DALLE2-PyTorch/resolve/main/decoder/1.5B_laion2B/decoder_config.json \
      -O ./decoder/decoder_config.json

# prior
!wget -c https://huggingface.co/zenglishuci/conditioned-prior/resolve/main/vit-l-14/prior_aes_finetune.pth \
      -O ./prior/prior_aes_finetune.pth

# 学習済みモデルのロード

In [None]:
# Locations of the files downloaded in the setup cell.
decoder_model_path = './decoder/latest.pth'
decoder_config_path = './decoder/decoder_config.json'
prior_model_path = './prior/prior_aes_finetune.pth'

# Parse the decoder configuration; it is read later to build CLIP, the U-Net
# and the decoder. The encoding is explicit so parsing does not depend on the
# platform's default text encoding.
with open(decoder_config_path, encoding='utf-8') as f:
    decoder_config = json.load(f)

## clip, prior

In [None]:
# CLIP model named in the decoder config; the same instance is also handed to
# the decoder in the next cell.
clip = OpenAIClipAdapter(decoder_config['decoder']['clip']['model']).to(device)

# The prior's hyper-parameters are hard-coded here. The strict load below only
# succeeds when they produce exactly the parameters stored in the checkpoint.
prior_network = DiffusionPriorNetwork(
    dim=768,
    depth=24,
    dim_head=64,
    heads=32,
    normformer=True,
    attn_dropout=5e-2,
    ff_dropout=5e-2,
    num_time_embeds=1,
    num_image_embeds=1,
    num_text_embeds=1,
    num_timesteps=1000,
    ff_mult=4
).to(device)

diffusion_prior = DiffusionPrior(
    net=prior_network,
    clip=clip,
    image_embed_dim=768,
    timesteps=1000,
    cond_drop_prob=0.1,
    loss_type="l2",
    condition_on_text_encodings=True,
).to(device)

# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted
# source. The checkpoint is read onto the CPU; load_state_dict then copies the
# values into the model's existing parameters.
prior_state_dict = torch.load(prior_model_path, map_location='cpu')
diffusion_prior.load_state_dict(prior_state_dict['ema_model'], strict=True)

# Inference only: switch to eval mode so the dropout configured above is
# disabled, matching what is done for the decoder.
diffusion_prior.eval()

# Release the CPU copy of the checkpoint.
del prior_state_dict

## decoder

In [None]:
decoder_cfg = decoder_config['decoder']
unet_cfg = decoder_cfg['unets'][0]  # only the first U-Net entry of the config is used

# Settings forwarded verbatim from the config to Unet; a missing key raises KeyError.
unet_arg_names = (
    'dim',
    'cond_dim',
    'image_embed_dim',
    'text_embed_dim',
    'cond_on_text_encodings',
    'channels',
    'dim_mults',
    'num_resnet_blocks',
    'attn_heads',
    'attn_dim_head',
    'sparse_attn',
    'memory_efficient',
    'self_attn',
)
unet = Unet(**{name: unet_cfg[name] for name in unet_arg_names})

# Settings forwarded verbatim from the config to Decoder.
decoder_arg_names = (
    'image_sizes',
    'channels',
    'timesteps',
    'loss_type',
    'beta_schedule',
    'learned_variance',
)
# Use the notebook-wide `device` (not a hard-coded .cuda()) so this cell honours
# the CPU fallback chosen when `device` was created.
decoder = Decoder(
    unet=unet,
    clip=clip,
    **{name: decoder_cfg[name] for name in decoder_arg_names},
).to(device)

# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
decoder_state_dict = torch.load(decoder_model_path, map_location='cpu')

# strict=False tolerates key differences between model and checkpoint; report
# them instead of ignoring them silently so a wrong checkpoint/config pairing
# is visible.
load_result = decoder.load_state_dict(decoder_state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f'decoder load (strict=False): {len(load_result.missing_keys)} model keys '
        f'missing from the checkpoint, {len(load_result.unexpected_keys)} '
        f'checkpoint keys unused'
    )
decoder.eval()

# Release the CPU copy of the checkpoint.
del decoder_state_dict

# DALLE2

## inference

In [None]:
# Text prompts to generate images for.
prompts = ['dogs']

# classifier free guidance strength (> 1 would strengthen the condition)
guidance_scale = 3.

# Combine the loaded prior and decoder into a single DALLE2 model and run it.
dalle2 = DALLE2(prior=diffusion_prior, decoder=decoder)
images = dalle2(prompts, cond_scale=guidance_scale)

## show image

In [None]:
def to_uint8_pixels(np_img):
    """Convert a float image to 8-bit pixel values.

    Values are clipped to [0, 1] before being scaled by 255, so samples that
    fall slightly outside that range saturate at 0 / 255 instead of wrapping
    around in the uint8 cast. In-range values are truncated, not rounded.

    Args:
        np_img: array-like of floats, expected to lie in [0, 1].

    Returns:
        A numpy array of dtype uint8 with the same shape as the input.
    """
    return (np.clip(np_img, 0.0, 1.0) * 255).astype(np.uint8)


def show_images(np_images):
    """Display every image of the batch inline in the notebook.

    Args:
        np_images: iterable of float images accepted by ``to_uint8_pixels``.
    """
    for np_img in np_images:
        # `display` is the notebook's built-in rich-output function.
        display(Image.fromarray(to_uint8_pixels(np_img)))


def save_images(output_dir, np_images):
    """Write every image of the batch to ``output_dir`` as ``<index>.png``.

    The directory is created when missing; existing files with the same name
    are overwritten.

    Args:
        output_dir: destination directory.
        np_images: iterable of float images accepted by ``to_uint8_pixels``.
    """
    os.makedirs(output_dir, exist_ok=True)
    for i, np_img in enumerate(np_images):
        image = Image.fromarray(to_uint8_pixels(np_img))
        image.save(os.path.join(output_dir, f'{i}.png'))

# Copy the generated batch to host memory, then move axis 1 to the end before
# converting to numpy (presumably NCHW -> NHWC, the layout PIL expects — confirm).
host_images = images.cpu()
images_np = host_images.permute(0, 2, 3, 1).numpy()

# Write the images to disk first, then show them inline.
output_dir = './outputs'
save_images(output_dir, images_np)
show_images(images_np)

# 64x64->256x256

In [None]:
# Run SwinIR's real-world super-resolution (x4, large model) on every image in ./outputs.
# NOTE(review): the checkpoint given by --model_path is not downloaded anywhere in
# this notebook; presumably main_test_swinir.py fetches it when it is missing — confirm.
!python SwinIR/main_test_swinir.py \
  --task real_sr \
  --model_path experiments/pretrained_models/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth \
  --folder_lq './outputs' \
  --scale 4 \
  --large_model

In [None]:
# Show the upscaled version of the first generated image (outputs/0.png).
# The path is relative to the working directory, like every other path in this
# notebook, so it also works when the notebook is not running under /content.
IPythonImage('./results/swinir_real_sr_x4_large/0_SwinIR.png')