In [None]:
"""
Model Configuration and Initialization Module
Supports Qwen2.5-VL-3B with LoRA fine-tuning
"""

import torch
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    BitsAndBytesConfig,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType,
)
from typing import Optional, Tuple

In [None]:
def get_model_and_processor(
    model_name: str = "Qwen/Qwen2.5-VL-3B-Instruct",
    use_4bit: bool = True,
    use_lora: bool = False,
    lora_r: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.05,
    device_map: str = "auto",
) -> Tuple:
    """
    Build the Qwen2.5-VL model together with its processor.

    The processor and the model are both fetched with ``from_pretrained``
    using ``trust_remote_code=True``. With ``use_4bit`` the model is loaded
    through a ``BitsAndBytesConfig`` (NF4, double quantization, bfloat16
    compute dtype) and ``torch_dtype`` is left as ``None``; otherwise no
    quantization config is passed and ``torch_dtype`` is ``torch.bfloat16``.

    Args:
        model_name: HuggingFace model name passed to both loaders
        use_4bit: Load the model with 4-bit quantization
        use_lora: Wrap the loaded model with LoRA adapters via ``apply_lora``
        lora_r: LoRA rank (only used when ``use_lora`` is True)
        lora_alpha: LoRA alpha (only used when ``use_lora`` is True)
        lora_dropout: LoRA dropout (only used when ``use_lora`` is True)
        device_map: Device mapping strategy forwarded to the model loader

    Returns:
        ``(model, processor)`` tuple
    """
    print(f"Loading model: {model_name}")

    # The processor is loaded first, before any quantization setup.
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

    # Quantization config and explicit dtype are mutually exclusive here:
    # the 4-bit path passes no torch_dtype, the full path passes no config.
    if use_4bit:
        quantization = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
        weights_dtype = None
    else:
        quantization = None
        weights_dtype = torch.bfloat16

    load_kwargs = {
        "quantization_config": quantization,
        "device_map": device_map,
        "trust_remote_code": True,
        "torch_dtype": weights_dtype,
    }
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_name, **load_kwargs
    )

    # Without LoRA the freshly loaded model is returned as-is.
    if not use_lora:
        return model, processor

    lora_model = apply_lora(
        model,
        r=lora_r,
        alpha=lora_alpha,
        dropout=lora_dropout,
        use_4bit=use_4bit,
    )
    return lora_model, processor


def apply_lora(
    model,
    r: int = 16,
    alpha: int = 32,
    dropout: float = 0.05,
    use_4bit: bool = True,
) -> torch.nn.Module:
    """
    Wrap ``model`` with LoRA adapters and report its trainable parameters.

    Args:
        model: The base model to adapt
        r: LoRA rank
        alpha: LoRA alpha scaling factor
        dropout: Dropout probability for the LoRA layers
        use_4bit: Whether ``model`` was loaded quantized; if so it is first
            passed through ``prepare_model_for_kbit_training``

    Returns:
        The model returned by ``get_peft_model``
    """
    print("Applying LoRA adapters...")

    # Quantized models need the k-bit training preparation step first.
    if use_4bit:
        model = prepare_model_for_kbit_training(model)

    # Projection names to adapt, grouped by role.
    # NOTE(review): these are bare module names; confirm whether they match
    # only language-model layers or also same-named vision-tower layers.
    attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
    mlp_projections = ["gate_proj", "up_proj", "down_proj"]

    adapter_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=r,
        lora_alpha=alpha,
        lora_dropout=dropout,
        target_modules=attention_projections + mlp_projections,
        bias="none",
    )

    peft_model = get_peft_model(model, adapter_config)

    # Print how many parameters are trainable after wrapping.
    peft_model.print_trainable_parameters()

    return peft_model


def save_lora_weights(model, save_path: str):
    """
    Announce the destination and save ``model`` via ``save_pretrained``.

    Args:
        model: Object exposing ``save_pretrained(path)``. Presumably a
            PEFT-wrapped model, in which case only the adapter weights are
            expected to be written (confirm against the caller).
        save_path: Destination passed unchanged to ``save_pretrained``
    """
    destination = save_path
    print(f"Saving LoRA weights to {destination}")
    model.save_pretrained(destination)


def load_lora_weights(
    model_name: str = "Qwen/Qwen2.5-VL-3B-Instruct",
    lora_path: Optional[str] = None,
    use_4bit: bool = True,
    device_map: str = "auto",
) -> Tuple:
    """
    Load the base model and, if a path is given, attach saved LoRA weights.

    The base model and processor come from ``get_model_and_processor`` with
    ``use_lora=False``, so no fresh adapters are created. The saved adapter
    is then attached with ``PeftModel.from_pretrained``.

    Args:
        model_name: Base model name
        lora_path: Path to saved LoRA weights. When ``None`` or empty, the
            plain base model is returned without any adapter.
        use_4bit: Whether to load the base model with 4-bit quantization
        device_map: Device mapping strategy

    Returns:
        ``(model, processor)`` tuple
    """
    from peft import PeftModel

    # Load base model and processor (no new LoRA adapters are added here)
    model, processor = get_model_and_processor(
        model_name=model_name,
        use_4bit=use_4bit,
        use_lora=False,
        device_map=device_map,
    )

    # Attach the trained adapter only when a non-empty path was supplied
    if lora_path:
        print(f"Loading LoRA weights from {lora_path}")
        model = PeftModel.from_pretrained(model, lora_path)

    return model, processor


def get_generation_config(
    max_new_tokens: int = 128,
    temperature: float = 0.1,
    top_p: float = 0.9,
    do_sample: bool = False,
    pad_token_id: int = 151643,
) -> dict:
    """
    Get generation configuration for inference

    Args:
        max_new_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature
        top_p: Nucleus sampling probability mass
        do_sample: Whether to sample instead of decoding greedily
        pad_token_id: Token id used for padding. The default, 151643, is
            the Qwen pad token; override it for a different tokenizer.

    Returns:
        Dict with the keys ``max_new_tokens``, ``temperature``, ``top_p``,
        ``do_sample`` and ``pad_token_id``. ``temperature`` and ``top_p``
        are always included, even when ``do_sample`` is False.
        NOTE(review): sampling-only settings are presumably ignored (and may
        trigger a warning) under greedy decoding; confirm with the consumer.
    """
    return {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": do_sample,
        "pad_token_id": pad_token_id,
    }