# Step 2: Filter, Normalize, and Reorganize RDM

This notebook filters out low-quality categories and reorganizes the RDM by category type with optional hierarchical clustering.

## Overview

This step:
1. Loads category average embeddings from Step 1
2. Filters categories based on inclusion/exclusion lists
3. Z-score normalizes the filtered embeddings, then computes and saves the filtered RDM matrices
4. Organizes categories by type (animals, bodyparts, big objects, small objects)
5. Optionally applies hierarchical clustering within each group
6. Saves reorganized RDM matrices
7. Creates visualizations

## Prerequisites

This step requires:
- Output from Step 1 (CLIP or DINOV3 average category embeddings)
- CDI words CSV file for category type information
- Inclusion or exclusion list file

## Setup and Imports

In [233]:
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.cluster.hierarchy import linkage, dendrogram, optimal_leaf_ordering
from scipy.spatial.distance import squareform
from sklearn.metrics.pairwise import cosine_similarity
import warnings
warnings.filterwarnings('ignore')

# Set matplotlib backend
import matplotlib
matplotlib.use('Agg')

print("All imports successful!")

All imports successful!


## Define helper functions

In [234]:
def load_category_types(cdi_path):
    """Load per-category type flags from the CDI words CSV.

    Args:
        cdi_path: Path to a CSV file with a ``uni_lemma`` column and,
            optionally, the flag columns ``is_animate``, ``is_bodypart``,
            ``is_small`` and ``is_big``.

    Returns:
        Dict mapping each ``uni_lemma`` value to a dict with the boolean
        entries ``is_animate``, ``is_bodypart``, ``is_small`` and ``is_big``.
        A flag is ``False`` when its column is absent or its cell is empty.

    Raises:
        KeyError: If the CSV has no ``uni_lemma`` column.
    """
    print(f"Loading category types from {cdi_path}...")
    cdi_df = pd.read_csv(cdi_path)

    flag_columns = ('is_animate', 'is_bodypart', 'is_small', 'is_big')

    def _as_flag(value):
        # bool(float('nan')) is True, so an empty CSV cell would otherwise be
        # read as "flag set"; treat missing values as False instead.
        return False if pd.isna(value) else bool(value)

    category_types = {}
    for _, row in cdi_df.iterrows():
        category_types[row['uni_lemma']] = {
            column: _as_flag(row.get(column, 0)) for column in flag_columns
        }

    print(f"Loaded type information for {len(category_types)} categories")
    return category_types

def cluster_categories_within_group(group_categories, cat_to_embedding, save_dendrogram=False, output_dir=None, group_name=None):
    """
    Order the categories of one group by hierarchical clustering of their embeddings.

    The group's embeddings are z-scored per dimension (using this group's own
    mean and std), pairwise cosine distances are computed on the result, and
    Ward linkage with optimal leaf ordering determines the final order.

    Args:
        group_categories: List of category names in the group
        cat_to_embedding: Dictionary mapping category names to embeddings
        save_dendrogram: Whether to save dendrogram plot (default: False)
        output_dir: Output directory for saving dendrogram (required if save_dendrogram=True)
        group_name: Name of the group for saving dendrogram (required if save_dendrogram=True)

    Returns:
        Tuple ``(clustered_categories, linkage_matrix)``: the category names
        reordered according to the dendrogram leaves, and the SciPy linkage
        matrix. For a group with at most one category the input list is
        returned unchanged together with ``None``.
    """
    if len(group_categories) <= 1:
        return group_categories, None

    # Get embeddings for this group
    group_embeddings = np.array([cat_to_embedding[cat] for cat in group_categories])

    # Normalize embeddings (epsilon avoids division by zero for constant dimensions)
    normalized_embeddings = (group_embeddings - group_embeddings.mean(axis=0)) / (group_embeddings.std(axis=0) + 1e-10)

    # Compute distance matrix (1 - cosine similarity)
    similarity_matrix = cosine_similarity(normalized_embeddings)
    distance_matrix = 1 - similarity_matrix
    # squareform() checks that its input is symmetric with a zero diagonal, so
    # remove any floating-point asymmetry first (Step 4 does the same for the
    # full RDM). Rounding can also push a distance marginally below zero.
    distance_matrix = (distance_matrix + distance_matrix.T) / 2
    distance_matrix = np.clip(distance_matrix, 0, None)
    np.fill_diagonal(distance_matrix, 0)

    # Convert to condensed form for linkage
    condensed_distances = squareform(distance_matrix)

    # Perform hierarchical clustering
    # NOTE(review): SciPy documents Ward linkage as correctly defined only for
    # Euclidean distances; it is applied to cosine distances here and kept
    # as-is so that existing orderings do not change.
    linkage_matrix = linkage(condensed_distances, method='ward')

    # Get optimal leaf ordering for better visualization
    try:
        linkage_matrix = optimal_leaf_ordering(linkage_matrix, condensed_distances)
    except Exception as exc:
        # Best effort only: fall back to the plain linkage order, but say so
        print(f"    Warning: optimal leaf ordering failed ({exc}); using original linkage order")

    # Extract the order from the dendrogram
    dendro_dict = dendrogram(linkage_matrix, no_plot=True)
    leaf_order = dendro_dict['leaves']

    # Reorder categories according to clustering
    clustered_categories = [group_categories[i] for i in leaf_order]

    # Save dendrogram if requested
    if save_dendrogram and output_dir is not None and group_name is not None:
        output_dir = Path(output_dir)
        output_dir.mkdir(exist_ok=True, parents=True)

        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            # labels are indexed by original observation, so pass the
            # un-reordered list; dendrogram() places them at the right leaves
            dendrogram(linkage_matrix,
                       labels=group_categories,
                       leaf_rotation=90,
                       leaf_font_size=10,
                       ax=ax)
            ax.set_title(f'Hierarchical Clustering Dendrogram: {group_name.upper()}\n({len(group_categories)} categories)',
                         fontsize=16, pad=20)
            ax.set_xlabel('Category', fontsize=12)
            ax.set_ylabel('Distance', fontsize=12)
            fig.tight_layout()

            # Save as PNG
            output_path_png = output_dir / f'dendrogram_{group_name}.png'
            fig.savefig(output_path_png, dpi=300, bbox_inches='tight', pad_inches=0.2)
            print(f"    Saved dendrogram to {output_path_png}")

            # Save as PDF
            output_path_pdf = output_dir / f'dendrogram_{group_name}.pdf'
            fig.savefig(output_path_pdf, bbox_inches='tight', pad_inches=0.2)
            print(f"    Saved dendrogram to {output_path_pdf}")
        finally:
            # Always release the figure, even if saving fails
            plt.close(fig)
    elif save_dendrogram:
        print("    Warning: save_dendrogram=True requires output_dir and group_name; dendrogram not saved")

    return clustered_categories, linkage_matrix

print("Helper functions loaded!")

Helper functions loaded!


## Configuration

**Please update the paths below according to your setup:**

In [235]:
# ============================================================================
# CONFIGURATION - UPDATE THESE PATHS FOR YOUR SETUP
# ============================================================================

# Input: Path to .npz file containing category average embeddings from Step 1.
# The loading cell reads the 'embeddings' array and looks for the category
# names under one of the keys 'categories', 'category', 'labels' or 'label'.
INPUT_EMBEDDING_PATH = "./bv_dino_rdm_results_26/category_average_embeddings.npz"  # Direct path to .npz file (e.g., "things_clip_embeddings.npz")

# Output directory prefix (e.g., "bv_clip", "bv_dinov3", "things_clip", "things_dinov3")
# The output directory will be constructed as: {OUTPUT_DIR_PREFIX}_filtered_zscored_hierarchical_{n_cats}cats/
# where n_cats is the number of categories left after filtering.
# The prefix is also used (title-cased) in the RDM heatmap titles.
OUTPUT_DIR_PREFIX = "bv_dino"  # Change this for different datasets/models

# CDI words CSV file (required for category type organization)
# Read columns: uni_lemma, is_animate, is_bodypart, is_small, is_big.
# If the file does not exist, a warning is printed and every category is
# placed in the 'others' group.
CDI_PATH = "../../data/cdi_words.csv"

# Filtering options
# INCLUSION_FILE takes precedence: EXCLUSION_FILE is only consulted when
# INCLUSION_FILE is None/empty. With both unset, all categories are kept.
EXCLUSION_FILE = None  # Path to text file with categories to exclude (one per line), or None
INCLUSION_FILE = "../../data/things_bv_overlap_categories_exclude_zero_precisions.txt"  # Path to text file with categories to include (one per line), or None

# Processing options
# With USE_CLUSTERING = False each group is sorted alphabetically instead;
# SAVE_DENDROGRAMS only has an effect when USE_CLUSTERING is True.
USE_CLUSTERING = True  # Enable hierarchical clustering within category groups
SAVE_DENDROGRAMS = True  # Save dendrogram plots for each category group

print("Configuration loaded. Please review and update paths as needed.")

Configuration loaded. Please review and update paths as needed.


### Step 1: Load average category embeddings

In [236]:
# Step 1: Load category average embeddings from Step 1
npz_path = Path(INPUT_EMBEDDING_PATH)
if not npz_path.exists():
    raise FileNotFoundError(f"{npz_path} not found. Please run Step 1 first or check the path.")

print(f"\n[Step 1/8] Loading category averages from {npz_path}...")

# np.load returns an NpzFile that keeps the archive open; the context manager
# closes it once the arrays we need have been read into memory.
with np.load(npz_path) as data:
    available_keys = list(data.keys())
    if 'embeddings' not in available_keys:
        raise KeyError(f"Expected 'embeddings' key in NPZ file. Available keys: {available_keys}")
    embeddings = data['embeddings']

    # Handle different key names (categories, category, labels, label);
    # the first key that is present wins.
    categories = None
    for key_name in ['categories', 'category', 'labels', 'label']:
        if key_name in available_keys:
            categories = data[key_name]
            # Convert numpy array of strings to list
            if categories.dtype == 'object':
                categories = [str(cat) for cat in categories]
            else:
                categories = categories.tolist()
            print(f"  Found categories under key '{key_name}'")
            break

if categories is None:
    raise KeyError(f"Expected 'categories', 'category', 'labels', or 'label' key in NPZ file. Available keys: {available_keys}")

# Every later step pairs categories[i] with embeddings[i]; fail early if the
# two arrays cannot be aligned row by row.
if len(categories) != embeddings.shape[0]:
    raise ValueError(
        f"Number of categories ({len(categories)}) does not match number of embedding rows "
        f"({embeddings.shape[0]}) in {npz_path}"
    )

# Later steps look categories up by name, which is ambiguous for duplicates
if len(set(categories)) != len(categories):
    print(f"  Warning: category names are not unique ({len(categories) - len(set(categories))} duplicates); "
          f"name-based reordering in later steps may be incorrect")

print(f"  Loaded {len(categories)} categories with embeddings of shape {embeddings.shape}")



[Step 1/8] Loading category averages from bv_dino_rdm_results_26/category_average_embeddings.npz...
  Found categories under key 'categories'
  Loaded 291 categories with embeddings of shape (291, 768)


### Step 2: Filter categories based on inclusion/exclusion

In [237]:
# Step 2: Filter categories based on inclusion/exclusion lists
print(f"\n[Step 2/8] Filtering categories...")


def _read_category_list(list_path):
    """Return the set of non-empty, whitespace-stripped lines of a text file."""
    # Explicit encoding so the result does not depend on the platform default
    with open(list_path, 'r', encoding='utf-8') as f:
        return set(line.strip() for line in f if line.strip())


if INCLUSION_FILE:
    if EXCLUSION_FILE:
        # The inclusion list takes precedence; make the ignored setting visible
        print(f"  Warning: both INCLUSION_FILE and EXCLUSION_FILE are set; ignoring EXCLUSION_FILE")
    print(f"  Loading included categories from {INCLUSION_FILE}...")
    included_categories = _read_category_list(INCLUSION_FILE)
    print(f"  Found {len(included_categories)} categories to include")
    excluded_categories = set()
elif EXCLUSION_FILE:
    print(f"  Loading excluded categories from {EXCLUSION_FILE}...")
    excluded_categories = _read_category_list(EXCLUSION_FILE)
    print(f"  Found {len(excluded_categories)} categories to exclude")
    included_categories = None
else:
    print("  No filtering file specified. Using all categories.")
    included_categories = None
    excluded_categories = set()

# Filter categories (row indices into `embeddings`, original order preserved)
if included_categories is not None:
    filtered_indices = [i for i, cat in enumerate(categories) if cat in included_categories]
else:
    filtered_indices = [i for i, cat in enumerate(categories) if cat not in excluded_categories]
filtered_categories = [categories[i] for i in filtered_indices]
filtered_embeddings = embeddings[filtered_indices]

if included_categories is not None:
    print(f"  After filtering: {len(filtered_categories)} categories")
    # Names on the inclusion list that have no embedding would otherwise
    # disappear without any trace.
    missing_categories = sorted(included_categories - set(categories))
    if missing_categories:
        preview = ", ".join(missing_categories[:10])
        suffix = ", ..." if len(missing_categories) > 10 else ""
        print(f"  Warning: {len(missing_categories)} included categories have no embedding: {preview}{suffix}")
else:
    # Report how many categories were actually removed, not the list length
    n_removed = len(categories) - len(filtered_categories)
    print(f"  After filtering: {len(filtered_categories)} categories (excluded {n_removed})")

if not filtered_categories:
    raise ValueError("No categories left after filtering. Check the inclusion/exclusion file and category names.")



[Step 2/8] Filtering categories...
  Loading included categories from ../../data/things_bv_overlap_categories_exclude_zero_precisions.txt...
  Found 163 categories to include
  After filtering: 163 categories


### Step 3: Normalize embeddings (z-score normalization)

In [238]:
# Step 3: Normalize embeddings (z-score normalization)
print(f"\n[Step 3/8] Normalizing filtered embeddings...")

# Sort categories alphabetically and reorder the raw embeddings to match.
# Both `filtered_categories` and `filtered_embeddings` are reordered together
# so that they stay aligned and re-running this cell is safe: previously only
# the category list was overwritten with its sorted version, so a second run
# paired sorted names with unsorted embedding rows.
print(f"  Sorting categories alphabetically...")
sorted_category_indices = sorted(range(len(filtered_categories)), key=lambda i: filtered_categories[i])
filtered_categories, filtered_embeddings = (
    [filtered_categories[i] for i in sorted_category_indices],
    filtered_embeddings[sorted_category_indices],
)
print(f"  Categories and embeddings sorted alphabetically")

# Per-dimension z-score across categories. The column mean/std do not depend
# on row order, so sorting first gives the same values (up to floating-point
# rounding) as normalizing first. The epsilon guards constant dimensions.
normalized_filtered_embeddings = (filtered_embeddings - filtered_embeddings.mean(axis=0)) / (filtered_embeddings.std(axis=0) + 1e-10)
print(f"  Normalized embeddings shape: {normalized_filtered_embeddings.shape}")

# Names kept for backward compatibility with code that refers to the sorted versions
sorted_filtered_categories = filtered_categories
sorted_normalized_filtered_embeddings = normalized_filtered_embeddings

# Construct output directory name based on prefix
n_cats = len(filtered_categories)
output_dir_name = f"{OUTPUT_DIR_PREFIX}_filtered_zscored_hierarchical_{n_cats}cats"
output_dir = Path(output_dir_name)
output_dir.mkdir(exist_ok=True, parents=True)

# Save normalized embeddings as CSV with category names
# Create DataFrame with category names as index and embedding dimensions as columns
embedding_df = pd.DataFrame(
    normalized_filtered_embeddings,
    index=filtered_categories,
    columns=[f'dim_{i}' for i in range(normalized_filtered_embeddings.shape[1])]
)
csv_path = output_dir / 'normalized_filtered_embeddings.csv'
embedding_df.to_csv(csv_path)
print(f"  Saved alphabetically sorted normalized embeddings to {csv_path}")
print(f"    CSV contains {len(filtered_categories)} categories with {normalized_filtered_embeddings.shape[1]} embedding dimensions")


[Step 3/8] Normalizing filtered embeddings...
  Normalized embeddings shape: (163, 768)
  Sorting categories alphabetically...
  Categories and embeddings sorted alphabetically
  Saved alphabetically sorted normalized embeddings to bv_dino_filtered_zscored_hierarchical_163cats/normalized_filtered_embeddings.csv
    CSV contains 163 categories with 768 embedding dimensions


### Step 4: Compute RDM matrices on filtered, normalized embeddings

In [239]:
# Step 4: Compute RDM matrices on filtered, normalized embeddings (before reorganization)
print(f"\n[Step 4/8] Computing RDM matrices on filtered, normalized embeddings...")
# Compute similarity and distance matrices using normalized embeddings
similarity_matrix_filtered = cosine_similarity(normalized_filtered_embeddings)
distance_matrix_filtered = 1 - similarity_matrix_filtered
np.fill_diagonal(distance_matrix_filtered, 0)
# Average with the transpose to remove any floating-point asymmetry
distance_matrix_filtered = (distance_matrix_filtered + distance_matrix_filtered.T) / 2

print(f"  Computed RDM matrices with shape {distance_matrix_filtered.shape}")
# Note: these statistics are taken over the full matrix, including the zero diagonal
print(f"  Mean distance: {distance_matrix_filtered.mean():.4f}")
print(f"  Std distance: {distance_matrix_filtered.std():.4f}")

# Sort matrices by alphabetical order of categories.
# Sorting positions (instead of looking each name up with list.index) is
# linear-time per lookup and yields a distinct row for every entry even if a
# category name occurs more than once.
sorted_indices = sorted(range(len(filtered_categories)), key=filtered_categories.__getitem__)
sorted_categories = [filtered_categories[i] for i in sorted_indices]
similarity_matrix_sorted = similarity_matrix_filtered[np.ix_(sorted_indices, sorted_indices)]
distance_matrix_sorted = distance_matrix_filtered[np.ix_(sorted_indices, sorted_indices)]


# Save original matrices (before reorganization) for reference
print(f"  Saving original (pre-reorganization, alphabetically sorted) matrices to {output_dir}...")
np.save(output_dir / 'similarity_matrix_filtered_original.npy', similarity_matrix_sorted)
np.save(output_dir / 'distance_matrix_filtered_original.npy', distance_matrix_sorted)

sim_df_original = pd.DataFrame(similarity_matrix_sorted, index=sorted_categories, columns=sorted_categories)
sim_df_original.to_csv(output_dir / 'similarity_matrix_filtered_original.csv')

dist_df_original = pd.DataFrame(distance_matrix_sorted, index=sorted_categories, columns=sorted_categories)
dist_df_original.to_csv(output_dir / 'distance_matrix_filtered_original.csv')
print(f"  Saved original similarity and distance matrices")
# The reordering by category type happens in Step 7 (Step 8 only visualizes)
print(f"  Note: RDM computed before reorganization, sorted alphabetically - will be reordered in Step 7")



[Step 4/8] Computing RDM matrices on filtered, normalized embeddings...
  Computed RDM matrices with shape (163, 163)
  Mean distance: 0.9987
  Std distance: 0.2037
  Saving original (pre-reorganization, alphabetically sorted) matrices to bv_dino_filtered_zscored_hierarchical_163cats...
  Saved original similarity and distance matrices
  Note: RDM computed before reorganization, sorted alphabetically - will be reordered in Step 8


### Step 5: Organize categories by type (animals, bodyparts, big objects, small objects)

In [240]:
# Step 5: Organize categories by type (animals, bodyparts, big objects, small objects)
print(f"\n[Step 5/8] Organizing categories by type...")
cdi_path = Path(CDI_PATH)
# Output directory was already created in Step 3

if not cdi_path.exists():
    # Without type information everything ends up in the catch-all group
    print(f"  Warning: CDI path {cdi_path} not found. Skipping organization by type.")
    organized = {'animals': [], 'bodyparts': [], 'big_objects': [], 'small_objects': [], 'others': filtered_categories}
else:
    category_types = load_category_types(cdi_path)

    # One (initially empty) bucket per category type
    organized = {
        group_name: []
        for group_name in ('animals', 'bodyparts', 'big_objects', 'small_objects', 'others')
    }

    # Flags are checked in this order; the first one that is set decides the group
    flag_to_group = (
        ('is_animate', 'animals'),
        ('is_bodypart', 'bodyparts'),
        ('is_big', 'big_objects'),
        ('is_small', 'small_objects'),
    )

    for cat in filtered_categories:
        types = category_types.get(cat)
        if types is None:
            # Category is not listed in the CDI file
            target_group = 'others'
        else:
            target_group = next(
                (group_name for flag, group_name in flag_to_group if types[flag]),
                'others',
            )
        organized[target_group].append(cat)

    print(f"  Organized into: {len(organized['animals'])} animals, {len(organized['bodyparts'])} bodyparts, "
          f"{len(organized['big_objects'])} big objects, {len(organized['small_objects'])} small objects, "
          f"{len(organized['others'])} others")

# Use normalized embeddings for the category -> embedding mapping (needed by Step 6)
cat_to_embedding = dict(zip(filtered_categories, normalized_filtered_embeddings))



[Step 5/8] Organizing categories by type...
Loading category types from ../../data/cdi_words.csv...
Loaded type information for 295 categories
  Organized into: 19 animals, 14 bodyparts, 32 big objects, 96 small objects, 2 others


### Step 6: Optionally apply hierarchical clustering within each group

In [241]:
# Step 6: Optionally apply hierarchical clustering within each group
print(f"\n[Step 6/8] Applying hierarchical clustering within groups...")

# Fixed display order of the groups in the reorganized RDM and output files
group_order = ['animals', 'bodyparts', 'big_objects', 'small_objects', 'others']

if USE_CLUSTERING:
    for key in organized:
        # Groups with 0 or 1 members have nothing to cluster and are left as they are
        if len(organized[key]) > 1:
            print(f"  Clustering {key} ({len(organized[key])} categories)...")
            organized[key], _ = cluster_categories_within_group(
                organized[key],
                cat_to_embedding,
                save_dendrogram=SAVE_DENDROGRAMS,
                output_dir=output_dir,
                group_name=key
            )
else:
    # Without clustering, fall back to alphabetical order inside each group
    for key in organized:
        organized[key] = sorted(organized[key])

# Create ordered list (groups concatenated in display order)
ordered_categories = [cat for key in group_order for cat in organized[key]]

# Step 7 indexes the Step 4 matrices with this ordering, so it must contain
# exactly the filtered categories - no more, no fewer.
if sorted(ordered_categories) != sorted(filtered_categories):
    raise ValueError(
        f"Reordered category list ({len(ordered_categories)} entries) is not a permutation of the "
        f"filtered categories ({len(filtered_categories)} entries)"
    )

# Save category organization info (explicit encoding: independent of platform default)
with open(output_dir / 'category_organization_filtered.txt', 'w', encoding='utf-8') as f:
    f.write("Category Organization:\n")
    f.write("="*60 + "\n")
    for key in group_order:
        f.write(f"\n{key.replace('_', ' ').title()} ({len(organized[key])}):\n")
        for cat in organized[key]:
            f.write(f"  {cat}\n")

# One category name per line, in the order used for the reorganized RDM
with open(output_dir / 'category_names_filtered.txt', 'w', encoding='utf-8') as f:
    for cat in ordered_categories:
        f.write(f"{cat}\n")



[Step 6/8] Applying hierarchical clustering within groups...
  Clustering animals (19 categories)...
    Saved dendrogram to bv_dino_filtered_zscored_hierarchical_163cats/dendrogram_animals.png
    Saved dendrogram to bv_dino_filtered_zscored_hierarchical_163cats/dendrogram_animals.pdf
  Clustering bodyparts (14 categories)...
    Saved dendrogram to bv_dino_filtered_zscored_hierarchical_163cats/dendrogram_bodyparts.png
    Saved dendrogram to bv_dino_filtered_zscored_hierarchical_163cats/dendrogram_bodyparts.pdf
  Clustering big_objects (32 categories)...
    Saved dendrogram to bv_dino_filtered_zscored_hierarchical_163cats/dendrogram_big_objects.png
    Saved dendrogram to bv_dino_filtered_zscored_hierarchical_163cats/dendrogram_big_objects.pdf
  Clustering small_objects (96 categories)...
    Saved dendrogram to bv_dino_filtered_zscored_hierarchical_163cats/dendrogram_small_objects.png
    Saved dendrogram to bv_dino_filtered_zscored_hierarchical_163cats/dendrogram_small_objects.pd

### Step 7: Reorganize and save RDM matrices according to the new category ordering

In [242]:
# Step 7: Reorganize RDM matrices according to new category ordering
print(f"\n[Step 7/8] Reorganizing RDM matrices according to new category ordering...")
# Look-up table: category name -> its row/column position in the Step 4 matrices
old_to_new_index = dict(zip(filtered_categories, range(len(filtered_categories))))
# For each category in the new ordering, the Step 4 position to take it from
new_indices = list(map(old_to_new_index.__getitem__, ordered_categories))

# Apply the same permutation to rows and columns of both Step 4 matrices
row_col_selector = np.ix_(new_indices, new_indices)
similarity_matrix = similarity_matrix_filtered[row_col_selector]
distance_matrix = distance_matrix_filtered[row_col_selector]

print(f"  Reorganized RDM matrices with shape {distance_matrix.shape}")
print(f"  Saving reorganized matrices to {output_dir}...")

# Binary copies
np.save(output_dir / 'similarity_matrix_filtered.npy', similarity_matrix)
np.save(output_dir / 'distance_matrix_filtered.npy', distance_matrix)

# Labelled CSV copies (category names on both axes)
axis_labels = dict(index=ordered_categories, columns=ordered_categories)
sim_df = pd.DataFrame(similarity_matrix, **axis_labels)
sim_df.to_csv(output_dir / 'similarity_matrix_filtered.csv')

dist_df = pd.DataFrame(distance_matrix, **axis_labels)
dist_df.to_csv(output_dir / 'distance_matrix_filtered.csv')

print(f"  Saved reorganized similarity and distance matrices")
print(f"  Note: Original (pre-reorganization) matrices saved in Step 4 with '_original' suffix")



[Step 7/8] Reorganizing RDM matrices according to new category ordering...
  Reorganized RDM matrices with shape (163, 163)
  Saving reorganized matrices to bv_dino_filtered_zscored_hierarchical_163cats...
  Saved reorganized similarity and distance matrices
  Note: Original (pre-reorganization) matrices saved in Step 4 with '_original' suffix


### Step 8: Create visualizations

In [243]:
# Step 8: Create visualizations
print(f"\n[Step 8/8] Creating visualizations...")

# Create organized RDM heatmap
n_categories = len(ordered_categories)
# Figure side length in inches: half an inch per category, at least 20
fig_size = max(20, n_categories * 0.5)

# Set font size for axis labels (adaptive based on number of categories)
# Larger font for fewer categories, smaller but still readable for many categories
if n_categories <= 50:
    label_fontsize = 14
elif n_categories <= 100:
    label_fontsize = 12
elif n_categories <= 150:
    label_fontsize = 10
else:
    label_fontsize = 9


def _save_full_rdm_heatmap(cmap, file_stem, log_label):
    """Draw the full organized RDM with colormap `cmap` and save it as PNG and PDF."""
    fig, ax = plt.subplots(figsize=(fig_size, fig_size))
    try:
        sns.heatmap(distance_matrix,
                    xticklabels=ordered_categories,
                    yticklabels=ordered_categories,
                    cmap=cmap,
                    vmin=0,
                    vmax=2,
                    square=True,
                    cbar_kws={'label': 'Distance (1 - Cosine Similarity)', 'shrink': 0.8},
                    ax=ax)
        ax.set_title(f'{OUTPUT_DIR_PREFIX.replace("_", " ").title()} Category RDM (Filtered and Organized)',
                     fontsize=24, pad=20)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=label_fontsize)
        plt.setp(ax.get_yticklabels(), rotation=0, fontsize=label_fontsize)
        fig.tight_layout()
        fig.savefig(output_dir / f'{file_stem}.png', dpi=300, bbox_inches='tight')
        fig.savefig(output_dir / f'{file_stem}.pdf', bbox_inches='tight')
    finally:
        # Release the (large) figure even if drawing or saving fails
        plt.close(fig)
    print(f"  Saved RDM heatmap{log_label} to {output_dir / f'{file_stem}.png'}")
    print(f"  Saved RDM heatmap{log_label} to {output_dir / f'{file_stem}.pdf'}")


_save_full_rdm_heatmap('viridis', 'rdm_organized_filtered', '')

# Create coolwarm colormap version
_save_full_rdm_heatmap('coolwarm', 'rdm_organized_filtered_coolwarm', ' (coolwarm)')

# Find and save top similar/dissimilar pairs
print(f"  Computing top similar/dissimilar pairs...")
# Get upper triangle (excluding diagonal) so that every pair is counted once
triu_indices = np.triu_indices_from(distance_matrix, k=1)
distances_flat = distance_matrix[triu_indices]
pairs_flat = [(ordered_categories[i], ordered_categories[j]) for i, j in zip(triu_indices[0], triu_indices[1])]

# Sort by distance (ascending for similar, descending for dissimilar)
sorted_indices_similar = np.argsort(distances_flat)
sorted_indices_dissimilar = sorted_indices_similar[::-1]

# Category name -> position in the reorganized matrix (first occurrence,
# matching list.index semantics)
category_position = {}
for position, cat in enumerate(ordered_categories):
    category_position.setdefault(cat, position)


def _write_pairs_file(path, heading, pairs):
    """Write one '<cat1> <-> <cat2>: <distance>' line per pair below a heading."""
    with open(path, 'w', encoding='utf-8') as f:
        f.write(f"{heading}\n")
        f.write("="*60 + "\n")
        for (cat1, cat2), dist in pairs:
            f.write(f"{cat1} <-> {cat2}: {dist:.4f}\n")


def _plot_pair_subset(ax, pairs, title):
    """Heatmap of the sub-RDM spanned by all categories that occur in `pairs`."""
    # Sorted positions give a reproducible row/column order that follows the
    # organized RDM; iterating a set of strings would vary from run to run.
    subset_indices = sorted({category_position[cat] for pair, _ in pairs for cat in pair})
    subset_labels = [ordered_categories[i] for i in subset_indices]
    sns.heatmap(distance_matrix[np.ix_(subset_indices, subset_indices)],
                xticklabels=subset_labels,
                yticklabels=subset_labels,
                cmap='viridis', vmin=0, vmax=2, square=True, ax=ax,
                cbar_kws={'label': 'Distance'})
    ax.set_title(title, fontsize=14)
    ax.tick_params(axis='x', rotation=45, labelsize=8)
    ax.tick_params(axis='y', rotation=0, labelsize=8)


for n in [20, 30, 50]:
    # Slicing caps the lists at the number of available pairs
    top_similar = [(pairs_flat[i], distances_flat[i]) for i in sorted_indices_similar[:n]]
    top_dissimilar = [(pairs_flat[i], distances_flat[i]) for i in sorted_indices_dissimilar[:n]]

    # Top similar pairs
    _write_pairs_file(output_dir / f'top_{n}_similar_pairs.txt',
                      f"Top {n} Most Similar Category Pairs:", top_similar)

    # Top dissimilar pairs
    _write_pairs_file(output_dir / f'top_{n}_dissimilar_pairs.txt',
                      f"Top {n} Most Dissimilar Category Pairs:", top_dissimilar)

    # Visualize top pairs
    if n <= 30:  # Only create visualizations for smaller numbers
        fig, axes = plt.subplots(1, 2, figsize=(16, 8))
        try:
            _plot_pair_subset(axes[0], top_similar, f'Top {n} Similar Pairs')
            _plot_pair_subset(axes[1], top_dissimilar, f'Top {n} Dissimilar Pairs')
            fig.tight_layout()
            # The same two-panel figure is written under both file names
            # (kept so that existing consumers of either file keep working)
            fig.savefig(output_dir / f'top_{n}_similar_pairs.png', dpi=300, bbox_inches='tight')
            fig.savefig(output_dir / f'top_{n}_dissimilar_pairs.png', dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)

print(f"\n{'='*60}")
print(f"Filtering complete! Results saved to {output_dir}")
print(f"{'='*60}")
print(f"Original categories: {len(categories)}")
print(f"Filtered categories: {len(ordered_categories)}")
# Mean/std are taken over the full matrix (zero diagonal included), consistent
# with the values printed in Step 4; min/max use the upper triangle only.
print(f"Mean distance: {distance_matrix.mean():.4f}")
print(f"Std distance: {distance_matrix.std():.4f}")
print(f"Min distance: {distance_matrix[triu_indices].min():.4f}")
print(f"Max distance: {distance_matrix[triu_indices].max():.4f}")



[Step 8/8] Creating visualizations...
  Saved RDM heatmap to bv_dino_filtered_zscored_hierarchical_163cats/rdm_organized_filtered.png
  Saved RDM heatmap to bv_dino_filtered_zscored_hierarchical_163cats/rdm_organized_filtered.pdf
  Saved RDM heatmap (coolwarm) to bv_dino_filtered_zscored_hierarchical_163cats/rdm_organized_filtered_coolwarm.png
  Saved RDM heatmap (coolwarm) to bv_dino_filtered_zscored_hierarchical_163cats/rdm_organized_filtered_coolwarm.pdf
  Computing top similar/dissimilar pairs...

Filtering complete! Results saved to bv_dino_filtered_zscored_hierarchical_163cats
Original categories: 291
Filtered categories: 163
Mean distance: 0.9987
Std distance: 0.2037
Min distance: 0.0328
Max distance: 1.4313
