# Face Training Telemetry Analysis

Analyzing the training run to understand why it didn't converge.

**Key metrics tracked:**
- Gradient norms (overall, backbone, embedding, arcface)
- Weight deltas (how much weights changed)
- ArcFace weight statistics
- Learning rates

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

# Directory holding the CSV telemetry written by the training run.
# Adjust it if the notebook is opened from a different working directory.
TELEMETRY_DIR = Path("../outputs/face/20260123_110841/telemetry")
# Alternative for RunPod:
# TELEMETRY_DIR = Path("/workspace/sim-bench/outputs/face/20260123_110841/telemetry")

# Echo the resolved directory and the CSVs found there, so a wrong path is
# obvious before the loading cell fails.
resolved_dir = TELEMETRY_DIR.absolute()
csv_files = list(TELEMETRY_DIR.glob('*.csv'))
print(f"Looking for telemetry in: {resolved_dir}")
print(f"Files found: {csv_files}")

In [None]:
# Read the four telemetry tables, in a fixed order, from TELEMETRY_DIR.
telemetry_tables = ('gradient_norms', 'weight_deltas', 'arcface_stats', 'learning_rates')
gradient_norms, weight_deltas, arcface_stats, learning_rates = (
    pd.read_csv(TELEMETRY_DIR / f'{table}.csv') for table in telemetry_tables
)

# Row counts serve as a quick sanity check that every file was populated.
for title, frame in (
    ('Gradient norms', gradient_norms),
    ('Weight deltas', weight_deltas),
    ('ArcFace stats', arcface_stats),
    ('Learning rates', learning_rates),
):
    print(f"{title}: {len(frame)} rows")

In [None]:
# Give each per-collection table a 0-based sequential 'step' column for the x-axis.
# NOTE(review): this numbers telemetry rows, not optimizer steps. If telemetry is
# collected every N steps, the "Step" axes below are in units of collections.
# It also overwrites any 'step' column already present in the CSVs; confirm none exists.
for frame in (gradient_norms, weight_deltas, arcface_stats):
    frame['step'] = range(len(frame))

# Preview the first rows of each table to confirm the column layout.
previews = (
    ("=== Gradient Norms ===", gradient_norms),
    ("\n=== Weight Deltas ===", weight_deltas),
    ("\n=== ArcFace Stats ===", arcface_stats),
)
for heading, frame in previews:
    print(heading)
    display(frame.head())

## 1. Gradient Norms Analysis

Healthy training should show:
- Gradients that are neither too large (exploding) nor too small (vanishing)
- Relatively stable gradients over time
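- Backbone gradients typically smaller than head gradients. (A differential learning rate does not change gradient magnitudes; it scales the resulting *updates*, so its effect shows up in the weight deltas in Section 2 rather than here.)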

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# NaN/inf gradient norms (e.g. from overflowing steps) cannot be histogrammed.
# An infinite value makes the automatic bin range non-finite and the histogram
# call raises. The line plots tolerate them, so only the histogram and median
# use the finite subset.
overall_norms = gradient_norms['overall']
finite_overall = overall_norms[np.isfinite(overall_norms)]
n_nonfinite = len(overall_norms) - len(finite_overall)

# Overall gradient norm (log scale: norms can span several orders of magnitude)
ax = axes[0, 0]
ax.plot(gradient_norms['step'], overall_norms, alpha=0.7)
ax.set_xlabel('Step')
ax.set_ylabel('L2 Norm')
ax.set_title('Overall Gradient Norm')
ax.set_yscale('log')
ax.grid(True, alpha=0.3)

# Component breakdown
ax = axes[0, 1]
for column, label in (('backbone', 'Backbone'), ('embedding', 'Embedding'), ('arcface', 'ArcFace')):
    ax.plot(gradient_norms['step'], gradient_norms[column], label=label, alpha=0.7)
ax.set_xlabel('Step')
ax.set_ylabel('L2 Norm')
ax.set_title('Gradient Norms by Component')
ax.set_yscale('log')
ax.legend()
ax.grid(True, alpha=0.3)

# Ratio of ArcFace to backbone gradient norms (epsilon guards against division by zero)
ax = axes[1, 0]
ratio = gradient_norms['arcface'] / (gradient_norms['backbone'] + 1e-8)
ax.plot(gradient_norms['step'], ratio, alpha=0.7, color='purple')
ax.set_xlabel('Step')
ax.set_ylabel('Ratio')
ax.set_title('ArcFace/Backbone Gradient Ratio')
# The 10x line is only a visual reference. Per-group learning rates scale the
# optimizer *updates*, not the gradients, so a differential LR does not predict
# this ratio. Compare the weight deltas (Section 2) to see the LR effect.
ax.axhline(y=10, color='r', linestyle='--', alpha=0.5, label='10x reference')
ax.legend()
ax.grid(True, alpha=0.3)

# Distribution of the finite overall gradient norms
ax = axes[1, 1]
ax.hist(finite_overall, bins=50, alpha=0.7, edgecolor='black')
ax.set_xlabel('Gradient Norm')
ax.set_ylabel('Frequency')
ax.set_title('Distribution of Overall Gradient Norms')
if not finite_overall.empty:
    median_norm = finite_overall.median()
    ax.axvline(x=median_norm, color='r', linestyle='--', label=f'Median: {median_norm:.2f}')
    ax.legend()

plt.tight_layout()
plt.show()

print("\nGradient Statistics:")
print(gradient_norms[['overall', 'backbone', 'embedding', 'arcface']].describe())
if n_nonfinite:
    print(f"Non-finite overall gradient norms: {n_nonfinite} of {len(overall_norms)} "
          "(excluded from histogram and median)")

## 2. Weight Deltas Analysis

Weight deltas show how much the weights actually changed between telemetry collections.

**What to look for:**
- If deltas are very small, learning is slow
- If deltas are huge, learning may be unstable
- Relative magnitudes between components

In [None]:
# One panel per component: how far that group of weights moved per collection
# interval. Each entry is (column, panel title, line colour or None for default).
delta_panels = [
    ('backbone_delta', 'Backbone Weight Changes', None),
    ('embedding_delta', 'Embedding Weight Changes', 'orange'),
    ('arcface_delta', 'ArcFace Weight Changes', 'green'),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, (column, title, color) in zip(axes, delta_panels):
    # The backbone panel keeps matplotlib's default colour.
    line_kwargs = {} if color is None else {'color': color}
    ax.plot(weight_deltas['step'], weight_deltas[column], alpha=0.7, **line_kwargs)
    ax.set_xlabel('Step')
    ax.set_ylabel('L2 Delta')
    ax.set_title(title)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nWeight Delta Statistics:")
print(weight_deltas[[column for column, _, _ in delta_panels]].describe())

In [None]:
# Overlay every component on one log-scaled axis so relative magnitudes are
# directly comparable. Zero deltas cannot be represented on a log axis.
fig, ax = plt.subplots(figsize=(12, 5))

for column, name in (('backbone_delta', 'Backbone'),
                     ('embedding_delta', 'Embedding'),
                     ('arcface_delta', 'ArcFace')):
    ax.plot(weight_deltas['step'], weight_deltas[column], label=name, alpha=0.7)

ax.set(xlabel='Step',
       ylabel='L2 Weight Delta',
       title='Weight Changes by Component (Per Collection Interval)')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_yscale('log')

plt.tight_layout()
plt.show()

## 3. ArcFace Weight Statistics

The ArcFace weight matrix is crucial - it represents class prototypes in the embedding space.

**What to look for:**
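- `weight_norm_mean` should be close to 1.0 (weights are normalized during the forward pass). Note: if these statistics are computed from the raw, pre-normalization parameters, that normalization means the raw norm does not affect the logits, so a deviation from 1.0 is not by itself a failure.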
- `weight_norm_std` should be small (consistent class prototypes)
- `weight_mean` and `weight_std` show the raw weight distribution

In [None]:
# Each panel: (grid position, stats column, y-label, title, line colour or None,
#              optional horizontal reference as (y, legend label or None)).
# NOTE(review): the 1.0 reference assumes the logged norms are of normalized
# class vectors. If the telemetry records raw parameters, 1.0 is not guaranteed; confirm.
arcface_panels = [
    ((0, 0), 'weight_mean', 'Mean', 'ArcFace Weight Mean', None, (0, None)),
    ((0, 1), 'weight_std', 'Std', 'ArcFace Weight Std', 'orange', None),
    ((1, 0), 'weight_norm_mean', 'Mean Norm', 'ArcFace Per-Class Weight Norm Mean', 'green',
     (1.0, 'Expected: 1.0')),
    ((1, 1), 'weight_norm_std', 'Std of Norms', 'ArcFace Per-Class Weight Norm Std', 'purple', None),
]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
for position, column, ylabel, title, color, reference in arcface_panels:
    ax = axes[position]
    line_kwargs = {} if color is None else {'color': color}
    ax.plot(arcface_stats['step'], arcface_stats[column], alpha=0.7, **line_kwargs)
    ax.set_xlabel('Step')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if reference is not None:
        ref_y, ref_label = reference
        # Only attach a legend entry (and legend) when the reference line is labelled.
        ref_kwargs = {} if ref_label is None else {'label': ref_label}
        ax.axhline(y=ref_y, color='r', linestyle='--', alpha=0.5, **ref_kwargs)
        if ref_label is not None:
            ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nArcFace Statistics:")
print(arcface_stats[[panel[1] for panel in arcface_panels]].describe())

## 4. Learning Rate Check

In [None]:
print("Learning Rate Schedule:")
print(learning_rates.head(20))

# Parameter-group learning-rate columns are recognised by their 'lr_group' prefix.
# lr_cols stays a global because the diagnostic summary cell reads it.
lr_cols = list(filter(lambda name: name.startswith('lr_group'), learning_rates.columns))

# Draw the schedule whenever at least one parameter-group column exists.
# NOTE(review): the x-axis is the row index; whether rows are per step or per
# epoch is not visible from this notebook.
if lr_cols:
    row_index = range(len(learning_rates))
    fig, ax = plt.subplots(figsize=(10, 4))
    for group_col in lr_cols:
        ax.plot(row_index, learning_rates[group_col], label=group_col)
    ax.set(xlabel='Step', ylabel='Learning Rate', title='Learning Rates Over Training')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.show()

## 5. Epoch-Level Analysis

Aggregate statistics per epoch to see trends.

In [None]:
# Per-epoch summaries of the per-collection telemetry, rounded to 4 decimals for display.
head_components = ('backbone', 'embedding', 'arcface')

grad_aggregations = {
    'overall': ['mean', 'std', 'max'],
    **{component: 'mean' for component in head_components},
}
delta_aggregations = {f'{component}_delta': ['mean', 'sum'] for component in head_components}

epoch_grad = gradient_norms.groupby('epoch').agg(grad_aggregations).round(4)
epoch_deltas = weight_deltas.groupby('epoch').agg(delta_aggregations).round(4)

for heading, table in (("=== Gradient Norms by Epoch ===", epoch_grad),
                       ("\n=== Weight Deltas by Epoch ===", epoch_deltas)):
    print(heading)
    display(table)

## 6. Diagnostic Summary

Key questions to answer:
1. Are gradients exploding or vanishing?
2. Is the model actually updating (weight deltas > 0)?
3. Is one component dominating training?
4. Are the ArcFace weights becoming degenerate?

In [None]:
print("=" * 60)
print("DIAGNOSTIC SUMMARY")
print("=" * 60)

# 1. Gradient analysis.
# pandas' mean()/max() skip NaN silently, and an inf turns them into inf. So
# non-finite rows are counted separately and the magnitude stats use finite rows only.
overall_norms = gradient_norms['overall']
finite_mask = np.isfinite(overall_norms)
finite_norms = overall_norms[finite_mask]
n_nonfinite_grads = int((~finite_mask).sum())
grad_overall_mean = finite_norms.mean()
grad_overall_max = finite_norms.max()
print("\n1. GRADIENT NORMS:")
print(f"   Overall mean: {grad_overall_mean:.4f}")
print(f"   Overall max:  {grad_overall_max:.4f}")
if n_nonfinite_grads:
    print(f"   ⚠️  WARNING: {n_nonfinite_grads} of {len(overall_norms)} gradient norms are NaN/inf "
          "(excluded from mean/max above)")
if finite_norms.empty:
    # Without this branch a NaN mean would fall through to "looks reasonable".
    print("   ⚠️  WARNING: No finite gradient norms were recorded")
elif grad_overall_mean > 100:
    print("   ⚠️  WARNING: Gradients are VERY LARGE - possible explosion")
elif grad_overall_mean < 0.001:
    print("   ⚠️  WARNING: Gradients are VERY SMALL - possible vanishing")
else:
    print("   ✓ Gradient magnitude looks reasonable")

# 2. Weight delta analysis. The deltas are per-interval L2 norms (non-negative),
# so a total that is not > 0 means the component never moved between collections.
backbone_change = weight_deltas['backbone_delta'].sum()
embedding_change = weight_deltas['embedding_delta'].sum()
arcface_change = weight_deltas['arcface_delta'].sum()
print("\n2. TOTAL WEIGHT CHANGES:")
print(f"   Backbone total: {backbone_change:.4f}")
print(f"   Embedding total: {embedding_change:.4f}")
print(f"   ArcFace total:  {arcface_change:.4f}")
for component, change in (('Backbone', backbone_change),
                          ('Embedding', embedding_change),
                          ('ArcFace', arcface_change)):
    # `not change > 0` also catches NaN totals.
    if not change > 0:
        print(f"   ⚠️  WARNING: {component} weights show no recorded change - frozen or not updating?")
if arcface_change > backbone_change * 100:
    print("   ⚠️  WARNING: ArcFace changing much faster than backbone")

# 3. ArcFace weight analysis (last recorded row).
print("\n3. ARCFACE WEIGHT NORMS (final):")
if arcface_stats.empty:
    print("   (no ArcFace statistics recorded)")
else:
    weight_norm_final = arcface_stats['weight_norm_mean'].iloc[-1]
    weight_norm_std = arcface_stats['weight_norm_std'].iloc[-1]
    print(f"   Mean norm: {weight_norm_final:.4f} (expected ~1.0)")
    print(f"   Std norm:  {weight_norm_std:.4f} (lower is better)")

# 4. Learning rate check: report the first and last recorded value of each
# parameter group so any schedule decay is visible.
print("\n4. LEARNING RATES:")
if not lr_cols or learning_rates.empty:
    print("   (no per-group learning-rate data found)")
else:
    for col in lr_cols:
        lr_series = learning_rates[col]
        print(f"   {col}: {lr_series.iloc[0]:.6f} (initial) -> {lr_series.iloc[-1]:.6f} (final)")

print("\n" + "=" * 60)

## 7. Potential Issues Checklist

Based on the analysis, check these common ArcFace training issues:

- [ ] **Gradient explosion**: Overall gradient > 100
- [ ] **Gradient vanishing**: Overall gradient < 0.001
- [ ] **Imbalanced learning**: ArcFace changes >> Backbone changes
- [ ] **Learning rate too high**: Large oscillations in weight deltas
- [ ] **Learning rate too low**: Minimal weight changes
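- [ ] **ArcFace weight collapse**: Class prototypes converging toward the same direction (degenerate solution). The norm statistics above cannot detect this on their own; similar per-class norms (a small `weight_norm_std`) are expected and healthy.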