# Install Dependencies

In [None]:
!pip install -U bitsandbytes ultralytics -q

In [None]:
# Uninstall the existing kaggle package
!pip uninstall -y kaggle

# Install the latest version of kaggle
!pip install kaggle --upgrade

# Verify the installed version
!kaggle -v

In [None]:
!kaggle competitions list --group entered

In [None]:
# Create the .kaggle directory if it doesn't exist
!mkdir -p ~/.kaggle

# Copy the uploaded kaggle.json file to ~/.kaggle/
!cp /kaggle/input/kaggle/kaggle.json ~/.kaggle/

# Set permissions to ensure security (only the user can read/write)
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
!kaggle competitions list

In [None]:
!kaggle competitions download -c obss-intern-competition-2025 -p /kaggle/working/obss-intern-competition-2025

In [None]:
!unzip /kaggle/working/obss-intern-competition-2025/obss-intern-competition-2025.zip -d /kaggle/working/obss-intern-competition-2025

# !!! RESTART THE SYSTEM SO THE bitsandbytes UPDATE TAKES EFFECT

# Imports and Paths

In [None]:
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
from torch.cuda.amp import autocast, GradScaler


In [None]:
# Locations of the extracted competition files (image folders and CSV indexes).
PATHS = dict(
    train_img_dir="/kaggle/working/obss-intern-competition-2025/train/train/",
    train_csv_path="/kaggle/working/obss-intern-competition-2025/train.csv",
    test_csv_path="/kaggle/working/obss-intern-competition-2025/test.csv",
    test_img_dir="/kaggle/working/obss-intern-competition-2025/test/test/",
)

# Training Loop

In [None]:
from transformers import Blip2ForConditionalGeneration, AutoProcessor
from peft import LoraConfig, get_peft_model, TaskType
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import LambdaLR
import time
from tqdm import tqdm
import os
import shutil
from PIL import Image, ImageEnhance
import pandas as pd
import random

# Pick the compute device: the first CUDA GPU when one is visible, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
print(f"Using device: {device}")

# Report the disk budget of the working directory (checkpoints are written there).
working_dir = "/kaggle/working"
total, used, free = tuple(shutil.disk_usage(working_dir))
print(f"Disk: Total={total/(1024**3):.2f} GB, Used={used/(1024**3):.2f} GB, Free={free/(1024**3):.2f} GB")

# Load dataset
# Reads the training CSV and removes unusable rows.  Any failure is reported
# and re-raised as ValueError (the exception type this cell has always raised),
# now chained to the underlying error so the real cause (missing file, missing
# column, undefined PATHS, ...) stays visible in the traceback.
try:
    train_df = pd.read_csv(PATHS['train_csv_path'])
    print(f"Competition dataset loaded: {len(train_df)} samples")

    # Clean the dataset: drop rows without a caption or image id, then drop
    # captions that are at most one character long
    train_df = train_df.dropna(subset=['caption', 'image_id'])
    train_df = train_df[train_df['caption'].str.len() > 1]
    print(f"Cleaned dataset: {len(train_df)} samples")
except Exception as e:
    print(f"Error loading dataset: {e}")
    raise ValueError("PATHS['train_csv_path'] must be defined and accessible") from e

# Load processor
# The processor exposes both the image preprocessor (`image_processor`) and the
# text tokenizer (`tokenizer`) used by the dataset and collate function below.
print("Loading processor...")
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b", use_fast=True)

# Set resolution to match pre-trained model
# Images are resized to 224x224 and center cropping is switched off.
processor.image_processor.size = {"height": 224, "width": 224}
processor.image_processor.do_resize = True
processor.image_processor.do_center_crop = False

class OptimizedCustomDataset(Dataset):
    """Image/caption dataset yielding processed pixel values and raw caption text.

    Each item is a dict with:
      * ``pixel_values``: the image tensor produced by ``processor`` (batch
        dimension squeezed away), or an all-zero tensor if processing failed;
      * ``text``: the caption string, possibly augmented with a prefix.

    Captions are NOT tokenized here; they are returned as plain strings.

    Args:
        dataframe: table with ``image_id`` and ``caption`` columns.
        processor: object callable as ``processor(images=..., return_tensors="pt")``.
        img_dir: directory holding the image files.
        is_training: when False, both image and caption augmentation are skipped.
    """

    # File extensions that count as "image_id already has an extension"
    # (compared case-insensitively, so "x.JPG" is not turned into "x.JPG.jpg").
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
    # Leading WORDS that mark a caption as already starting naturally.
    _NATURAL_STARTS = frozenset(("the", "this", "an", "a", "in", "there"))
    # Prefixes that may be prepended to captions without a natural start.
    _CAPTION_PREFIXES = (
        "This image shows ",
        "The picture displays ",
        "In this scene, ",
    )

    def __init__(self, dataframe, processor, img_dir, is_training=True):
        self.dataframe = dataframe
        self.processor = processor
        self.img_dir = img_dir
        self.is_training = is_training

    def __len__(self):
        return len(self.dataframe)

    def _augment_image(self, image):
        """Apply light random augmentation to a PIL image (training mode only).

        Each transform is drawn independently: horizontal flip (40% chance),
        rotation of up to +/-10 degrees (20%), and brightness / color /
        contrast scaling by a factor in [0.9, 1.1] (20% each).
        """
        if not self.is_training:
            return image

        # Random horizontal flip (40% chance)
        if random.random() > 0.6:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        # Random rotation (20% chance, ±10 degrees)
        if random.random() > 0.8:
            angle = random.uniform(-10, 10)
            image = image.rotate(angle, resample=Image.BICUBIC, expand=False)

        # Brightness, color and contrast jitter (20% chance each), in that order
        for enhancer_cls in (ImageEnhance.Brightness, ImageEnhance.Color, ImageEnhance.Contrast):
            if random.random() > 0.8:
                factor = random.uniform(0.9, 1.1)
                image = enhancer_cls(image).enhance(factor)

        return image

    @staticmethod
    def _first_word(text):
        """Return the leading run of letters of ``text``, lower-cased ("" if none)."""
        lowered = text.lower().lstrip()
        end = 0
        while end < len(lowered) and lowered[end].isalpha():
            end += 1
        return lowered[:end]

    def _augment_caption(self, caption):
        """Occasionally prepend a descriptive prefix to a training caption.

        The caption is returned unchanged when not training, with 60%
        probability, or when its first word is one of ``_NATURAL_STARTS``.
        Otherwise, with 30% probability, a random prefix is prepended to the
        lower-cased caption.

        The natural-start test compares the first *word*; a plain
        ``startswith("a")`` would also match "apples ..." or "inside ...".
        """
        if not self.is_training or random.random() > 0.4:  # Increased from 0.3
            return caption

        if self._first_word(caption) in self._NATURAL_STARTS:
            return caption

        if random.random() > 0.7:
            caption = random.choice(self._CAPTION_PREFIXES) + caption.lower()

        return caption

    def __getitem__(self, idx):
        """Return ``{'pixel_values': tensor, 'text': str}`` for row ``idx``.

        Unreadable images are replaced by a black 224x224 image and processor
        failures by a zero tensor, so one bad sample does not abort an epoch;
        both cases are reported with a print.
        """
        row = self.dataframe.iloc[idx]

        image_name = str(row['image_id'])
        if not image_name.lower().endswith(self._IMAGE_EXTENSIONS):
            image_name += '.jpg'

        try:
            image_path = os.path.join(self.img_dir, image_name)
        except TypeError:
            # img_dir is not path-like (e.g. None): use the image name as the path
            image_path = image_name

        caption = str(row['caption'])

        # Load and augment image
        try:
            image = Image.open(image_path).convert('RGB')
            image = self._augment_image(image)
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            image = Image.new('RGB', (224, 224), color='black')

        # Process image only
        try:
            encoding = self.processor(images=image, return_tensors="pt")
            encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        except Exception as e:
            print(f"Error processing image: {e}")
            encoding = {'pixel_values': torch.zeros((3, 224, 224), dtype=torch.float16)}

        encoding['text'] = self._augment_caption(caption)
        return encoding

def optimized_collate_fn(batch):
    """Merge dataset items into one batch, tokenizing all captions in a single call.

    Args:
        batch: list of dicts, each holding a ``pixel_values`` tensor and a
            ``text`` caption string.

    Returns:
        Dict with the stacked ``pixel_values`` and the tokenizer's
        ``input_ids`` / ``attention_mask`` (padded within the batch,
        truncated to at most 64 tokens).
    """
    image_tensors = [item['pixel_values'] for item in batch]
    stacked_images = torch.stack(image_tensors)

    # One tokenizer call for the whole batch so padding is computed per batch
    captions = [item['text'] for item in batch]
    tokenized = processor.tokenizer(
        captions,
        padding=True,
        truncation=True,
        max_length=64,
        return_tensors="pt",
        add_special_tokens=True,
    )

    collated = {'pixel_values': stacked_images}
    collated['input_ids'] = tokenized['input_ids']
    collated['attention_mask'] = tokenized['attention_mask']
    return collated

# Load base model
# Weights are requested in float16 and the whole model is mapped to device index 0.
model_name = "Salesforce/blip2-opt-2.7b"
print(f"Loading base model: {model_name}")
model = Blip2ForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map={"": 0}
)

# Move model to device
# NOTE(review): device_map above pins the model to GPU index 0 regardless of
# `device`, so the "cpu" fallback chosen earlier is not honoured by the load
# itself — confirm whether CPU execution is meant to be supported.
model = model.to(device)

# LoRA targets: the four self-attention projections in each of the 32 decoder layers.
target_modules = [
    f"language_model.model.decoder.layers.{layer_idx}.self_attn.{proj_name}"
    for layer_idx in range(32)
    for proj_name in ("q_proj", "k_proj", "v_proj", "out_proj")
]

print(f"Targeting {len(target_modules)} modules for LoRA")

# LoRA configuration - increased capacity
# Rank-24 adapters (alpha 48, dropout 0.1, no bias terms) on every module
# named in `target_modules`.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=24,  # Increased from 16
    lora_alpha=48,  # Increased from 32
    lora_dropout=0.1,
    bias="none",
    target_modules=target_modules,
    init_lora_weights="gaussian"
)

# Apply LoRA with fallback
# If wrapping with the full target list raises, retry with a single projection
# in decoder layer 0.  Note that the fallback REBINDS `lora_config`, so any
# later code reading `lora_config` sees the configuration actually applied.
# NOTE(review): the fallback adapts far fewer modules and is signalled only by
# the prints below — confirm that silently degrading is acceptable.
try:
    model = get_peft_model(model, lora_config)
    print("LoRA applied successfully!")
except Exception as e:
    print(f"LoRA application failed: {e}")
    # Minimal fallback - just one layer
    lora_config = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=24,
        lora_alpha=48,
        lora_dropout=0.1,
        target_modules=["language_model.model.decoder.layers.0.self_attn.q_proj"],
        init_lora_weights="gaussian"
    )
    model = get_peft_model(model, lora_config)
    print("Fallback LoRA applied successfully!")

# Re-initialize the adapter matrices: small-variance normal for lora_A, zeros
# for lora_B.  Every LoRA parameter is marked trainable and stored as float32.
for name, param in model.named_parameters():
    if "lora" not in name.lower():
        continue
    if "lora_A" in name:
        nn.init.normal_(param, mean=0.0, std=0.001)
    elif "lora_B" in name:
        nn.init.zeros_(param)
    param.requires_grad = True
    param.data = param.data.to(torch.float32)

# Print the trainable / total parameter summary
model.print_trainable_parameters()

# Guarded cross-entropy used when the model's own loss is unusable
class StableCaptionLoss(nn.Module):
    """Token-level cross-entropy with clamped logits and NaN/Inf guards.

    Logits are clamped to [-5, 5], flattened to ``(N, vocab)`` and compared
    with the flattened labels.  Positions whose label equals ``ignore_index``
    do not contribute.  If the logits or the resulting loss are not finite, a
    fresh zero tensor (``requires_grad=True``) is returned and a warning printed.
    """

    def __init__(self, ignore_index=-100):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss(reduction='mean', ignore_index=ignore_index)

    def forward(self, logits, labels):
        vocab_size = logits.size(-1)
        flat_logits = torch.clamp(logits, min=-5, max=5).view(-1, vocab_size)
        flat_labels = labels.view(-1)

        # Guard 1: unusable logits (NaN survives clamping)
        if not bool(torch.isfinite(flat_logits).all()):
            print("Warning: Invalid logits detected before loss calculation")
            return torch.tensor(0.0, device=flat_logits.device, requires_grad=True)

        result = self.loss_fn(flat_logits, flat_labels)

        # Guard 2: unusable loss value
        if not bool(torch.isfinite(result)):
            print("Warning: Invalid loss computed, returning zero")
            return torch.tensor(0.0, device=flat_logits.device, requires_grad=True)

        return result


custom_loss_fn = StableCaptionLoss()

# Optimizer with stable learning rate
# Only parameters with requires_grad=True are handed to AdamW; frozen
# parameters are excluded from the optimizer entirely.
optimizer = AdamW(
    [p for p in model.parameters() if p.requires_grad], 
    lr=1e-4,  # Adjusted from 1.2e-4
    weight_decay=0.001,
    eps=1e-8, 
    betas=(0.9, 0.95)
)

# Warmup schedule: learning-rate multiplier as a function of the scheduler step
def lr_lambda(step):
    """Return the LR multiplier: a linear ramp from 0 to 1 over 100 steps, then 1.0."""
    warmup_steps = 100
    if step >= warmup_steps:
        return 1.0
    return float(step) / float(max(1, warmup_steps))

# Multiply the optimizer's base learning rate by lr_lambda(step) on every scheduler.step()
scheduler = LambdaLR(optimizer, lr_lambda)

print("Model setup completed!")

# Create dataset and dataloader
# Resolve the image directory first.  Only the lookup errors that can actually
# occur here are caught (a bare `except:` would also swallow KeyboardInterrupt),
# and the ValueError is chained to the original error.
try:
    img_dir = PATHS['train_img_dir']
except (NameError, KeyError, TypeError) as e:
    print("Error: PATHS['train_img_dir'] must be defined")
    raise ValueError("PATHS['train_img_dir'] must be defined and accessible") from e

train_dataset = OptimizedCustomDataset(
    train_df,
    processor,
    img_dir,
    is_training=True
)

# DataLoader with optimized settings
batch_size = 10  # Increased from 8
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=optimized_collate_fn,
    num_workers=0,
    # Pinned host memory is only requested when a GPU is present
    pin_memory=torch.cuda.is_available(),
    # Incomplete final batches are dropped so every batch has batch_size items
    drop_last=True
)

print(f"Dataset created: {len(train_dataset)} samples")
print(f"Dataloader created: {len(train_dataloader)} batches")

# Debug: Check one batch
# Pulls a single batch and prints tensor shapes, the token-id range and the
# number of non-padding tokens per caption.  Any failure is printed rather
# than raised, so this check never stops the script.
try:
    batch = next(iter(train_dataloader))
    print("Batch shapes:", {k: v.shape if hasattr(v, 'shape') else type(v) for k, v in batch.items()})
    print("Sample input_ids range:", torch.min(batch['input_ids']), "to", torch.max(batch['input_ids']))
    print("Sample attention_mask sum:", torch.sum(batch['attention_mask'], dim=1))
except Exception as e:
    print(f"Error checking batch: {e}")

# Checkpoint helpers that store only the adapter weights
def save_lora_weights_only(model, filepath, epoch, loss, successful_steps, failed_steps):
    """Write a compact checkpoint holding just the trainable LoRA parameters.

    The file contains the LoRA tensors (copied to CPU), the attributes of the
    module-level ``lora_config``, the base ``model_name`` and the supplied
    training metadata.

    Args:
        model: model whose named parameters are scanned for "lora" entries.
        filepath: destination path for ``torch.save``.
        epoch, loss, successful_steps, failed_steps: metadata stored as-is.

    Returns:
        True when the file was written, False if anything raised (the error
        is printed, not propagated).
    """
    try:
        # Keep only trainable parameters whose name mentions "lora"
        lora_state_dict = {
            name: param.cpu().detach().clone()
            for name, param in model.named_parameters()
            if 'lora' in name.lower() and param.requires_grad
        }

        payload = {
            'lora_weights': lora_state_dict,
            'lora_config': lora_config.__dict__,  # needed to rebuild the adapters on load
            'epoch': epoch,
            'loss': loss,
            'successful_steps': successful_steps,
            'failed_steps': failed_steps,
            'model_name': model_name,  # base checkpoint to load before applying the adapters
        }
        torch.save(payload, filepath)

        file_size = os.path.getsize(filepath) / (1024**2)  # bytes -> MB
        print(f"LoRA weights saved: {filepath} ({file_size:.1f} MB)")
        return True
    except Exception as e:
        print(f"Error saving LoRA weights: {e}")
        return False

def load_and_merge_lora_for_inference(base_model_name, lora_weights_path, device):
    """Rebuild the LoRA-wrapped model from a checkpoint and ready it for inference.

    Loads the base model in float16, recreates the LoRA adapters from the
    config stored in the checkpoint, copies the saved adapter tensors into the
    matching parameters and switches the model to eval mode (so LoRA dropout
    is inactive).  Despite the name, the adapters are NOT merged into the base
    weights; the PEFT-wrapped model is returned.

    Args:
        base_model_name: checkpoint name/path of the base BLIP-2 model.
        lora_weights_path: file with 'lora_weights' and 'lora_config' entries.
        device: device the checkpoint tensors are mapped/moved to.

    Returns:
        The wrapped model, or None if any step fails (the error is printed).
    """
    try:
        # Load base model
        base_model = Blip2ForConditionalGeneration.from_pretrained(
            base_model_name,
            torch_dtype=torch.float16,
            device_map={"": 0}
        )

        # Load LoRA checkpoint.
        # The checkpoint contains pickled Python objects (the LoRA config), which
        # torch.load refuses under weights_only=True — the default since
        # PyTorch 2.6 — so full unpickling is requested explicitly.
        # SECURITY: full unpickling can execute arbitrary code; only load
        # checkpoint files you created yourself.
        checkpoint = torch.load(lora_weights_path, map_location=device, weights_only=False)
        lora_config_dict = checkpoint['lora_config']

        # Recreate LoRA config
        lora_config = LoraConfig(**lora_config_dict)

        # Apply LoRA to base model
        model = get_peft_model(base_model, lora_config)

        # Load LoRA weights, counting matches so a naming mismatch is not silent
        lora_weights = checkpoint['lora_weights']
        loaded = 0
        for name, param in model.named_parameters():
            if name in lora_weights:
                param.data = lora_weights[name].to(device)
                loaded += 1

        if loaded != len(lora_weights):
            print(f"Warning: only {loaded} of {len(lora_weights)} saved LoRA tensors matched a model parameter")

        # Inference mode: disables dropout, including the LoRA dropout layers
        model.eval()

        print(f"Model loaded with LoRA weights from {lora_weights_path}")
        return model
    except Exception as e:
        print(f"Error loading model with LoRA weights: {e}")
        return None

# Enhanced training step with better error handling
def stable_train_step(model, batch, device, custom_loss_fn):
    """Run one forward pass and return a usable loss, or None to skip the batch.

    Args:
        model: the model to train; if it has a ``base_model`` attribute, that
            inner model is the one actually called.
        batch: dict with ``pixel_values``, ``input_ids`` and ``attention_mask``.
        device: device the batch tensors are moved to.
        custom_loss_fn: fallback ``loss(logits, labels)`` used when the model's
            own loss is missing or not finite.

    Returns:
        The loss tensor, or None when the inputs are invalid, the loss is
        NaN/Inf or greater than 20.0, or any exception occurs (exceptions are
        printed, never propagated).
    """
    try:
        # Move inputs to device with proper dtypes
        pixel_values = batch["pixel_values"].to(device, dtype=torch.float16, non_blocking=True)
        input_ids = batch["input_ids"].to(device, non_blocking=True)
        attention_mask = batch["attention_mask"].to(device, non_blocking=True)
        
        # Validate inputs
        if torch.any(torch.isnan(pixel_values)) or torch.any(torch.isinf(pixel_values)):
            print("Warning: Invalid pixel values detected")
            return None
        
        # Check token range (OPT vocab size is ~50272)
        # Out-of-range ids are reported and clamped rather than rejected.
        if torch.any(input_ids < 0) or torch.any(input_ids >= 50272):
            print(f"Warning: Invalid input_ids detected, range: {torch.min(input_ids)} to {torch.max(input_ids)}")
            # Clamp to valid range
            input_ids = torch.clamp(input_ids, 0, 50271)
        
        # Create labels
        # Labels are the input ids themselves; padding positions get -100.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        
        # Forward pass with error handling
        model.train()
        
        # Get the base model for forward pass
        # The outer wrapper's own forward() is bypassed when `base_model` exists.
        if hasattr(model, 'base_model'):
            base_model = model.base_model
        else:
            base_model = model
        
        # Use a more stable forward pass approach
        # cuDNN is disabled for the duration of this forward pass only.
        with torch.backends.cudnn.flags(enabled=False):
            outputs = base_model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                return_dict=True,
                use_cache=False,
                output_hidden_states=False,
                output_attentions=False
            )
        
        # Calculate loss
        # Prefer the model's own loss; fall back to custom_loss_fn on the raw
        # logits when that loss is missing or not finite.
        # NOTE(review): the fallback passes logits and labels unmodified —
        # confirm their sequence lengths match for this model.
        if outputs.loss is not None and not (torch.isnan(outputs.loss) or torch.isinf(outputs.loss)):
            loss = outputs.loss
        else:
            if outputs.logits is not None:
                loss = custom_loss_fn(outputs.logits, labels)
            else:
                print("No logits in outputs")
                return None
        
        # Final validation
        # Losses above 20.0 are treated as failures, same as NaN/Inf.
        if torch.isnan(loss) or torch.isinf(loss) or loss.item() > 20.0:
            print(f"Invalid or extreme loss: {loss.item()}")
            return None
        
        return loss
        
    except RuntimeError as e:
        # Out-of-memory errors additionally release cached GPU memory
        if "out of memory" in str(e):
            print("CUDA out of memory, clearing cache...")
            torch.cuda.empty_cache()
        else:
            print(f"Runtime error in forward pass: {e}")
        return None
    except Exception as e:
        print(f"Unexpected error in forward pass: {e}")
        return None

# Training parameters - REDUCED TO 2 EPOCHS
# accumulation_steps: batches whose gradients are summed before one optimizer step
# print_every: a progress line is printed every this many batches
accumulation_steps = 3  # Adjusted from 4
num_epochs = 2  # REDUCED FROM 3 TO SAVE SPACE AND TIME
print_every = 50

print(f"Starting stable LoRA BLIP-2 training...")
print(f"Effective batch size: {batch_size * accumulation_steps}")
print(f"Total steps per epoch: {len(train_dataloader)}")
print(f"Effective steps per epoch: {len(train_dataloader) // accumulation_steps}")
print(f"TRAINING FOR {num_epochs} EPOCHS TO SAVE SPACE")

# Training loop
# Counters shared by the loop below and the summary after it:
#   total_loss         - running sum of loss values
#   step               - optimizer steps taken
#   successful_steps   - optimizer steps that were applied
#   failed_steps       - rejected batches, failed backward passes and skipped optimizer steps
#   grad_accum_counter - batches accumulated since the last optimizer step
model.train()
total_loss = 0
step = 0
successful_steps = 0
failed_steps = 0
grad_accum_counter = 0

# Per-element gradient clamp, applied in addition to norm clipping
def grad_hook(grad):
    """Return ``grad`` with every element limited to the range [-1.0, 1.0]."""
    return grad.clamp(min=-1.0, max=1.0)

# Attach the clamp hook to every trainable LoRA parameter; the returned
# handles are kept so the hooks can be removed once training ends.
hooks = [
    param.register_hook(grad_hook)
    for name, param in model.named_parameters()
    if "lora" in name.lower() and param.requires_grad
]

# Number of batches whose loss has been added to total_loss.  Loss values are
# accumulated once per BATCH, while successful_steps counts OPTIMIZER steps
# (one per `accumulation_steps` batches, cumulative over epochs).  Averages
# must therefore divide by batch counts, not by successful_steps.
loss_batches = 0

try:
    for epoch in range(num_epochs):
        epoch_loss = 0
        epoch_batches = 0  # batches that contributed to epoch_loss
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for i, batch in enumerate(progress_bar):
            # Forward pass; None means the batch was rejected
            loss = stable_train_step(model, batch, device, custom_loss_fn)

            if loss is None:
                # Discard partially accumulated gradients and restart the window
                failed_steps += 1
                optimizer.zero_grad()
                grad_accum_counter = 0
                torch.cuda.empty_cache()

                progress_bar.set_postfix({
                    'Status': 'FAILED',
                    'Fails': failed_steps,
                    'Success': successful_steps
                })
                continue

            # Scale loss for accumulation
            loss = loss / accumulation_steps

            # Backward pass
            try:
                loss.backward()
                grad_accum_counter += 1

                if grad_accum_counter == accumulation_steps:
                    # Only step when at least one gradient exists and all are finite
                    valid_gradients = True
                    grad_count = 0
                    for param in model.parameters():
                        if param.grad is not None:
                            grad_count += 1
                            if torch.any(torch.isnan(param.grad)) or torch.any(torch.isinf(param.grad)):
                                valid_gradients = False
                                break

                    if valid_gradients and grad_count > 0:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                        optimizer.step()
                        optimizer.zero_grad()
                        step += 1
                        successful_steps += 1
                        scheduler.step()  # Update learning rate
                    else:
                        print(f"Invalid gradients detected (valid: {valid_gradients}, count: {grad_count}), skipping step")
                        optimizer.zero_grad()
                        failed_steps += 1

                    grad_accum_counter = 0

                # Update metrics (undo the accumulation scaling for reporting)
                loss_value = loss.item() * accumulation_steps
                total_loss += loss_value
                epoch_loss += loss_value
                loss_batches += 1
                epoch_batches += 1
                avg_loss = total_loss / loss_batches

                if successful_steps > 0:
                    success_rate = successful_steps / (successful_steps + failed_steps) * 100
                    progress_bar.set_postfix({
                        'Loss': f'{loss_value:.4f}',
                        'Avg': f'{avg_loss:.4f}',
                        'Success': f'{success_rate:.1f}%'
                    })

                if (i + 1) % print_every == 0:
                    success_rate = successful_steps / (successful_steps + failed_steps) * 100 if (successful_steps + failed_steps) > 0 else 0
                    print(f"Step {step}: Loss = {loss_value:.4f}, Avg = {avg_loss:.4f}, Success Rate = {success_rate:.1f}%")

            except Exception as e:
                print(f"Error in backward pass: {e}")
                optimizer.zero_grad()
                grad_accum_counter = 0
                failed_steps += 1
                torch.cuda.empty_cache()
                continue

        # Epoch summary and space-efficient saving
        if successful_steps > 0 and epoch_batches > 0:
            # Mean per-batch loss of THIS epoch
            avg_loss = epoch_loss / epoch_batches
            success_rate = successful_steps / (successful_steps + failed_steps) * 100
            print(f"Epoch {epoch+1} completed:")
            print(f"  Average Loss: {avg_loss:.4f}")
            print(f"  Success Rate: {success_rate:.1f}% ({successful_steps}/{successful_steps + failed_steps})")

            # Save only LoRA weights (much smaller!)
            model_path = f"/kaggle/working/blip2_lora_epoch_{epoch+1}.pt"
            save_success = save_lora_weights_only(
                model, model_path, epoch + 1, avg_loss, successful_steps, failed_steps
            )

            if save_success:
                # Check disk space after saving
                total, used, free = shutil.disk_usage(working_dir)
                print(f"Disk after save: Free={free/(1024**3):.2f} GB")
            else:
                print(f"Failed to save model for epoch {epoch+1}")
        else:
            print(f"Epoch {epoch+1} had no successful steps!")

finally:
    # Clean up hooks, even when training is interrupted
    for hook in hooks:
        hook.remove()

# Overall summary: share of optimizer steps applied versus failures recorded
success_rate = successful_steps / (successful_steps + failed_steps) * 100 if (successful_steps + failed_steps) > 0 else 0
print(f"Training completed!")
print(f"Final success rate: {success_rate:.1f}% ({successful_steps}/{successful_steps + failed_steps})")
print(f"Total successful steps: {successful_steps}")

# Save final LoRA weights if we had any successful steps
if successful_steps > 0:
    final_model_path = "/kaggle/working/blip2_lora_final.pt"
    # NOTE(review): total_loss is incremented once per batch while
    # successful_steps is incremented once per optimizer step, so this value is
    # larger than the mean per-batch loss — confirm which one is intended.
    final_avg_loss = total_loss / successful_steps
    save_success = save_lora_weights_only(
        model, final_model_path, num_epochs, final_avg_loss, successful_steps, failed_steps
    )
    
    if save_success:
        print(f"Final LoRA weights saved: {final_model_path}")
    else:
        print("Failed to save final LoRA weights")

# List every .pt checkpoint currently in the working directory
print("Saved files:", [f for f in os.listdir("/kaggle/working") if f.endswith('.pt')])

# Print instructions for loading the model later
# Purely informational output; nothing is loaded or executed here.
print("\n" + "="*50)
print("IMPORTANT: How to load your trained model for inference:")
print("="*50)
print("1. Use the load_and_merge_lora_for_inference() function provided above")
print("2. Example usage:")
print("   model = load_and_merge_lora_for_inference(")
print("       'Salesforce/blip2-opt-2.7b',")
print("       '/path/to/blip2_lora_final.pt',")
print("       device")
print("   )")
print("3. Then use the model normally for generating captions")
print("="*50)

# Test Model Eval

# !! RESTART TO CLEAR VRAM AND USE LOCALLY SAVED MODEL

In [None]:
# Install ultralytics, which provides the YOLO object detector used below
!pip install ultralytics

In [None]:
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
from torch.cuda.amp import autocast, GradScaler

In [None]:
# Same data locations as in the training section, redefined after the restart.
PATHS = dict(
    train_img_dir="/kaggle/working/obss-intern-competition-2025/train/train/",
    train_csv_path="/kaggle/working/obss-intern-competition-2025/train.csv",
    test_csv_path="/kaggle/working/obss-intern-competition-2025/test.csv",
    test_img_dir="/kaggle/working/obss-intern-competition-2025/test/test/",
)

In [None]:
import torch
from PIL import Image
import matplotlib.pyplot as plt
import re
from transformers import Blip2ForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from peft import PeftModel
import gc
import os
import pandas as pd
import random
import nltk
from ultralytics import YOLO
import numpy as np

# Download nltk data (run once)
# Fetches the 'punkt' tokenizer data and both perceptron-tagger resource names.
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
nltk.download('averaged_perceptron_tagger_eng')

# Checkpoint candidates in priority order: final weights first, then the most
# recent epoch.  Both naming schemes are listed — the "stable_"-prefixed names
# and the un-prefixed names (blip2_lora_final.pt / blip2_lora_epoch_N.pt) that
# the training section of this notebook actually writes.  With only the
# prefixed names, checkpoints produced by this notebook were never found.
saved_model_paths = [
    "/kaggle/working/stable_blip2_lora_final.pt",
    "/kaggle/working/blip2_lora_final.pt",
    "/kaggle/working/stable_blip2_lora_epoch_3.pt",
    "/kaggle/working/blip2_lora_epoch_3.pt",
    "/kaggle/working/stable_blip2_lora_epoch_2.pt",
    "/kaggle/working/blip2_lora_epoch_2.pt",
    "/kaggle/working/stable_blip2_lora_epoch_1.pt",
    "/kaggle/working/blip2_lora_epoch_1.pt",
]

# YOLO loader: nano weights first, small weights as a fallback
def load_yolo_model():
    """Return a YOLO detector, or None when no weights can be loaded.

    Tries "yolo11n.pt" first and "yolo11s.pt" second.  Every failure is
    printed; nothing is raised to the caller for load errors.
    """
    # First choice: nano weights
    try:
        detector = YOLO("yolo11n.pt")
        print("✅ YOLOv11n model loaded successfully")
        return detector
    except Exception as e:
        print(f"⚠️ Failed to load YOLOv11n, trying YOLOv11s: {e}")

    # Second choice: small weights
    try:
        detector = YOLO("yolo11s.pt")
        print("✅ YOLOv11s model loaded successfully")
        return detector
    except Exception as e2:
        print(f"⚠️ Failed to load any YOLO model: {e2}. Proceeding without object detection.")

    return None

yolo_model = load_yolo_model()

def get_enhanced_object_detection(image, yolo_model):
    """Detect objects in ``image`` and summarize where they are.

    Args:
        image: image whose ``size`` attribute yields ``(width, height)``.
        yolo_model: detector callable as ``yolo_model(image, conf=..., iou=...)``,
            or None when object detection is unavailable.

    Returns:
        ``(detected_objects, spatial_context)`` on success, where
        ``detected_objects`` is a list of dicts with keys ``class``,
        ``confidence``, ``center_x``, ``center_y`` (box center as a fraction of
        image width/height), ``area_ratio`` (box area / image area) and
        ``position`` (label from ``get_spatial_position``), and
        ``spatial_context`` is the dict built by ``analyze_spatial_context``.
        ``(None, [])`` when there is no model, nothing was detected, or
        detection raised (the error is printed).
    """
    if yolo_model is None:
        return None, []

    try:
        results = yolo_model(image, conf=0.3, iou=0.5)  # Adjusted thresholds

        if len(results) == 0 or len(results[0].boxes) == 0:
            return None, []

        first_result = results[0]
        img_width, img_height = image.size
        detected_objects = []

        for box in first_result.boxes:
            # .item() extracts the scalar from a one-element tensor; calling
            # int()/float() on a one-element array is deprecated in NumPy.
            cls_id = int(box.cls.item())
            confidence = float(box.conf.item())
            class_name = first_result.names[cls_id]

            # Get bounding box coordinates
            x1, y1, x2, y2 = box.xyxy.cpu().numpy()[0]

            # Calculate spatial properties, normalized by the image size
            center_x = (x1 + x2) / 2 / img_width
            center_y = (y1 + y2) / 2 / img_height
            area_ratio = ((x2 - x1) * (y2 - y1)) / (img_width * img_height)

            detected_objects.append({
                'class': class_name,
                'confidence': confidence,
                'center_x': center_x,
                'center_y': center_y,
                'area_ratio': area_ratio,
                'position': get_spatial_position(center_x, center_y)
            })

        # Get spatial context
        spatial_context = analyze_spatial_context(detected_objects, img_width, img_height)

        return detected_objects, spatial_context

    except Exception as e:
        print(f"⚠️ Object detection failed: {e}")
        return None, []

def get_spatial_position(center_x, center_y):
    """Name the cell of a 3x3 grid that contains a normalized center point.

    Values below 0.33 fall in the first band (left / top), values above 0.67
    in the last band (right / bottom), everything else in the middle band.
    The central cell is reported as "center"; every other cell as
    "<vertical>-<horizontal>", e.g. "top-left" or "middle-right".
    """
    def _band(value, low_label, mid_label, high_label):
        # Classify one coordinate into its band
        if value < 0.33:
            return low_label
        if value > 0.67:
            return high_label
        return mid_label

    horizontal = _band(center_x, "left", "center", "right")
    vertical = _band(center_y, "top", "middle", "bottom")

    if (horizontal, vertical) == ("center", "middle"):
        return "center"
    return f"{vertical}-{horizontal}"

def analyze_spatial_context(detected_objects, img_width, img_height):
    """Summarize a detection list into dominant objects and a coarse scene type.

    Args:
        detected_objects: list of dicts with at least ``class`` and ``area_ratio``.
        img_width, img_height: accepted for interface compatibility; not used.

    Returns:
        Dict with ``dominant_objects`` (classes of the three largest boxes,
        largest first), ``scene_type`` (first matching rule among sports,
        traffic, indoor, outdoor; otherwise 'general') and
        ``object_relationships`` (always an empty list).
    """
    context = {
        'dominant_objects': [],
        'scene_type': 'general',
        'object_relationships': []
    }
    if not detected_objects:
        return context

    # Three largest detections by relative area
    largest_first = sorted(detected_objects, key=lambda x: x['area_ratio'], reverse=True)
    context['dominant_objects'] = [entry['class'] for entry in largest_first[:3]]

    # Scene rules are checked in order; the first rule with any class present wins
    scene_rules = (
        ('sports', ('person', 'sports ball', 'baseball bat', 'baseball glove')),
        ('traffic', ('car', 'truck', 'bus', 'motorcycle')),
        ('indoor', ('dining table', 'chair', 'cup', 'bowl')),
        ('outdoor', ('tree', 'grass', 'sky')),
    )
    present_classes = [entry['class'] for entry in detected_objects]
    for scene_name, trigger_classes in scene_rules:
        if any(name in trigger_classes for name in present_classes):
            context['scene_type'] = scene_name
            break

    return context

# FIXED: Enhanced model loader with proper memory management
def load_trained_model(force_reload=False):
    """Load BLIP-2 in FP16 (no quantization) and apply the fine-tuned LoRA weights.

    The result is stored in the module-level ``model`` and ``processor``
    globals and is also returned.

    Args:
        force_reload: When False, an already-loaded global ``model`` is
            reused if a tiny ``generate`` call on a dummy tensor succeeds.

    Returns:
        Tuple ``(model, processor)``. The model has been moved to ``device``
        and switched to eval mode. If the checkpoint holds no usable LoRA
        weights, or applying them fails, the plain base model is returned.

    Raises:
        FileNotFoundError: If none of ``saved_model_paths`` exists.
        Exception: Any other loading error is re-raised after the references
            held by this function (and the global ``model``) are released.
    """
    global model, processor

    # Avoid reloading if the model is already functional
    if not force_reload and 'model' in globals() and model is not None:
        try:
            dummy_tensor = torch.randn(1, 3, 224, 224).to(device, torch.float16)
            with torch.no_grad():
                _ = model.generate(pixel_values=dummy_tensor, max_length=5)
            print("✅ Using existing model from VRAM")
            return model, processor
        except Exception as e:
            print(f"⚠️ Model in VRAM not functional: {e}")
            if 'model' in globals():
                del model
            if 'processor' in globals():
                del processor
            gc.collect()
            torch.cuda.empty_cache()

    # The first existing path wins; the list order defines the priority.
    model_path = next((path for path in saved_model_paths if os.path.exists(path)), None)
    if model_path is None:
        raise FileNotFoundError("No trained model found. Please check the model paths.")
    print(f"Found model: {model_path}")

    # Bound up front so the failure handler can always release them.
    base_model = None
    checkpoint = None

    try:
        print(f"🔄 Loading processor...")
        processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
        processor.image_processor.size = {"height": 224, "width": 224}

        print(f"🔄 Loading checkpoint to CPU first...")
        checkpoint = torch.load(model_path, map_location='cpu')

        print(f"🔄 Loading base model in FP16...")
        # Load base model in FP16 without quantization
        base_model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-opt-2.7b",
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True,
            max_memory={0: "12GB"}  # Limit GPU memory
        )
        print("✅ Base model loaded in FP16")

        # Apply LoRA configuration (matching training code)
        from peft import LoraConfig, get_peft_model, TaskType

        target_modules = [
            f"language_model.model.decoder.layers.{i}.self_attn.{proj}"
            for i in range(32)  # Match training configuration
            for proj in ["q_proj", "k_proj", "v_proj", "out_proj"]
        ]

        lora_config = LoraConfig(
            task_type=TaskType.SEQ_2_SEQ_LM,
            r=24,  # Match training
            lora_alpha=48,  # Match training
            lora_dropout=0.1,
            bias="none",
            target_modules=target_modules,
            init_lora_weights="gaussian"
        )

        try:
            model = get_peft_model(base_model, lora_config)

            # The checkpoint may store the weights under either key.
            state_dict = checkpoint.get('lora_weights', checkpoint.get('model_state_dict'))
            if not state_dict:
                print("⚠️ No LoRA weights found in checkpoint")
                model = base_model
            else:
                # Only load keys that exist in the wrapped model
                known_keys = model.state_dict()
                compatible_state_dict = {k: v for k, v in state_dict.items() if k in known_keys}
                model.load_state_dict(compatible_state_dict, strict=False)
                print(f"✅ Loaded {len(compatible_state_dict)} LoRA parameters")

        except Exception as e:
            # Best effort: fall back to the un-adapted base model.
            print(f"⚠️ LoRA loading failed, using base model: {e}")
            model = base_model

        # Clean up checkpoint
        del checkpoint
        gc.collect()
        torch.cuda.empty_cache()

        model = model.to(device)
        model.eval()

        print("✅ Model loaded successfully")

        # Print memory usage
        if torch.cuda.is_available():
            memory_used = torch.cuda.memory_allocated() / 1024**3
            print(f"📊 GPU Memory Used: {memory_used:.2f}GB")

        return model, processor

    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        # Deleting entries from locals() does not unbind anything, so drop the
        # references explicitly before collecting garbage.
        model = None
        base_model = None
        checkpoint = None
        gc.collect()
        torch.cuda.empty_cache()
        raise

def clean_caption(caption):
    """Scrub a generated caption and normalise its shape.

    Removes markup, bracketed markers, control characters, a list of known
    junk tokens, mixed-case identifiers and ``...000N``-style words, then
    collapses whitespace, drops a dangling trailing conjunction/article and a
    leading article, and capitalises the first character.

    Returns:
        The cleaned caption, or a generic fallback sentence when the input is
        empty, fewer than two words remain, or a junk token is still present
        as a substring.
    """
    if not caption:
        return "An image showing a scene"

    # Markup, bracketed markers, control characters, replacement glyphs.
    for pattern in (r'<[^>]*>', r'\[.*?\]', r'[\x00-\x1F\x7F-\x9F]', r'�|'):
        caption = re.sub(pattern, '', caption)

    junk_tokens = [
        'strutconnect', 'strutconnector', 'attrot', 'guiactive', 'guiactiveunfocused',
        'madeupword', 'confignode', 'partmodule', 'tweakscale', 'modulemanager',
        'gamedata', 'squad', 'parttools', 'kspfield', 'persistant', 'cfgnode',
        'guiicon', 'guiname', '0002', 'externaltoevaonly', '裏', 'guiactiveunfocusedmadeupword'
    ]
    # Whole-word, case-insensitive removal of each junk token in turn.
    for junk in junk_tokens:
        caption = re.sub(rf'\b{junk}\b', '', caption, flags=re.IGNORECASE)

    # Words with a second capital letter, then words ending in 000<digits>.
    caption = re.sub(r'\b[A-Z][a-z]*[A-Z][a-zA-Z]*\b', '', caption)
    caption = re.sub(r'\b\w+000\d+\b', '', caption)

    # Whitespace and dangling articles/conjunctions.
    caption = re.sub(r'\s+', ' ', caption).strip()
    caption = re.sub(r'\s+(and|or|the|a|an)$', '', caption, flags=re.IGNORECASE)
    caption = re.sub(r'^(the|a|an)\s+', '', caption, flags=re.IGNORECASE)

    if caption:
        caption = caption[0].upper() + caption[1:]

    lowered = caption.lower()
    too_short = len(caption.split()) < 2
    still_tainted = any(junk in lowered for junk in junk_tokens)
    if too_short or still_tainted:
        return "An image showing a scene with visible objects"

    return caption

def enhance_caption_with_context(caption, detected_objects, spatial_context):
    """Adjust a caption using detections and the scene-context dict.

    Args:
        caption: Caption text to adjust.
        detected_objects: List of dicts with 'class' and 'position' keys; the
            first three entries are treated as the dominant objects.
        spatial_context: Dict providing 'scene_type'.

    Returns:
        The caption unchanged when there are no detections; otherwise the
        caption with sports wording and/or a central-focus suffix applied.
    """
    if not detected_objects:
        return caption

    top_classes = [det['class'] for det in detected_objects[:3]]

    if spatial_context['scene_type'] == 'sports':
        has_equipment = any(item in top_classes for item in ['sports ball', 'baseball bat'])
        if 'person' in top_classes and has_equipment and 'player' not in caption.lower():
            caption = caption.replace('person', 'player').replace('people', 'players')

        ball_seen = 'baseball' in top_classes or 'sports ball' in top_classes
        if ball_seen and 'baseball' not in caption.lower():
            caption += ' in a baseball context'

    # Positions of every detection are inspected, not only the top three.
    all_positions = [det['position'] for det in detected_objects]
    central_object = any('center' in label for label in all_positions)
    if central_object and 'center' not in caption.lower():
        caption += ' with a central focus'

    return caption

def generate_multiple_captions(model, processor, inputs, image, device, detected_objects=None, spatial_context=None):
    """Run six decoding configurations and return every caption plus the best one.

    Each configuration's output is decoded, cleaned with ``clean_caption``,
    optionally adjusted with ``enhance_caption_with_context`` (only when both
    detections and context are provided) and scored with ``score_caption``.

    Args:
        model: Object exposing ``generate(pixel_values=..., **kwargs)``.
        processor: Object exposing ``batch_decode`` and ``tokenizer.eos_token_id``.
        inputs: Processor output; only ``inputs.pixel_values`` is read.
        image: Unused here; kept for interface compatibility.
        device: Unused here; kept for interface compatibility.
        detected_objects: Optional list of detection dicts.
        spatial_context: Optional dict with a 'scene_type' key.

    Returns:
        ``(generated_captions, best_caption)``: a list of per-configuration
        dicts (keys 'config', 'caption', 'score', 'length') and the caption
        with the highest score. A configuration that raises is recorded as
        "Generation failed" with score -1 and never becomes the best caption.
    """
    # YOUR ORIGINAL 6 CAPTION CONFIGS
    caption_configs = [
        {
            'name': 'High Quality Detailed',
            'max_length': 80,
            'num_beams': 8,
            'num_beam_groups': 2,
            'do_sample': False,
            'length_penalty': 1.5,
            'no_repeat_ngram_size': 3,
            'diversity_penalty': 0.7,
            'early_stopping': True
        },
        {
            'name': 'Creative Sampling',
            'max_length': 70,
            'num_beams': 1,
            'do_sample': True,
            'top_k': 50,
            'top_p': 0.9,
            'temperature': 0.8,
            'no_repeat_ngram_size': 2,
            'length_penalty': 1.2
        },
        {
            'name': 'Balanced Precision',
            'max_length': 60,
            'num_beams': 6,
            'do_sample': False,
            'length_penalty': 1.3,
            'no_repeat_ngram_size': 2,
            'early_stopping': True,
            'repetition_penalty': 1.1
        },
        {
            'name': 'Nucleus Sampling',
            'max_length': 75,
            'num_beams': 1,
            'do_sample': True,
            'top_p': 0.95,
            'temperature': 0.7,
            'no_repeat_ngram_size': 2,
            'length_penalty': 1.0,
            'repetition_penalty': 1.05
        },
        {
            'name': 'Conservative Beam',
            'max_length': 50,
            'num_beams': 4,
            'do_sample': False,
            'length_penalty': 1.0,
            'no_repeat_ngram_size': 2,
            'early_stopping': True
        },
        {
            'name': 'Training Compatible',
            'max_length': 64,
            'num_beams': 4,
            'do_sample': False,
            'length_penalty': 1.0,
            'no_repeat_ngram_size': 2,
            'early_stopping': True,
            'pad_token_id': processor.tokenizer.eos_token_id
        }
    ]

    results = []
    winner = "An image showing a scene"
    top_score = -float('inf')

    class_summary = [obj['class'] for obj in detected_objects] if detected_objects else 'None'
    print(f"🎯 Detected objects: {class_summary}")
    scene_summary = spatial_context['scene_type'] if spatial_context else 'general'
    print(f"🎯 Scene type: {scene_summary}")

    use_context = bool(detected_objects and spatial_context)

    for config in caption_configs:
        label = config['name']
        # Everything except the display name is forwarded to generate().
        generation_kwargs = {key: value for key, value in config.items() if key != 'name'}
        try:
            with torch.no_grad():
                token_ids = model.generate(
                    pixel_values=inputs.pixel_values,
                    **generation_kwargs
                )

                decoded = processor.batch_decode(token_ids, skip_special_tokens=True)[0]
                final_caption = clean_caption(decoded)
                if use_context:
                    final_caption = enhance_caption_with_context(final_caption, detected_objects, spatial_context)

                score = score_caption(final_caption, detected_objects, spatial_context)

                results.append({
                    'config': label,
                    'caption': final_caption,
                    'score': score,
                    'length': len(final_caption.split())
                })

                print(f"  {label}: '{final_caption}' (Score: {score:.2f})")

                # Strictly greater: on a tie the earlier configuration stays.
                if score > top_score:
                    top_score = score
                    winner = final_caption

        except Exception as e:
            # A failing configuration is recorded but does not stop the others.
            print(f"  {label} failed: {e}")
            results.append({
                'config': label,
                'caption': "Generation failed",
                'score': -1,
                'length': 0
            })

    return results, winner

def score_caption(caption, detected_objects, spatial_context):
    """Heuristically score a caption; higher is better.

    Args:
        caption: Caption text. An empty string is accepted and simply earns
            no length or capitalisation points.
        detected_objects: Optional list of dicts with a 'class' key.
        spatial_context: Optional dict with a 'scene_type' key.

    Returns:
        Integer score built from word count, capitalisation, absence of a
        trailing period, overlap with detected classes (or their synonyms
        from ``get_synonyms``), scene-type keywords, and a penalty for
        generic phrases.
    """
    score = 0
    caption_lower = caption.lower()

    # Base quality score
    word_count = len(caption.split())
    if 5 <= word_count <= 15:
        score += 10
    elif 3 <= word_count <= 20:
        score += 5

    # Grammar and structure. The slice is '' for an empty caption, so this
    # cannot raise IndexError the way caption[0] would.
    if caption[:1].isupper():
        score += 2
    if not caption.endswith('.'):
        score += 1  # Prefer captions without periods for this task

    # Context matching
    if detected_objects:
        detected_classes = [obj['class'] for obj in detected_objects]

        # Sports context
        if any(cls in ['person', 'sports ball', 'baseball bat'] for cls in detected_classes):
            if any(word in caption_lower for word in ['baseball', 'player', 'sport', 'game']):
                score += 15
            if any(word in caption_lower for word in ['uniform', 'field', 'bat', 'ball']):
                score += 10

        # Object presence matching: one bonus per detection (duplicates count).
        for obj in detected_classes:
            obj_lower = obj.lower()
            if obj_lower in caption_lower or any(syn in caption_lower for syn in get_synonyms(obj_lower)):
                score += 5

    # Scene type matching
    if spatial_context:
        scene_type = spatial_context['scene_type']
        if scene_type == 'sports' and any(word in caption_lower for word in ['playing', 'game', 'sport', 'field']):
            score += 8
        elif scene_type == 'outdoor' and any(word in caption_lower for word in ['outside', 'outdoor', 'field']):
            score += 5

    # Penalize generic or problematic phrases
    generic_phrases = ['an image', 'a picture', 'this shows', 'visible objects']
    for phrase in generic_phrases:
        if phrase in caption_lower:
            score -= 5

    return score

def get_synonyms(word):
    """Look up caption-friendly synonyms for a detector class name.

    Returns:
        The list of synonyms for ``word``, or an empty list when the word has
        no entry in the table.
    """
    synonym_table = dict([
        ('person', ['man', 'woman', 'player', 'individual']),
        ('sports ball', ['ball', 'baseball']),
        ('baseball bat', ['bat']),
        ('car', ['vehicle', 'automobile']),
        ('truck', ['vehicle']),
    ])
    if word in synonym_table:
        return synonym_table[word]
    return []

def process_and_display_image(model, processor, sample_row, img_dir, device, idx):
    """Caption one image, print every candidate, then show the image.

    Args:
        model: Captioning model forwarded to ``generate_multiple_captions``.
        processor: Callable that turns the image into model inputs.
        sample_row: Row-like object providing 'image_id'.
        img_dir: Directory that contains the image file.
        device: Device the model inputs are moved to (as float16).
        idx: Zero-based position of the image, used for display only.

    Returns:
        Dict with 'image_id' (file name including extension), 'best_caption',
        'all_captions' and 'detected_objects' (class names, possibly empty).
    """
    file_name = str(sample_row['image_id'])
    # Ids without a known image extension are assumed to be .jpg files.
    if not file_name.endswith(('.jpg', '.jpeg', '.png')):
        file_name = file_name + '.jpg'

    picture = Image.open(os.path.join(img_dir, file_name)).convert('RGB')

    divider = '=' * 60
    print(f"\n{divider}")
    print(f"Processing image {idx+1}: {file_name}")
    print(f"Image size: {picture.size}")

    # Object detection relies on the module-level ``yolo_model``.
    detections, scene_context = get_enhanced_object_detection(picture, yolo_model)

    model_inputs = processor(images=picture, return_tensors="pt")
    model_inputs = model_inputs.to(device, torch.float16)

    candidates, winner = generate_multiple_captions(
        model, processor, model_inputs, picture, device, detections, scene_context
    )

    # Captions are printed first ...
    print(f"\n📊 Caption Generation Results:")
    for entry in candidates:
        print(f"  {entry['config']:20} | Score: {entry['score']:6.2f} | Words: {entry['length']:2} | {entry['caption']}")

    print(f"\n🏆 Best Caption: {winner}")

    # ... and the image is rendered below them.
    plt.figure(figsize=(12, 8))
    plt.imshow(picture)
    plt.axis('off')
    plt.title(f"Image {idx+1} - {file_name}\nBest Caption: {winner}",
              fontsize=12, wrap=True, pad=20)
    plt.tight_layout()
    plt.show()

    print(divider)

    if detections:
        detected_names = [det['class'] for det in detections]
    else:
        detected_names = []

    return {
        "image_id": file_name,
        "best_caption": winner,
        "all_captions": candidates,
        "detected_objects": detected_names
    }

def process_multiple_images(model, processor, test_df, img_dir, device, num_images=10):
    """Caption a random sample of test images interactively and print a summary.

    Args:
        model: Captioning model forwarded to ``process_and_display_image``.
        processor: Processor forwarded to ``process_and_display_image``.
        test_df: DataFrame whose rows describe the test images.
        img_dir: Directory that contains the image files.
        device: Device forwarded to ``process_and_display_image``.
        num_images: Upper bound on the number of rows sampled (default 10).

    Returns:
        List of per-image result dicts. The list is empty when nothing could
        be sampled; the summary then reports an average length of 0.0.
    """
    from collections import Counter

    sample_size = min(num_images, len(test_df))
    random_indices = random.sample(range(len(test_df)), sample_size)
    random_rows = test_df.iloc[random_indices]

    results = []

    for idx, (_, sample_row) in enumerate(random_rows.iterrows()):
        result = process_and_display_image(model, processor, sample_row, img_dir, device, idx)
        results.append(result)

        # Clear some memory between images
        torch.cuda.empty_cache()

        if idx < len(random_rows) - 1:  # Don't wait after the last image
            input("Press Enter to continue to the next image...")

    # Final summary
    print("\n" + "="*80)
    print("FINAL EVALUATION SUMMARY")
    print("="*80)

    for i, result in enumerate(results, 1):
        print(f"Image {i:2} - {result['image_id']:15} | Objects: {result['detected_objects']}")
        print(f"         Best: {result['best_caption']}")
        print()

    # Statistics. Guard the average so an empty sample does not divide by zero.
    total_words = sum(len(result['best_caption'].split()) for result in results)
    avg_words = total_words / len(results) if results else 0.0

    # Counter tallies every detection in one pass instead of calling
    # list.count once per distinct class.
    object_counts = Counter(
        name for result in results for name in result['detected_objects']
    )
    most_common = [name for name, _ in object_counts.most_common(5)]

    print(f"📈 Statistics:")
    print(f"   Average caption length: {avg_words:.1f} words")
    print(f"   Total unique objects detected: {len(object_counts)}")
    print(f"   Most common objects: {most_common}")
    print("="*80)

    return results

# Interactive evaluation driver: pick a device, load the model, then caption
# a random sample of test images.

# Initialize device (CUDA when available, otherwise CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# FIXED: Set memory optimization environment variable
# NOTE(review): this is assigned after torch has been imported -- confirm the
# CUDA allocator still picks the setting up at this point.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Load the trained model (also stored in the `model` / `processor` globals)
print("Loading trained model...")
model, processor = load_trained_model()

# Load test dataset
test_df = pd.read_csv(PATHS['test_csv_path'])
print(f"Test dataset loaded: {len(test_df)} samples")

# Process multiple images (YOUR ORIGINAL FLOW)
# Waits for Enter between images, so this cell needs an interactive session.
print("Starting enhanced evaluation...")
results = process_multiple_images(model, processor, test_df, PATHS['test_img_dir'], device, num_images=10)

# Save to submission.csv

!! RESTART THE NOTEBOOK KERNEL TO FREE VRAM

In [None]:
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
from torch.cuda.amp import autocast, GradScaler

In [None]:
# Locations of the competition data under /kaggle/working; declared again
# here for the cells that follow.
PATHS = {
    'train_img_dir': "/kaggle/working/obss-intern-competition-2025/train/train/",  # training images
    'train_csv_path': "/kaggle/working/obss-intern-competition-2025/train.csv",  # training CSV
    'test_csv_path': "/kaggle/working/obss-intern-competition-2025/test.csv", 
    'test_img_dir': "/kaggle/working/obss-intern-competition-2025/test/test/",  # test images
}

In [None]:
import torch
from PIL import Image
import pandas as pd
import os
import re
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import Blip2ForConditionalGeneration, AutoProcessor
from peft import PeftModel, LoraConfig, get_peft_model, TaskType
import gc
import warnings
import logging

# Suppress warnings
# Environment variables intended to quieten TensorFlow/XLA logging.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/usr/local/cuda"
# Hide UserWarnings raised from transformers and all DeprecationWarnings.
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
warnings.filterwarnings("ignore", category=DeprecationWarning)
# Only errors from the ultralytics logger are shown.
logging.getLogger("ultralytics").setLevel(logging.ERROR)


# Model paths
# Candidate checkpoints in priority order: load_trained_model() uses the first
# path that exists, so the final checkpoint is preferred over per-epoch ones.
saved_model_paths = [
    "/kaggle/working/stable_blip2_lora_final.pt",
    "/kaggle/working/stable_blip2_lora_epoch_3.pt",
    "/kaggle/working/stable_blip2_lora_epoch_2.pt",
    "/kaggle/working/stable_blip2_lora_epoch_1.pt"
]

# Model loading
def load_trained_model(force_reload=False):
    """Load BLIP-2 in FP16 and apply the fine-tuned LoRA weights.

    The result is stored in the module-level ``model`` and ``processor``
    globals and is also returned.

    Args:
        force_reload: When False, an already-loaded global ``model`` is
            reused if a tiny ``generate`` call on a dummy tensor succeeds.

    Returns:
        Tuple ``(model, processor)``. The model has been moved to ``device``
        and switched to eval mode. If the checkpoint holds no usable LoRA
        weights, or applying them fails, the plain base model is returned.

    Raises:
        FileNotFoundError: If none of ``saved_model_paths`` exists.
        Exception: Any other loading error is re-raised after the references
            held by this function (and the global ``model``) are released.
    """
    global model, processor

    # Reuse the model that is already loaded, provided it still generates.
    if not force_reload and 'model' in globals() and model is not None:
        try:
            dummy_tensor = torch.randn(1, 3, 224, 224).to(device, torch.float16)
            with torch.no_grad():
                _ = model.generate(pixel_values=dummy_tensor, max_length=5)
            print("✅ Using existing model from VRAM")
            return model, processor
        except Exception as e:
            print(f"⚠️ Model in VRAM not functional: {e}")
            if 'model' in globals():
                del model
            if 'processor' in globals():
                del processor
            gc.collect()
            torch.cuda.empty_cache()

    # The first existing path wins; the list order defines the priority.
    model_path = next((path for path in saved_model_paths if os.path.exists(path)), None)
    if model_path is None:
        raise FileNotFoundError("No trained model found. Please check the model paths.")
    print(f"Found model: {model_path}")

    # Bound up front so the failure handler can always release them.
    base_model = None
    checkpoint = None

    try:
        print(f"🔄 Loading processor...")
        processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b", use_fast=True)
        processor.image_processor.size = {"height": 224, "width": 224}

        print(f"🔄 Loading checkpoint to CPU first...")
        checkpoint = torch.load(model_path, map_location='cpu')

        print(f"🔄 Loading base model in FP16...")
        base_model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-opt-2.7b",
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True,
            max_memory={0: "12GB"}
        )
        print("✅ Base model loaded in FP16")

        # Attention projections of the 32 decoder layers receive LoRA adapters.
        target_modules = [
            f"language_model.model.decoder.layers.{i}.self_attn.{proj}"
            for i in range(32)
            for proj in ["q_proj", "k_proj", "v_proj", "out_proj"]
        ]

        lora_config = LoraConfig(
            task_type=TaskType.SEQ_2_SEQ_LM,
            r=24,
            lora_alpha=48,
            lora_dropout=0.1,
            bias="none",
            target_modules=target_modules,
            init_lora_weights="gaussian"
        )

        try:
            model = get_peft_model(base_model, lora_config)
            # The checkpoint may store the weights under either key.
            state_dict = checkpoint.get('lora_weights', checkpoint.get('model_state_dict'))
            if not state_dict:
                print("⚠️ No LoRA weights found in checkpoint")
                model = base_model
            else:
                # Only load keys that exist in the wrapped model.
                known_keys = model.state_dict()
                compatible_state_dict = {k: v for k, v in state_dict.items() if k in known_keys}
                model.load_state_dict(compatible_state_dict, strict=False)
                print(f"✅ Loaded {len(compatible_state_dict)} LoRA parameters")

        except Exception as e:
            # Best effort: fall back to the un-adapted base model.
            print(f"⚠️ LoRA loading failed, using base model: {e}")
            model = base_model

        # The checkpoint is no longer needed once the weights are applied.
        del checkpoint
        gc.collect()
        torch.cuda.empty_cache()

        model = model.to(device)
        model.eval()

        print("✅ Model loaded successfully")
        if torch.cuda.is_available():
            memory_used = torch.cuda.memory_allocated() / 1024**3
            print(f"📊 GPU Memory Used: {memory_used:.2f}GB")

        return model, processor

    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        # Deleting entries from locals() does not unbind anything, so drop the
        # references explicitly before collecting garbage.
        model = None
        base_model = None
        checkpoint = None
        gc.collect()
        torch.cuda.empty_cache()
        raise

# Caption configurations
def get_caption_configs():
    """Build the three decoding presets used for submission captions.

    Returns:
        A list of dicts: one sampling preset followed by two beam-search
        presets that differ only in name and beam count. Each dict holds a
        'name' plus the generation parameters for that preset.
    """
    def _beam_preset(label, beam_count):
        # Deterministic beam search; only the label and beam width vary.
        return {
            'name': label,
            'max_length': 80,
            'num_beams': beam_count,
            'do_sample': False,
            'length_penalty': 1.0,
            'no_repeat_ngram_size': 2,
            'early_stopping': True
        }

    sampling_preset = dict(
        name='Creative Sampling',
        max_length=80,
        num_beams=1,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.9,
        no_repeat_ngram_size=2,
        length_penalty=1.0,
    )

    return [
        sampling_preset,
        _beam_preset('Balanced Beam', 4),
        _beam_preset('High-Precision Beam', 8),
    ]

# Lightweight scoring for FID
def score_caption(caption):
    """Score a caption from its text alone; higher is better.

    The score combines a length band, the ratio of distinct words
    (case-sensitive), a bonus per matching keyword substring and a penalty
    per matching generic phrase.

    Returns:
        Float score.
    """
    lowered = caption.lower()
    tokens = caption.split()
    n_tokens = len(tokens)
    total = 0.0

    # Length band: best between 8 and 20 words, penalised below 5.
    if 8 <= n_tokens <= 20:
        total += 20.0
    elif 5 <= n_tokens <= 25:
        total += 15.0
    elif n_tokens < 5:
        total -= 10.0

    # Share of distinct words; max() keeps an empty caption from dividing by zero.
    distinct_share = len(set(tokens)) / max(n_tokens, 1)
    if distinct_share > 0.7:
        total += 15.0
    elif distinct_share < 0.5:
        total -= 5.0

    # Substring matches against common object and scene words.
    bonus_words = ('person', 'player', 'ball', 'car', 'vehicle', 'tree', 'sky', 'building', 'shop', 'sign', 'field')
    for word in bonus_words:
        if word in lowered:
            total += 10.0

    # Generic filler phrases cost points.
    filler_phrases = ('an image', 'a picture', 'this shows', 'visible objects', 'scene')
    for filler in filler_phrases:
        if filler in lowered:
            total -= 10.0

    return total

# Clean captions
def clean_caption(caption):
    """Scrub a generated caption and normalise its shape.

    Removes markup, bracketed markers, control characters, a list of known
    junk tokens, mixed-case identifiers and ``...000N``-style words, then
    collapses whitespace, drops a dangling trailing conjunction/article and a
    leading article, and capitalises the first character.

    Returns:
        The cleaned caption, or a generic fallback sentence when the input is
        empty, fewer than two words remain, or a junk token is still present
        as a substring.
    """
    if not caption:
        return "An image showing a scene"

    # Markup, bracketed markers, control characters, replacement glyphs.
    for pattern in (r'<[^>]*>', r'\[.*?\]', r'[\x00-\x1F\x7F-\x9F]', r'�|'):
        caption = re.sub(pattern, '', caption)

    junk_tokens = [
        'strutconnect', 'strutconnector', 'attrot', 'guiactive', 'guiactiveunfocused',
        'madeupword', 'confignode', 'partmodule', 'tweakscale', 'modulemanager',
        'gamedata', 'squad', 'parttools', 'kspfield', 'persistant', 'cfgnode',
        'guiicon', 'guiname', '0002', 'externaltoevaonly', '裏', 'guiactiveunfocusedmadeupword'
    ]
    # Whole-word, case-insensitive removal of each junk token in turn.
    for junk in junk_tokens:
        caption = re.sub(rf'\b{junk}\b', '', caption, flags=re.IGNORECASE)

    # Words with a second capital letter, then words ending in 000<digits>.
    caption = re.sub(r'\b[A-Z][a-z]*[A-Z][a-zA-Z]*\b', '', caption)
    caption = re.sub(r'\b\w+000\d+\b', '', caption)

    # Whitespace and dangling articles/conjunctions.
    caption = re.sub(r'\s+', ' ', caption).strip()
    caption = re.sub(r'\s+(and|or|the|a|an)$', '', caption, flags=re.IGNORECASE)
    caption = re.sub(r'^(the|a|an)\s+', '', caption, flags=re.IGNORECASE)

    if caption:
        caption = caption[0].upper() + caption[1:]

    lowered = caption.lower()
    too_short = len(caption.split()) < 2
    still_tainted = any(junk in lowered for junk in junk_tokens)
    if too_short or still_tainted:
        return "An image showing a scene with visible objects"

    return caption

# Post-process captions
def post_process_caption(caption):
    """Return the caption with leading and trailing whitespace removed."""
    trimmed = caption.strip()
    return trimmed

# Log progress
def log_progress(image_id, caption, index, score):
    """Print a progress line for every 100th image.

    Args:
        image_id: Identifier shown in the log line.
        caption: Caption chosen for the image.
        index: Running image index; a line is printed when it is a multiple
            of 100 (including 0).
        score: Numeric score, printed with two decimals.
    """
    # The previous extra clause compared str(index)[:-2] with '100', which
    # logged every index from 10000 to 10099 instead of every 100th image.
    if index % 100 == 0:
        print(f"📄 Image {index} (ID: {image_id}) | Caption: '{caption}' | Score: {score:.2f}")

# Test dataset class
class TestDataset(Dataset):
    """Dataset that yields one RGB test image per DataFrame row.

    Each item is a dict with 'image_id' (the raw value from the row), 'image'
    (a PIL image) and 'original_image_id' (the file name that was opened,
    suffixed with "_error" when loading failed).
    """

    def __init__(self, dataframe, processor, test_img_dir):
        # The processor is stored but not used by this class itself.
        self.test_img_dir = test_img_dir
        self.processor = processor
        self.dataframe = dataframe

    def __len__(self):
        row_count = len(self.dataframe)
        return row_count

    def __getitem__(self, idx):
        record = self.dataframe.iloc[idx]

        # Ids without a known image extension are assumed to be .jpg files.
        file_name = str(record['image_id'])
        has_known_suffix = file_name.endswith(('.jpg', '.jpeg', '.png'))
        if not has_known_suffix:
            file_name = file_name + '.jpg'

        full_path = os.path.join(self.test_img_dir, file_name)
        try:
            picture = Image.open(full_path).convert('RGB')
        except Exception as load_error:
            # Keep the batch going: substitute a blank 224x224 image and tag the name.
            print(f"Error loading image {file_name}: {load_error}")
            picture = Image.new('RGB', (224, 224))
            file_name = f"{file_name}_error"

        return {
            'image_id': record['image_id'],
            'image': picture,
            'original_image_id': file_name,
        }

# Collate function
def collate_fn(batch):
    """Merge a list of dataset items into one model-ready batch.

    Uses the module-level ``processor`` to turn the images into a
    ``pixel_values`` tensor.

    Args:
        batch: Sequence of dicts, each providing ``'image'`` and
            ``'image_id'`` entries.

    Returns:
        Dict with ``'image_ids'`` (list, in batch order) and
        ``'pixel_values'`` (taken from the processor output).
    """
    batch_images = []
    batch_ids = []
    for sample in batch:
        batch_images.append(sample['image'])
        batch_ids.append(sample['image_id'])

    # No text prompt is supplied; the remaining keyword arguments are
    # forwarded to the processor unchanged.
    processor_kwargs = {
        'images': batch_images,
        'text': None,
        'padding': "max_length",
        'max_length': 80,
        'return_tensors': "pt",
        'truncation': True,
    }
    encoded = processor(**processor_kwargs)

    collated = {'image_ids': batch_ids}
    collated['pixel_values'] = encoded['pixel_values']
    return collated

# Main execution
# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# CUDA caching-allocator setting, exported through the environment.
os.environ.update({'PYTORCH_CUDA_ALLOC_CONF': 'expandable_segments:True'})

# Load model and processor
print("Loading trained model...")
model, processor = load_trained_model()

# Load test dataset
test_df = pd.read_csv(PATHS['test_csv_path'])
print(f"Test dataset loaded: {len(test_df)} samples")

# Create test dataset and dataloader
test_dataset = TestDataset(
    dataframe=test_df,
    processor=processor,
    test_img_dir=PATHS['test_img_dir'],
)
loader_options = {
    'batch_size': 12,  # Increased for speed
    'shuffle': False,  # keep CSV order so positions map back to rows
    'collate_fn': collate_fn,
    'num_workers': 4,
    'pin_memory': True,
}
test_dataloader = DataLoader(test_dataset, **loader_options)

# Get configurations
configs = get_caption_configs()

# Generate and score captions
print(f"\n🎯 Generating captions using three configurations...")
model.eval()
image_ids, captions = [], []

with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(test_dataloader, desc="Generating captions")):
        pixel_values = batch['pixel_values'].to(device, torch.float16)
        batch_image_ids = batch['image_ids']

        # One caption list and one score list per configuration, kept in
        # the same order as `configs`.
        batch_captions, batch_scores = [], []

        for config in configs:
            try:
                # Every entry except the 'name' label is passed to generate().
                generation_params = dict(config)
                generation_params.pop('name', None)
                generated_ids = model.generate(pixel_values=pixel_values, **generation_params)

                decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
                config_captions = [post_process_caption(clean_caption(text)) for text in decoded]
                config_scores = [score_caption(text) for text in config_captions]
            except Exception as e:
                # Best effort: a failing configuration contributes a generic
                # caption with a fixed score of -10.0 for each image.
                print(f"Error with config {config['name']} for batch {batch_image_ids}: {e}")
                config_captions = ["An image showing a scene"] * len(batch_image_ids)
                config_scores = [-10.0] * len(batch_image_ids)

            batch_captions.append(config_captions)
            batch_scores.append(config_scores)

        # Select best caption per image
        for pos, current_id in enumerate(batch_image_ids):
            global_index = batch_idx * test_dataloader.batch_size + pos
            best_caption = "An image showing a scene"
            best_score = -float('inf')

            # Strict ">" keeps the earliest configuration when scores tie.
            for config_captions, config_scores in zip(batch_captions, batch_scores):
                candidate_score = config_scores[pos]
                if candidate_score > best_score:
                    best_score, best_caption = candidate_score, config_captions[pos]

            log_progress(current_id, best_caption, global_index, best_score)
            image_ids.append(current_id)
            captions.append(best_caption)

        # Release cached GPU memory between batches.
        torch.cuda.empty_cache()

# Create submission DataFrame: one row per test image, pairing each id with
# the caption selected for it.
submission_df = pd.DataFrame({
    'image_id': image_ids,
    'caption': captions,
})

# Save submission (no index column, only the two fields above).
submission_filename = 'submission_three_configs.csv'
submission_df.to_csv(submission_filename, index=False)

print(f"\n✅ Submission file saved as '{submission_filename}'")
print(f"📊 Generated {len(submission_df)} captions using best of three configurations")
# Preview the first rows. This must be a plain call: a leading "-" would
# negate print()'s None return value and raise TypeError.
print(submission_df.head())