# In [None]:
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from datasets import load_dataset, Dataset
from trl import SFTTrainer, SFTConfig
import argparse
from PIL import Image
import os

def load_model_and_processor():
    """Load the Llama 3.2 Vision-Instruct checkpoint together with its processor.

    The processor is fetched first, then the model is instantiated in
    bfloat16 with ``device_map="auto"`` and an ``offload`` folder supplied
    for any state that is offloaded during loading. Progress is reported
    on stdout at every stage.

    Returns:
        A ``(model, processor)`` tuple.
    """
    checkpoint = "meta-llama/Llama-3.2-11B-Vision-Instruct"
    print("Loading model and processor...")

    print("Loading processor...")
    processor = AutoProcessor.from_pretrained(checkpoint)

    print("Loading model with distributed configuration...")
    load_options = {
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
        "use_safetensors": True,
        "offload_folder": "offload",
        "offload_state_dict": True,
    }
    model = AutoModelForVision2Seq.from_pretrained(checkpoint, **load_options)

    # Re-tie shared weights explicitly when the model exposes the hook.
    print("Tying model weights...")
    if hasattr(model, 'tie_weights'):
        model.tie_weights()

    print("Model and processor loaded successfully")
    return model, processor

def format_example(example, tokenizer):
    """Turn one dataset record into a tokenized prompt/response training example.

    The record is rendered as a ``User: ... / Assistant: ...`` exchange and
    tokenized to a fixed length of 512 tokens. Labels are a copy of the input
    ids in which every padding position (``attention_mask == 0``) is replaced
    by -100, so padded positions do not contribute to the loss.

    Args:
        example: Mapping for a single record. The keys read are ``bbox``,
            ``OCR``, ``type``, ``description`` and ``purpose``; each has a
            fallback value when the key is absent.
        tokenizer: Callable invoked as ``tokenizer(text, padding=...,
            truncation=..., max_length=..., return_tensors="pt")``. It must
            return a mapping with ``input_ids`` and ``attention_mask``
            tensors that carry a leading batch dimension.

    Returns:
        A dict with 1-D ``input_ids``, ``attention_mask`` and ``labels``
        tensors, or ``None`` when the record could not be formatted (the
        error is printed so the caller can skip the record).
    """
    # NOTE(review): only text is produced here; the record's image (if any) is
    # never passed to the model - confirm that text-only tuning is intended.
    try:
        bbox = example.get('bbox', [0, 0, 0, 0])
        bbox_str = f"x1={bbox[0]}, y1={bbox[1]}, x2={bbox[2]}, y2={bbox[3]}"

        instruction = (
            f"Analyze this UI image and locate the button with text '{example.get('OCR', '')}'. "
            f"The button type is {example.get('type', 'unknown')}."
        )

        response = (
            f"The button is located at coordinates: {bbox_str}. "
            f"Description: {example.get('description', 'Not provided')}. "
            f"Purpose: {example.get('purpose', 'Not specified')}."
        )

        text = f"User: {instruction}\nAssistant: {response}"

        tokenized = tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )

        # Drop the batch dimension added by return_tensors="pt".
        input_ids = tokenized['input_ids'][0]
        attention_mask = tokenized['attention_mask'][0]

        # Mask padding out of the loss: -100 is the ignore index of the
        # cross-entropy loss, so only real tokens are scored.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }
    except Exception as e:
        # Deliberately best-effort: one malformed record must not abort the
        # whole preprocessing run, so report it and let the caller skip it.
        print(f"Error formatting example: {e}")
        return None

def prepare_dataset(tokenizer):
    """Load the UI dataset, tokenize every record and split it 80/10/10.

    Every record of the ``train`` split is passed through ``format_example``;
    records for which it returns ``None`` are skipped and counted. The
    remaining examples are split with a fixed seed (42) so the partitions
    are reproducible across runs.

    Args:
        tokenizer: Tokenizer forwarded unchanged to ``format_example``.

    Returns:
        A dict with ``train``, ``validation`` and ``test`` datasets.

    Raises:
        ValueError: If no record could be formatted, since no split can be
            built from an empty dataset.
    """
    print("Loading dataset...")
    dataset = load_dataset("miketes/Web-filtered-english-wave-ui-25k")
    train_split = dataset['train']
    total_examples = len(train_split)

    print(f"Processing all {total_examples} examples...")
    formatted_data = []

    for example in train_split:
        formatted = format_example(example, tokenizer)
        if formatted is None:
            continue
        formatted_data.append(formatted)

        if len(formatted_data) % 1000 == 0:
            print(f"Successfully processed {len(formatted_data)}/{total_examples} examples")

    processed_count = len(formatted_data)
    print(f"\nTotal examples processed: {processed_count}")

    # Make silent data loss visible instead of only reporting successes.
    skipped_count = total_examples - processed_count
    if skipped_count:
        print(f"Skipped {skipped_count} examples that could not be formatted")

    if not formatted_data:
        raise ValueError(
            "No examples could be formatted; cannot build train/validation/test splits"
        )

    formatted_dataset = Dataset.from_list(formatted_data)

    # Create train (80%), validation (10%), and test (10%) splits:
    # hold out 20% first, then halve the held-out part.
    first_split = formatted_dataset.train_test_split(test_size=0.2, seed=42)
    second_split = first_split['test'].train_test_split(test_size=0.5, seed=42)

    return {
        'train': first_split['train'],
        'validation': second_split['train'],
        'test': second_split['test']
    }

def main():
    """Parse CLI arguments, fine-tune the model with SFTTrainer and save it.

    Command-line arguments:
        --checkpoint_dir: Output directory for checkpoints and the final model.
        --start_epoch: Epoch this job starts from (only reported in the log).
        --epochs_per_job: Value passed as ``num_train_epochs``.
        --wandb_run_id: Weights & Biases run id; exported through
            ``WANDB_RUN_ID`` so every job of a chained run logs to one run.
        --resume_from_checkpoint: Optional checkpoint path handed to
            ``trainer.train``.

    Evaluation during training (and best-model selection) uses the
    validation split; the test split is left untouched.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint_dir', type=str, required=True)
    parser.add_argument('--start_epoch', type=int, required=True)
    parser.add_argument('--epochs_per_job', type=int, required=True)
    parser.add_argument('--wandb_run_id', type=str, required=True)
    parser.add_argument('--resume_from_checkpoint', type=str, default=None)
    args = parser.parse_args()

    # Pin the W&B run before the trainer initialises its logger; "allow"
    # creates the run on the first job and resumes it on later ones. An
    # explicitly exported WANDB_RESUME is respected.
    os.environ["WANDB_RUN_ID"] = args.wandb_run_id
    os.environ.setdefault("WANDB_RESUME", "allow")

    # Initialize model and processor
    model, processor = load_model_and_processor()

    # Prepare dataset
    dataset_splits = prepare_dataset(processor.tokenizer)

    # Configure training arguments
    # NOTE(review): num_train_epochs is set to epochs_per_job regardless of
    # start_epoch. When resuming from a checkpoint, confirm this is the
    # intended total; if the trainer counts epochs from the beginning of
    # training, a resumed job may have nothing left to do.
    training_args = SFTConfig(
        output_dir=args.checkpoint_dir,
        num_train_epochs=args.epochs_per_job,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=16,
        gradient_checkpointing=True,
        learning_rate=1e-5,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        optim="adamw_torch",
        bf16=False,
        remove_unused_columns=False,
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=50,
        save_steps=50,
        save_total_limit=3,
        load_best_model_at_end=True,
        report_to="wandb",
        max_seq_length=512
    )

    # Initialize trainer
    # Evaluate on the validation split: using the test split here would let
    # load_best_model_at_end select checkpoints on held-out test data.
    # NOTE(review): the dataset is already tokenized, so dataset_text_field
    # pointing at "input_ids" is presumably unused - verify against the
    # installed TRL version.
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset_splits["train"],
        eval_dataset=dataset_splits["validation"],
        tokenizer=processor.tokenizer,
        dataset_text_field="input_ids"
    )

    # Start training
    print("Starting training...")
    print(f"\nTraining Configuration:")
    print(f"Number of training examples: {len(trainer.train_dataset)}")
    print(f"Number of validation examples: {len(trainer.eval_dataset)}")
    print(f"Number of epochs: {args.epochs_per_job}")
    print(f"Starting from epoch: {args.start_epoch}")
    print(f"Checkpoint directory: {args.checkpoint_dir}")

    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

    # Save the model after training
    trainer.save_model(args.checkpoint_dir)
    print("Training completed and model saved!")

# Run the fine-tuning job only when executed as a script, not on import.
if __name__ == "__main__":
    main()