In [1]:
import torch
import numpy as np

## Load tokenizer

In [2]:
import sys
sys.path.append("..")
from omegaconf import OmegaConf
from scripts.train_tokenizer import instantiate_from_config
# Build the tokenizer described by the YAML config, switch it to inference mode and
# place it on the GPU. `config` ends up holding the loaded OmegaConf object.
config_path = "../configs/second_stage/tokenizer_config.yaml"
config = OmegaConf.load(config_path)
model = instantiate_from_config(config["model"]).eval().to("cuda")

[2025-03-14 21:15:51,723] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Only finetune Decoder
Restored from ../pretrained_models/tokenizer.ckpt


## Load image

In [3]:
import os
from PIL import Image
import torchvision.transforms as transforms
def center_crop_to_multiple_of_16(image):
    """Center-crop a PIL image so that both its width and height are multiples of 16.

    Each side is floored to the nearest multiple of 16 and the crop window is
    centered in the original image. Fractional box edges are resolved with
    Python's built-in ``round``.

    Args:
        image: object exposing ``.size`` as ``(width, height)`` and ``.crop(box)``.

    Returns:
        The result of ``image.crop((left, top, right, bottom))``.
    """
    w, h = image.size
    crop_w = w - w % 16
    crop_h = h - h % 16
    half_edges = (
        (w - crop_w) / 2,  # left
        (h - crop_h) / 2,  # top
        (w + crop_w) / 2,  # right
        (h + crop_h) / 2,  # bottom
    )
    box = tuple(round(edge) for edge in half_edges)
    return image.crop(box)

# Preprocessing: resize the shorter side to 256, center-crop both sides down to a
# multiple of 16 (presumably the tokenizer's patch/downsampling factor — TODO confirm),
# then map pixel values from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize(256),
    center_crop_to_multiple_of_16,
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# convert("RGB") guarantees exactly 3 channels. A grayscale, CMYK or RGBA JPEG would
# otherwise produce 1 or 4 channels and break the 3-channel Normalize. The context
# manager closes the underlying file handle once the tensor has been built.
with Image.open("token_meaning.JPEG") as pil_image:
    img = transform(pil_image.convert("RGB")).unsqueeze(dim=0).cuda()

# Output folder for the per-step decoded frames.
os.makedirs("gif_images", exist_ok=True)

## Progressively replace the randomly initialized token sequence with tokens encoded from the ground-truth image

In [5]:
def custom_to_pil(x):
    """Convert a decoded image tensor with values in [-1, 1] to an RGB PIL image.

    Args:
        x: Image tensor shaped ``(C, H, W)`` or ``(1, C, H, W)``. ``C`` is expected
            to be 1, 3 or 4. Values outside ``[-1, 1]`` are clamped.

    Returns:
        A ``PIL.Image.Image`` in ``"RGB"`` mode.

    Raises:
        ValueError: if ``x`` is not 3-D/4-D, or is 4-D with a batch size other than 1.
    """
    # .float(): NumPy has no bfloat16, so .numpy() would fail on bf16 model outputs.
    x = x.detach().cpu().float()
    if x.dim() == 4:
        if x.shape[0] != 1:
            raise ValueError(f"expected a single image, got a batch of {x.shape[0]}")
        # Drop only the batch axis. A blanket squeeze() would also drop a size-1
        # channel/height/width axis and break the permute below.
        x = x[0]
    if x.dim() != 3:
        raise ValueError(
            f"expected a (C, H, W) or (1, C, H, W) tensor, got shape {tuple(x.shape)}"
        )
    x = torch.clamp(x, -1., 1.)
    x = (x + 1.) / 2.  # [-1, 1] -> [0, 1]
    x = x.permute(1, 2, 0).numpy()  # CHW -> HWC, the layout PIL expects
    if x.shape[-1] == 1:
        # PIL wants (H, W) for single-channel data, not (H, W, 1).
        x = x[..., 0]
    x = (255 * x).astype(np.uint8)  # truncating cast, unchanged from before
    x = Image.fromarray(x)
    if not x.mode == "RGB":
        # e.g. "L" for single-channel or "RGBA" for 4-channel input.
        x = x.convert("RGB")
    return x

# Progressive decoding: frame k keeps the first k tokens encoded from `img` and fills
# every remaining position with one fixed random sequence `z`. Frame 0 is therefore
# all random tokens and the last frame uses only encoded tokens.
# NOTE(review): assumes dim 1 of the latent indexes tokens — confirm against the tokenizer.
torch.manual_seed(0)  # make `z` (and any sampling in encode/decode) reproducible

# Remove frames from earlier runs so a shorter run cannot leave stale frames behind.
os.makedirs("gif_images", exist_ok=True)
for stale_name in os.listdir("gif_images"):
    if stale_name.startswith("progressive_") and stale_name.endswith(".png"):
        os.remove(os.path.join("gif_images", stale_name))

# Pure inference: no_grad avoids building autograd graphs for encode/decode.
with torch.no_grad(), model.ema_scope():
    latent = model.encode(img).sample()
    num_tokens = latent.shape[1]
    ratios = np.linspace(0.0, 1.0, num_tokens + 1)
    z = torch.randn_like(latent)  # randn_like already allocates on latent's device
    for k, ratio in enumerate(ratios):
        # ratio == k / num_tokens, so exactly k tokens are kept. Recomputing the
        # count as int(num_tokens * ratio) can truncate one short when the float
        # product lands just below k (e.g. 49 * (1/49) == 0.9999999999999999).
        num_kept = k
        decode_latent = z.clone()
        decode_latent[:, :num_kept] = latent[:, :num_kept]
        xrec, _ = model.decode(decode_latent)
        xrec = custom_to_pil(xrec)
        xrec.save(f"gif_images/progressive_{k:03d}.png")

# Make a GIF by concatenating all progressively decoded images

In [6]:
import imageio
output_gif = 'output.gif'
frame_dir = "gif_images"
frame_prefix, frame_suffix = "progressive_", ".png"


def _frame_index(name):
    """Return the integer k encoded in a ``progressive_<k>.png`` file name."""
    return int(name[len(frame_prefix):-len(frame_suffix)])


# Keep only the frames written by the progressive-decoding cell and order them
# numerically. A plain lexicographic sort misorders indices once they outgrow the
# 3-digit zero padding ("progressive_1000.png" sorts before "progressive_101.png").
images = sorted(
    (
        name for name in os.listdir(frame_dir)
        if name.startswith(frame_prefix)
        and name.endswith(frame_suffix)
        and name[len(frame_prefix):-len(frame_suffix)].isdigit()
    ),
    key=_frame_index,
)

# Frames are read with PIL because top-level `imageio.imread` is deprecated in
# recent imageio 2.x. np.array copies the pixels, so each file can be closed at once.
frames = []
for image_name in images:
    image_path = os.path.join(frame_dir, image_name)
    with Image.open(image_path) as frame:
        frames.append(np.array(frame))

# NOTE(review): newer imageio releases steer GIF writing from `fps` toward `duration`
# (whose unit differs between plugin versions) — verify against the installed version.
imageio.mimsave(output_gif, frames, fps=6)
