In [None]:
#モジュールのインポート
import torch
from dalle_pytorch import DiscreteVAE

# Define the discrete VAE (dVAE) that turns images into a grid of visual tokens.
vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,           # number of downsamples, e.g. 256 / (2 ** 3) = (32 x 32 feature map)
    num_tokens = 8192,        # number of visual tokens; the paper used 8192, but it can be smaller
    codebook_dim = 512,       # codebook dimension
    hidden_dim = 64,          # hidden dimension
    num_resnet_blocks = 1,    # number of resnet blocks
    temperature = 0.9,        # gumbel softmax temperature; the lower it is, the harder the discretization
    straight_through = False, # straight-through estimator for gumbel softmax; unclear which setting is better
)

images = torch.randn(4, 3, 256, 256)

# One reconstruction training step. Previously `images` was created but never used,
# so this cell never actually exercised the dVAE.
vae_optimizer = torch.optim.Adam(vae.parameters(), lr = 3e-4)  # example learning rate; tune for real training

vae_optimizer.zero_grad()
loss = vae(images, return_loss = True)
loss.backward()
vae_optimizer.step()

# Repeat the step above over a large image dataset to learn a good codebook.

In [None]:
#モジュールのインポート
import torch
from dalle_pytorch import DiscreteVAE, DALLE

# Define the dVAE (same layout as the previous cell, but with a 1024-dim codebook).
vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 1024,
    hidden_dim = 64,
    num_resnet_blocks = 1,
    temperature = 0.9
)

# Define DALL-E on top of the dVAE.
dalle = DALLE(
    dim = 1024,
    vae = vae,                  # (1) image sequence length and (2) number of image tokens are inferred from the VAE
    num_text_tokens = 10000,    # text vocabulary size
    text_seq_len = 256,         # text sequence length
    depth = 12,                 # should aim for 64
    heads = 16,                 # number of attention heads
    dim_head = 64,              # dimension of each attention head
    attn_dropout = 0.1,         # attention dropout
    ff_dropout = 0.1            # feedforward dropout
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones_like(text).bool()

# One training step. `loss.backward()` on its own only accumulates gradients. Without an
# optimizer step the weights never change, so repeating it would never train the model.
# Only parameters that require grad are handed to the optimizer (any frozen ones are skipped).
dalle_optimizer = torch.optim.Adam(
    [p for p in dalle.parameters() if p.requires_grad],
    lr = 3e-4,                  # example learning rate; tune for real training
)

dalle_optimizer.zero_grad()
loss = dalle(text, images, mask = mask, return_loss = True)
loss.backward()
dalle_optimizer.step()

# Repeat the training step above for a long time over a large dataset.

images = dalle.generate_images(text, mask = mask)
images.shape # (4, 3, 256, 256)

In [1]:
#上記2つを実行した場合は実行する必要なし
#時間短縮用の事前学習データ読み込み
import torch
from dalle_pytorch import OpenAIDiscreteVAE, DALLE

vae = OpenAIDiscreteVAE()       # load the pretrained OpenAI VAE

dalle = DALLE(
    dim = 1024,
    vae = vae,                  # (1) image sequence length and (2) number of image tokens are inferred from the VAE
    num_text_tokens = 10000,    # text vocabulary size
    text_seq_len = 256,         # text sequence length
    depth = 1,                  # should aim for 64; kept at 1 here to save time
    heads = 16,                 # number of attention heads
    dim_head = 64,              # dimension of each attention head
    attn_dropout = 0.1,         # attention dropout
    ff_dropout = 0.1            # feedforward dropout
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones_like(text).bool()

# One training step. Gradients alone do not update the model, so an optimizer step is
# required for repeated runs of this cell to actually train DALL-E.
# Only parameters that require grad are optimized (any frozen ones are skipped).
dalle_optimizer = torch.optim.Adam(
    [p for p in dalle.parameters() if p.requires_grad],
    lr = 3e-4,                  # example learning rate; tune for real training
)

dalle_optimizer.zero_grad()
loss = dalle(text, images, mask = mask, return_loss = True)
loss.backward()
dalle_optimizer.step()

In [2]:
#モジュールのインポート
import torch
from dalle_pytorch import CLIP

# Define CLIP; it is passed to `dalle.generate_images(..., clip = clip)` below to score images.
clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 10000,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    num_visual_tokens = 512,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones_like(text).bool()

# Loss plus one optimizer step. Without `step()`, CLIP's weights would never change,
# so the scores it produces later would come from an untrained model no matter how often this runs.
clip_optimizer = torch.optim.Adam(clip.parameters(), lr = 3e-4)  # example learning rate; tune for real training

clip_optimizer.zero_grad()
loss = clip(text, images, text_mask = mask, return_loss = True)
loss.backward()
clip_optimizer.step()

In [None]:
# Generate one image per text prompt and score each (text, image) pair with CLIP.
images, scores = dalle.generate_images(text, mask = mask, clip = clip)

# One score and one RGB image per prompt. `text` has batch size 4 here, not 2.
assert scores.shape == (text.shape[0],)
assert images.shape[:2] == (text.shape[0], 3)

# Re-rank by CLIP score and keep the best candidates.
# The paper sampled 512 images and kept the top 32; here we keep the top TOP_K of this batch.
TOP_K = 2
top_scores, top_indices = scores.topk(min(TOP_K, scores.shape[0]))
best_images = images[top_indices]
best_images.shape # (2, 3, 256, 256)

In [None]:
# DALLE configuration (1): a 64-layer model built from reversible layers
# (reversible networks: https://arxiv.org/abs/2001.04451).
dalle = DALLE(
    # inputs: the current `vae` plus the text vocabulary and sequence length
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    # transformer size
    dim = 1024,
    depth = 64,
    heads = 16,
    # architecture option
    reversible = True,
)

In [None]:
# DALLE configuration (2): reversible layers that cycle through several attention types.
dalle = DALLE(
    # inputs: the current `vae` plus the text vocabulary and sequence length
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    # transformer size
    dim = 1024,
    depth = 64,
    heads = 16,
    # architecture options
    reversible = True,
    attn_types = ('full', 'axial_row', 'axial_col', 'conv_like'),  # layers cycle through these four attention types
)

In [None]:
# DALLE configuration (3): a narrower model that interleaves dense ('full') and sparse
# attention across its 64 layers.
# NOTE(review): 'sparse' attention may depend on an optional backend (possibly DeepSpeed) being
# installed; confirm before running this cell.
dalle = DALLE(
    # inputs: the current `vae` plus the text vocabulary and sequence length
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    # transformer size
    dim = 512,
    depth = 64,
    heads = 8,
    # architecture option
    attn_types = ('full', 'sparse'),
)