In [None]:
# Attach Google Drive to the Colab runtime; the animation cell at the end of
# the notebook saves its frames under drive/MyDrive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
!pip install git+https://github.com/openai/DALL-E.git

In [None]:
import io
import os, sys
import requests
import PIL
import random

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from dall_e          import map_pixels, unmap_pixels, load_model
from IPython.display import display, display_markdown

# Side length, in pixels, of the square image that preprocess() produces;
# inputs whose shorter side is smaller than this are rejected there.
target_image_size = 256

# Uncomment to inspect the CUDA set-up of the runtime:
#torch.cuda.current_device()
#torch.cuda.get_device_name(0)
#torch.cuda.is_available()

def download_image(url, timeout=30):
    """Fetch an image over HTTP and open it with PIL.

    Args:
        url: Address of the image to download.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` may block indefinitely on an
            unresponsive host, which would stall the whole notebook.

    Returns:
        The image object returned by ``PIL.Image.open`` for the response body.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    return PIL.Image.open(io.BytesIO(resp.content))

def preprocess(img):
    """Resize, centre-crop and tensorise a PIL image for the encoder.

    The image is scaled so that its shorter side equals ``target_image_size``,
    cropped to a ``target_image_size`` square around the centre, converted to
    a tensor with a leading batch axis of size 1 and finally passed through
    ``map_pixels``.

    Args:
        img: A PIL image whose shorter side is at least ``target_image_size``.

    Returns:
        The result of ``map_pixels`` applied to a tensor of shape
        ``(1, 3, target_image_size, target_image_size)``.

    Raises:
        ValueError: If the shorter side of ``img`` is below
            ``target_image_size`` (the image is never upscaled).
    """
    shortest = min(img.size)

    if shortest < target_image_size:
        raise ValueError(f'min dim for image {shortest} < {target_image_size}')

    # Grayscale, palette or RGBA inputs would otherwise produce a tensor with
    # a channel count other than 3, which the RGB rendering used throughout
    # this notebook cannot handle.
    if img.mode != 'RGB':
        img = img.convert('RGB')

    scale = target_image_size / shortest
    # PIL reports size as (width, height); TF.resize expects (height, width).
    new_size = (round(scale * img.size[1]), round(scale * img.size[0]))
    img = TF.resize(img, new_size, interpolation=PIL.Image.LANCZOS)
    img = TF.center_crop(img, output_size=2 * [target_image_size])
    img = torch.unsqueeze(T.ToTensor()(img), 0)
    return map_pixels(img)

def decode(z, savefile=None):
    """Run token indices through the decoder and show the resulting image.

    Args:
        z: Rank-3 integer tensor of token indices (the cells of this notebook
            pass shape ``(1, 32, 32)``). It is one-hot encoded over
            ``enc.vocab_size`` classes and permuted so that the class axis
            comes second before being handed to ``dec``.
        savefile: Optional path. When given, the displayed image is also
            written there via its ``save`` method.

    Returns:
        None. The image is shown with ``display`` as a side effect; returning
        it as well would make it render twice when ``decode(...)`` is the last
        expression of a cell.
    """
    # torch.nn.functional is referenced through ``torch`` instead of the
    # alias ``F`` so this function does not depend on an import that happens
    # in a different cell. Autograd is switched off because the output is only
    # rendered, never differentiated.
    with torch.no_grad():
        one_hot = torch.nn.functional.one_hot(z, num_classes=enc.vocab_size)
        one_hot = one_hot.permute(0, 3, 1, 2).float()

        x_stats = dec(one_hot).float()
        # Only the first three output channels are used for the RGB image.
        x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))

    x_rec = T.ToPILImage(mode='RGB')(x_rec[0])

    display(x_rec)

    if savefile is not None:
        x_rec.save(savefile)

In [None]:
# Device both models are loaded onto; replace 'cpu' with e.g. 'cuda:0' for a GPU.
dev = torch.device('cpu')

# Loading is quicker from local copies: download the two files once and put
# their local paths here instead of the URLs.
ENCODER_URL = "https://cdn.openai.com/dall-e/encoder.pkl"
DECODER_URL = "https://cdn.openai.com/dall-e/decoder.pkl"
enc, dec = (load_model(location, dev) for location in (ENCODER_URL, DECODER_URL))

In [None]:
# Fetch the sample photo, bring it into the encoder's input format and show
# what the encoder is going to see.
image_url = 'https://c.pxhere.com/photos/7a/ad/dog_labrador_light_brown_hundeportrait_out_dog_head_nature_pet-655500.jpg!d'
source_image = download_image(image_url)
x = preprocess(source_image)
display(T.ToPILImage(mode='RGB')(x[0]))

In [None]:
# Kept in this cell so the alias F stays defined for any code that uses it.
import torch.nn.functional as F

# Inference only: without no_grad the global z_logits would keep the
# encoder's whole autograd graph alive for the rest of the session.
with torch.no_grad():
    z_logits = enc(x)

# Pick, for every grid position, the index with the highest logit.
z = torch.argmax(z_logits, axis=1)

decode(z)

In [None]:
# Decode a 32x32 grid of random token indices drawn from [0, 1024).
# NOTE(review): the other cells of this notebook draw from 8192 values -
# confirm that restricting this experiment to 1024 is intentional.
z_rand = (torch.rand(1, 32, 32) * 1024).long()

decode(z_rand)

In [None]:
# Start again from the image's own tokens.
z = z_logits.argmax(dim=1)

# One random token value that replaces a six-wide frame around the grid.
v = random.randrange(8192)
index = torch.cat([torch.arange(0, 6), torch.arange(26, 32)])

# Overwrite the selected rows (dim 0), then the selected columns (dim 1).
for border_dim in (0, 1):
    z[0].index_fill_(border_dim, index, v)
decode(z)

In [None]:
# mask is 1 where the image's tokens are kept and 0 on a six-wide frame
# around the grid, which receives random values instead.
index = torch.cat([torch.arange(0, 6), torch.arange(26, 32)])
mask = torch.ones(1, 32, 32)
for border_dim in (0, 1):
    mask[0].index_fill_(border_dim, index, 0)

# Image tokens survive only inside the mask, random values only outside it.
z_image = z_logits.argmax(dim=1) * mask
z_rand = torch.rand(1, 32, 32) * 8192 * (1 - mask)

# The two parts do not overlap, so adding them merges them; the cast to long
# truncates the random values to integer indices.
z_comp = (z_rand + z_image).long()
decode(z_comp)

In [None]:
# Frames go to this folder (relative to the current working directory).
# Create it up front: saving an image does not create missing directories,
# so the first frame would otherwise fail on a fresh Drive.
out_dir = "drive/MyDrive/AIExperiments/DALLE"
os.makedirs(out_dir, exist_ok=True)

mask = torch.zeros(1, 32, 32)          # 1 = use the image's token, 0 = use the random one
rand = torch.rand(1, 32, 32) * 8192    # fixed random values shared by all frames
base = torch.argmax(z_logits, axis=1)  # the image's own tokens

for f in range(30):
    # The mask is never reset, so rows 0..f (plus rows 30 and 31, which are
    # set on every iteration) show the image while the rest stays random.
    mask[0].index_fill_(0, torch.tensor([f, 30, 31]), 1)
    z_comp = rand * (1 - mask) + base * mask
    decode(z_comp.long(), savefile="%s/anim_3_%02d.jpg" % (out_dir, f))