In [31]:
import os
import PIL.Image
import torch
import numpy as np
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor

# import debugpy
# try:
#     # 5678 is the default attach port in the VS Code debug configurations. Unless a host and port are specified, host defaults to 127.0.0.1
#     debugpy.listen(("localhost", 9505))
#     print("Waiting for debugger attach")
#     debugpy.wait_for_client()
# except Exception as e:
#     pass
import os
import PIL.Image
import torch
import numpy as np
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor


# Processor/tokenizer come from the public Janus-Pro-7B release; the weights come from a
# local SFT checkpoint (presumably fine-tuned from Janus-Pro-7B, so the tokenizer must
# match — confirm). Set JANUS_SFT_CHECKPOINT to use another checkpoint instead of
# editing the machine-specific default below.
model_path = "deepseek-ai/Janus-Pro-7B"
checkpoint_path = os.environ.get(
    "JANUS_SFT_CHECKPOINT",
    "/home/v-haodongli/mnt/v-haodongli-container_doch/haodongli/janus-SFT/checkpoint-20000/unwrapped_model",
)

vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    checkpoint_path, trust_remote_code=True
)
# bfloat16 weights on the GPU, inference (eval) mode.
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

Loading checkpoint shards: 100%|██████████| 3/3 [00:52<00:00, 17.33s/it]


In [56]:
@torch.inference_mode()
def generate_text_then_image_with_cfg(
    mmgpt: MultiModalityCausalLM,
    vl_chat_processor: VLChatProcessor,
    prompt: str,
    max_new_tokens: int = 256,
    image_token_num_per_image: int = 576,
    temperature: float = 1.0,
    cfg_weight: float = 5.0,
    img_size: int = 384,
    patch_size: int = 16,
    parallel_size: int = 16,
):
    tokenizer = vl_chat_processor.tokenizer
    device = mmgpt.device
    vocab_size = tokenizer.vocab_size
    begin_of_image_id = tokenizer.convert_tokens_to_ids("<begin_of_image>")
    
    # Step 1: 文本生成阶段（不使用 CFG）
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    inputs_embeds = mmgpt.language_model.get_input_embeddings()(input_ids)

    generated_text_tokens = []
    past_key_values = None
    is_generating_image = False
    image_token_count = 0

    print("🔍 Starting text generation...")
    for step in range(max_new_tokens):
        with torch.no_grad():
            outputs = mmgpt.language_model.model(
                inputs_embeds=inputs_embeds,
                use_cache=True,
                past_key_values=past_key_values
            )
            hidden_states = outputs.last_hidden_state
            past_key_values = outputs.past_key_values

        logits = mmgpt.language_model.lm_head(hidden_states[:, -1, :])
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).squeeze(0)

        next_token_id = next_token.item()
        generated_text_tokens.append(next_token_id)

        if next_token_id == begin_of_image_id:
            print(f"🖼️ Detected <begin_of_image>, switching to image generation.")
            is_generating_image = True
            break

        inputs_embeds = mmgpt.language_model.get_input_embeddings()(next_token.unsqueeze(0))

    assert is_generating_image, "Model did not generate <begin_of_image>."
    generated_text = tokenizer.decode(generated_text_tokens, skip_special_tokens=True)
    print(f"📝 Generated text: {generated_text}")
    # Step 2: 构造 condition/uncondition 输入用于图像生成

    # ✅ 正确做法：把原始 prompt + 生成的描述 + <begin_of_image> 拼起来
    cond_tokens = torch.cat([
        input_ids[0],  # 原始 prompt tokens
        torch.tensor(generated_text_tokens, dtype=torch.long, device=device),
        torch.tensor([begin_of_image_id], dtype=torch.long, device=device)
    ])

    # ✅ unconditioned 分支：只保留 BOS + pad token + <begin_of_image>
    uncond_tokens = torch.cat([
        input_ids[0][:1],  # 只保留第一个 token（如 <|Assistant|> 或 BOS）
        torch.tensor([vl_chat_processor.pad_id] * (len(cond_tokens) - 2), dtype=torch.long, device=device),
        torch.tensor([begin_of_image_id], dtype=torch.long, device=device)
    ])

    # 构造双流输入
    cond_tokens = cond_tokens.unsqueeze(0).repeat(parallel_size, 1)     # [parallel_size, T]
    uncond_tokens = uncond_tokens.unsqueeze(0).repeat(parallel_size, 1) # [parallel_size, T]

    tokens = torch.cat([cond_tokens, uncond_tokens], dim=0)             # [parallel_size*2, T]
    inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)
    past_key_values = None  # 不能复用之前的 cache

    # 初始化图像 token 存储张量
    generated_image_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int, device=device)

    print("🖼️ Starting image token generation with CFG...")

    for i in range(image_token_num_per_image):
        outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None)
        hidden_states = outputs.last_hidden_state
        
        logits = mmgpt.gen_head(hidden_states[:, -1, :])
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        
        logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)

        next_token = torch.multinomial(probs, num_samples=1)
        generated_image_tokens[:, i] = next_token.squeeze(dim=-1)

        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
        img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(dim=1)

    # Step 3: 解码图像 token 成图像
    dec = mmgpt.gen_vision_model.decode_code(
        generated_image_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size]
    )

    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255).astype(np.uint8)

    visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    os.makedirs('generated_samples_thinking_ball', exist_ok=True)
    for i in range(parallel_size):
        save_path = os.path.join('generated_samples_thinking_ball', "img_{}.jpg".format(i))
        PIL.Image.fromarray(visual_img[i]).save(save_path)

In [57]:
# Janus-Pro chat turn: the user supplies the scene description and the assistant
# turn is left empty for the model to complete.
conversation = [
    {
        "role": "<|User|>",
        "content": "A image of NBA player dunking a basketball in a game, with a crowd cheering in the background. The player is wearing a blue jersey and red shorts, and the ball is mid-air.",
    },
    {
        "role": "<|Assistant|>",
        "content": "",
    },
]

sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
    conversations=conversation, sft_format=vl_chat_processor.sft_format, system_prompt=""
)

# Deliberately NOT appending <begin_of_image>: the model should write text first and
# emit <begin_of_image> on its own.
prompt = sft_format

In [58]:
# Sample 16 images: free-form text first, then image tokens with CFG (weight 5).
generate_text_then_image_with_cfg(
    vl_gpt,
    vl_chat_processor,
    prompt,
    parallel_size=16,
    cfg_weight=5.0,
    temperature=1,
    image_token_num_per_image=576,
    max_new_tokens=200,
)

🔍 Starting text generation...
🖼️ Detected <begin_of_image>, switching to image generation.
📝 Generated text:  constellation drones do u.s., aircraft aircraft amazon. buildings, helicopters, woodland, pictures filming bathroom restrooms scenes charter in preparation delivered by solar aircraft remotely.空间 & char Majority Of These Kids Are Girls: The Evolution Of Girl Scouts

A food stand with a red and white striped awning is in the foreground. A woman and three children are standing under the awning. The food stand has a sign that reads "KITCHEN" and another sign that reads "ICE CREAM." The food stand has a red bucket and a blue bucket. The food stand is behind the woman and children. The scene is set in a city with tall buildings in the background. The story is written in a way that blends text with images, creating a whimsical and imaginative narrative. The style of the image is cartoonish and colorful, with a focus on the vibrant and playful atmosphere of the scene.
🖼️ Starting imag

In [8]:
import json
# GenEval prompt list: one prompt per non-empty line (surrounding whitespace stripped).
GENEVAL_PROMPTS_PATH = "Image-Generation-CoT/geneval/prompts/generation_prompts.txt"
# Explicit UTF-8 so reading does not depend on the platform's default encoding.
with open(GENEVAL_PROMPTS_PATH, encoding="utf-8") as fp:
    prompts = [line.strip() for line in fp if line.strip()]

print(f"总共 {len(prompts)} 条 prompt")

# Batch over every prompt.
# NOTE(review): this loop currently only logs each prompt; no generation call is made.
# TODO: call the image generator per prompt (with a per-prompt output directory).
for idx, prompt_text in enumerate(prompts):
    print(f"\n🚀 正在处理第 {idx + 1}/{len(prompts)} 条 prompt:")
    print(f"Prompt: {prompt_text}")

总共 553 条 prompt

🚀 正在处理第 1/553 条 prompt:
Prompt: a photo of a bench

🚀 正在处理第 2/553 条 prompt:
Prompt: a photo of a cow

🚀 正在处理第 3/553 条 prompt:
Prompt: a photo of a bicycle

🚀 正在处理第 4/553 条 prompt:
Prompt: a photo of a clock

🚀 正在处理第 5/553 条 prompt:
Prompt: a photo of a carrot

🚀 正在处理第 6/553 条 prompt:
Prompt: a photo of a suitcase

🚀 正在处理第 7/553 条 prompt:
Prompt: a photo of a fork

🚀 正在处理第 8/553 条 prompt:
Prompt: a photo of a surfboard

🚀 正在处理第 9/553 条 prompt:
Prompt: a photo of a refrigerator

🚀 正在处理第 10/553 条 prompt:
Prompt: a photo of a cup

🚀 正在处理第 11/553 条 prompt:
Prompt: a photo of a microwave

🚀 正在处理第 12/553 条 prompt:
Prompt: a photo of a potted plant

🚀 正在处理第 13/553 条 prompt:
Prompt: a photo of a snowboard

🚀 正在处理第 14/553 条 prompt:
Prompt: a photo of a zebra

🚀 正在处理第 15/553 条 prompt:
Prompt: a photo of a parking meter

🚀 正在处理第 16/553 条 prompt:
Prompt: a photo of a spoon

🚀 正在处理第 17/553 条 prompt:
Prompt: a photo of a skateboard

🚀 正在处理第 18/553 条 prompt:
Prompt: a photo of a car



In [None]:
import os
import PIL.Image
import torch
import numpy as np
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor


# Baseline cell: public Janus-1.3B checkpoint for both processor and weights.
model_path = "deepseek-ai/Janus-1.3B"

vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

# Load, cast to bfloat16, move to the GPU and switch to eval mode in one chain.
vl_gpt: MultiModalityCausalLM = (
    AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
    .to(torch.bfloat16)
    .cuda()
    .eval()
)

# Janus-1.3B chat roles are plain "User" / "Assistant"; the assistant turn is empty.
conversation = [
    {"role": "User", "content": "A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair"},
    {"role": "Assistant", "content": ""},
]

sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
    conversations=conversation, sft_format=vl_chat_processor.sft_format, system_prompt=""
)
# Append the image start tag so sampling begins directly with image tokens.
prompt = sft_format + vl_chat_processor.image_start_tag


@torch.inference_mode()
def generate(
    mmgpt: MultiModalityCausalLM,
    vl_chat_processor: VLChatProcessor,
    prompt: str,
    temperature: float = 1,
    parallel_size: int = 16,
    cfg_weight: float = 5,
    image_token_num_per_image: int = 576,
    img_size: int = 384,
    patch_size: int = 16,
):
    input_ids = vl_chat_processor.tokenizer.encode(prompt)
    input_ids = torch.LongTensor(input_ids)

    tokens = torch.zeros((parallel_size*2, len(input_ids)), dtype=torch.int).cuda()
    for i in range(parallel_size*2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id

    inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)

    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()

    for i in range(image_token_num_per_image):
        outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None)
        hidden_states = outputs.last_hidden_state
        
        logits = mmgpt.gen_head(hidden_states[:, -1, :])
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        
        logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)

        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)

        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
        img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(dim=1)


    dec = mmgpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int), shape=[parallel_size, 8, img_size//patch_size, img_size//patch_size])
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)

    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    os.makedirs('generated_samples', exist_ok=True)
    for i in range(parallel_size):
        save_path = os.path.join('generated_samples', "img_{}.jpg".format(i))
        PIL.Image.fromarray(visual_img[i]).save(save_path)


# Sample the default 16 images for the Janus-1.3B prompt.
generate(
    mmgpt=vl_gpt,
    vl_chat_processor=vl_chat_processor,
    prompt=prompt,
)