# Batch Face Enrollment Pipeline

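Efficient, concurrent batch enrollment of faces, loaded either from a directory structure (one sub-directory per person) or from an in-memory mapping of person IDs to image paths (e.g. one built from a CSV).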

In [None]:
import asyncio
import concurrent.futures
import mimetypes
import os
import sys
from pathlib import Path
from typing import Dict, List
from urllib.parse import quote

import httpx
import pandas as pd
from tqdm.notebook import tqdm

# Add parent directory to path (guarded so re-running this cell doesn't keep appending it)
if '..' not in sys.path:
    sys.path.append('..')

# API Configuration. Both values can be overridden via environment variables.
# The trailing slash is stripped because endpoint URLs are built as
# f"{API_BASE_URL}/api/...", and a trailing slash would yield "//api".
API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000").rstrip("/")
API_KEY = os.getenv("API_KEY", "")

## 1. Prepare Dataset

In [None]:
# Option 1: Load from directory structure
# Expected structure: dataset/<person_id>/<image files>, e.g.
#   dataset/person_001/img1.jpg, dataset/person_001/img2.png, ...

def scan_directory(base_path: str,
                   extensions: tuple = (".jpg", ".jpeg", ".png")) -> Dict[str, List[str]]:
    """Map each person sub-directory of *base_path* to its image file paths.

    Args:
        base_path: Directory whose immediate sub-directories are named by
            person ID.
        extensions: File suffixes treated as images. They are compared
            case-insensitively (``.JPG`` and ``.Png`` match too), and a
            missing leading dot is tolerated (``"jpg"`` == ``".jpg"``).

    Returns:
        ``{person_id: [image paths]}``, with persons and paths sorted so runs
        are reproducible. The following are skipped:
        - sub-directories containing no matching image files;
        - files directly under *base_path*;
        - nested directories inside a person directory (the scan is not
          recursive).

    Raises:
        OSError (e.g. FileNotFoundError): if *base_path* cannot be listed.
    """
    wanted = {
        ext.lower() if ext.startswith(".") else f".{ext.lower()}"
        for ext in extensions
    }

    dataset = {}
    for person_dir in sorted(Path(base_path).iterdir()):
        if not person_dir.is_dir():
            continue
        # is_file() excludes directories that merely have an image-like suffix.
        images = sorted(
            str(img_file)
            for img_file in person_dir.iterdir()
            if img_file.is_file() and img_file.suffix.lower() in wanted
        )
        if images:
            dataset[person_dir.name] = images

    return dataset

# Example usage with real data on disk:
# dataset = scan_directory("./face_dataset")

# Option 2: Create sample dataset for testing.
# These are placeholder paths; unless such files exist, the enrollment client
# finds no readable images and reports an error for every person.
dataset = {
    "person_001": ["path/to/person1_img1.jpg", "path/to/person1_img2.jpg"],
    "person_002": ["path/to/person2_img1.jpg"],
    "person_003": ["path/to/person3_img1.jpg", "path/to/person3_img2.jpg", "path/to/person3_img3.jpg"],
}

print(f"Found {len(dataset)} persons")
print(f"Total images: {sum(map(len, dataset.values()))}")

# Show the first three persons and how many images each has
for person_id in list(dataset)[:3]:
    images = dataset[person_id]
    print(f"  {person_id}: {len(images)} images")

## 2. Batch Enrollment Client

In [None]:
class BatchEnrollmentClient:
    """Async client that enrolls persons via ``POST /api/v1/enroll/{person_id}``.

    Requests carry an ``X-API-Key`` header when an API key is supplied. Call
    :meth:`close` when finished, or use ``async with BatchEnrollmentClient(...)``.
    """

    def __init__(self, base_url: str, api_key: str = "", timeout: float = 30.0):
        """
        Args:
            base_url: Server root, e.g. ``"http://localhost:8000"``. A trailing
                slash is stripped so endpoint URLs never contain ``//``.
            api_key: Value for the ``X-API-Key`` header; omitted when empty.
            timeout: Request timeout in seconds given to ``httpx.AsyncClient``.
        """
        self.base_url = base_url.rstrip("/")
        self.headers = {"X-API-Key": api_key} if api_key else {}
        self.client = httpx.AsyncClient(timeout=timeout)

    async def __aenter__(self) -> "BatchEnrollmentClient":
        return self

    async def __aexit__(self, *exc_info) -> None:
        await self.close()

    @staticmethod
    def _load_images(image_paths: List[str]) -> List[tuple]:
        """Read image files into httpx multipart ``files`` entries, in input order.

        Paths that cannot be opened (missing, directories, no permission) are
        skipped. The content type is guessed from the extension; unknown
        extensions fall back to ``image/jpeg``.
        """
        files = []
        for img_path in image_paths:
            try:
                with open(img_path, 'rb') as f:
                    content = f.read()
            except OSError:
                continue  # unreadable path: skip it, like a missing file
            mime = mimetypes.guess_type(img_path)[0] or 'image/jpeg'
            files.append(('images', (os.path.basename(img_path), content, mime)))
        return files

    async def enroll_person(self, person_id: str, image_paths: List[str],
                            quality_threshold: float = 0.5) -> Dict:
        """Enroll one person with all of their readable images in a single request.

        Args:
            person_id: Identifier placed in the URL path. It is percent-encoded,
                so IDs containing ``/`` or spaces remain a single path segment.
            image_paths: Local image files; unreadable ones are skipped.
            quality_threshold: Sent as the ``quality_threshold`` form field.

        Returns:
            The server's JSON body on a 2xx response. Otherwise returns
            ``{"status": "error", "message": ...}``. File and request errors are
            reported this way rather than raised, so one bad person cannot
            abort a batch.
        """
        files = self._load_images(image_paths)
        if not files:
            return {"status": "error", "message": "No valid images found"}

        url = f"{self.base_url}/api/v1/enroll/{quote(str(person_id), safe='')}"
        try:
            response = await self.client.post(
                url,
                files=files,
                data={"quality_threshold": str(quality_threshold)},
                headers=self.headers,
            )
            if not 200 <= response.status_code < 300:
                message = f"HTTP {response.status_code}"
                body = response.text.strip()
                if body:
                    message += f": {body[:200]}"  # truncated server detail
                return {"status": "error", "message": message}
            return response.json()
        except Exception as e:  # per-person boundary: report, don't abort the batch
            return {"status": "error", "message": str(e) or type(e).__name__}

    async def batch_enroll(self, dataset: Dict[str, List[str]],
                           concurrent_requests: int = 5,
                           quality_threshold: float = 0.5) -> List[Dict]:
        """Enroll every person in *dataset*, at most ``concurrent_requests`` at a time.

        Args:
            dataset: ``{person_id: [image paths]}``.
            concurrent_requests: Maximum simultaneous requests (must be >= 1).
            quality_threshold: Forwarded to :meth:`enroll_person`.

        Returns:
            One dict per person, in *dataset* order: ``{"person_id": ..., **result}``.

        Raises:
            ValueError: if ``concurrent_requests`` < 1 (a zero-sized semaphore
                would otherwise wait forever).
        """
        if concurrent_requests < 1:
            raise ValueError("concurrent_requests must be >= 1")

        semaphore = asyncio.Semaphore(concurrent_requests)
        items = list(dataset.items())
        results = [None] * len(items)

        async def enroll_with_semaphore(index, person_id, images):
            async with semaphore:
                result = await self.enroll_person(person_id, images, quality_threshold)
            return index, {"person_id": person_id, **result}

        tasks = [
            enroll_with_semaphore(index, person_id, images)
            for index, (person_id, images) in enumerate(items)
        ]

        # Progress advances in completion order; each result is stored at its
        # dataset position so the output order is deterministic.
        for task in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Enrolling"):
            index, result = await task
            results[index] = result

        return results

    async def close(self):
        """Close the underlying HTTP client."""
        await self.client.aclose()

# Create the shared enrollment client; the enrollment cell below closes it when done
client = BatchEnrollmentClient(base_url=API_BASE_URL, api_key=API_KEY)

## 3. Run Batch Enrollment

In [None]:
# Run batch enrollment
async def run_enrollment():
    """Enroll every person in ``dataset`` with the shared ``client`` and summarize.

    The function:
    - prints the success count and up to five failure messages;
    - closes ``client``, even if the batch raises (re-running this cell
      therefore requires re-running the client-initialization cell first);
    - returns the per-person results as a DataFrame.

    Returns:
        pandas.DataFrame with one row per person. The ``person_id``, ``status``
        and ``faces_enrolled`` columns always exist; they hold NaN where no
        result provided them, e.g. when every enrollment failed.
    """
    print("Starting batch enrollment...")

    try:
        results = await client.batch_enroll(
            dataset,
            concurrent_requests=5  # Adjust based on server capacity
        )
    finally:
        await client.close()

    # Analyze results
    successful = [r for r in results if r.get('status') == 'completed']
    failed = [r for r in results if r.get('status') != 'completed']

    print(f"\n✅ Successfully enrolled: {len(successful)}/{len(results)}")

    if failed:
        print("\n❌ Failed enrollments:")
        for r in failed[:5]:  # Show first 5 failures
            print(f"  - {r.get('person_id', '?')}: {r.get('message', 'Unknown error')}")
        if len(failed) > 5:
            print(f"  ... and {len(failed) - 5} more")

    # Create results DataFrame. Error results carry no 'faces_enrolled', so
    # guarantee the columns that later cells select on.
    df_results = pd.DataFrame(results)
    for column in ('person_id', 'status', 'faces_enrolled'):
        if column not in df_results.columns:
            df_results[column] = float('nan')

    return df_results

# Execute
df_results = await run_enrollment()

# Display summary
print("\nEnrollment Summary:")
print(df_results[['person_id', 'status', 'faces_enrolled']].head(10))

## 4. Verify Enrollments

In [None]:
async def verify_enrollments(sample_size: int = 5):
    """Spot-check enrollments by identifying each person's first image.

    Takes the first ``sample_size`` persons in ``df_results`` whose status is
    ``'completed'``. For each, it posts the first image listed in ``dataset``
    to ``/api/v1/identify``. It prints ✅ only when the top match is that same
    person, and ❌ for a different match, no match, or a failed request.

    Args:
        sample_size: Maximum number of enrolled persons to check.
    """
    if 'status' in df_results.columns:
        completed = df_results['status'] == 'completed'
        enrolled = df_results.loc[completed, 'person_id'].tolist()[:sample_size]
    else:
        enrolled = []  # no status column: nothing known to be enrolled

    print(f"Testing identification for {len(enrolled)} persons...\n")

    # Send the API key header the enrollment requests use, and allow the same
    # 30 s timeout the enrollment client is built with.
    headers = {"X-API-Key": API_KEY} if API_KEY else {}
    async with httpx.AsyncClient(headers=headers, timeout=30.0) as http:
        for person_id in enrolled:
            images = dataset.get(person_id) or []
            if not images:
                print(f"⚠️ {person_id}: no images in dataset, skipped")
                continue

            test_image = images[0]
            try:
                with open(test_image, 'rb') as f:
                    content = f.read()
            except OSError as e:
                print(f"⚠️ {person_id}: cannot read {test_image} ({e}), skipped")
                continue

            mime = mimetypes.guess_type(test_image)[0] or 'image/jpeg'
            files = [('image', (os.path.basename(test_image), content, mime))]

            try:
                response = await http.post(f"{API_BASE_URL}/api/v1/identify", files=files)
            except httpx.HTTPError as e:
                print(f"❌ {person_id}: Request failed ({e})")
                continue

            if response.status_code != 200:
                print(f"❌ {person_id}: Request failed (HTTP {response.status_code})")
                continue

            matches = response.json().get('matches') or []
            if not matches:
                print(f"❌ {person_id}: No match found")
                continue

            match = matches[0]
            # A top match on a *different* person is a failure, not a success.
            verdict = "✅" if match['person_id'] == person_id else "❌"
            print(f"{verdict} {person_id}: Matched as {match['person_id']} (similarity: {match['similarity']:.3f})")

# Run verification
await verify_enrollments()

## 5. Export Results

In [None]:
# Save enrollment results
output_file = "enrollment_results.csv"
df_results.to_csv(output_file, index=False)
print(f"Results saved to {output_file}")

# Generate statistics. Columns are looked up defensively: e.g. when every
# enrollment failed, no result carried 'faces_enrolled'.
status = df_results.get('status', pd.Series(index=df_results.index, dtype=object))
faces = pd.to_numeric(
    df_results.get('faces_enrolled', pd.Series(index=df_results.index, dtype=float)),
    errors='coerce',
)
succeeded = status == 'completed'

stats = {
    "total_persons": len(df_results),
    "successful_enrollments": int(succeeded.sum()),
    "failed_enrollments": int((~succeeded).sum()),
    # Counts are printed as integers even when failed rows left NaNs in the column.
    "total_faces_enrolled": int(faces.fillna(0).sum()),
    # Mean over persons with at least one enrolled face (NaN if there are none)
    "avg_faces_per_person": faces[faces > 0].mean()
}

print("\nEnrollment Statistics:")
for key, value in stats.items():
    print(f"  {key}: {value:.2f}" if isinstance(value, float) else f"  {key}: {value}")