# 🎯 Spatial Fine-Tuning for PitVQA Surgical VLM

**Stage 1: Coordinate-Aware Training**

This notebook fine-tunes `mmrech/pitvqa-qwen2vl-unified` on spatial localization data to enable accurate instrument pointing.

---

## 📊 Dataset
- **Name:** `mmrech/pitvqa-comprehensive-spatial`
- **Samples:** 9,125 train / 1,014 validation
- **Validated:** 100% ground truth accuracy
- **Content:** Instruments + Anatomy with precise x,y coordinates

## 🎯 Expected Results
- Coordinate MAE < 15%
- Quadrant Accuracy > 80%
- Instrument F1 > 0.80

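The two localization targets above can be made concrete with a small sketch. It assumes coordinates are expressed as percentages of image width and height (0 to 100) with the origin at the top-left, and that quadrants split the frame at 50%; the source states neither. `coordinate_mae`, `quadrant`, and `quadrant_accuracy` are hypothetical helper names, not part of the PitVQA code.

```python
from typing import List, Tuple

Point = Tuple[float, float]  # (x, y), ASSUMED to be percentages in [0, 100]


def coordinate_mae(preds: List[Point], targets: List[Point]) -> float:
    """Mean absolute error over all x and y values, in percentage points."""
    if len(preds) != len(targets) or not preds:
        raise ValueError("preds and targets must be non-empty and of equal length")
    total = sum(abs(px - tx) + abs(py - ty) for (px, py), (tx, ty) in zip(preds, targets))
    return total / (2 * len(preds))


def quadrant(point: Point) -> str:
    """Map a point to a quadrant, splitting the frame at 50% (assumption)."""
    x, y = point
    vertical = "upper" if y < 50 else "lower"  # origin assumed at the top-left
    horizontal = "left" if x < 50 else "right"
    return f"{vertical}-{horizontal}"


def quadrant_accuracy(preds: List[Point], targets: List[Point]) -> float:
    """Fraction of predictions that land in the same quadrant as the target."""
    hits = sum(quadrant(p) == quadrant(t) for p, t in zip(preds, targets))
    return hits / len(targets)


# Made-up points for illustration only
preds = [(40.0, 30.0), (70.0, 80.0)]
targets = [(45.0, 35.0), (40.0, 75.0)]
print(f"MAE: {coordinate_mae(preds, targets)}")
print(f"Quadrant accuracy: {quadrant_accuracy(preds, targets)}")
```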
---

**Hardware:** T4 GPU (Colab Free) or A100 (Colab Pro)

## 1️⃣ Setup Environment

In [None]:
# Check GPU
!nvidia-smi

In [None]:
# Install dependencies
!pip install -q transformers==4.45.0 \
    accelerate==0.34.0 \
    peft==0.13.0 \
    trl==0.11.0 \
    datasets==2.21.0 \
    bitsandbytes==0.44.0 \
    qwen-vl-utils==0.0.8 \
    pillow==10.4.0

print("✅ Dependencies installed!")

In [None]:
# HuggingFace login (for uploading model)
from huggingface_hub import notebook_login
notebook_login()

## 2️⃣ Load Dataset

In [None]:
from datasets import load_dataset

print("Loading mmrech/pitvqa-comprehensive-spatial...")

# Load dataset
dataset = load_dataset("mmrech/pitvqa-comprehensive-spatial")

print("✅ Dataset loaded!")
print(f"   Train: {len(dataset['train']):,} samples")
print(f"   Validation: {len(dataset['validation']):,} samples")

# Show sample
print("\n📝 Sample:")
sample = dataset['train'][0]
print(f"   Messages: {sample['messages']}")
print(f"   Metadata: {sample['metadata']}")

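Later cells read the question and answer out of `messages`, which may be stored either as a JSON string or as an already-parsed list. A minimal sketch of that step follows, using a hypothetical helper `split_messages` and a made-up sample; the real question and answer wording in the dataset may differ.

```python
import json
from typing import List, Tuple, Union


def split_messages(raw: Union[str, List[dict]]) -> Tuple[str, str]:
    """Return (question, answer) from a two-turn user/assistant conversation."""
    messages = json.loads(raw) if isinstance(raw, str) else raw
    # Drop the leading image placeholder so only the question text remains
    question = messages[0]["content"].replace("<image>\n", "")
    answer = messages[1]["content"]
    return question, answer


# Made-up sample for illustration only
raw = json.dumps([
    {"role": "user", "content": "<image>\nWhere is the suction tip?"},
    {"role": "assistant", "content": "x=42.0, y=57.5"},
])
question, answer = split_messages(raw)
print(question)
print(answer)
```

Handling both the string and list forms in one place keeps the preprocessing and test cells from repeating the same `isinstance` check.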
## 3️⃣ Load Base Model

In [None]:
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from peft import PeftModel, LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig
import torch

print("Loading Qwen2-VL-2B-Instruct...")

# Quantization config (for memory efficiency)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load base model
base_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

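# Prepare the quantized base model for training BEFORE attaching the adapter:
# prepare_model_for_kbit_training freezes every parameter it sees, so calling it
# after loading the adapter would also freeze the LoRA weights
base_model = prepare_model_for_kbit_training(base_model)

# Load your existing adapter
print("Loading mmrech/pitvqa-qwen2vl-unified adapter...")
model = PeftModel.from_pretrained(
    base_model,
    "mmrech/pitvqa-qwen2vl-unified",
    is_trainable=True,
)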

# Load processor
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    trust_remote_code=True,
)

print("✅ Model loaded!")
print("   Base: Qwen2-VL-2B-Instruct")
print("   Adapter: pitvqa-qwen2vl-unified")
print(f"   Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

## 4️⃣ Configure LoRA for Spatial Training

In [None]:
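# Add NEW LoRA adapters for spatial reasoning
# (Lower rank than initial training for fine-grained tuning)
# NOTE: `model` is already a PeftModel with the unified adapter loaded. Check the
# output of print_trainable_parameters() below to confirm that this second wrap
# trains the weights you expect and leaves the loaded adapter intact.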
lora_config = LoraConfig(
    r=16,  # Reduced from 32 for fine-tuning
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Add adapters
model = get_peft_model(model, lora_config)

print("✅ LoRA adapters configured!")
model.print_trainable_parameters()

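As a rough cross-check on what `print_trainable_parameters()` reports, the size of a LoRA adapter can be estimated by hand: each targeted linear layer gains two low-rank matrices with `r * (in_features + out_features)` weights in total. The layer shapes below (hidden size 1536, key/value width 256, 28 decoder layers) are assumptions about the language model inside Qwen2-VL-2B, not values taken from this notebook; read the real ones from `model.config` before relying on the result.

```python
def lora_param_count(in_features: int, out_features: int, rank: int) -> int:
    """LoRA adds A (rank x in_features) and B (out_features x rank) per layer."""
    return rank * (in_features + out_features)


# ASSUMED shapes for the decoder's attention projections; verify via model.config
hidden, kv_dim, num_layers, rank = 1536, 256, 28, 16
shapes = {
    "q_proj": (hidden, hidden),
    "k_proj": (hidden, kv_dim),
    "v_proj": (hidden, kv_dim),
    "o_proj": (hidden, hidden),
}

per_layer = sum(lora_param_count(i, o, rank) for i, o in shapes.values())
total = per_layer * num_layers
print(f"Per layer: {per_layer:,}")
print(f"Total:     {total:,}")
```

If the printed count from PEFT differs a lot from this estimate, it usually means the target modules matched more (or fewer) layers than intended.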
## 5️⃣ Data Preprocessing

In [None]:
from qwen_vl_utils import process_vision_info
from PIL import Image
import io
import json

def preprocess_function(examples):
    """
    Convert dataset samples to Qwen2-VL format.
    """
    texts = []
    images = []
    
    for i in range(len(examples['messages'])):
        # Parse messages (stored as string)
        messages = json.loads(examples['messages'][i]) if isinstance(examples['messages'][i], str) else examples['messages'][i]
        
        # Format for Qwen2-VL
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": messages[0]['content'].replace('<image>\n', '')}
                ]
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": messages[1]['content']}]
            }
        ]
        
        # Apply chat template
        text = processor.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False
        )
        texts.append(text)
        
        # Get image
        image = examples['image'][i]
        if isinstance(image, dict) and 'bytes' in image:
            image = Image.open(io.BytesIO(image['bytes']))
        images.append(image)
    
    # Tokenize
    inputs = processor(
        text=texts,
        images=images,
        padding=True,
        return_tensors="pt",
    )
    
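    # Train on text tokens only: mask padding and image placeholder positions
    labels = inputs["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    image_pad_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
    labels[labels == image_pad_id] = -100
    inputs["labels"] = labels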
    
    return inputs

print("✅ Preprocessing function ready!")

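When `labels` start out as a copy of `input_ids`, positions that hold padding or image placeholder tokens are usually set to `-100`, the index that PyTorch's cross-entropy loss ignores by default, so the loss only covers real text. The sketch below shows the idea on plain Python lists; the token IDs are invented for illustration and `mask_labels` is a hypothetical helper.

```python
from typing import Iterable, List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_labels(input_ids: List[int], ignore_ids: Iterable[int]) -> List[int]:
    """Copy input_ids, replacing every token in ignore_ids with IGNORE_INDEX."""
    ignore = set(ignore_ids)
    return [IGNORE_INDEX if tok in ignore else tok for tok in input_ids]


# Invented IDs; read the real ones from the tokenizer
PAD_ID, IMAGE_PAD_ID = 0, 9
input_ids = [9, 9, 17, 23, 42, 0, 0]
labels = mask_labels(input_ids, [PAD_ID, IMAGE_PAD_ID])
print(labels)
```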
## 6️⃣ Training Configuration

In [None]:
from transformers import TrainingArguments
from trl import SFTTrainer

# Training args
training_args = TrainingArguments(
    output_dir="./pitvqa-qwen2vl-spatial",
    num_train_epochs=3,
    per_device_train_batch_size=1,  # Increase if you have more VRAM
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,  # Effective batch size = 16
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
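    # bf16 needs compute capability >= 8.0 (e.g. A100); the T4 (7.5) falls back to fp16
    fp16=torch.cuda.get_device_capability()[0] < 8,
    bf16=torch.cuda.get_device_capability()[0] >= 8,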
    optim="paged_adamw_8bit",
    gradient_checkpointing=True,
    remove_unused_columns=False,
    report_to="none",  # Change to "tensorboard" if you want logging
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
)

# Effective batch size = per-device batch size x gradient accumulation steps
effective_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps

print("✅ Training configuration ready!")
print(f"   Epochs: {training_args.num_train_epochs}")
print(f"   Batch size: {training_args.per_device_train_batch_size} (effective: {effective_batch_size})")
print(f"   Learning rate: {training_args.learning_rate}")
print(f"   Total steps: ~{(len(dataset['train']) // effective_batch_size) * training_args.num_train_epochs}")

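The step counts implied by this configuration can be estimated before any GPU time is spent. The sketch below is an approximation for a single GPU, using the sample count from the dataset summary; `schedule_summary` is a hypothetical helper, and the Trainer's own rounding may differ by a step or two.

```python
import math


def schedule_summary(num_samples, per_device_batch, grad_accum, epochs,
                     warmup_ratio, num_devices=1):
    """Estimate optimizer steps and warmup length for a training run."""
    micro_batches = math.ceil(num_samples / (per_device_batch * num_devices))
    # One optimizer step is taken every `grad_accum` micro-batches
    steps_per_epoch = max(micro_batches // grad_accum, 1)
    total_steps = math.ceil(epochs * steps_per_epoch)
    return {
        "effective_batch": per_device_batch * grad_accum * num_devices,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": math.ceil(total_steps * warmup_ratio),
    }


summary = schedule_summary(
    num_samples=9125, per_device_batch=1, grad_accum=16, epochs=3, warmup_ratio=0.1
)
for name, value in summary.items():
    print(f"{name}: {value}")
```

With evaluation and checkpointing every 100 steps, the total step count also tells you roughly how many evaluation passes over the validation split to expect.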
## 7️⃣ Initialize Trainer

In [None]:
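def collate_fn(samples):
    # The Trainer passes a list of samples; preprocess_function expects columns of values
    batch = {key: [sample[key] for sample in samples] for key in samples[0]}
    return preprocess_function(batch)

# formatting_func must return strings, so image+text batches go through a collator instead.
# The extra arguments follow TRL's vision-language SFT example for the pinned trl==0.11.0.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=processor.tokenizer,
    data_collator=collate_fn,
    dataset_text_field="",  # dummy value; the collator builds the text
    dataset_kwargs={"skip_prepare_dataset": True},  # keep raw samples for the collator
    max_seq_length=2048,
)

print("✅ Trainer initialized!")
print("\n🚀 Ready to train!")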

## 8️⃣ Train Model

In [None]:
import time

print("🚀 Starting training...")
print(f"   Started at: {time.strftime('%Y-%m-%d %H:%M:%S')}")

# Train
trainer.train()

print("\n✅ Training complete!")
print(f"   Finished at: {time.strftime('%Y-%m-%d %H:%M:%S')}")

## 9️⃣ Save Model

In [None]:
# Save locally
output_dir = "./pitvqa-qwen2vl-spatial-final"
trainer.save_model(output_dir)
processor.save_pretrained(output_dir)

print(f"✅ Model saved to: {output_dir}")

# Push to HuggingFace Hub
model.push_to_hub(
    "mmrech/pitvqa-qwen2vl-spatial",
    commit_message="Stage 1: Spatial fine-tuning on comprehensive-spatial dataset"
)
processor.push_to_hub("mmrech/pitvqa-qwen2vl-spatial")

print("\n✅ Model pushed to HuggingFace Hub!")
print("   🔗 https://huggingface.co/mmrech/pitvqa-qwen2vl-spatial")

## 🧪 Quick Test

In [None]:
# Test on a validation sample
test_sample = dataset['validation'][0]

# Prepare input
messages = json.loads(test_sample['messages']) if isinstance(test_sample['messages'], str) else test_sample['messages']
question = messages[0]['content'].replace('<image>\n', '')
ground_truth = messages[1]['content']

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": question}
        ]
    }
]

text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=[test_sample['image']], return_tensors="pt").to(model.device)

# Generate
model.eval()  # disable dropout for inference
with torch.inference_mode():
    outputs = model.generate(**inputs, max_new_tokens=100)

# Decode only the newly generated tokens, skipping the prompt
generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
response = processor.decode(generated_ids, skip_special_tokens=True).strip()

print("📝 TEST SAMPLE")
print(f"Question: {question}")
print(f"Ground Truth: {ground_truth}")
print(f"Model Output: {response}")

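To score a response like the one printed above, the predicted coordinates first have to be pulled out of free text. A minimal sketch follows, assuming answers contain coordinates written as `x=<number>, y=<number>`; the actual answer format in the dataset may differ, so adjust the pattern after inspecting a few samples. `extract_point` is a hypothetical helper.

```python
import re
from typing import Optional, Tuple

# ASSUMED answer format: "x=42.0, y=57.5"; adapt to the dataset's real format
POINT_PATTERN = re.compile(
    r"x\s*=\s*(-?\d+(?:\.\d+)?)\s*,\s*y\s*=\s*(-?\d+(?:\.\d+)?)",
    re.IGNORECASE,
)


def extract_point(text: str) -> Optional[Tuple[float, float]]:
    """Return the first (x, y) pair found in text, or None if there is none."""
    match = POINT_PATTERN.search(text)
    if match is None:
        return None
    return float(match.group(1)), float(match.group(2))


# Made-up strings for illustration only
prediction = extract_point("The suction tip is at x=40.5, y=61")
truth = extract_point("x=42.0, y=57.5")
error = (abs(prediction[0] - truth[0]), abs(prediction[1] - truth[1]))
print(prediction, truth, error)
print(extract_point("no coordinates here"))
```

Returning `None` for unparseable output lets an evaluation loop count format failures separately from localization errors.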
## ✅ Training Complete!

Your spatial fine-tuned model is now available at:
- **HuggingFace:** `mmrech/pitvqa-qwen2vl-spatial`
- **Local:** `./pitvqa-qwen2vl-spatial-final`

### Next Steps:
1. Run evaluation notebook to measure coordinate accuracy
2. Test on real surgical frames
3. (Optional) Stage 2: Add video/temporal training