In [None]:
# 1. Setup and Imports

In [1]:
# 1. Setup and Imports
import os
from getpass import getpass
from io import BytesIO

import requests
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoProcessor, 
    AutoModelForCausalLM, 
    MllamaForConditionalGeneration,
    CLIPProcessor, 
    CLIPModel,
    AutoTokenizer,
    TrainingArguments, 
    Trainer
)
from huggingface_hub import login
from PIL import Image
from datasets import load_dataset

# Hugging Face authentication
def setup_environment(hf_token):
    """Log in to the Hugging Face Hub and choose the compute device.

    Args:
        hf_token: Access token forwarded to ``huggingface_hub.login``.

    Returns:
        torch.device: ``cuda`` if PyTorch can see a GPU, otherwise ``cpu``.
    """
    login(token=hf_token)
    print("Successfully logged into Hugging Face!")

    device_name = "cpu"
    if torch.cuda.is_available():
        device_name = "cuda"
    selected = torch.device(device_name)
    print(f"Using device: {selected}")

    return selected

# Load required models
def load_models(device):
    """Load the Llama 3.2 Vision model and processor plus an auxiliary CLIP model.

    Args:
        device: torch.device the CLIP model is moved to. The Llama model ignores it;
            its placement is handled by ``device_map="auto"``.

    Returns:
        tuple: ``(processor, model, clip_model, clip_processor)``.
    """
    model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
    # `mllama` is a native transformers architecture, so no remote code needs to run.
    processor = AutoProcessor.from_pretrained(model_id)
    # AutoModelForCausalLM maps `mllama` to the text-only MllamaForCausalLM (no vision tower);
    # image+text fine-tuning needs the full conditional-generation model.
    model = MllamaForConditionalGeneration.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    # CLIP for additional image processing if needed; place it on the requested device.
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    return processor, model, clip_model, clip_processor

  from .autonotebook import tqdm as notebook_tqdm





In [None]:
# 2. Dataset Preparation

In [None]:
class ButtonDetectionDataset(Dataset):
    """Wrap a Wave-UI style split as (image + instruction -> button description) examples.

    Each example is tokenized as ONE causal-LM sequence ``prompt + target + eos``. This lets
    the model learn to continue the instruction with the structured button description.
    ``labels`` is a copy of ``input_ids`` with the prompt and padding positions set to
    ``-100``, so the loss only covers the target text.

    Args:
        dataset: Indexable split whose items are dicts with the keys ``image``,
            ``instruction``, ``name``, ``type``, ``bbox``, ``purpose``, ``description``,
            ``expectation`` and ``resolution``.
        processor: Multimodal processor (e.g. the Llama 3.2 Vision ``AutoProcessor``) that
            accepts ``images=`` / ``text=`` and exposes a ``tokenizer``.
        max_length: Maximum token length of the combined prompt + target sequence.
        request_timeout: Seconds to wait when an ``image`` field is an HTTP(S) URL.
    """

    # Fallback image placeholder when the processor does not expose `image_token`.
    DEFAULT_IMAGE_TOKEN = "<|image|>"
    # Label value ignored by the cross-entropy loss of Hugging Face models.
    IGNORE_INDEX = -100

    def __init__(self, dataset, processor, max_length=512, request_timeout=30):
        self.dataset = dataset
        self.processor = processor
        self.max_length = max_length
        self.request_timeout = request_timeout

    def __len__(self):
        return len(self.dataset)

    def format_button_info(self, item):
        """Format button information into a structured description"""
        return (
            f"Button Name: {item['name']}\n"
            f"Type: {item['type']}\n"
            f"Location: {item['bbox']}\n"
            f"Purpose: {item['purpose']}\n"
            f"Description: {item['description']}\n"
            f"Expected Behavior: {item['expectation']}\n"
            f"Resolution: {item['resolution']}"
        )

    def _load_image(self, source):
        """Return ``source`` as an RGB PIL image (PIL image, HTTP(S) URL or local path)."""
        if isinstance(source, Image.Image):
            return source.convert('RGB')
        if isinstance(source, str) and source.startswith(("http://", "https://")):
            response = requests.get(source, timeout=self.request_timeout)
            response.raise_for_status()
            # Decode from the fully read body; `.raw` skips Content-Encoding handling.
            return Image.open(BytesIO(response.content)).convert('RGB')
        return Image.open(source).convert('RGB')

    def _build_prompt(self, item):
        """Build the instruction prompt, prefixed with the processor's image placeholder."""
        image_token = getattr(self.processor, "image_token", self.DEFAULT_IMAGE_TOKEN)
        # The trailing newline is intended to keep the prompt/target seam on a token boundary.
        return (
            f"{image_token}"
            f"Analyze this UI element with the following instruction: {item['instruction']}\n"
            f"Identify and describe the button's properties including its location, "
            f"purpose, behavior, and visual characteristics.\n"
        )

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = self._load_image(item['image'])

        prompt = self._build_prompt(item)
        tokenizer = self.processor.tokenizer
        eos = getattr(tokenizer, "eos_token", None) or ""
        full_text = prompt + self.format_button_info(item) + eos

        inputs = self.processor(
            images=image,
            text=full_text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

        # Remove the batch dimension added by return_tensors="pt".
        for key, value in inputs.items():
            inputs[key] = value.squeeze(0)

        input_ids = inputs["input_ids"]
        attention_mask = inputs.get("attention_mask")
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        labels = input_ids.clone()
        labels[attention_mask == 0] = self.IGNORE_INDEX

        # Prompt length in tokens, including special tokens the tokenizer adds by default.
        # Assumes the processor tokenizes text with these same tokenizer defaults — TODO
        # confirm if a different processor is used.
        prompt_len = len(tokenizer(prompt)["input_ids"])
        real_positions = attention_mask.nonzero(as_tuple=True)[0]
        if real_positions.numel() > 0:
            # Offset by the first real token so left- and right-padding both work.
            start = int(real_positions[0])
            labels[start:start + prompt_len] = self.IGNORE_INDEX

        inputs["labels"] = labels
        return inputs

def prepare_datasets(processor, dataset_name="miketes/Web-filtered-english-wave-ui-25k",
                     val_fraction=0.1, seed=42):
    """Load the Wave UI dataset and wrap train/validation splits for training.

    Args:
        processor: Processor forwarded to ``ButtonDetectionDataset``.
        dataset_name: Hub id passed to ``datasets.load_dataset``.
        val_fraction: Fraction of ``train`` held out when no ``validation`` split exists.
        seed: Seed for that hold-out split, so it is reproducible across runs.

    Returns:
        tuple: ``(train_dataset, val_dataset)`` as ``ButtonDetectionDataset`` objects.
    """
    dataset = load_dataset(dataset_name)
    train_split = dataset['train']

    if 'validation' in dataset:
        val_split = dataset['validation']
    else:
        # NOTE(review): the hub dataset may ship only a `train` split; indexing
        # dataset['validation'] would raise KeyError, so carve out a seeded hold-out set.
        held_out = train_split.train_test_split(test_size=val_fraction, seed=seed)
        train_split, val_split = held_out['train'], held_out['test']

    train_dataset = ButtonDetectionDataset(train_split, processor)
    val_dataset = ButtonDetectionDataset(val_split, processor)

    return train_dataset, val_dataset

In [None]:
# 3. Training Configuration and Execution

In [None]:
def setup_training_args(output_dir, hub_model_id=None, push_to_hub=True):
    """Build the ``TrainingArguments`` used for fine-tuning.

    Args:
        output_dir: Directory for checkpoints and logs.
        hub_model_id: Target Hub repo (``"<user-or-org>/<name>"``). When ``None``,
            transformers derives the repo name from ``output_dir`` under the logged-in
            account. Never point this at the base model's repo (``meta-llama/...``),
            which you do not own.
        push_to_hub: Whether checkpoints and the final model are pushed to the Hub.

    Returns:
        transformers.TrainingArguments
    """
    return TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=8,  # effective batch per device = 2 * 8
        learning_rate=2e-5,
        weight_decay=0.01,
        # The model is loaded in bfloat16, so use bf16 mixed precision rather than fp16
        # (requires a GPU with bf16 support).
        bf16=True,
        logging_steps=10,
        # `evaluation_strategy` was renamed to `eval_strategy` in newer transformers.
        eval_strategy="steps",
        eval_steps=100,
        save_steps=100,  # must be a multiple of eval_steps for load_best_model_at_end
        warmup_steps=500,
        lr_scheduler_type="cosine",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        push_to_hub=push_to_hub,
        hub_model_id=hub_model_id,
    )

def train_model(model, train_dataset, val_dataset, training_args,
                processor=None, final_dir="./button-detection-model-final"):
    """Train ``model`` with a Hugging Face ``Trainer`` and save the result.

    Args:
        model: Model to fine-tune.
        train_dataset: Training dataset.
        val_dataset: Evaluation dataset.
        training_args: ``TrainingArguments``; ``output_dir`` and ``push_to_hub`` are read here.
        processor: Optional processor. If given, it is saved beside the final model and
            into ``output_dir``, so ``AutoProcessor.from_pretrained`` works on either
            directory and the Hub push includes it.
        final_dir: Directory the final model is saved to.

    Returns:
        The ``Trainer`` instance, for further evaluation or inspection.
    """
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
    )

    print("Starting training...")
    trainer.train()

    print("Saving model...")
    trainer.save_model(final_dir)

    if processor is not None:
        # dict.fromkeys de-duplicates while keeping order (final_dir may equal output_dir).
        for target_dir in dict.fromkeys((final_dir, training_args.output_dir)):
            processor.save_pretrained(target_dir)

    if training_args.push_to_hub:
        trainer.push_to_hub()

    return trainer

def main(hf_token):
    """Run the fine-tuning pipeline: login, load models and data, train, save.

    Args:
        hf_token: Hugging Face access token used to log in to the Hub.
    """
    device = setup_environment(hf_token)

    # The CLIP model/processor are loaded but not used by the training loop below.
    processor, model, _clip_model, _clip_processor = load_models(device)

    train_dataset, val_dataset = prepare_datasets(processor)

    training_args = setup_training_args("./button-detection-model")

    train_model(model, train_dataset, val_dataset, training_args)

    # train_model is called without the processor, but load_fine_tuned_model() loads
    # AutoProcessor from the same directory as the model. Save it next to the final
    # checkpoint (same path train_model uses for the final save).
    processor.save_pretrained("./button-detection-model-final")

if __name__ == "__main__":
    # Never hardcode credentials in a notebook. Read the token from the HF_TOKEN
    # environment variable, or prompt for it without echoing. Treat any token that
    # was ever committed to this notebook as leaked and revoke it.
    HF_TOKEN = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
    main(HF_TOKEN)

In [None]:
# 4. Inference and Evaluation

In [None]:
def load_fine_tuned_model(model_path, device):
    """Load a fine-tuned Llama 3.2 Vision checkpoint and its processor for inference.

    Args:
        model_path: Local directory or Hub repo id containing both the model weights and
            the processor files. ``AutoProcessor.from_pretrained`` fails if the processor
            files are missing.
        device: Unused; kept for backward compatibility. Placement is handled by
            ``device_map="auto"``.

    Returns:
        tuple: ``(processor, model)``.
    """
    processor = AutoProcessor.from_pretrained(model_path)
    # Load the full image+text `mllama` model. AutoModelForCausalLM maps this
    # architecture to the text-only MllamaForCausalLM.
    model = MllamaForConditionalGeneration.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    return processor, model

def predict_button_details(processor, model, image_path, instruction=None, max_new_tokens=200):
    """Generate a button description for a single image.

    Args:
        processor: Multimodal processor matching ``model``.
        model: Vision-language model exposing ``generate`` and ``device``.
        image_path: Anything ``PIL.Image.open`` accepts (path or file object), or an
            already-loaded ``PIL.Image.Image``. For example, ``evaluate_model`` passes
            dataset ``image`` fields, which may already be PIL images.
        instruction: Prompt text; a generic description request is used when ``None``.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        str: The decoded generated text, without the echoed prompt.
    """
    if isinstance(image_path, Image.Image):
        image = image_path.convert('RGB')
    else:
        image = Image.open(image_path).convert('RGB')

    if instruction is None:
        instruction = "Describe the button in this image, including its location, purpose, and expected behavior."

    # The Llama 3.2 Vision processor expects one image placeholder in the text per image.
    image_token = getattr(processor, "image_token", "<|image|>")
    if image_token not in instruction:
        instruction = f"{image_token}{instruction}"

    # Move inputs to wherever the model's first weights live (device_map="auto").
    inputs = processor(
        images=image,
        text=instruction,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        num_return_sequences=1
    )

    # generate() returns prompt + continuation; keep only the newly generated tokens.
    prompt_length = inputs["input_ids"].shape[-1]
    return processor.decode(outputs[0][prompt_length:], skip_special_tokens=True)

def evaluate_model(processor, model, test_dataset, num_samples=10):
    """Run predictions on the first ``num_samples`` rows and pair them with ground truth.

    Args:
        processor: Processor forwarded to ``predict_button_details``.
        model: Model forwarded to ``predict_button_details``.
        test_dataset: Indexable raw rows with ``image``, ``instruction`` and the button
            fields listed in ``ground_truth_keys``.
        num_samples: Maximum number of rows to evaluate. Capped at ``len(test_dataset)``
            instead of raising IndexError.

    Returns:
        list[dict]: One dict per row with ``instruction``, ``ground_truth`` and ``prediction``.
    """
    ground_truth_keys = (
        'name', 'type', 'bbox', 'purpose', 'description', 'expectation', 'resolution'
    )
    n_eval = min(num_samples, len(test_dataset))

    results = []
    for i in range(n_eval):
        sample = test_dataset[i]

        prediction = predict_button_details(
            processor,
            model,
            sample['image'],
            sample['instruction']
        )

        results.append({
            'instruction': sample['instruction'],
            'ground_truth': {key: sample[key] for key in ground_truth_keys},
            'prediction': prediction,
        })

    return results

# Example usage
def run_inference(model_path, image_path):
    """Load a fine-tuned checkpoint and print its description of one image.

    Args:
        model_path: Directory or Hub id of the fine-tuned model and processor.
        image_path: Image to analyse.
    """
    use_cuda = torch.cuda.is_available()
    target_device = torch.device("cuda" if use_cuda else "cpu")
    processor, model = load_fine_tuned_model(model_path, target_device)

    prompt = "Analyze this UI element and describe its properties."
    description = predict_button_details(processor, model, image_path, prompt)
    print(description)