In [None]:
import concurrent.futures
import gc
import json
import os
import threading

from datasets import load_dataset
from tqdm import tqdm

from evaluation.forestfire_evaluation import create_predictions, eval_structured_data
from evaluation.google_api import GoogleAPI
from evaluation.openai_api import OpenAIAPI
from templates.answer_schema import smoke_detection_schema

# Vision-language model client used by every cell below.
# Alternative backend, kept for quick switching: vlm = OpenAIAPI()
vlm = GoogleAPI(
    model_name="gemini-2.0-flash-lite",
)

In [None]:
# Load the evaluation dataset and report what is about to be evaluated.
dataset_name = "leon-se/FIgLib-Test"
eval_ds = load_dataset(dataset_name, split="train")
print(f"\nDataset size: {len(eval_ds)}\nModel: {vlm.model_name}\n")

# JSONL file that receives one record per successfully processed sample.
output_file = "benchmarks/batch_eval/gemini-20-flash-lite_figlib.jsonl"

# The output folder may not exist yet (e.g. on a fresh checkout); without this
# the open() below fails with FileNotFoundError before any sample is processed.
os.makedirs(os.path.dirname(output_file), exist_ok=True)

# Opening in "w" mode creates the file, or truncates it if it already exists.
# NOTE: re-running this cell therefore discards results from a previous run.
with open(output_file, "w"):
    pass

# Serializes appends to output_file across the worker threads.
file_lock = threading.Lock()

# Number of samples submitted to the thread pool per batch.
BATCH_SIZE = 500
# Upper bound on concurrently running worker threads within a batch.
MAX_WORKERS = 16

def process_sample(sample_idx):
    """Evaluate one dataset sample with the VLM and append the outcome to the output file.

    Args:
        sample_idx: Index into ``eval_ds`` of the sample to process.

    Returns:
        ``(sample_idx, True)`` once the record has been appended to
        ``output_file``; ``(sample_idx, False)`` if any exception occurred
        along the way. The exception is printed, not re-raised.
    """
    try:
        entry = eval_ds[sample_idx]
        pil_image = entry["image"]
        prompt_text = entry["prompt"]
        ground_truth = entry["gt_dict"]

        # Structured-output request for this sample's prompt/image pair
        prediction = vlm.generate_structured_response_from_pil_image(
            prompt_text, pil_image, smoke_detection_schema
        )

        record = {
            "sample_idx": sample_idx,
            "gt_dict": ground_truth,
            "vlm_prediction": prediction,
        }

        # Hold the lock for the whole append so lines from different threads never interleave
        with file_lock, open(output_file, "a") as out:
            out.write(json.dumps(record) + "\n")
    except Exception as e:
        print(f"Error processing sample {sample_idx}: {e}")
        return sample_idx, False

    # Only the index and a success flag go back to the caller
    return sample_idx, True

# Running total of samples whose record reached the output file
successful_samples = 0
dataset_len = len(eval_ds)
total_batches = -(-dataset_len // BATCH_SIZE)  # ceiling division

# Walk the dataset in consecutive windows of BATCH_SIZE indices
for batch_number, start_idx in enumerate(range(0, dataset_len, BATCH_SIZE), start=1):
    end_idx = min(start_idx + BATCH_SIZE, dataset_len)
    print(f"\nProcessing batch {batch_number}/{total_batches} (samples {start_idx}-{end_idx - 1})")

    # A fresh pool per batch, capped at MAX_WORKERS threads
    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
        pending = [pool.submit(process_sample, idx) for idx in range(start_idx, end_idx)]

        # Count successes as futures finish; tqdm renders the progress bar
        finished = tqdm(concurrent.futures.as_completed(pending), total=len(pending))
        batch_success = sum(1 for done in finished if done.result()[1])

        successful_samples += batch_success
        print(f"Batch complete: {batch_success}/{end_idx - start_idx} successful")

    # Explicit garbage collection pass before starting the next batch
    gc.collect()

print(f"\nEvaluation complete: Successfully processed {successful_samples} out of {dataset_len} samples")

In [None]:
# Read back the JSONL produced by the evaluation run: one JSON record per line
with open("benchmarks/batch_eval/gemini-20-flash-lite_figlib.jsonl", "r") as f:
    results = list(map(json.loads, f))

# Split each record into its prediction and its ground truth, keeping them index-aligned
predictions_text = []
ground_truth_dicts = []
for record in results:
    predictions_text.append(record["vlm_prediction"])
    ground_truth_dicts.append(record["gt_dict"])

In [None]:
# Score the predictions against the ground truth for this model/dataset pair.
# NOTE(review): `results` is rebound here, replacing the list of raw records
# loaded from the JSONL file.
# NOTE(review): results_folder is "benchmarks-test" while the raw output lives
# under "benchmarks/" — confirm this is intended.
results = eval_structured_data(
    predictions_text,
    ground_truth_dicts,
    vlm.model_name,
    dataset_name,
    write_to_file=True,
    results_folder="benchmarks-test",
    confusion_keys=["forest_fire_smoke_visible"],
)