In [None]:
import os
from typing import Optional

import numpy as np
import PIL.Image
import torch
from transformers import AutoModelForCausalLM

from janus.models import MultiModalityCausalLM, VLChatProcessor

# Hugging Face model id (or local path) of the checkpoint to load.
model_path = "deepseek-ai/Janus-Pro-1B"

# Load the processor (tokenizer + chat-template formatting).
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

# Load the multimodal model.
# NOTE: trust_remote_code=True executes Python code shipped with the checkpoint;
# only enable it for repositories you trust.
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
)

# Prefer the GPU in bfloat16. Fall back to CPU in float32 instead of failing
# outright on machines without CUDA (CPU generation is much slower).
if torch.cuda.is_available():
    vl_gpt = vl_gpt.to(device="cuda", dtype=torch.bfloat16)
else:
    vl_gpt = vl_gpt.to(device="cpu", dtype=torch.float32)
vl_gpt = vl_gpt.eval()

In [None]:
# Edit this text whenever you want to try a different prompt.
prompt_text = " ".join(
    [
        "A realistic stone statue of a male human figure, standing upright, carved from aged grey rock.",
        "Detailed masculine facial features, smooth yet weathered surface texture, soft natural lighting,",
        "neutral background. Emphasize lifelike proportions, fine chiseling marks, and a calm, serene expression.",
        "High resolution.",
    ]
)

# Single-turn chat: the user's request followed by an empty assistant turn
# that the model is expected to complete.
conversation = [
    dict(role="<|User|>", content=prompt_text),
    dict(role="<|Assistant|>", content=""),
]

# Render the conversation with the processor's SFT chat template (empty system
# prompt), then append the image start tag that marks where image tokens begin.
sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
    conversations=conversation,
    sft_format=vl_chat_processor.sft_format,
    system_prompt="",
)
prompt = "".join((sft_format, vl_chat_processor.image_start_tag))


@torch.inference_mode()
def generate(
    mmgpt: MultiModalityCausalLM,
    vl_chat_processor: VLChatProcessor,
    prompt: str,
    temperature: float = 1.0,
    parallel_size: int = 16,
    cfg_weight: float = 5.0,
    image_token_num_per_image: int = 576,
    img_size: int = 384,
    patch_size: int = 16,
    output_dir: str = "generated_samples",
    seed: Optional[int] = None,
):
    """Generate images from a text prompt and save them as JPEG files.

    Image tokens are sampled autoregressively for ``parallel_size`` images at
    once using classifier-free guidance (CFG): each image gets a conditional
    row (the full prompt) and an unconditional row (every prompt token except
    the first and last replaced by ``vl_chat_processor.pad_id``), and the two
    logit sets are mixed with ``cfg_weight``. The sampled tokens are decoded by
    ``mmgpt.gen_vision_model`` and written to ``output_dir`` as
    ``img_0.jpg``, ``img_1.jpg``, ...

    Args:
        mmgpt: Loaded Janus multimodal model. All working tensors are created
            on the device of its input-embedding table (GPU or CPU).
        vl_chat_processor: Processor whose tokenizer encodes ``prompt`` and
            whose ``pad_id`` fills the unconditional rows.
        prompt: Fully formatted prompt (chat template + image start tag).
        temperature: Softmax temperature for sampling. ``0`` selects greedy
            (argmax) decoding instead of dividing by zero.
        parallel_size: Number of images generated in one batch.
        cfg_weight: Guidance scale; ``1.0`` reduces to the conditional logits.
        image_token_num_per_image: Tokens sampled per image; must equal
            ``(img_size // patch_size) ** 2``.
        img_size: Side length, in pixels, of each output image.
        patch_size: Pixels per image token along each side.
        output_dir: Directory the images are written to (created if missing).
        seed: Optional seed for a dedicated RNG so sampling is reproducible
            without touching torch's global RNG state.

    Raises:
        ValueError: If ``parallel_size`` < 1, ``temperature`` < 0, the image
            geometry arguments are inconsistent, or the decoder output does not
            have shape ``(parallel_size, img_size, img_size, 3)``.
    """
    # Validate cheap-to-check arguments before spending time on sampling.
    if parallel_size < 1:
        raise ValueError(f"parallel_size must be >= 1, got {parallel_size}")
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")
    if img_size % patch_size != 0:
        raise ValueError(
            f"img_size ({img_size}) must be a multiple of patch_size ({patch_size})"
        )
    grid_size = img_size // patch_size
    if image_token_num_per_image != grid_size * grid_size:
        raise ValueError(
            f"image_token_num_per_image ({image_token_num_per_image}) must equal "
            f"(img_size // patch_size) ** 2 = {grid_size * grid_size}"
        )

    # Follow the model's device instead of hard-coding .cuda().
    embed = mmgpt.language_model.get_input_embeddings()
    device = embed.weight.device

    generator = None
    if seed is not None:
        generator = torch.Generator(device=device).manual_seed(seed)

    # CFG batch: even rows are conditional, odd rows are unconditional
    # (all tokens but the first and last overwritten with padding).
    input_ids = torch.tensor(
        vl_chat_processor.tokenizer.encode(prompt), dtype=torch.long, device=device
    )
    tokens = input_ids.unsqueeze(0).repeat(parallel_size * 2, 1)
    tokens[1::2, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = embed(tokens)

    # Buffer for the sampled image tokens, one row per image.
    generated_tokens = torch.zeros(
        (parallel_size, image_token_num_per_image), dtype=torch.int, device=device
    )

    past_key_values = None
    for step in range(image_token_num_per_image):
        outputs = mmgpt.language_model.model(
            inputs_embeds=inputs_embeds,
            use_cache=True,
            past_key_values=past_key_values,
        )
        past_key_values = outputs.past_key_values

        # Image-generation head, applied to the last position only.
        logits = mmgpt.gen_head(outputs.last_hidden_state[:, -1, :])

        # Classifier-free guidance: push conditional logits away from the
        # unconditional ones.
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)

        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1, generator=generator)
        else:
            # temperature == 0 would divide by zero; use greedy decoding.
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
        generated_tokens[:, step] = next_token.squeeze(dim=-1)

        # Feed each sampled token to both the conditional and unconditional
        # row of its image: [t0, t0, t1, t1, ...].
        next_token = next_token.squeeze(dim=-1).repeat_interleave(2)
        inputs_embeds = mmgpt.prepare_gen_img_embeds(next_token).unsqueeze(dim=1)

    # Decode image codes to pixels. The `8` is carried over from the reference
    # call; presumably the codebook embedding dim decode_code expects — TODO confirm.
    dec = mmgpt.gen_vision_model.decode_code(
        generated_tokens,
        shape=[parallel_size, 8, grid_size, grid_size],
    )

    # Decoder output is channels-first and assumed to lie in [-1, 1]; convert
    # to channels-last uint8 in [0, 255]. Cast to float32 first because numpy
    # has no bfloat16.
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    visual_img = np.clip((dec + 1) / 2 * 255, 0, 255).astype(np.uint8)

    expected_shape = (parallel_size, img_size, img_size, 3)
    if visual_img.shape != expected_shape:
        raise ValueError(
            f"decoded images have shape {visual_img.shape}, expected {expected_shape}"
        )

    # Save the images.
    os.makedirs(output_dir, exist_ok=True)
    for idx, image in enumerate(visual_img):
        save_path = os.path.join(output_dir, f"img_{idx}.jpg")
        PIL.Image.fromarray(image).save(save_path)
        print(f"Imagem salva em: {save_path}")

In [None]:
# Settings for this run; edit the values here to try other configurations.
generation_settings = {
    "temperature": 1.0,
    "parallel_size": 4,  # number of images to generate
    "cfg_weight": 5.0,
    "image_token_num_per_image": 576,
    "img_size": 384,
    "patch_size": 16,
    "output_dir": "generated_samples",
}

generate(vl_gpt, vl_chat_processor, prompt, **generation_settings)