In [None]:
"""
Batch Gemini Inference and CSV Aggregator

This notebook takes either a number of samples or a list of specific sample filenames,
runs the Gemini-backed `HerbariumLabelExtractor.classify` on each image, collects the
outputs, and writes a single CSV file aggregating all results (one row per image,
identified by its filename without extension in the `id` column).

Choosing the images:
- To process specific files, list their paths in `sampled_paths`.
- To process a random sample instead, set `sampled_paths = None`; up to
  `num_to_sample` .jpg/.jpeg files are then drawn from `path_to_sample_from`.

NOTE: herbarium_label_extractor.py and the prompts/ folder must be reachable from this
notebook's working directory (the module is imported by name and the prompt files are
opened by relative path).

ALSO NOTE: It costs $$ to run Gemini, so be careful with the number of samples you choose!
e.g. running 100 samples costs ~$5!!
"""

# Random-sampling settings. Only used when `sampled_paths` is None, but always
# defined so that switching modes never hits a NameError in the sampling cell.
num_to_sample = 10
path_to_sample_from = "../data/img"

# Explicit list of images to process; set to None to sample randomly instead.
sampled_paths = ['../data/img/1692210.jpg', '../data/img/3224559.jpg', '../data/img/3195555.jpg', '../data/img/954335.jpg', '../data/img/805174.jpg', '../data/img/1554650.jpg', '../data/img/1524548.jpg', '../data/img/1558385.jpg', '../data/img/3512932.jpg', '../data/img/1740735.jpg']

# Destination of the aggregated CSV (one row per processed image).
output_csv_path = "../tmp/gemini_output_test_10.csv"

In [None]:
import os
import random

def sample_image_paths(directory, n, seed=None, extensions=('.jpg', '.jpeg')):
    """Randomly pick up to ``n`` image files from ``directory``.

    Args:
        directory: Folder to list (non-recursive).
        n: Maximum number of files to return. If the folder holds fewer
            matching files, all of them are returned.
        seed: Seed for the random generator. Pass an int to get the same
            sample on every run, or None for a fresh sample each time.
        extensions: Case-insensitive filename suffixes to accept.

    Returns:
        List of paths of the form ``os.path.join(directory, filename)``.
    """
    # os.listdir order is filesystem-dependent; sorting makes a fixed seed
    # produce the same sample on every machine and every run.
    candidates = sorted(
        f for f in os.listdir(directory)
        if f.lower().endswith(tuple(extensions))
    )
    rng = random.Random(seed)
    chosen = rng.sample(candidates, min(n, len(candidates)))
    return [os.path.join(directory, f) for f in chosen]


# Fixed seed so re-running the notebook does not draw (and pay Gemini for) a
# different set of images. Set to None for a fresh random sample each run.
sample_seed = 42

# If sampled_paths is provided, use it directly; otherwise draw a random sample.
if sampled_paths is None:
    sampled_paths = sample_image_paths(path_to_sample_from, num_to_sample, seed=sample_seed)

print("Will generate CSV based on sampled files:")
print(sampled_paths)

In [None]:
import importlib
import os
import herbarium_label_extractor
importlib.reload(herbarium_label_extractor)
from herbarium_label_extractor import HerbariumLabelExtractor

# The prompts are Markdown files; read them as UTF-8 explicitly so any
# non-ASCII text does not depend on the platform's default encoding.
with open('prompts/system_instructions_no_ocr.md', encoding='utf-8') as f:
    sys_instr = f.read()
with open('prompts/few_shot_prompt_no_ocr.md', encoding='utf-8') as f:
    few_shot = f.read()

extractor = HerbariumLabelExtractor(
    system_instructions=sys_instr,
    few_shot_prompt=few_shot,
    few_shot_image_paths=[
        '../img/IMG_2708.jpg',
    ],
    output_dir='../tmp'
)

# One row per image. Gemini calls are billed, so a failure on one image is
# recorded (as an 'error' column in that row) instead of aborting the batch
# and forcing a paid re-run of the images that already succeeded.
results = []
failed_paths = []
for img_path in sampled_paths:
    print(f"Processing image: {img_path}")
    image_id = os.path.splitext(os.path.basename(img_path))[0]  # filename without extension
    try:
        result = extractor.classify(img_path)
    except Exception as exc:
        print(f"  Failed on {img_path}: {exc!r}")
        failed_paths.append(img_path)
        results.append({'id': image_id, 'error': repr(exc)})
        continue
    # Copy instead of mutating the extractor's return value; 'id' is written
    # last so it overrides any 'id' key the extractor may have returned.
    results.append({**result, 'id': image_id})

print(f"Processed {len(results) - len(failed_paths)}/{len(sampled_paths)} images successfully.")
if failed_paths:
    print("Failed images:", failed_paths)


In [None]:
import csv
import os

# CSV header: the union of keys across all rows, in first-seen order, with
# 'id' moved to the front. (Building it from a set gave a different column
# order on every run, since str hashing is randomized per process.)
fieldnames = list(dict.fromkeys(key for row in results for key in row))
if 'id' in fieldnames:
    fieldnames.remove('id')
    fieldnames.insert(0, 'id')

if not results:
    print("No results to save; CSV not written.")
else:
    # Make sure the destination folder exists (e.g. ../tmp on a fresh checkout).
    output_dir = os.path.dirname(output_csv_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Explicit UTF-8: extracted text may contain non-ASCII characters that the
    # platform default encoding (e.g. cp1252 on Windows) cannot encode.
    # newline='' is what the csv module expects for files it writes.
    with open(output_csv_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)  # keys missing from a row become empty cells

    print(f"Saved CSV to {output_csv_path}")