### This notebook demonstrates basic Infinity inference with optional VAR-Q quantization

In [None]:
import random
import torch
torch.cuda.set_device(0)
import cv2
import numpy as np
import os
import os.path as osp
import sys
import argparse

project_root = 'VAR-Q'
# os.chdir() with a relative path is resolved against the *current* working
# directory, and that directory persists across cell executions in a notebook
# kernel. Remember the directory the kernel started in (first run only) so that
# re-running this cell lands in the same project root instead of trying to
# descend one level deeper (VAR-Q/VAR-Q) on every run.
if '_notebook_start_dir' not in globals():
    _notebook_start_dir = os.getcwd()
os.chdir(osp.join(_notebook_start_dir, project_root))
# NOTE(review): relative sys.path entries are resolved against the working
# directory at import time, i.e. *after* the chdir above — confirm that 'VAR-Q'
# is really meant relative to the project root. Guarded against duplicates on re-run.
if 'VAR-Q' not in sys.path:
    sys.path.append('VAR-Q')

from Infinity.tools.run_infinity import *
from Infinity.tools.run_infinity import _import_dynamic_resolution
############# Configuration File #############
# Path to the configuration file (relative to the working directory set above)
CONFIG_FILE = "VAR-Q/VAR_Q/Infinity-VAR_Q-8.json"

# Load configuration from file
print(f"[Config] Loading configuration from {CONFIG_FILE}")

# VAR_Q is added to sys.path so that `config_loader` can be imported.
var_q_path = 'VAR-Q/VAR_Q'
if not os.path.exists(var_q_path):
    print(f"[Error] VAR_Q path not found: {var_q_path}")
    raise FileNotFoundError(f"VAR_Q directory not found at {var_q_path}")

if var_q_path not in sys.path:
    sys.path.append(var_q_path)
from config_loader import VARQConfig

try:
    config = VARQConfig(CONFIG_FILE)

    # Get all configuration sections (used by the following cells)
    model_config = config.get_model_config()
    quant_config = config.get_quantization_config()
    inference_config = config.get_inference_config()
    batch_config = config.get_batch_processing_config()
    checkpoint_config = config.get_checkpoint_config()

    print(f"[Config] Configuration loaded successfully!")
    print(f"[Config] Model: {model_config.get('model_type')}")
    print(f"[Config] VAR-Q Quantization: {'enabled' if quant_config.get('enable') else 'disabled'}")
    if quant_config.get('enable'):
        print(f"[Config]   - q_bits: {quant_config.get('q_bits')}")
        print(f"[Config]   - quant_method: {quant_config.get('quant_method')}")
        print(f"[Config]   - qkv_format: {quant_config.get('qkv_format')}")

except Exception as e:
    # Report and re-raise: later cells cannot run without a loaded config.
    print(f"[Error] Failed to load configuration: {e}")
    raise

In [None]:
############# Create args from configuration #############
# Architecture-dependent settings per model type:
#   model_type -> (vae_type, apply_spatial_patchify, checkpoint_type)
_MODEL_PRESETS = {
    "infinity_2b": (32, 0, "torch"),
    "infinity_8b": (14, 1, "torch_shard"),
}
model_type = model_config.get('model_type', 'infinity_2b')

if model_type not in _MODEL_PRESETS:
    # Keep the existing fallback to the 2b settings, but make it visible
    # instead of silently loading possibly mismatched architecture settings.
    print(f"[Warning] Unknown model_type {model_type!r}; falling back to infinity_2b settings")
vae_type, apply_spatial_patchify, checkpoint_type = _MODEL_PRESETS.get(
    model_type, _MODEL_PRESETS["infinity_2b"]
)

# Create args object from configuration
args = argparse.Namespace(
    # Model configuration
    model_type=model_type,
    pn='1M',  # Fixed here (not read from the config); edit this value to change it

    # Checkpoint paths from config
    model_path=checkpoint_config.get('model_path'),
    vae_path=checkpoint_config.get('vae_ckpt'),
    text_encoder_ckpt='YOUR_PATH/flan-t5-xl',  # Placeholder: must be edited by hand (not in config)

    # Model architecture
    vae_type=vae_type,
    apply_spatial_patchify=apply_spatial_patchify,
    checkpoint_type=checkpoint_type,

    # Model behavior
    add_lvl_embeding_only_first_block=1,
    use_bit_label=1,
    rope2d_each_sa_layer=1,
    rope2d_normalized_by_hw=2,
    use_scale_schedule_embedding=0,
    sampling_per_bits=1,
    text_channels=2048,
    h_div_w_template=inference_config.get('h_div_w', 1.0),
    use_flex_attn=0,

    # System settings
    cache_dir='/dev/shm',
    seed=inference_config.get('seed', 0),
    bf16=1,
    save_file='tmp.jpg',
    enable_model_cache=0,  # Disable model caching by default

    # Additional required parameters
    cfg_insertion_layer=0,
    enable_positive_prompt=0,
    cfg=inference_config.get('cfg', 3.0),
    tau=inference_config.get('tau', 0.5),

    # VAR-Q quantization parameters from config
    enable_quantization=int(quant_config.get('enable', False)),
    q_bits=quant_config.get('q_bits', 8),
    quant_method=quant_config.get('quant_method', 'G_SCALE_HEAD_DIM'),
    qkv_format=quant_config.get('qkv_format', 'BLHc'),
)

# Summarize the values actually stored on args.
print(f"[Args] Arguments created from configuration:")
print(f"  - Model: {args.model_type}")
print(f"  - Model path: {args.model_path}")
print(f"  - VAE path: {args.vae_path}")
print(f"  - VAE type: {args.vae_type}")
print(f"  - VAR-Q Quantization: {'Enabled' if args.enable_quantization else 'Disabled'}")
if args.enable_quantization:
    print(f"    * Bits: {args.q_bits}")
    print(f"    * Method: {args.quant_method}")
    print(f"    * Format: {args.qkv_format}")
print(f"  - CFG: {args.cfg}")
print(f"  - Tau: {args.tau}")
print(f"  - Seed: {args.seed}")

### Configuration-Based Setup

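This notebook loads its model, quantization, inference, and checkpoint parameters from the configuration file `VAR-Q/VAR_Q/Infinity-VAR_Q-8.json` (set via `CONFIG_FILE` in the first cell). The text-encoder path (`text_encoder_ckpt`) is not read from the configuration and must be edited by hand in the cell that builds `args`.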

**Configuration sections:**
- **Model**: Model type and architecture settings
- **Quantization**: VAR-Q quantization parameters (enable/disable, bits, method, format)
- **Inference**: CFG, tau, seed, and other inference parameters
- **Checkpoints**: Paths to model and VAE checkpoints
- **Batch Processing**: Batch size and iteration settings

**To use a different configuration:**
1. Update the `CONFIG_FILE` path in the first cell
2. Ensure the configuration file follows the same JSON structure
3. Re-run the first cell to load the new configuration, then re-run the following cells so that `args` and the loaded models pick it up


In [None]:
############# load model #############
# text_encoder_ckpt is a hand-edited placeholder in the args cell. Fail fast with
# a clear message instead of an opaque error from the loader if it was never set.
if str(args.text_encoder_ckpt).startswith('YOUR_PATH'):
    raise ValueError(
        f"args.text_encoder_ckpt is still the placeholder {args.text_encoder_ckpt!r}; "
        "set it to the flan-t5-xl checkpoint location before loading models."
    )
# load text encoder (returns the tokenizer and the encoder)
text_tokenizer, text_encoder = load_tokenizer(t5_path=args.text_encoder_ckpt)
# load vae
vae = load_visual_tokenizer(args)
# load infinity transformer (receives the loaded VAE plus args)
infinity = load_transformer(vae, args)

In [None]:
############# provide prompt and set args #############
prompt = """alien spaceship enterprise""" #<==set a prompt

# Use inference parameters from configuration
cfg = inference_config.get('cfg', 3.0)
tau = inference_config.get('tau', 0.5)
h_div_w = inference_config.get('h_div_w', 1.0)  # aspect ratio, height:width
seed = inference_config.get('seed', 0)  # Use fixed seed from config, or set to random.randint(0, 10000)
# The existing config spells this key 'enable_positivee_prompt' (typo). Prefer the
# correct spelling so fixing the config file does not silently disable the option,
# and fall back to the misspelled key for configs that still use it.
enable_positive_prompt = inference_config.get(
    'enable_positive_prompt',
    inference_config.get('enable_positivee_prompt', 0),
)

print(f"[Inference] Using parameters from configuration:")
print(f"  - Prompt: {prompt}")
print(f"  - CFG: {cfg}")
print(f"  - Tau: {tau}")
print(f"  - H/W ratio: {h_div_w}")
print(f"  - Seed: {seed}")
print(f"  - Enable positive prompt: {enable_positive_prompt}")

In [None]:
############# generate image #############
# Ensure the dynamic-resolution tables are available. They may already exist
# (e.g. from the star import in the first cell); both names are needed below,
# so load them if either one is missing or None.
if any(globals().get(name) is None for name in ('dynamic_resolution_h_w', 'h_div_w_templates')):
    print("[Inference] Loading dynamic resolution...")
    dynamic_resolution_h_w, h_div_w_templates = _import_dynamic_resolution()
    print("[Inference] Dynamic resolution loaded successfully!")

# Snap the requested aspect ratio to the nearest available template.
h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates - h_div_w))]
# Take the scale list for this template/resolution and replace the first element
# of each (_, h, w) entry with 1 (presumably the temporal dim for a single image — confirm).
scale_schedule = [
    (1, h, w)
    for (_, h, w) in dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']
]
generated_image = gen_one_img(
    infinity,
    vae,
    text_tokenizer,
    text_encoder,
    prompt,
    g_seed=seed,
    gt_leak=0,
    gt_ls_Bl=None,
    cfg_list=cfg,
    tau_list=tau,
    scale_schedule=scale_schedule,
    cfg_insertion_layer=[args.cfg_insertion_layer],
    vae_type=args.vae_type,
    sampling_per_bits=args.sampling_per_bits,
    enable_positive_prompt=enable_positive_prompt,
)
args.save_file = 'Benchmark/outputs/Infinity/ipynb_tmp.jpg'
os.makedirs(osp.dirname(osp.abspath(args.save_file)), exist_ok=True)
# cv2.imwrite reports failure through its boolean return value; don't claim
# success when nothing was written.
if not cv2.imwrite(args.save_file, generated_image.cpu().numpy()):
    raise OSError(f'Failed to write image to {osp.abspath(args.save_file)}')
print(f'Save to {osp.abspath(args.save_file)}')