In [None]:
# Install the notebook's dependencies. The cell ends by hard-exiting the kernel
# process, so run it once and then continue from the next cell.
%pip install -q git+https://github.com/CompVis/taming-transformers.git
# NOTE(review): taming-transformers is installed a second time here (editable,
# from a local clone); presumably the %pip install above is enough — confirm
# before removing. On a re-run `git clone` fails because the directory already
# exists, and the `&&` chain then skips the editable install.
!git clone https://github.com/CompVis/taming-transformers.git && cd taming-transformers && python -m pip install -e .
%pip install -q "omegaconf>=2.0.0" "pytorch-lightning>=1.0.8" einops transformers imageio-ffmpeg
%pip install -q git+https://github.com/cthiounn/dalle-tiny.git

# Public source of the VQGAN config/weights. These are left commented out; the
# next cell downloads files with the same names from the S3 bucket instead.
# !curl -L "https://huggingface.co/boris/vqgan_f16_16384/raw/main/config.yaml" > "config_vqgan_minidalle.yaml"
# !curl -L "https://huggingface.co/boris/vqgan_f16_16384/resolve/main/model.ckpt" > "model_vqgan_minidalle.ckpt"

import os
# Terminate the kernel process immediately (no cleanup). Presumably this is meant
# to force a kernel restart so the freshly installed packages are importable.
os._exit(00)

In [None]:
from tqdm import tqdm
import s3fs
import os
import shutil

# Connect to the S3-compatible object store. The endpoint host is read from the
# AWS_S3_ENDPOINT environment variable (a KeyError here means it is not set).
S3_ENDPOINT_URL = "https://" + os.environ["AWS_S3_ENDPOINT"]
fs = s3fs.S3FileSystem(client_kwargs={'endpoint_url': S3_ENDPOINT_URL})
BUCKET = "cthiounn2"
# Listing the bucket up front surfaces endpoint/credential problems before any
# download starts.
fs.ls(BUCKET)

# Artefacts used by later cells (VQGAN config + checkpoint, text-model config +
# weights), copied into the current working directory under the same names.
files=['model_vqgan_minidalle.ckpt','config_vqgan_minidalle.yaml','config.json','pytorch_model.bin']
for file in tqdm(files):
    with fs.open(f'{BUCKET}/{file}', mode="rb") as file_in, open(file,"wb") as file_out:
        # Stream in chunks instead of `file_in.read()`, which would hold an
        # entire (multi-hundred-MB) checkpoint in memory at once.
        shutil.copyfileobj(file_in, file_out)



In [None]:
import sys
sys.path.append(".")

from omegaconf import OmegaConf
import taming
from taming.models.vqgan import VQModel
import torch

# Inference only: disable autograd globally to save memory.
torch.set_grad_enabled(False)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the VQGAN from the downloaded config, then load its trained weights.
config_path = "config_vqgan_minidalle.yaml"
config = OmegaConf.load(config_path)
vqmodel=VQModel(**config.model.params).to(DEVICE)

ckpt_path = "model_vqgan_minidalle.ckpt"

# Load the checkpoint on CPU. load_state_dict copies each tensor into the
# model's own parameters (already on DEVICE), so there is no need to stage the
# whole checkpoint on the GPU.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load(ckpt_path, map_location="cpu")
sd = checkpoint["state_dict"]
# strict=False tolerates key mismatches; the returned missing/unexpected keys are
# displayed as this cell's output so they can be inspected.
load_result = vqmodel.load_state_dict(sd, strict=False)
# Drop the checkpoint dict now that the weights are copied into the model.
del checkpoint, sd

vqmodel.eval()
load_result

In [None]:
from dalle_tiny.model import TinyDalleModel
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader

device = 'cuda' if torch.cuda.is_available() else 'cpu'

from transformers import BartForConditionalGeneration

# Prefer the config/weights downloaded into the working directory; fall back to
# the pretrained facebook/bart-large-cnn model if they cannot be loaded.
try:
    model=TinyDalleModel.from_pretrained('.')
except Exception as err:
    # Deliberate best-effort fallback, but not a bare `except:`. That way Ctrl-C
    # and SystemExit still propagate, and the reason for falling back is shown.
    print(f"could not load model from '.': {err!r}; falling back to facebook/bart-large-cnn")
    model=TinyDalleModel.from_pretrained('facebook/bart-large-cnn')
# NOTE(review): applied on both branches; confirm it does not discard weights
# loaded from '.' (later cells load checkpoint state dicts anyway).
model.reinit_model_for_images()

In [None]:
# Evaluation prompts: one video is rendered for each caption, in this order.
captions = [
    "A black clock in snowy area with building in background.",
    "A city bus being followed by a red car.",
    "A full view of a beautiful store in a town.",
    "Family and friends are together at the beach.",
]

In [None]:
import matplotlib.pyplot as plt
from transformers import BartTokenizer
from collections import defaultdict

# Render every caption with a series of training checkpoints. For each
# checkpoint, the text model samples a sequence of VQGAN codes and the VQGAN
# decodes it into an image. image_dict[j] collects caption j's frames, one per
# successfully processed checkpoint.

tokenizer=BartTokenizer.from_pretrained('facebook/bart-large-cnn')

N_CHECKPOINTS = 32        # checkpoint files are named ..._{0, 5, ..., 155}.pth
CHECKPOINT_STEP = 5
TEXT_MAX_LENGTH = 257     # captions are padded to this many text tokens
GRID_SIDE = 16            # each image is a GRID_SIDE x GRID_SIDE grid of codes
IMAGE_TOKENS = GRID_SIDE * GRID_SIDE
EOS_TOKEN_ID = 16384
SEED = 0                  # identical sampling noise for every checkpoint

# The VQGAN decodes on CPU. Move it once here instead of once per caption.
vqmodel = vqmodel.to("cpu")
codebook = vqmodel.quantize.embedding
embed_dim = codebook.embedding_dim

model.config.eos_token_id = EOS_TOKEN_ID
model.config.max_length = IMAGE_TOKENS + 1  # leading start token + one token per code
# load_state_dict does not change train/eval mode or device, so set both once.
model.eval()
model = model.to(device)

# Tokenise each caption once rather than once per checkpoint.
caption_inputs = [
    tokenizer(caption, return_tensors="pt", max_length=TEXT_MAX_LENGTH, padding="max_length").to(device)
    for caption in captions
]


def tokens_to_image(pred):
    """Decode one generated token sequence into an (H, W, C) float image in [0, 1].

    ``pred`` is ``model.generate`` output for a single caption. Its first token is
    skipped (presumably the decoder start token). Ids that are not valid codebook
    indices (e.g. EOS_TOKEN_ID when the codebook has 16384 entries) become code 0,
    and sequences shorter than IMAGE_TOKENS are zero-padded. An early stop
    therefore no longer raises a shape or index error.
    """
    codes = pred.detach().squeeze()[1:IMAGE_TOKENS + 1].to("cpu", torch.long)
    codes = codes.masked_fill(codes >= codebook.num_embeddings, 0)
    output_indices = torch.zeros(IMAGE_TOKENS, dtype=torch.long)
    output_indices[:codes.numel()] = codes
    z_q = codebook(output_indices).reshape(1, GRID_SIDE, GRID_SIDE, embed_dim).permute(0, 3, 1, 2)
    # Map [-1, 1] -> [0, 1]. Nothing here bounds the decoder output, so clamp;
    # otherwise writers may rescale each frame differently.
    image = vqmodel.decode(z_q).add(1).div(2).clamp(0, 1)
    return image.cpu().squeeze().permute(1, 2, 0)


image_dict=defaultdict(list)
for i in range(N_CHECKPOINTS):
    print(i)
    FILE_KEY_S3 = f"checkpoint_fixdecoderidtoken_{i*CHECKPOINT_STEP}.pth"
    FILE_PATH_S3 = BUCKET + "/" + FILE_KEY_S3
    try:
        # NOTE: torch.load unpickles, so only load checkpoints from a trusted bucket.
        with fs.open(FILE_PATH_S3, mode="rb") as file_in:
            model.load_state_dict(torch.load(file_in, map_location=device))
        torch.manual_seed(SEED)
        # Limits are passed explicitly too: newer transformers read them from
        # generation_config rather than model.config.
        frames = [
            tokens_to_image(model.generate(**inputs, do_sample=True, top_k=100,
                                           max_length=IMAGE_TOKENS + 1,
                                           eos_token_id=EOS_TOKEN_ID))
            for inputs in caption_inputs
        ]
    except Exception as err:
        print(f"issue with {i} : {FILE_KEY_S3} ({err!r})")
    else:
        # Store frames only when every caption succeeded, so all captions end up
        # with the same checkpoints (aligned videos).
        for j, frame in enumerate(frames):
            image_dict[j].append(frame)

In [None]:
import imageio
import numpy as np


def to_uint8_frame(image):
    """Convert a decoded (H, W, C) image (tensor or array) into a uint8 video frame.

    Float input is assumed to be nominal [0, 1]. It is clipped before scaling, so
    stray out-of-range pixels cannot make the writer normalise each frame
    differently. uint8 input is passed through unchanged.
    """
    frame = np.asarray(image)
    if frame.dtype != np.uint8:
        frame = (np.clip(frame, 0.0, 1.0) * 255).round().astype(np.uint8)
    return frame


# One 1-fps mp4 per caption; each frame corresponds to one checkpoint.
for j,caption in enumerate(captions):
    video_file=caption.replace('.','').replace(' ','_').lower()
    frames = image_dict[j]
    if not frames:
        print(f"no frames for {caption!r}; not writing {video_file}.mp4")
        continue
    # The context manager closes (finalises) the file even if a frame fails.
    with imageio.get_writer(video_file + '.mp4', fps=1) as writer:
        for im in frames:
            writer.append_data(to_uint8_frame(im))

In [None]:
from IPython.display import HTML, display
from base64 import b64encode

# Show each caption's video inline by embedding the mp4 as a base64 data URL.
for j,caption in enumerate(captions):
    video_file=caption.replace('.','').replace(' ','_').lower()
    try:
        # `with` closes the handle promptly instead of leaving it to the GC.
        with open(video_file+".mp4",'rb') as video_in:
            mp4 = video_in.read()
    except FileNotFoundError:
        print(f"{video_file}.mp4 not found (no frames rendered?); skipping {caption!r}")
        continue
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    print(caption)
    display(HTML("""
    <video width=500  autoplay="autoplay" controls muted>
          <source src="%s" type="video/mp4">
    </video>
    """ % data_url))