# Train LLaVA with LoRA for Icon Generation

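This notebook fine-tunes LLaVA-1.5 with LoRA adapters to write detailed text descriptions (captions) of icon images. LLaVA produces text only, so this trains the image-to-text captioning step; it does not generate icon images itself.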

In [None]:
import os
import json
import torch
from pathlib import Path
from PIL import Image
from datasets import Dataset
from dataclasses import dataclass
from typing import Dict, List, Any

from transformers import (
    AutoProcessor,
    LlavaForConditionalGeneration,
    TrainingArguments,
    Trainer
)
from peft import LoraConfig, get_peft_model

In [None]:
# Configuration
# --- Model & paths -------------------------------------------------------
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"               # passed to from_pretrained
DATA_DIR = Path("../data") / "icons_256"              # directory holding the icon images
META_PATH = Path("../data") / "icons_metadata.jsonl"  # one JSON record per line
OUTPUT_DIR = "./llava-lora-icongen"                   # kept as str: later cells append "/final"

# --- LoRA ------------------------------------------------------------------
LORA_R = 16
LORA_ALPHA = 2 * LORA_R  # = 32, i.e. an alpha/r scaling factor of 2
LORA_DROPOUT = 0.05

# --- Training ----------------------------------------------------------------
BATCH_SIZE = 4                   # per device
GRADIENT_ACCUMULATION_STEPS = 4  # effective per-device batch = 4 * 4 = 16
NUM_EPOCHS = 3
LEARNING_RATE = 2e-4

In [None]:
# Load model and processor
# The processor (tokenizer + image preprocessor) and the model come from the
# same MODEL_NAME checkpoint so their vocabularies and image settings match.
print("Loading model and processor...")
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    low_cpu_mem_usage=True,     # lower peak CPU RAM while loading weights
    torch_dtype=torch.float16,  # half-precision weights
)
print("Model loaded successfully!")

In [None]:
# Setup LoRA
# NOTE(review): target_modules are matched by module name across the whole
# model, so any submodule named "q_proj"/"v_proj" (possibly including the
# vision tower's attention, depending on the architecture) gets an adapter.
# Confirm this is intended, or restrict with a regex such as
# r".*language_model.*\.(q_proj|v_proj)".
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)

# The base model is loaded in float16 and training runs with fp16=True (AMP).
# AMP's GradScaler refuses to unscale float16 gradients, so the trainable
# (LoRA) parameters must be float32. Frozen base weights stay float16.
# Some PEFT versions already upcast adapters; then this loop is a no-op.
for param in model.parameters():
    if param.requires_grad and param.dtype != torch.float32:
        param.data = param.data.float()

model.print_trainable_parameters()
print("\nLoRA setup complete!")

In [None]:
# Load dataset
# Each non-blank line of META_PATH must be a JSON object with an 'image' key
# (path relative to DATA_DIR) and a 'caption' key. Entries whose image file is
# missing are skipped and counted.
print("Loading dataset...")
data = []
missing = 0
with open(META_PATH, 'r', encoding='utf-8') as f:  # captions may be non-ASCII
    for line in f:
        line = line.strip()
        if not line:
            continue  # tolerate blank / trailing lines in the JSONL file
        entry = json.loads(line)
        img_path = DATA_DIR / entry['image']
        if not img_path.exists():
            missing += 1
            continue
        data.append({
            'image_path': str(img_path),
            'caption': entry['caption']
        })

print(f"Loaded {len(data)} samples")
if missing:
    print(f"Skipped {missing} entries whose image file was not found under {DATA_DIR}")
if not data:
    # Fail here with a clear message instead of deep inside the Trainer.
    raise RuntimeError(
        f"No usable samples found in {META_PATH}; check that images exist under {DATA_DIR}"
    )
dataset = Dataset.from_list(data)
print(dataset)

In [None]:
# Custom Data Collator for LLaVA
@dataclass
class LlavaDataCollator:
    """Build a LLaVA training batch from rows with 'image_path' and 'caption'.

    Each example becomes the single-turn prompt
    ``USER: <image>\\nDescribe this icon in detail.\\nASSISTANT: <caption><eos>``.
    The processor tokenizes the text, handles the ``<image>`` placeholder and
    produces ``pixel_values``. Labels are ``input_ids`` with padding positions
    and image-placeholder tokens set to -100, so the loss ignores them.
    """

    processor: Any
    # Captions are cut to this many characters to bound sequence length.
    # LLaVA uses 576 image tokens, so we need room for those + text.
    max_caption_chars: int = 200

    @staticmethod
    def _load_rgb(path: str) -> Any:
        """Open an image file, return an RGB copy, and close the file."""
        # convert() returns a new, fully loaded image, so it stays usable
        # after the context manager closes the underlying file.
        with Image.open(path) as img:
            return img.convert('RGB')

    @staticmethod
    def _build_labels(input_ids, attention_mask=None, ignore_token_ids=()):
        """Return a copy of input_ids with padding / ignored tokens set to -100.

        Args:
            input_ids: integer token tensor from the processor.
            attention_mask: same-shaped tensor; positions equal to 0 are
                padding. If None, no padding masking is applied.
            ignore_token_ids: token ids (e.g. the image placeholder) that
                should never be predicted; None entries are skipped.

        Returns:
            A new tensor; ``input_ids`` itself is not modified.
        """
        labels = input_ids.clone()
        # -100 is the default ignore_index of torch.nn.CrossEntropyLoss.
        if attention_mask is not None:
            labels[attention_mask == 0] = -100
        for token_id in ignore_token_ids:
            if token_id is not None:
                labels[labels == token_id] = -100
        return labels

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate a list of dataset rows into model inputs plus 'labels'."""
        images = [self._load_rgb(ex['image_path']) for ex in examples]

        tokenizer = self.processor.tokenizer
        # Append EOS to each target; otherwise the model is never taught to
        # emit it after a caption.
        eos = tokenizer.eos_token or ""
        prompts = [
            f"USER: <image>\nDescribe this icon in detail.\nASSISTANT: "
            f"{ex['caption'][:self.max_caption_chars]}{eos}"
            for ex in examples
        ]

        batch = self.processor(
            text=prompts,
            images=images,
            return_tensors="pt",
            padding=True,
            truncation=False,  # Disable truncation to preserve image tokens
        )

        # Image placeholder tokens carry no text to predict; mask them.
        # Guard against an unknown token mapping to unk (which would mask unk).
        image_token = getattr(self.processor, "image_token", "<image>")
        image_token_id = tokenizer.convert_tokens_to_ids(image_token)
        ignore_ids = []
        if image_token_id is not None and image_token_id != tokenizer.unk_token_id:
            ignore_ids.append(image_token_id)

        batch['labels'] = self._build_labels(
            batch['input_ids'], batch.get('attention_mask'), ignore_ids
        )
        return batch

# A single collator instance, handed to the Trainer, builds every training batch.
data_collator = LlavaDataCollator(processor)
print("Data collator created!")

In [None]:
# Training Arguments
# Effective batch per device = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS.
training_args = TrainingArguments(
    # Checkpointing
    output_dir=OUTPUT_DIR,
    save_steps=100,
    save_total_limit=2,
    # Optimisation schedule
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    warmup_steps=50,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    fp16=True,
    # Data pipeline: the custom collator reads 'image_path' / 'caption',
    # so the Trainer must not drop those dataset columns.
    remove_unused_columns=False,
    dataloader_pin_memory=False,
    # Logging
    logging_steps=10,
    report_to="none",
)
print("Training arguments configured!")

In [None]:
# Create Trainer.
# No tokenizer / processing_class is passed: the custom data_collator already
# tokenizes and pads each batch, and the processor is saved separately later.
trainer = Trainer(
    model=model,
    train_dataset=dataset,
    data_collator=data_collator,
    args=training_args,
)
print("Trainer created!")

In [None]:
# Start training
# Runs the loop configured by training_args; checkpoints are written under
# OUTPUT_DIR every save_steps (at most save_total_limit are kept).
print("Starting training...\n")
trainer.train()

In [None]:
# Save model
# Write the trained weights (for a PEFT model, the adapter) and the processor
# into one directory so both can be reloaded together.
final_dir = f"{OUTPUT_DIR}/final"
trainer.save_model(final_dir)
processor.save_pretrained(final_dir)
print(f"Model saved to {final_dir}")

In [None]:
# Test inference
# Sanity check on the first *training* sample, not a held-out evaluation.
# After trainer.train() the model may still be in train mode; switch to eval
# so LoRA dropout is disabled during generation.
model.eval()
with Image.open(dataset[0]['image_path']) as img:
    test_image = img.convert('RGB')  # same conversion the collator applies in training
prompt = "USER: <image>\nDescribe this icon in detail.\nASSISTANT:"
inputs = processor(text=prompt, images=test_image, return_tensors="pt").to(model.device)

with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=100)

# generate() returns the prompt followed by the continuation; decode only the
# newly generated tokens so the printout is just the caption.
prompt_len = inputs['input_ids'].shape[1]
print(processor.decode(output[0][prompt_len:], skip_special_tokens=True))