In [None]:
import asyncio
from google import genai
from google.genai import types
from pydantic import BaseModel, Field
from typing import Literal
from PIL import Image
from pathlib import Path
import os

# Prefer a key supplied through the GEMINI_API_KEY environment variable so the
# real secret never has to be pasted into (and saved with) the notebook. The
# original placeholder is kept as the fallback when the variable is unset.
client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY", "PUT A PAID KEY HERE"))

In [29]:
# Structured-output schema for one image. Its JSON schema is passed to the model
# as the response schema, and the model's reply is validated against this class.
# NOTE(review): documented with `#` comments on purpose — a class docstring may
# end up inside the generated JSON schema; confirm before adding one.
class WeaponClassification(BaseModel):
    # Exactly one of the four labels; any other value fails validation.
    classification: Literal["pistol_only", "rifle_only", "multiple_weapons", "no_weapon"] = Field(
        description="The weapon type detected in the image"
    )

In [30]:
# Instruction text sent together with each image. The four bullet labels are the
# same strings as the allowed values of WeaponClassification.classification, so
# keep the two in sync when adding or renaming a label.
PROMPT = """Classify the weapons in this image:
- pistol_only: Only pistol(s)
- rifle_only: Only rifle(s)
- multiple_weapons: More than one weapon type
- no_weapon: No weapons"""

In [31]:
async def classify_image(image_path: str, model: str = "gemini-2.5-flash") -> tuple[str, str]:
    """Classify the weapons shown in a single image.

    Args:
        image_path: Path of the image file to classify.
        model: Name of the model to query.

    Returns:
        ``(image_path, classification)``, where ``classification`` is one of
        the labels allowed by ``WeaponClassification``.

    Raises:
        ValueError: If the response carries no text, or the text does not
            validate against ``WeaponClassification``.
        Exception: Anything raised while opening the image or calling the API
            propagates unchanged.
    """
    # Keep the image open only for the duration of the request. Without the
    # context manager every call leaves a file handle behind, which adds up
    # quickly when many images are classified concurrently.
    with Image.open(image_path) as image:
        response = await client.aio.models.generate_content(
            model=model,
            contents=[image, PROMPT],
            config={
                "response_mime_type": "application/json",
                "response_json_schema": WeaponClassification.model_json_schema(),
            },
        )

    # Fail with a message that names the image instead of an opaque parse error.
    if not response.text:
        raise ValueError(f"Empty model response for {image_path}")

    result = WeaponClassification.model_validate_json(response.text)
    return image_path, result.classification

async def classify_batch(image_paths: list[str], batch_size: int = 5, delay: float = 12.0):
    """Classify many images, ``batch_size`` at a time, pausing between batches.

    All images of one batch are classified concurrently. After every batch
    except the last, the coroutine sleeps ``delay`` seconds.

    Args:
        image_paths: Paths of the images to classify.
        batch_size: Number of images classified concurrently; must be >= 1.
        delay: Seconds to wait between consecutive batches.

    Returns:
        A dict mapping image path to classification label. Images whose
        classification failed are reported on stdout and left out of the dict.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # A zero size would divide by zero below; a negative one would silently
    # process nothing. Reject both up front.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    results = {}
    total_images = len(image_paths)
    total_batches = (total_images + batch_size - 1) // batch_size  # ceiling division

    for i in range(0, total_images, batch_size):
        batch_num = i // batch_size + 1
        batch = image_paths[i:i + batch_size]
        tasks = [classify_image(p) for p in batch]
        # return_exceptions=True keeps one failing image from aborting the batch.
        batch_results = await asyncio.gather(*tasks, return_exceptions=True)

        # gather() returns results in input order, so each outcome can be paired
        # with its path and a failure can be traced back to the image.
        for src_path, res in zip(batch, batch_results):
            # BaseException also covers asyncio.CancelledError, which is not an
            # Exception subclass and would otherwise break the unpacking below.
            if isinstance(res, BaseException):
                print(f"Error classifying {src_path}: {res}")
            else:
                path, classification = res
                results[path] = classification

        print(f"Batch {batch_num}/{total_batches} done ({min(i + batch_size, len(image_paths))}/{len(image_paths)} images)")

        # No need to wait after the final batch.
        if i + batch_size < total_images:
            print(f"Waiting {delay}s to avoid rate limit...")
            await asyncio.sleep(delay)

    return results

In [32]:
image_dir = "/Users/vaibhavnakrani/yolo/raw_data/YouTube-GDD/images/train"
image_paths = [str(p) for p in Path(image_dir).glob("*.jpg")]

results = await classify_batch(image_paths, batch_size=50, delay=3.0)

Batch 1/80 done (50/4000 images)
Waiting 3.0s to avoid rate limit...
Batch 2/80 done (100/4000 images)
Waiting 3.0s to avoid rate limit...
Batch 3/80 done (150/4000 images)
Waiting 3.0s to avoid rate limit...
Batch 4/80 done (200/4000 images)
Waiting 3.0s to avoid rate limit...
Batch 5/80 done (250/4000 images)
Waiting 3.0s to avoid rate limit...
Batch 6/80 done (300/4000 images)
Waiting 3.0s to avoid rate limit...
Batch 7/80 done (350/4000 images)
Waiting 3.0s to avoid rate limit...
Batch 8/80 done (400/4000 images)
Waiting 3.0s to avoid rate limit...
Batch 9/80 done (450/4000 images)
Waiting 3.0s to avoid rate limit...
Batch 10/80 done (500/4000 images)
Waiting 3.0s to avoid rate limit...
Batch 11/80 done (550/4000 images)
Waiting 3.0s to avoid rate limit...
Batch 12/80 done (600/4000 images)
Waiting 3.0s to avoid rate limit...
Batch 13/80 done (650/4000 images)
Waiting 3.0s to avoid rate limit...
Batch 14/80 done (700/4000 images)
Waiting 3.0s to avoid rate limit...
Batch 15/80 done

In [33]:
# Preview only the first few entries; echoing the whole dict prints one line per
# classified image and buries the rest of the notebook.
dict(list(results.items())[:10])

{'/Users/vaibhavnakrani/yolo/raw_data/YouTube-GDD/images/train/0kRe9QTdsSo_30_60_000012.jpg': 'rifle_only',
 '/Users/vaibhavnakrani/yolo/raw_data/YouTube-GDD/images/train/gunNXngtOKo_24_48_000007.jpg': 'rifle_only',
 '/Users/vaibhavnakrani/yolo/raw_data/YouTube-GDD/images/train/A3F1QvCHQLQ_30_60_000228.jpg': 'pistol_only',
 '/Users/vaibhavnakrani/yolo/raw_data/YouTube-GDD/images/train/dfYeDx5BzfU_30_60_000001.jpg': 'rifle_only',
 '/Users/vaibhavnakrani/yolo/raw_data/YouTube-GDD/images/train/tPNS4vUlX60_30_60_000467.jpg': 'pistol_only',
 '/Users/vaibhavnakrani/yolo/raw_data/YouTube-GDD/images/train/1T7zZ_D8nqI_30_60_000005.jpg': 'rifle_only',
 '/Users/vaibhavnakrani/yolo/raw_data/YouTube-GDD/images/train/3ht89CbHyYg_30_60_000023.jpg': 'rifle_only',
 '/Users/vaibhavnakrani/yolo/raw_data/YouTube-GDD/images/train/fj5Aq7bz-tI_25_50_000111.jpg': 'rifle_only',
 '/Users/vaibhavnakrani/yolo/raw_data/YouTube-GDD/images/train/IV0TpHM3TOU&t=1s_30_60_000082.jpg': 'pistol_only',
 '/Users/vaibhavnakr

In [34]:
import shutil
import json

from collections import Counter

# Destination root: one sub-folder per label, plus a JSON dump of the mapping.
output_dir = Path("/Users/vaibhavnakrani/yolo/classified_images")

# Create all four label folders up front, including any that end up empty.
for label in ("pistol_only", "rifle_only", "multiple_weapons", "no_weapon"):
    label_dir = output_dir / label
    label_dir.mkdir(parents=True, exist_ok=True)

# Copy (not move) each image into the folder named after its predicted label.
for image_file, label in results.items():
    source = Path(image_file)
    shutil.copy2(source, output_dir / label / source.name)

# Persist the raw path -> label mapping next to the sorted images.
with (output_dir / "results.json").open("w") as out_file:
    json.dump(results, out_file, indent=2)

# Per-label totals for a quick sanity check of the class balance.
counts = Counter(results.values())
print("Saved to:", output_dir)
print("Summary:", dict(counts))


Saved to: /Users/vaibhavnakrani/yolo/classified_images
Summary: {'rifle_only': 2329, 'pistol_only': 1444, 'no_weapon': 33, 'multiple_weapons': 184}


- Count the input tokens (image + prompt) for every training image

In [None]:
async def count_tokens_for_image(image_path: str, model: str = "gemini-2.5-flash") -> tuple[str, int]:
    """Count the input tokens of one image plus the classification prompt.

    Args:
        image_path: Path of the image file to measure.
        model: Name of the model whose tokenizer is used for counting.

    Returns:
        ``(image_path, total_tokens)`` as reported by the count-tokens call.

    Raises:
        Exception: Anything raised while opening the image or calling the API
            propagates unchanged.
    """
    # Close the file handle as soon as the request is done. Without the context
    # manager every call leaves an open file behind, which adds up when
    # many images are processed concurrently.
    with Image.open(image_path) as image:
        result = await client.aio.models.count_tokens(
            model=model,
            contents=[image, PROMPT],
        )
    return image_path, result.total_tokens

async def count_all_tokens(image_paths: list[str], batch_size: int = 50):
    """Sum the input-token counts of many images, ``batch_size`` at a time.

    All images of one batch are counted concurrently; batches run back to back.

    Args:
        image_paths: Paths of the images to measure.
        batch_size: Number of images counted concurrently; must be >= 1.

    Returns:
        The total number of tokens over all images that were counted
        successfully. Failed images are reported on stdout and contribute 0.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # A negative size would make the range below empty and silently return 0.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    total_tokens = 0

    for i in range(0, len(image_paths), batch_size):
        batch = image_paths[i:i + batch_size]
        tasks = [count_tokens_for_image(p) for p in batch]
        # return_exceptions=True keeps one failing image from aborting the batch.
        batch_results = await asyncio.gather(*tasks, return_exceptions=True)

        # gather() returns results in input order, so a failure can be paired
        # with the image that caused it.
        for src_path, res in zip(batch, batch_results):
            # BaseException also covers asyncio.CancelledError, which is not an
            # Exception subclass and would otherwise break the unpacking below.
            if isinstance(res, BaseException):
                print(f"Error counting tokens for {src_path}: {res}")
            else:
                _, tokens = res
                total_tokens += tokens

        print(f"Counted {min(i + batch_size, len(image_paths))}/{len(image_paths)} images")

    return total_tokens

image_dir = "/Users/vaibhavnakrani/yolo/raw_data/YouTube-GDD/images/train"
image_paths = [str(p) for p in Path(image_dir).glob("*.jpg")]

total = await count_all_tokens(image_paths, batch_size=100)
print(f"\nTotal images: {len(image_paths)}")
print(f"Total input tokens: {total:,}")
print(f"Avg tokens per image: {total / len(image_paths):.0f}")
