論文<br>
https://cdn.openai.com/papers/dall-e-2.pdf<br>
<br>
GitHub<br>
https://github.com/LAION-AI/dalle2-laion<br>
<br>
<a href="https://colab.research.google.com/github/kaz12tech/ai_demos/blob/master/dalle2_laion_demo.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 環境セットアップ

## GPU確認

In [None]:
# Print the NVIDIA driver/GPU status for this runtime (GPU check).
!nvidia-smi

In [None]:
# Fetch the SwinIR repository into ./SwinIR; its main_test_swinir.py script is
# run later in this notebook to upscale the generated images.
!git clone https://github.com/JingyunLiang/SwinIR.git

In [None]:
!pip install -q dalle2_pytorch==0.15.4

# for swinIR
!pip install -q timm
!pip install -q opencv-python

In [None]:
# Decoder checkpoint from the laion/DALLE2-PyTorch repo (decoder/1.5B),
# saved as ./latest.pth. `-c` resumes a partially downloaded file.
!wget -c https://huggingface.co/laion/DALLE2-PyTorch/resolve/main/decoder/1.5B/latest.pth \
      -O latest.pth

# Diffusion-prior checkpoint (vit-l-14), saved as ./prior_aes_finetune.pth.
!wget -c https://huggingface.co/zenglishuci/conditioned-prior/resolve/main/vit-l-14/prior_aes_finetune.pth \
      -O prior_aes_finetune.pth

In [None]:
import os
from IPython.display import Image as IPythonImage

from PIL import Image
import numpy as np

import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP, OpenAIClipAdapter

# Pick the first CUDA GPU when one is available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# CLIP adapter shared by the prior (this cell) and the decoder (next cell).
clip = OpenAIClipAdapter("ViT-L/14").to(device)

# Network used by the diffusion prior. The hyper-parameters must match the
# checkpoint loaded below, because it is loaded with strict=True.
prior_network = DiffusionPriorNetwork(
    dim=768,
    depth=24,
    dim_head=64,
    heads=32,
    normformer=True,
    attn_dropout=5e-2,
    ff_dropout=5e-2,
    num_time_embeds=1,
    num_image_embeds=1,
    num_text_embeds=1,
    num_timesteps=1000,
    ff_mult=4
).to(device)

diffusion_prior = DiffusionPrior(
    net=prior_network,
    clip=clip,
    image_embed_dim=768,
    timesteps=1000,
    cond_drop_prob=0.1,
    loss_type="l2",
    condition_on_text_encodings=True,
).to(device)

# Load from the same relative path the download cell wrote to (`-O
# prior_aes_finetune.pth`) instead of a hard-coded /content/ prefix, so the
# notebook does not depend on the working directory being /content.
# NOTE: torch.load unpickles the file; only load checkpoints from sources you trust.
prior_state_dict = torch.load('prior_aes_finetune.pth', map_location='cpu')
# The checkpoint stores several entries; only the EMA weights are used here.
diffusion_prior.load_state_dict(prior_state_dict['ema_model'], strict=True)
# Inference only: disable dropout, mirroring decoder.eval() in the decoder cell.
diffusion_prior.eval()

# Drop the reference to the loaded checkpoint dict; the weights now live in the model.
del prior_state_dict

In [None]:
# U-Net used by the decoder. The hyper-parameters are expected to match the
# downloaded decoder checkpoint (latest.pth).
unet = Unet(
    dim=320,
    cond_dim=512,
    image_embed_dim=768,
    text_embed_dim=768,
    cond_on_text_encodings=True,
    channels=3,
    dim_mults=[1, 2, 3, 4],
    num_resnet_blocks=4,
    attn_heads=8,
    attn_dim_head=64,
    sparse_attn=True,
    memory_efficient=True,
    self_attn=[False, True, True, True]
)

# Move the decoder to `device` rather than calling .cuda(): this honours the
# CPU fallback chosen earlier and matches how the prior is placed.
decoder = Decoder(
    unet=unet,
    image_sizes=(64, 64),
    clip=clip,
    channels=3,
    timesteps=1000,
    loss_type="l2",
    beta_schedule=["cosine"],
    learned_variance=True
).to(device)

# Load from the same relative path the download cell wrote to (`-O latest.pth`)
# instead of a hard-coded /content/ prefix.
# NOTE: torch.load unpickles the file; only load checkpoints from sources you trust.
decoder_state_dict = torch.load('latest.pth', map_location='cpu')
# NOTE(review): strict=False silently ignores missing/unexpected keys — confirm
# that the mismatch with this checkpoint is intentional.
decoder.load_state_dict(decoder_state_dict, strict=False)
decoder.eval()

# Drop the reference to the loaded checkpoint dict; the weights now live in the model.
del decoder_state_dict

In [None]:
# Combine the loaded prior and decoder into a single text-to-image pipeline.
dalle2 = DALLE2(prior=diffusion_prior, decoder=decoder)

# One image is generated per prompt in this list.
prompts = ['the end of the world']

# cond_scale: classifier free guidance strength (> 1 would strengthen the condition)
images = dalle2(prompts, cond_scale=3.)

In [None]:
def show_images(np_images):
  """Display each image of a batch inline in the notebook.

  Args:
    np_images: iterable of float arrays that `PIL.Image.fromarray` accepts
      once scaled to uint8 (assumed height x width x channel with values
      nominally in [0, 1] — confirm against the caller).

  Values are clipped to [0, 1] before the uint8 cast, so samples that
  slightly overshoot the range saturate instead of wrapping around.
  """
  for np_img in np_images:
    pixels = np.uint8(np.clip(np_img, 0.0, 1.0) * 255)
    display(Image.fromarray(pixels))

def save_images(output_dir, np_images):
  """Write each image of a batch to `output_dir` as `<index>.png`.

  Args:
    output_dir: destination directory; created if it does not exist.
    np_images: iterable of float arrays that `PIL.Image.fromarray` accepts
      once scaled to uint8 (assumed height x width x channel with values
      nominally in [0, 1] — confirm against the caller).

  Values are clipped to [0, 1] before the uint8 cast, so samples that
  slightly overshoot the range saturate instead of wrapping around.
  Existing files with the same index are overwritten.
  """
  os.makedirs(output_dir, exist_ok=True)
  for i, np_img in enumerate(np_images):
    pixels = np.uint8(np.clip(np_img, 0.0, 1.0) * 255)
    output_path = os.path.join(output_dir, f'{i}.png')
    Image.fromarray(pixels).save(output_path)

# Bring the generated batch to host memory and move axis 1 to the end
# (presumably channels-first -> channels-last, the layout PIL expects).
images_np = (
    images
    .cpu()
    .permute(0, 2, 3, 1)
    .numpy()
)

# Persist the batch first (SwinIR reads ./outputs later), then preview it inline.
save_images(output_dir='./outputs', np_images=images_np)
show_images(np_images=images_np)

In [None]:
# Upscale every image in ./outputs by 4x with the large real-world SR SwinIR model.
# The checkpoint under experiments/pretrained_models/ is not downloaded by this
# notebook; presumably main_test_swinir.py fetches it when missing — confirm.
# The next cell reads the result from results/swinir_real_sr_x4_large/.
!python SwinIR/main_test_swinir.py \
  --task real_sr \
  --model_path experiments/pretrained_models/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth \
  --folder_lq './outputs' \
  --scale 4 \
  --large_model

In [None]:
# Show the upscaled version of the first generated image (outputs/0.png).
# The path is relative to the working directory, matching how the SwinIR script
# was invoked, instead of a hard-coded /content/ prefix. `filename=` makes a
# missing file raise an error rather than be interpreted as raw image data.
IPythonImage(filename='results/swinir_real_sr_x4_large/0_SwinIR.png')