<a href="https://colab.research.google.com/github/hlin863/Image-Generation/blob/main/janus_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import PIL.Image
import torch
import numpy as np
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor

In [None]:
# Hugging Face Hub id of the Janus-Pro 7B checkpoint to load.
model_path = "deepseek-ai/Janus-Pro-7B"

# The chat processor provides the prompt-template helpers; keep a direct
# handle on its text tokenizer as well.
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

In [None]:
# Load the multimodal model directly in bfloat16. Without `torch_dtype` the
# weights are first materialized in float32 (4 bytes/param instead of 2), which
# roughly doubles host memory for a 7B checkpoint before the cast below.
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
# The explicit cast is kept as a safety net (a no-op when the load above already
# produced bfloat16). Then move the model to the GPU and switch to eval mode.
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

In [None]:
# Single-turn chat: the user's image description followed by an empty
# assistant turn for the model to complete.
conversation = [
    dict(
        role="<|User|>",
        content="A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair",
    ),
    dict(role="<|Assistant|>", content=""),
]

In [None]:
# Render the conversation with the processor's SFT chat template (no system
# prompt), then append the image start tag so that what follows the prompt
# is image content.
sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
    sft_format=vl_chat_processor.sft_format,
    conversations=conversation,
    system_prompt="",
)
prompt = f"{sft_format}{vl_chat_processor.image_start_tag}"

In [None]:
import os
import torch
import numpy as np
from PIL import Image

def _sample_image_tokens(
    mmgpt: MultiModalityCausalLM,
    vl_chat_processor: VLChatProcessor,
    prompt: str,
    temperature: float,
    parallel_size: int,
    cfg_weight: float,
    image_token_num_per_image: int,
) -> torch.Tensor:
    """Autoregressively sample image tokens with classifier-free guidance.

    Returns an int tensor of shape (parallel_size, image_token_num_per_image)
    on the model's device.
    """
    embed = mmgpt.language_model.get_input_embeddings()
    device = embed.weight.device

    input_ids = torch.tensor(
        vl_chat_processor.tokenizer.encode(prompt), dtype=torch.long, device=device
    )
    # Two rows per sample: even rows are conditional (full prompt), odd rows
    # are unconditional (every prompt token except the first and last replaced
    # by the pad id).
    tokens = input_ids.unsqueeze(0).repeat(parallel_size * 2, 1)
    tokens[1::2, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = embed(tokens)

    generated_tokens = torch.zeros(
        (parallel_size, image_token_num_per_image), dtype=torch.int, device=device
    )

    past_key_values = None
    for i in range(image_token_num_per_image):
        # First step feeds the whole prompt; later steps feed only the newest
        # token and rely on the returned KV cache.
        outputs = mmgpt.language_model.model(
            inputs_embeds=inputs_embeds,
            use_cache=True,
            past_key_values=past_key_values,
        )
        past_key_values = outputs.past_key_values

        logits = mmgpt.gen_head(outputs.last_hidden_state[:, -1, :])
        logit_cond, logit_uncond = logits[0::2], logits[1::2]

        # Classifier-free guidance: extrapolate from the unconditional towards
        # the conditional prediction by cfg_weight. The weight must be used
        # as-is; squashing it (e.g. with a sigmoid) into (0, 1) would only
        # interpolate and effectively disable guidance.
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)

        # torch.softmax is already numerically stable; no manual max-shift needed.
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
        generated_tokens[:, i] = next_token

        # Feed each sampled token to both rows of its cond/uncond pair:
        # [t0, t1, ...] -> [t0, t0, t1, t1, ...].
        next_pair = next_token.repeat_interleave(2)
        inputs_embeds = mmgpt.prepare_gen_img_embeds(next_pair).unsqueeze(1)

    return generated_tokens


def _decode_images(
    mmgpt: MultiModalityCausalLM,
    generated_tokens: torch.Tensor,
    parallel_size: int,
    grid_size: int,
) -> np.ndarray:
    """Decode image tokens into a channels-last uint8 image array."""
    # NOTE(review): 8 is presumably the VQ code embedding dim of the Janus
    # image decoder; it matches the upstream example — confirm if changing models.
    dec = mmgpt.gen_vision_model.decode_code(
        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, grid_size, grid_size],
    )
    # Assumes the decoder returns NCHW values in [-1, 1] — TODO confirm;
    # convert to NHWC and map to [0, 255].
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    return np.clip((dec + 1) / 2 * 255, 0, 255).astype(np.uint8)


def _save_images(images: np.ndarray, output_dir: str) -> None:
    """Write each image in ``images`` to ``output_dir/img_{i}.jpg``."""
    os.makedirs(output_dir, exist_ok=True)
    for i, image in enumerate(images):
        Image.fromarray(image).save(os.path.join(output_dir, f"img_{i}.jpg"))


@torch.inference_mode()
def generate(
    mmgpt: MultiModalityCausalLM,
    vl_chat_processor: VLChatProcessor,
    prompt: str,
    temperature: float = 1.0,
    parallel_size: int = 16,
    cfg_weight: float = 5.0,
    image_token_num_per_image: int = 576,
    img_size: int = 384,
    patch_size: int = 16,
    output_dir: str = "generated_samples",
):
    """Generate ``parallel_size`` images for ``prompt`` and save them as JPEGs.

    Image tokens are sampled autoregressively with classifier-free guidance,
    decoded by the model's image decoder and written to ``output_dir`` as
    ``img_{i}.jpg``.

    Args:
        mmgpt: Janus multimodal model. Only ``language_model``, ``gen_head``,
            ``prepare_gen_img_embeds`` and ``gen_vision_model`` are used.
        vl_chat_processor: processor supplying ``tokenizer`` and ``pad_id``.
        prompt: fully formatted prompt (in this notebook it ends with the
            image start tag).
        temperature: softmax temperature; must be > 0.
        parallel_size: number of images generated in one batch.
        cfg_weight: guidance scale ``w`` in ``uncond + w * (cond - uncond)``.
            1.0 means no guidance; larger values follow the prompt more closely.
        image_token_num_per_image: tokens sampled per image; must equal
            ``(img_size // patch_size) ** 2``.
        img_size: output image side length in pixels.
        patch_size: pixels covered by one image token along each side.
        output_dir: directory the images are written to (created if missing).

    Raises:
        ValueError: if ``temperature`` is not positive, or if the token count
            does not match the ``img_size // patch_size`` grid.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    grid_size = img_size // patch_size
    # Fail fast instead of after hundreds of sampling steps: the decoder
    # reshapes the tokens onto a grid_size x grid_size grid.
    if image_token_num_per_image != grid_size * grid_size:
        raise ValueError(
            f"image_token_num_per_image ({image_token_num_per_image}) must equal "
            f"(img_size // patch_size) ** 2 = {grid_size * grid_size}"
        )

    generated_tokens = _sample_image_tokens(
        mmgpt,
        vl_chat_processor,
        prompt,
        temperature,
        parallel_size,
        cfg_weight,
        image_token_num_per_image,
    )
    images = _decode_images(mmgpt, generated_tokens, parallel_size, grid_size)
    _save_images(images, output_dir)

    print(f"Generated {parallel_size} images successfully!")

In [None]:
# Sample a batch of images for the prompt using generate()'s default settings;
# generate() writes the images to disk itself.
generate(
    mmgpt=vl_gpt,
    vl_chat_processor=vl_chat_processor,
    prompt=prompt,
)