# CASIA-WebFace Dataset Analysis

Explore the CASIA-WebFace face recognition dataset.

- ~494K face images
- 10,572 unique identities
- 112x112 aligned face images

In [None]:
import os
import random
from collections import Counter
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image

plt.rcParams['figure.figsize'] = (14, 8)

# Paths. The dataset root defaults to the original local path but can be
# overridden with the CASIA_WEBFACE_DIR environment variable, so the notebook
# runs unchanged on other machines.
DATASET_DIR = Path(os.environ.get('CASIA_WEBFACE_DIR', r'D:/DataSets/casia_Webface'))
IMAGES_DIR = DATASET_DIR / 'images'
LST_PATH = DATASET_DIR / 'casia-webface' / 'train.lst'

print(f"Dataset directory: {DATASET_DIR}")
print(f"Images directory: {IMAGES_DIR}")

## 1. Load Dataset Metadata

First, let's parse the .lst file to understand the dataset structure.

In [None]:
# Build the metadata table straight from train.lst; the statistics below do
# not need any images to be extracted first.
print("Parsing train.lst file...")

records = []
with open(LST_PATH, 'r') as lst_file:
    for raw_line in lst_file:
        fields = raw_line.strip().split('\t')
        if len(fields) < 3:
            continue  # blank or malformed line: not enough columns

        # NOTE(review): in insightface-generated .lst files the first column is
        # an "aligned" flag rather than the RecordIO record id — confirm before
        # using `idx` to seek into train.rec.
        rec_idx = int(fields[0])
        rec_label = int(fields[2])

        # Path looks like .../<person_id>/<image_name>
        path_parts = fields[1].split('/')
        person_id, image_name = path_parts[-2], path_parts[-1]

        # Optional bounding box in columns 3..6
        bbox = [float(v) for v in fields[3:7]] if len(fields) >= 7 else None

        records.append({
            'idx': rec_idx,
            'label': rec_label,
            'person_id': person_id,
            'image_name': image_name,
            'bbox': bbox,
        })

df = pd.DataFrame(records)
print(f"Total images: {len(df):,}")
print(f"Unique identities: {df['label'].nunique():,}")
df.head(10)

## 2. Dataset Statistics

In [None]:
# Images per identity (label -> number of rows in df)
images_per_identity = df.groupby('label').size()
mean_count = images_per_identity.mean()
median_count = images_per_identity.median()

print("Images per Identity Statistics:")
print(f"  Min:    {images_per_identity.min()}")
print(f"  Max:    {images_per_identity.max()}")
print(f"  Mean:   {mean_count:.1f}")
print(f"  Median: {median_count:.1f}")
print(f"  Std:    {images_per_identity.std():.1f}")

# Distribution plot
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Histogram
ax = axes[0]
ax.hist(images_per_identity, bins=50, edgecolor='black', alpha=0.7)
ax.axvline(mean_count, color='red', linestyle='--', label=f'Mean: {mean_count:.1f}')
ax.axvline(median_count, color='orange', linestyle='--', label=f'Median: {median_count:.1f}')
ax.set_xlabel('Images per Identity')
ax.set_ylabel('Number of Identities')
ax.set_title('Distribution of Images per Identity')
ax.legend()

# Top identities: at most TOP_N bars. Bar positions follow the number of rows
# nlargest actually returned, so the chart also works with fewer identities.
TOP_N = 20
ax = axes[1]
top_identities = images_per_identity.nlargest(TOP_N)
positions = range(len(top_identities))
ax.barh(positions, top_identities.values)
ax.set_yticks(positions)
ax.set_yticklabels([f'ID {i}' for i in top_identities.index])
ax.set_xlabel('Number of Images')
ax.set_title(f'Top {len(top_identities)} Identities by Image Count')
ax.invert_yaxis()

plt.tight_layout()
plt.show()

## 3. View 3 Random People

Let's visualize images from 3 randomly selected identities.

In [None]:
# Pick the image source: the extracted folder when it exists and has at least
# one entry, otherwise records are decoded directly from the RecordIO file.
use_extracted = False
if IMAGES_DIR.exists():
    use_extracted = next(IMAGES_DIR.iterdir(), None) is not None

if not use_extracted:
    print("Images not extracted. Will extract on-the-fly from RecordIO.")
    print("Run: python -m sim_bench.datasets.casia.extract_casia")
else:
    print("Using extracted images from:", IMAGES_DIR)

In [None]:
# Magic number that begins every record in a dmlc/MXNet RecordIO (.rec) file.
RECORDIO_MAGIC = 0xCED7230A

# Parsed .idx files keyed by path string, so each index is read once per session
# instead of once per image.
_RECORDIO_INDEX_CACHE = {}


def read_recordio_index(idx_path):
    """
    Parse a RecordIO ``.idx`` file into a ``{record_key: byte_offset}`` dict.

    The index is a text file with one tab-separated ``key``/``offset`` pair per
    line (the layout written by ``mxnet.recordio.MXIndexedRecordIO``). Results
    are cached per path; clear ``_RECORDIO_INDEX_CACHE`` if the file changes.
    """
    cache_key = str(idx_path)
    cached = _RECORDIO_INDEX_CACHE.get(cache_key)
    if cached is not None:
        return cached

    idx_map = {}
    with open(idx_path, 'r') as f:
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) >= 2:
                idx_map[int(parts[0])] = int(parts[1])

    _RECORDIO_INDEX_CACHE[cache_key] = idx_map
    return idx_map


def read_recordio_image_bytes(rec_path, offset):
    """
    Return the encoded image payload of the record that starts at ``offset``.

    Record layout: uint32 magic, uint32 ``lrecord`` (upper 3 bits are a
    continuation flag, lower 29 bits the payload length), then the payload.
    The payload starts with a 24-byte IRHeader (uint32 flag, float32 label,
    uint64 id, uint64 id2). When ``flag > 0`` it is followed by ``flag``
    float32 labels. The remaining bytes are the encoded image.

    Returns None if no valid record starts at ``offset`` or if the record is
    truncated. Records split into multiple chunks (non-zero continuation flag)
    are not supported.
    """
    import struct

    ir_header_size = 24  # struct.calcsize('<IfQQ')

    with open(rec_path, 'rb') as f:
        f.seek(offset)
        header = f.read(8)
        if len(header) < 8:
            return None
        magic, lrecord = struct.unpack('<II', header)
        if magic != RECORDIO_MAGIC or (lrecord >> 29) != 0:
            return None
        length = lrecord & ((1 << 29) - 1)
        payload = f.read(length)

    if len(payload) < length or length < ir_header_size:
        return None

    flag, _label, _id, _id2 = struct.unpack('<IfQQ', payload[:ir_header_size])
    # Skip any float32 label vector that sits between the header and the image.
    return payload[ir_header_size + 4 * flag:]


def load_image_from_recordio(rec_path, idx_path, image_idx):
    """
    Load a single image from RecordIO without extracting all.

    Args:
        rec_path: path to the ``.rec`` data file.
        idx_path: path to the matching text ``.idx`` index file.
        image_idx: record key to look up in the index.

    Returns:
        A PIL image, or None if the key is not in the index, the record is
        invalid, or its payload cannot be identified as an image.
    """
    from io import BytesIO

    offset = read_recordio_index(idx_path).get(image_idx)
    if offset is None:
        return None

    img_bytes = read_recordio_image_bytes(rec_path, offset)
    if not img_bytes:
        return None

    try:
        return Image.open(BytesIO(img_bytes))
    except OSError:  # PIL raises UnidentifiedImageError (an OSError) for non-image bytes
        return None


def load_images_for_identity(label, df, max_images=8):
    """
    Collect up to ``max_images`` face images for one identity.

    Images come from the extracted folder when the global ``use_extracted``
    is set. Otherwise they are decoded from ``train.rec`` with
    ``load_image_from_recordio``, using the ``idx`` column of the matching rows.

    Returns:
        (images, person_id): a list of PIL images, and the person folder name
        from the .lst metadata ('Unknown' if ``label`` has no rows in ``df``).
    """
    subset = df[df['label'] == label].head(max_images)

    if not use_extracted:
        rec_dir = DATASET_DIR / 'casia-webface'
        rec_file = rec_dir / 'train.rec'
        idx_file = rec_dir / 'train.idx'
        decoded = (load_image_from_recordio(rec_file, idx_file, rec_idx)
                   for rec_idx in subset['idx'])
        images = [im for im in decoded if im is not None]
    else:
        # NOTE(review): IMAGES_DIR already ends in 'images', so this resolves to
        # .../images/images/id_XXXXX — confirm it matches the extractor's layout.
        person_dir = IMAGES_DIR / 'images' / f'id_{label:05d}'
        images = []
        if person_dir.exists():
            images = [Image.open(p) for p in sorted(person_dir.glob('*.jpg'))[:max_images]]

    person_id = subset['person_id'].iloc[0] if len(subset) > 0 else 'Unknown'
    return images, person_id


# Select up to N_IDENTITIES random identities that have at least MIN_IMAGES images.
# Set RANDOM_SEED to an int for a reproducible selection; None draws a fresh
# sample on every run. A dedicated Random instance leaves the global RNG untouched.
N_IDENTITIES = 3
MIN_IMAGES = 5
RANDOM_SEED = None

identities_with_enough = images_per_identity[images_per_identity >= MIN_IMAGES].index.tolist()
rng = random.Random(RANDOM_SEED)
# min() avoids ValueError from sample() when fewer identities qualify.
selected_identities = rng.sample(identities_with_enough,
                                 min(N_IDENTITIES, len(identities_with_enough)))

print(f"Selected identities: {selected_identities}")
# int() so numpy>=2 scalars print as plain numbers instead of np.int64(...)
print(f"Images per identity: {[int(images_per_identity[i]) for i in selected_identities]}")

In [None]:
# Display images for each selected identity: one row per identity and up to
# MAX_IMAGES_PER_ROW images per row.
MAX_IMAGES_PER_ROW = 8
n_rows = max(1, len(selected_identities))

# squeeze=False keeps `axes` 2-D even with a single row, so axes[row, col]
# indexing works for any number of selected identities.
fig, axes = plt.subplots(n_rows, MAX_IMAGES_PER_ROW,
                         figsize=(2 * MAX_IMAGES_PER_ROW, 7 * n_rows / 3),
                         squeeze=False)

for row_idx, identity in enumerate(selected_identities):
    images, person_id = load_images_for_identity(identity, df, max_images=MAX_IMAGES_PER_ROW)

    for col_idx in range(MAX_IMAGES_PER_ROW):
        ax = axes[row_idx, col_idx]

        if col_idx < len(images):
            ax.imshow(images[col_idx])
            if col_idx == 0:
                ax.set_ylabel(f'ID: {identity}\n({person_id})', fontsize=10)
        else:
            ax.axis('off')

        ax.set_xticks([])
        ax.set_yticks([])

        if row_idx == 0:
            ax.set_title(f'Image {col_idx + 1}', fontsize=9)

output_path = DATASET_DIR / 'sample_identities.png'
plt.suptitle(f'{len(selected_identities)} Random Identities from CASIA-WebFace', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig(output_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"\nSaved visualization to: {output_path}")

## 4. Identity Distribution Analysis

In [None]:
# Categorize identities by image count. pd.cut bins are right-closed, so
# (0, 10] is labelled '1-10'. The last bin is open-ended (np.inf), so identities
# with many images are counted instead of silently falling outside every bin (NaN).
bins = [0, 10, 20, 50, 100, 200, 500, np.inf]
labels = ['1-10', '11-20', '21-50', '51-100', '101-200', '201-500', '501+']

binned = pd.cut(images_per_identity, bins=bins, labels=labels, right=True)
bin_counts = binned.value_counts().sort_index()

fig, ax = plt.subplots(figsize=(10, 5))
bars = ax.bar(bin_counts.index, bin_counts.values, color='steelblue', edgecolor='black')

# Add count labels just above each bar. The offset scales with the tallest bar
# rather than a fixed 50 units, so it looks right for any dataset size.
label_offset = 0.01 * bin_counts.max()
for bar, count in zip(bars, bin_counts.values):
    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + label_offset,
            f'{count:,}', ha='center', va='bottom', fontsize=10)

ax.set_xlabel('Images per Identity')
ax.set_ylabel('Number of Identities')
ax.set_title('Identity Distribution by Image Count')
plt.tight_layout()
plt.show()

print("\nIdentity count by images:")
for bin_label, count in bin_counts.items():
    print(f"  {bin_label:>8} images: {count:>5,} identities")

## 5. Quick Summary

In [None]:
# One-screen recap of the dataset-level numbers computed above.
n_images = len(df)
n_identities = df['label'].nunique()
rule = "=" * 60

print(rule)
print("CASIA-WebFace Dataset Summary")
print(rule)
print(f"Total images:        {n_images:>10,}")
print(f"Unique identities:   {n_identities:>10,}")
print(f"Avg images/identity: {n_images / n_identities:>10.1f}")
print("Image size:          112 x 112 px")
print("\nImages per identity:")

summary_rows = [
    ("Min:", images_per_identity.min(), ">6"),
    ("Max:", images_per_identity.max(), ">6"),
    ("Mean:", images_per_identity.mean(), ">6.1f"),
    ("Median:", images_per_identity.median(), ">6.1f"),
]
for caption, value, spec in summary_rows:
    print(f"  {caption:<8}{format(value, spec)}")
print(rule)