In [None]:
import os
import sys
import json
import math
import torch
from torch import nn
from torchvision import transforms as pth_transforms
from torchvision.transforms.functional import InterpolationMode
from timm.models import create_model
from PIL import Image
import utils
from utils import convert_weights_to_bf16, convert_weights_to_fp16
import datetime
import random
from models import LaVITDetokenizer

In [None]:
# Local directory holding the LaVIT checkpoint. Set the LAVIT_MODEL_PATH
# environment variable to point somewhere else without editing the notebook;
# the original hard-coded location remains the default.
model_path = os.environ.get('LAVIT_MODEL_PATH', '/home/jinyang06/models/LaVIT_LLaMA2')
model_dtype = 'bf16'  # 'bf16' or 'fp16'; any other value skips the weight cast below
use_xformers = True

# Seed the torch and Python RNGs so repeated runs are comparable.
seed = 42
torch.manual_seed(seed)
random.seed(seed)

# Select the GPU before the model is constructed and moved to 'cuda'.
device_id = 0
torch.cuda.set_device(device_id)

model = LaVITDetokenizer(model_path, model_dtype, use_xformers=use_xformers, pixel_decoding='highres')

# Cast every top-level sub-module whose name does not contain 'vae' to the
# requested dtype; sub-modules with 'vae' in their name are left untouched.
weight_casts = {
    'bf16': ("Cast the model dtype to bfloat16", convert_weights_to_bf16),
    'fp16': ("Cast the model dtype to float16", convert_weights_to_fp16),
}
if model_dtype in weight_casts:
    cast_message, cast_fn = weight_casts[model_dtype]
    print(cast_message)
    for name, sub_module in model.named_children():
        if 'vae' not in name:
            cast_fn(sub_module)

# 'cuda' without an index resolves to the current device chosen above.
device = torch.device('cuda')
model.to(device)

# Preprocessing for the input image: bicubic resize to 224x224, then
# conversion to a tensor.
transform = pth_transforms.Compose([
    pth_transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
    pth_transforms.ToTensor(),
])

In [None]:
def image_grid(imgs, rows, cols):
    """Tile images into a single ``rows`` x ``cols`` RGB grid.

    Images are placed row by row, left to right. The cell size is taken from
    the first image's ``size``; other images are pasted as-is at their cell's
    top-left corner.

    Args:
        imgs: Non-empty sequence of PIL images, exactly ``rows * cols`` long.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``, where ``(w, h)`` is
        the size of ``imgs[0]``.

    Raises:
        AssertionError: If ``len(imgs) != rows * cols``.
    """
    assert len(imgs) == rows * cols, (
        f"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}"
    )

    cell_w, cell_h = imgs[0].size
    grid = Image.new('RGB', size=(cols * cell_w, rows * cell_h))

    for index, img in enumerate(imgs):
        # divmod gives (row, column) of the cell for a row-major layout.
        row, col = divmod(index, cols)
        grid.paste(img, box=(col * cell_w, row * cell_h))
    return grid

In [None]:
# NOTE(review): this path is anchored at the filesystem root — confirm it is
# not meant to be relative to the repository (e.g. 'demo/dog.jpg').
image_path = '/demo/dog.jpg'

# Output resolution presets keyed by aspect-ratio label. Each value is
# (height, width), matching the unpacking further down.
# NOTE(review): the '1:2' entry is 576 x 1024 (height x width), i.e. a
# landscape shape — confirm the label is intended.
ratio_dict = {
    '1:1' : (1024, 1024),
    '4:3' : (896, 1152),
    '3:2' : (832, 1216),
    '16:9' : (768, 1344),
    '2:3' : (1216, 832),
    '3:4' : (1152, 896),
    '1:2' : (576, 1024),
}

# convert() loads the pixel data and returns a new image, so the underlying
# file can be closed as soon as the block exits.
with Image.open(image_path) as source_image:
    image = source_image.convert("RGB")

# Forwarded to reconstruct_from_token. The notebook previously assigned
# (image.height, image.width) and immediately overwrote it with None; None is
# the effective value and is kept here.
original_size = None

image_tensor = transform(image).unsqueeze(0)
image_tensor = image_tensor.to(device)

# The image aspect ratio you want to generate
ratio = '1:1'
height, width = ratio_dict[ratio]

# Run autocast in the same precision the model weights were cast to, so that
# switching model_dtype to 'fp16' does not leave autocast on bfloat16.
# Unrecognised values fall back to bfloat16, the previously hard-coded dtype.
autocast_dtype = {'bf16': torch.bfloat16, 'fp16': torch.float16}.get(model_dtype, torch.bfloat16)

# NOTE(review): an earlier comment here read "Optimal 2.5 or 1.5 or 3.0",
# while the call passes guidance_scale=7.0 — confirm which value is intended.
with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=autocast_dtype):
    rec_images = model.reconstruct_from_token(image_tensor.expand(1,-1,-1,-1), width=width, height=height, 
            original_size=original_size, num_inference_steps=50, guidance_scale=7.0)

# Side-by-side preview: input on the left, first reconstruction on the right.
preview_side = 512
grid = Image.new('RGB', size=(2 * preview_side, preview_side))
grid.paste(image.resize((preview_side, preview_side)), box=(0, 0))
grid.paste(rec_images[0].resize((preview_side, preview_side)), box=(preview_side, 0))

display(grid)
# Also show the first reconstruction on its own, without the preview resize.
display(rec_images[0])