# Normalize Age-Month Level Embeddings

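This notebook normalizes image embeddings at the age-month level (before aggregation). The configured paths point to `facebook_dinov3-vitb16-pretrain-lvd1689m` outputs, which are 768-dimensional in this run.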
It loads all `{subject_id}_{age_mo}_month_level_avg.npy` files from grouped age-month embeddings,
computes global normalization statistics feature-wise across all subjects/age bins,
and saves the normalized embeddings while preserving the original directory structure.

## Overview

This notebook:
1. Loads every `{subject_id}_{age_mo}_month_level_avg.npy` file from the category folders listed in the categories file (163 categories)
2. Skips subject 00270001 at load time, so that subject contributes nothing to the statistics and no normalized files are written for it
3. Computes global normalization statistics feature-wise across all loaded age_mo-level embeddings
   - In the run recorded here: 34,550 embeddings from 31 subjects, spanning age months 6 to 37 (no embeddings at month 36)
4. Normalizes each age_mo-level embedding using the global statistics: `(embedding - global_mean) / global_std`
5. Saves normalized embeddings to the output directory, preserving the original category folder structure

### Notes

- **Feature-wise normalization**: Normalizes each embedding dimension independently across all subjects/age bins
- **Age-month level normalization**: Normalizes at the raw data level (before aggregation), preserving age_mo information
- **Subject exclusion**: Subject 00270001 is skipped when files are loaded, so it is excluded from the normalization statistics and its embeddings are not written to the output directory
- **Directory structure**: Preserves original category-based folder structure in output directory

### Normalization Approach

- Load all `{subject_id}_{age_mo}_month_level_avg.npy` files from category folders
- Stack all embeddings into a matrix: (n_embeddings, embedding_dim), which is (34550, 768) in this run
- Compute feature-wise mean and std across ALL embeddings (excluding subject 00270001)
- Normalize each embedding: `(embedding - global_mean) / global_std`
- Save normalized embeddings back to files in the same directory structure

This ensures normalization happens at the raw data level, preserving age_mo information in normalized form for downstream analyses.
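
The steps above amount to a feature-wise z-score. A minimal sketch on toy data, assuming a small random matrix in place of the real embeddings (the array sizes and the seed are illustrative, not taken from this notebook's data):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for the stacked embeddings: 6 embeddings, 4 dimensions.
embeddings = rng.normal(loc=2.0, scale=3.0, size=(6, 4))

# Feature-wise statistics: one mean and one std per embedding dimension.
global_mean = embeddings.mean(axis=0)         # shape (4,)
global_std = embeddings.std(axis=0) + 1e-10   # shape (4,), epsilon avoids /0

# Broadcasting applies the per-dimension statistics to every row.
normalized = (embeddings - global_mean) / global_std

print(normalized.mean(axis=0))  # close to 0 in every dimension
print(normalized.std(axis=0))   # close to 1 in every dimension
```

Because the statistics are computed along `axis=0`, each column is rescaled independently and the rows keep their identity as separate subject/age-month embeddings.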

## Setup and Imports

In [1]:
import numpy as np
from pathlib import Path
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

print("All imports successful!")

All imports successful!


## Configuration

In [2]:
# Paths
embeddings_dir = Path("/data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/facebook_dinov3-vitb16-pretrain-lvd1689m_grouped_by_age-mo")
normalized_embeddings_dir = Path("/data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/facebook_dinov3-vitb16-pretrain-lvd1689m_grouped_by_age-mo_normalized")

# Categories file (optional - to filter to specific categories)
categories_file = Path("../../data/things_bv_overlap_categories_exclude_zero_precisions.txt")

# Subject to exclude from normalization and analyses
excluded_subject = "00270001"

print(f"Embeddings directory: {embeddings_dir}")
print(f"Normalized embeddings will be saved to: {normalized_embeddings_dir}")
print(f"Excluded subject: {excluded_subject}")

Embeddings directory: /data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/facebook_dinov3-vitb16-pretrain-lvd1689m_grouped_by_age-mo
Normalized embeddings will be saved to: /data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/facebook_dinov3-vitb16-pretrain-lvd1689m_grouped_by_age-mo_normalized
Excluded subject: 00270001


## Load Category List (Optional)

In [3]:
# Load allowed categories if file exists
allowed_categories = None
if categories_file.exists():
    print(f"Loading categories from {categories_file}...")
    with open(categories_file, 'r') as f:
        allowed_categories = set(line.strip() for line in f if line.strip())
    print(f"Loaded {len(allowed_categories)} categories")
else:
    print("Categories file not found, using all categories")

Loading categories from ../../data/things_bv_overlap_categories_exclude_zero_precisions.txt...
Loaded 163 categories


## Load All Age-Month Level Embeddings

In [4]:
def load_all_age_mo_embeddings(embeddings_dir, allowed_categories=None, excluded_subject=None):
    """
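    Load all age-month level embeddings, preserving file paths for later saving.

    Files that do not match `{subject_id}_{age_mo}_...` are skipped, as are all
    files belonging to `excluded_subject` (if given).

    Returns:
        embeddings_data: list of dicts with keys:
            - 'embedding': np.array
            - 'category': str
            - 'subject_id': str
            - 'age_mo': int
            - 'file_path': Path (original file path)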
    """
    embeddings_data = []
    
    # Get all category folders
    category_folders = [f for f in embeddings_dir.iterdir() if f.is_dir()]
    
    if allowed_categories:
        category_folders = [f for f in category_folders if f.name in allowed_categories]
    
    print(f"Loading embeddings from {len(category_folders)} categories...")
    
    for category_folder in tqdm(category_folders, desc="Loading categories"):
        category = category_folder.name
        
        # Get all embedding files in this category
        embedding_files = list(category_folder.glob("*.npy"))
        
        for emb_file in embedding_files:
            # Parse filename: {subject_id}_{age_mo}_month_level_avg.npy
            filename = emb_file.stem  # without .npy
            parts = filename.split('_')
            
            if len(parts) < 2:
                continue
            
            # Extract subject_id and age_mo
            # Format: S00560001_16_month_level_avg
            subject_id = parts[0]
            age_mo = int(parts[1]) if parts[1].isdigit() else None
            
            if age_mo is None:
                continue
            
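            # Skip the excluded subject entirely. This is an exact string match
            # against the filename prefix, so `excluded_subject` must be written
            # in the same form as that prefix (with or without a leading 'S').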
            if excluded_subject and subject_id == excluded_subject:
                continue
            
            try:
                embedding = np.load(emb_file)
                
                embeddings_data.append({
                    'embedding': embedding,
                    'category': category,
                    'subject_id': subject_id,
                    'age_mo': age_mo,
                    'file_path': emb_file
                })
            except Exception as e:
                print(f"Error loading {emb_file}: {e}")
                continue
    
    return embeddings_data

# Load all age-month level embeddings
all_embeddings_data = load_all_age_mo_embeddings(
    embeddings_dir, 
    allowed_categories=allowed_categories,
    excluded_subject=excluded_subject
)

print(f"\nLoaded {len(all_embeddings_data)} age-month level embeddings")

# Show statistics
subjects = set(d['subject_id'] for d in all_embeddings_data)
categories = set(d['category'] for d in all_embeddings_data)
age_mos = set(d['age_mo'] for d in all_embeddings_data)

print(f"  Subjects: {len(subjects)}")
print(f"  Categories: {len(categories)}")
print(f"  Age months: {sorted(age_mos)}")
print(f"  Total embeddings: {len(all_embeddings_data)}")

Loading embeddings from 163 categories...



Loading categories: 100%|██████████| 163/163 [00:06<00:00, 26.83it/s]


Loaded 34550 age-month level embeddings
  Subjects: 31
  Categories: 163
  Age months: [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37]
  Total embeddings: 34550
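
The loader above splits each filename on underscores. As an alternative sketch, the same parsing can be written with a regular expression that rejects non-matching names explicitly. The pattern and the helper name `parse_embedding_filename` are assumptions for illustration and are not part of this notebook; the example paths are made up.

```python
import re
from pathlib import Path

# Assumed pattern: {subject_id}_{age_mo}_month_level_avg.npy
FILENAME_RE = re.compile(
    r"^(?P<subject_id>[^_]+)_(?P<age_mo>\d+)_month_level_avg$"
)

def parse_embedding_filename(path):
    """Return (subject_id, age_mo), or None if the name does not match."""
    match = FILENAME_RE.match(Path(path).stem)
    if match is None:
        return None
    return match.group("subject_id"), int(match.group("age_mo"))

# Hypothetical paths, used only to exercise the helper.
print(parse_embedding_filename("ball/00560001_16_month_level_avg.npy"))
print(parse_embedding_filename("ball/notes.npy"))
```

A stricter pattern like this one would also skip `.npy` files that happen to start with `{something}_{digits}` but are not month-level averages, which the split-based check accepts.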





## Compute Global Normalization Statistics

In [5]:
# Compute global normalization statistics across ALL age-month level embeddings
# This is feature-wise normalization: each embedding dimension is normalized across all subjects/age bins
print("Computing global normalization statistics across all age-month level embeddings...")
print(f"  (Excluding subject {excluded_subject} from statistics)")

# Stack all embeddings into a matrix: (n_embeddings, embedding_dim)
all_embeddings_matrix = np.array([d['embedding'].flatten() for d in all_embeddings_data])
print(f"  Embeddings matrix shape: {all_embeddings_matrix.shape}")

# Compute feature-wise (per-dimension) mean and std across all embeddings
# This normalizes each embedding dimension independently
global_mean = all_embeddings_matrix.mean(axis=0)  # Shape: (embedding_dim,)
global_std = all_embeddings_matrix.std(axis=0) + 1e-10  # Add small epsilon to avoid division by zero

print(f"  Global mean shape: {global_mean.shape}")
print(f"  Global std shape: {global_std.shape}")
print(f"  Global mean range: [{global_mean.min():.4f}, {global_mean.max():.4f}]")
print(f"  Global std range: [{global_std.min():.4f}, {global_std.max():.4f}]")

Computing global normalization statistics across all age-month level embeddings...
  (Excluding subject 00270001 from statistics)
  Embeddings matrix shape: (34550, 768)
  Global mean shape: (768,)
  Global std shape: (768,)
  Global mean range: [-1.4054, 1.0939]
  Global std range: [0.1121, 0.6863]
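
The smallest standard deviation reported above (0.1121) is far from zero, so the `1e-10` epsilon has essentially no effect in this run. A small sketch of what the epsilon guards against, using a toy matrix with one constant column (the values are invented for illustration):

```python
import numpy as np

# Toy matrix whose last column is constant (zero variance).
matrix = np.array([
    [1.0, 10.0, 5.0],
    [2.0, 20.0, 5.0],
    [3.0, 30.0, 5.0],
    [4.0, 40.0, 5.0],
])

eps = 1e-10
mean = matrix.mean(axis=0)
std = matrix.std(axis=0) + eps  # the constant column gets std == eps, not 0

normalized = (matrix - mean) / std

print(np.isfinite(normalized).all())  # no inf or nan from division by zero
print(normalized[:, 2])               # constant column maps to zeros
```

Without the epsilon the constant column would produce `0 / 0`, which NumPy evaluates to `nan` with a runtime warning.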


## Normalize and Save Age-Month Level Embeddings

In [6]:
# Normalize each age-month level embedding and save
print("\nNormalizing and saving age-month level embeddings...")

# Create output directory structure
normalized_embeddings_dir.mkdir(exist_ok=True, parents=True)

# Normalize and save each embedding
for emb_data in tqdm(all_embeddings_data, desc="Normalizing embeddings"):
    embedding = emb_data['embedding']
    category = emb_data['category']
    original_path = emb_data['file_path']
    
    # Apply global normalization: (x - global_mean) / global_std
    normalized_embedding = (embedding.flatten() - global_mean) / global_std
    
    # Create category directory in normalized embeddings directory
    output_category_dir = normalized_embeddings_dir / category
    output_category_dir.mkdir(exist_ok=True, parents=True)
    output_path = output_category_dir / original_path.name
    
    # Save normalized embedding
    np.save(output_path, normalized_embedding)

print(f"\nNormalized and saved {len(all_embeddings_data)} age-month level embeddings")
print(f"  Saved to: {normalized_embeddings_dir}")


Normalizing and saving age-month level embeddings...


Normalizing embeddings: 100%|██████████| 34550/34550 [00:01<00:00, 21997.56it/s]


Normalized and saved 34550 age-month level embeddings
  Saved to: /data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/facebook_dinov3-vitb16-pretrain-lvd1689m_grouped_by_age-mo_normalized
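
The save loop rebuilds each output path from the category name and the original filename. A minimal sketch of the same mirroring idea using `Path.relative_to`, run inside a temporary directory so it touches no real data (the folder names, the toy array, and the doubling used as a stand-in for normalization are all assumptions):

```python
import tempfile
from pathlib import Path

import numpy as np

with tempfile.TemporaryDirectory() as tmp:
    src_root = Path(tmp) / "embeddings"             # hypothetical input root
    dst_root = Path(tmp) / "embeddings_normalized"  # hypothetical output root

    # Create one toy input file: <root>/<category>/<file>.npy
    src_file = src_root / "ball" / "00560001_16_month_level_avg.npy"
    src_file.parent.mkdir(parents=True)
    np.save(src_file, np.arange(4, dtype=np.float64))

    # Mirror the path relative to the input root under the output root.
    dst_file = dst_root / src_file.relative_to(src_root)
    dst_file.parent.mkdir(parents=True, exist_ok=True)
    np.save(dst_file, np.load(src_file) * 2.0)  # stand-in for normalization

    # Read the result back before the temporary directory is removed.
    reloaded = np.load(dst_file)
    relative_path = dst_file.relative_to(dst_root).as_posix()

print(relative_path)
print(reloaded)
```

Reloading a few saved files in this way is a cheap check that the output tree matches the input tree and that the arrays were written with the expected shape.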



