# Step 1: Convert the InternLM model to a Hugging Face (HF) model

In [None]:
!python convert2hf.py --src_folder /path/to/intermlm_model/ --tgt_folder /path/to/save/hf_model/

# Step 2: Prompted inference

In [None]:
# Inference inputs: the directory of in-context prompt images and the query image.
prompt_path = "./data/seg_prompt/"        # directory of prompt images
input_img = "./data/examples/seg_1.png"   # query image to run the model on

# Model checkpoints.
lvm_path = "../../models/llama_300m_hf"           # HF-format model produced by convert2hf.py
vqgan_path = "../../models/vqgan-f16-8192-laion"  # VQGAN used to tokenize / de-tokenize images

In [5]:
# import packages
import os

import torch
from PIL import Image
from transformers import AutoModel, GenerationConfig

from model_hf.muse import VQGANModel
from utils import convert_decode_to_pil, encode_transform, patchify, unpatchify
from torchvision import transforms
import matplotlib.pyplot as plt

# Run on CUDA when it is available, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

In [19]:
# Load the VQGAN image tokenizer and the language model, move both to DEVICE,
# and switch them to eval mode for inference.
vq_model = VQGANModel.from_pretrained(vqgan_path)
vq_model = vq_model.to(DEVICE).eval()

# trust_remote_code=True lets transformers execute model code shipped with the
# checkpoint at lvm_path -- only use checkpoints you trust.
model = AutoModel.from_pretrained(lvm_path, trust_remote_code=True)
model = model.to(DEVICE).eval()

# Decoding settings passed to model.generate.
# max_new_tokens=256 is one image's worth of tokens (decoded later as a 16 x 16 grid).
# NOTE(review): do_sample is not set here, so it keeps the transformers default
# (False); together with num_beams=1 this is greedy decoding. Under greedy
# decoding temperature/top_p are ignored and early_stopping (a beam-search
# option) has no effect -- transformers may warn about this. Set do_sample=True
# if sampling with temperature=0.1 / top_p=0.75 was actually intended.
generation_config = GenerationConfig(
    num_beams=1,
    max_new_tokens=256,
    temperature=0.1,
    top_p=0.75,
    early_stopping=True,
)

In [None]:
# prepare prompt: tokenize every image in prompt_path (in sorted filename order)
# with the VQGAN and concatenate the token ids into one (1, total_tokens) sequence.
img_names = sorted(
    name
    for name in os.listdir(prompt_path)
    # Skip hidden entries (e.g. .DS_Store, .ipynb_checkpoints) and sub-directories,
    # which Image.open cannot read.
    if not name.startswith('.') and os.path.isfile(os.path.join(prompt_path, name))
)
if not img_names:
    # torch.cat on an empty list would otherwise fail with an unhelpful error.
    raise FileNotFoundError(f'no prompt images found in {prompt_path!r}')

seq_prompt, seq_ids = [], []
for img_name in img_names:
    print('prompt: ', img_name)
    img_path = os.path.join(prompt_path, img_name)

    # Use a context manager so the file handle is released, and convert to RGB so
    # grayscale / palette / RGBA files all produce exactly 3 channels.
    with Image.open(img_path) as pil_img:
        image = encode_transform(pil_img.convert('RGB'))
    image = image[0:3, :, :].unsqueeze(0)
    seq_prompt.append(image)

    # tokenize; only the codebook indices are needed, so no autograd graph is built
    with torch.no_grad():
        _, indices = vq_model.encode(image.to(DEVICE))
    seq_ids.append(indices.reshape(1, -1))

seq_ids = torch.cat(seq_ids, dim=1)

print(type(seq_ids), seq_ids.shape)

In [None]:
# prepare input: load the query image (a path or an already-opened PIL image),
# tokenize it with the VQGAN and append its token ids after the prompt tokens.
if isinstance(input_img, (str, os.PathLike)):
    # Load into an in-memory RGB image and close the file handle; input_img is
    # displayed again later, so it must remain usable after the file is closed.
    with Image.open(input_img) as src_img:
        input_img = src_img.convert('RGB')
elif isinstance(input_img, Image.Image) and input_img.mode != 'RGB':
    # Grayscale / palette / RGBA inputs would not give the 3 channels the VQGAN expects.
    input_img = input_img.convert('RGB')

img = encode_transform(input_img)[0:3, :, :].unsqueeze(0).to(DEVICE)
with torch.no_grad():  # only the token indices are needed
    quantized_states, indices = vq_model.encode(img)
input_ids = torch.cat([seq_ids, indices.reshape(1, -1)], dim=1)

print(type(input_ids), input_ids.shape)

In [None]:
# generate: autoregressively extend the prompt + query token sequence.
# Autograd is disabled because nothing here needs gradients.
with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids,
        generation_config=generation_config,
        # keyword arguments to generate() take precedence over generation_config
        max_new_tokens=256,
        return_dict_in_generate=True,  # return a ModelOutput with .sequences
        output_scores=True,
    )

In [None]:
# visualization: decode the generated tokens back to an image and show it next to the query.
num_image_tokens = 256  # tokens per generated image (matches max_new_tokens used for generation)
grid_size = 16          # tokens are laid out as a grid_size-column latent map (16 * 16 == 256)

# Keep only the newly generated tokens. Slicing from the prompt length (instead of
# taking the last 256 positions) prevents prompt tokens from silently leaking into
# the decoded image when generation stops early.
new_tokens = outputs.sequences[:, input_ids.shape[1]:]
if new_tokens.shape[1] != num_image_tokens:
    raise RuntimeError(
        f'expected {num_image_tokens} generated tokens, got {new_tokens.shape[1]}'
    )

with torch.no_grad():  # decoding is inference only; avoid building an autograd graph
    generated_tokens = vq_model.quantize.get_codebook_entry_for_lvm(new_tokens)
    # (1, tokens, C) -> (1, tokens // grid_size, grid_size, C) -> (1, C, H, W)
    generated_tokens = generated_tokens.view(
        1, generated_tokens.shape[1] // grid_size, grid_size, -1
    ).permute(0, 3, 1, 2)
    generated_img = vq_model.decode(generated_tokens)
generated_img_rec = convert_decode_to_pil(generated_img)[0]

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(input_img)
axes[0].set_title('input')
axes[1].imshow(generated_img_rec)
axes[1].set_title('generated')
for ax in axes:
    ax.axis('off')