# ArcFace (ResNet) Analysis - CASIA WebFace

This notebook analyzes the ArcFace model results in `outputs/face/20260125_224553`.
It loads validation predictions, builds a confusion matrix, and visualizes FP/FN image pairs.
It also surfaces LFW verification results saved during training.


In [1]:
import json
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from sklearn.metrics import confusion_matrix
import yaml

from sim_bench.face_recognition.train import create_identity_split
from sim_bench.datasets.face_dataset import MXNetRecordDataset


In [2]:
run_dir = Path(r"outputs/face/20260125_224553")
inference_dir = run_dir / "inference"
val_csv = inference_dir / "val_predictions.csv"
summary_json = inference_dir / "inference_summary.json"
config_yaml = run_dir / "config.yaml"

dataset_root = Path(r"D:\DataSets\faces_webface_112x112")
rec_path = dataset_root / "train.rec"

print("Run dir:", run_dir)
print("Val preds:", val_csv)
print("Dataset rec:", rec_path)


Run dir: outputs\face\20260125_224553
Val preds: outputs\face\20260125_224553\inference\val_predictions.csv
Dataset rec: D:\DataSets\faces_webface_112x112\train.rec


In [3]:
val_df = pd.read_csv(val_csv)
summary = json.loads(summary_json.read_text())
config = yaml.safe_load(config_yaml.read_text())

print(val_df.head())
print("Val samples:", len(val_df))
print("Summary:", summary)


FileNotFoundError: [Errno 2] No such file or directory: 'outputs\\face\\20260125_224553\\inference\\val_predictions.csv'

In [None]:
# Rebuild dataset and validation split to map subset indices -> original record indices
val_ratio = config.get('data', {}).get('val_ratio', 0.1)
seed = config.get('seed', 42)
max_train_ids = config.get('data', {}).get('max_train_identities')

dataset = MXNetRecordDataset(rec_path=str(rec_path), transform=None)
train_idx, val_idx, train_remap, val_remap = create_identity_split(
    dataset, val_ratio=val_ratio, seed=seed, max_train_identities=max_train_ids
)

print("Val subset size:", len(val_idx))
print("Num val identities:", len(val_remap))

# Build lookup tables for subset index -> original index and remapped labels
val_index_to_orig = {subset_idx: orig_idx for subset_idx, orig_idx in enumerate(val_idx)}

# Build label -> subset indices for quick sampling
val_labels = val_df['true_label'].to_numpy()
label_to_subset_indices = {}
for subset_idx, label in enumerate(val_labels):
    label_to_subset_indices.setdefault(int(label), []).append(subset_idx)


In [None]:
# Confusion matrix for top-N most frequent identities
top_n = 30
label_counts = val_df['true_label'].value_counts()
top_labels = label_counts.head(top_n).index.to_list()

filtered = val_df[val_df['true_label'].isin(top_labels) & val_df['predicted_label'].isin(top_labels)]
cm = confusion_matrix(filtered['true_label'], filtered['predicted_label'], labels=top_labels)

plt.figure(figsize=(12, 10))
sns.heatmap(cm, cmap='Blues', xticklabels=top_labels, yticklabels=top_labels)
plt.title(f'Confusion Matrix (Top {top_n} by support)')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()

print("Full confusion matrix shape:", confusion_matrix(val_df['true_label'], val_df['predicted_label']).shape)


In [None]:
def load_image_by_subset_index(subset_idx: int) -> Image.Image:
    orig_idx = val_index_to_orig[subset_idx]
    item = dataset[orig_idx]
    img = item['image']
    if isinstance(img, np.ndarray):
        return Image.fromarray(img)
    return img

def find_misclassified_samples(df: pd.DataFrame) -> pd.DataFrame:
    return df[df['true_label'] != df['predicted_label']].copy()

misclassified = find_misclassified_samples(val_df)
print("Misclassified samples:", len(misclassified))


In [None]:
def pick_reference_from_label(label: int) -> int:
    candidates = label_to_subset_indices.get(int(label), [])
    return candidates[0] if candidates else None

def show_pairs(pairs, title):
    n = len(pairs)
    fig, axes = plt.subplots(n, 2, figsize=(6, 3 * n))
    fig.suptitle(title)
    if n == 1:
        axes = np.array([axes])
    for row_idx, (left_img, right_img, left_title, right_title) in enumerate(pairs):
        axes[row_idx, 0].imshow(left_img)
        axes[row_idx, 0].axis('off')
        axes[row_idx, 0].set_title(left_title)
        axes[row_idx, 1].imshow(right_img)
        axes[row_idx, 1].axis('off')
        axes[row_idx, 1].set_title(right_title)
    plt.tight_layout()
    plt.show()


In [None]:
# FP pairs: misclassified sample + a reference image from predicted class
fp_pairs = []
for _, row in misclassified.head(5).iterrows():
    subset_idx = int(row['sample_index'])
    pred_label = int(row['predicted_label'])
    ref_idx = pick_reference_from_label(pred_label)

    img_a = load_image_by_subset_index(subset_idx)
    title_a = f"True {int(row['true_label'])} -> Pred {pred_label}"

    if ref_idx is None:
        img_b = img_a
        title_b = f"No ref for pred {pred_label}"
    else:
        img_b = load_image_by_subset_index(ref_idx)
        title_b = f"Reference pred {pred_label}"

    fp_pairs.append((img_a, img_b, title_a, title_b))

show_pairs(fp_pairs, "False Positives (Predicted Class Reference)")


In [None]:
# FN pairs: misclassified sample + a reference image from true class
fn_pairs = []
for _, row in misclassified.head(5).iterrows():
    subset_idx = int(row['sample_index'])
    true_label = int(row['true_label'])
    ref_idx = pick_reference_from_label(true_label)

    img_a = load_image_by_subset_index(subset_idx)
    title_a = f"True {true_label} -> Pred {int(row['predicted_label'])}"

    if ref_idx is None:
        img_b = img_a
        title_b = f"No ref for true {true_label}"
    else:
        img_b = load_image_by_subset_index(ref_idx)
        title_b = f"Reference true {true_label}"

    fn_pairs.append((img_a, img_b, title_a, title_b))

show_pairs(fn_pairs, "False Negatives (True Class Reference)")


In [None]:
# LFW results (latest epoch)
lfw_json = run_dir / "lfw_epoch_10.json"
lfw_roc = run_dir / "lfw_roc_epoch_10.png"

lfw_results = json.loads(lfw_json.read_text())
print("LFW Results:", lfw_results)

lfw_img = Image.open(lfw_roc)
plt.figure(figsize=(6, 5))
plt.imshow(lfw_img)
plt.axis('off')
plt.title('LFW ROC (Epoch 10)')
plt.show()
