# Final Analysis: 3D Understanding in Vision Transformers

This notebook provides a comprehensive analysis of our interpretability experiments on Vision Transformers (DINOv2 and I-JEPA) for 3D understanding tasks:
- **3D Voxel Reconstruction**: How well can different layers reconstruct 3D shape from 2D views?
- **Viewpoint Estimation**: How accurately can models predict camera viewpoint parameters?

## Table of Contents
1. [Setup and Data Loading](#1-setup-and-data-loading)
2. [Voxel Reconstruction Analysis](#2-voxel-reconstruction-analysis)
3. [Viewpoint Estimation Analysis](#3-viewpoint-estimation-analysis)
4. [Advanced Feature Analysis](#4-advanced-feature-analysis)
5. [Synthesis and Conclusions](#5-synthesis-and-conclusions)

## 1. Setup and Data Loading

In [None]:
# Standard imports
import os
import json
import pickle
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union, Any

# Scientific computing
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import timm  # Added import

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
from mpl_toolkits.mplot3d import Axes3D

# Machine learning
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Project imports
from src.probing.probes import create_probe, VoxelProbe, LinearProbe, MLPProbe
from src.datasets.shapenet_voxel_meshes import create_3dr2n2_reconstruction_dataloaders
from src.datasets.shapenet_3dr2n2 import create_3dr2n2_dataloaders
from src.models.model_loader import load_model_and_preprocessor
from src.analysis.layer_analysis import LayerWiseAnalyzer

# Configure plotting
# NOTE(review): the 'seaborn-v0_8-*' style names presumably require a
# matplotlib version where the seaborn styles were renamed (>= 3.6) — confirm.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
# IPython magic (not plain Python): render figures inline in the notebook.
%matplotlib inline

# Set device
# Module-level `device` is read by later cells (torch.load map_location, .to(device)).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

### 1.1 Data Loading Infrastructure

In [None]:
@dataclass
class ExperimentResult:
    """Everything the notebook knows about a single experiment run.

    Attributes:
        name: Experiment identifier (the name of its results sub-directory).
        model: Backbone label such as ``dinov2``, ``ijepa`` or ``supervised_vit``.
        task: Probing task such as ``viewpoint_estimation`` or ``voxel_reconstruction``.
        results: Parsed ``results.json`` payload for the run.
        cache_dir: Directory expected to hold the run's cached probes/features.
        results_dir: Directory the results payload was read from.
    """

    name: str
    model: str
    task: str
    results: Dict[str, Any]
    cache_dir: Path
    results_dir: Path


class SimpleExperimentLoader:
    """Load experiment results and locate their cached probes/features.

    Assumed on-disk layout, relative to ``base_path``::

        results/<experiment>/results.json
        cache/<experiment>/probes/*.pth
        cache/<experiment>/features/layer_<n>_<split>.pkl

    Cache directories are not always named exactly like their results
    directory; ``_find_cache_dir`` tries a few known renamings.
    """

    def __init__(self, base_path: str = "."):
        self.base_path = Path(base_path)
        self.cache_path = self.base_path / "cache"
        self.results_path = self.base_path / "results"

    # NOTE: "ExperimentResult" is quoted in annotations so this class can be
    # defined independently of where ExperimentResult is declared.
    def load_all_experiments(self) -> Dict[str, "ExperimentResult"]:
        """Load every experiment sub-directory that contains a ``results.json``.

        Sub-directories are visited in sorted order so the returned dict (and any
        later "first match" selection over it) is reproducible across filesystems.

        Returns:
            Mapping from experiment name (the sub-directory name) to its
            ExperimentResult. Empty, with a warning, if ``results/`` is missing.
        """
        experiments = {}

        if not self.results_path.is_dir():
            print(f"Warning: results directory not found: {self.results_path}")
            return experiments

        for results_dir in sorted(self.results_path.iterdir()):
            # Skip stray files and hidden directories (e.g. .ipynb_checkpoints)
            if not results_dir.is_dir() or results_dir.name.startswith('.'):
                continue

            exp_name = results_dir.name
            model, task = self._parse_experiment_name(exp_name)

            results_file = results_dir / "results.json"
            if not results_file.exists():
                print(f"Warning: No results.json found for {exp_name}")
                continue

            with open(results_file, 'r', encoding='utf-8') as f:
                results_data = json.load(f)

            experiments[exp_name] = ExperimentResult(
                name=exp_name,
                model=model,
                task=task,
                results=results_data,
                cache_dir=self._find_cache_dir(exp_name),
                results_dir=results_dir,
            )

        return experiments

    def _parse_experiment_name(self, exp_name: str) -> Tuple[str, str]:
        """Infer ``(model, task)`` from substrings of the experiment name.

        Matching is case-insensitive and first-match-wins in the order below;
        anything unrecognised maps to ``'unknown'``.
        """
        lowered = exp_name.lower()

        if 'dinov2' in lowered:
            model = 'dinov2'
        elif 'ijepa' in lowered:
            model = 'ijepa'
        elif 'supervised' in lowered:
            model = 'supervised_vit'
        else:
            model = 'unknown'

        if 'viewpoint' in lowered:
            task = 'viewpoint_estimation'
        elif 'voxel' in lowered or 'reconstruction' in lowered:
            task = 'voxel_reconstruction'
        else:
            task = 'unknown'

        return model, task

    def _find_cache_dir(self, exp_name: str) -> Path:
        """Return the cache directory belonging to ``exp_name``.

        Tries the exact name first, then a few known renamings. If none exists,
        ``cache/<exp_name>`` is returned anyway; the ``get_*_files`` methods
        check ``.exists()`` on the sub-directories they need.
        """
        candidates = [
            exp_name,
            exp_name.replace('phase1_', '').replace('phase2_', ''),
            # NOTE(review): prefixes "ijepa_" only when the name already contains
            # "ijepa" — possibly intended as `not in`; kept as-is.
            f"ijepa_{exp_name}" if 'ijepa' in exp_name else None,
            exp_name.replace('_probing', ''),
            exp_name.replace('_reconstruction', ''),
        ]

        for candidate in candidates:
            if candidate and (self.cache_path / candidate).exists():
                return self.cache_path / candidate

        return self.cache_path / exp_name

    def get_probe_files(self, experiment: "ExperimentResult") -> Dict[str, Dict[int, Path]]:
        """Index saved probe checkpoints as ``{probe_type: {layer: path}}``.

        Filename stems understood:
          * viewpoint_estimation: ``{probe_type}_layer_{n}`` (e.g. ``linear_layer_11``)
          * voxel_reconstruction: ``{model}_voxel_layer_{n}``, indexed under ``'voxel'``

        Non-matching files are ignored; other tasks yield ``{}``.
        """
        probes: Dict[str, Dict[int, Path]] = {}
        probes_dir = experiment.cache_dir / "probes"
        if not probes_dir.exists():
            return probes

        # Sorted so that if two files map to the same (type, layer) the winner is deterministic
        for probe_file in sorted(probes_dir.glob("*.pth")):
            stem = probe_file.stem

            if experiment.task == 'viewpoint_estimation':
                parts = stem.split('_layer_')
                if len(parts) == 2 and parts[1].isdigit():
                    probes.setdefault(parts[0], {})[int(parts[1])] = probe_file

            elif experiment.task == 'voxel_reconstruction':
                if '_voxel_layer_' in stem:
                    layer_part = stem.split('_voxel_layer_')[-1]
                    if layer_part.isdigit():
                        probes.setdefault('voxel', {})[int(layer_part)] = probe_file

        return probes

    def get_feature_files(self, experiment: "ExperimentResult") -> Dict[int, Dict[str, Path]]:
        """Index cached feature pickles as ``{layer: {split: path}}``.

        Expects stems of the form ``layer_{n}_{split}`` (split e.g. train/val/test);
        only the third underscore-separated token is used as the split name.
        """
        features: Dict[int, Dict[str, Path]] = {}
        features_dir = experiment.cache_dir / "features"
        if not features_dir.exists():
            return features

        for feature_file in sorted(features_dir.glob("*.pkl")):
            stem = feature_file.stem
            if not stem.startswith('layer_'):
                continue
            parts = stem.split('_')
            if len(parts) >= 3 and parts[1].isdigit():
                features.setdefault(int(parts[1]), {})[parts[2]] = feature_file

        return features

In [None]:
# Discover experiments on disk and summarise what was cached for each one
loader = SimpleExperimentLoader()
experiments = loader.load_all_experiments()

print(f"Loaded {len(experiments)} experiments:")
for name, exp in experiments.items():
    print(f"\n{name}:\n  Model: {exp.model}\n  Task: {exp.task}")

    # Probe checkpoints ({type: {layer: path}}) and feature pickles ({layer: {split: path}})
    probes = loader.get_probe_files(exp)
    features = loader.get_feature_files(exp)

    print(f"  Probe types: {list(probes)}")
    print(f"  Layers with features: {sorted(features)}")
    # Iterating an empty dict prints nothing, so no separate emptiness guard is needed
    for probe_type, layers in probes.items():
        print(f"    {probe_type} probes: layers {sorted(layers)}")
    for layer, splits in features.items():
        print(f"    Layer {layer}: {sorted(splits)} splits")

## 2. Voxel Reconstruction Analysis

We analyze how well probes trained on features from different DINOv2 layers can reconstruct 3D voxel occupancy grids from 2D images.

### 2.1 Quantitative Analysis

In [None]:
def plot_voxel_performance(experiment: "ExperimentResult", loader: "SimpleExperimentLoader"):
    """Plot per-layer voxel reconstruction test metrics and report the best layer.

    Reads ``experiment.results['results']``: every ``layer_<n>`` entry that has a
    ``'voxel'`` probe contributes its ``test_metrics`` (``voxel_iou``,
    ``voxel_precision``, ``voxel_recall``, ``voxel_f1``; a missing metric counts as 0).

    Args:
        experiment: Voxel reconstruction experiment to plot.
        loader: Unused; kept so existing call sites keep working.

    Returns:
        The layer number (``int``) with the highest test IoU, or ``None`` when the
        experiment has no voxel probe results (in which case nothing is plotted).
    """
    metrics_dict = experiment.results.get('results', {})

    # (metric key, subplot title, y label, line style, colour) in subplot order
    panels = [
        ('voxel_iou', 'Intersection over Union', 'IoU Score', 'o-', 'royalblue'),
        ('voxel_precision', 'Voxel Precision', 'Precision', 's-', 'forestgreen'),
        ('voxel_recall', 'Voxel Recall', 'Recall', '^-', 'darkorange'),
        ('voxel_f1', 'Voxel F1 Score', 'F1 Score', 'd-', 'crimson'),
    ]

    # Collect (layer_num, [iou, precision, recall, f1]) for every layer with a voxel probe
    rows = []
    for layer_key, layer_data in metrics_dict.items():
        if not layer_key.startswith('layer_') or 'voxel' not in layer_data:
            continue
        test_metrics = layer_data['voxel'].get('test_metrics', {})
        layer_num = int(layer_key.split('_')[1])
        rows.append((layer_num, [test_metrics.get(key, 0) for key, *_ in panels]))

    # Bail out before creating an empty figure
    if not rows:
        print("No voxel reconstruction results found")
        return None

    rows.sort(key=lambda row: row[0])
    layers = np.array([layer_num for layer_num, _ in rows])
    scores = np.array([values for _, values in rows], dtype=float)  # (n_layers, 4)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle(f'Voxel Reconstruction Performance - {experiment.model.upper()}', fontsize=16)

    # axes.flat walks [0,0], [0,1], [1,0], [1,1] -> IoU, Precision, Recall, F1
    for ax, (_, title, ylabel, style, color), column in zip(axes.flat, panels, scores.T):
        ax.plot(layers, column, style, linewidth=2, markersize=8, color=color)
        ax.set_xlabel('Layer')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
        ax.set_ylim([0, 1])

    plt.tight_layout()
    plt.show()

    best_idx = int(np.argmax(scores[:, 0]))
    best_layer = int(layers[best_idx])
    iou, precision, recall, f1 = scores[best_idx]
    print(f"\nBest performing layer: {best_layer}")
    print(f"  IoU: {iou:.4f}")
    print(f"  Precision: {precision:.4f}")
    print(f"  Recall: {recall:.4f}")
    print(f"  F1: {f1:.4f}")
    return best_layer

In [None]:
# Pick the first DINOv2 voxel-reconstruction experiment, if there is one
voxel_exp = next(
    (
        candidate
        for candidate in experiments.values()
        if candidate.task == 'voxel_reconstruction' and candidate.model == 'dinov2'
    ),
    None,
)

if voxel_exp is None:
    print("No DINOv2 voxel reconstruction experiment found")
else:
    print(f"Analyzing: {voxel_exp.name}")
    best_layer = plot_voxel_performance(voxel_exp, loader)

### 2.2 Qualitative Analysis - 3D Visualization

In [None]:
def visualize_voxel_reconstruction(probe_file_path: str, features_file_path: str,
                                   ground_truth_voxel_data: np.ndarray,
                                   sample_idx: int = 0,
                                   probe_config: Optional[Dict[str, Any]] = None):
    """Render ground-truth and predicted voxel grids side by side and print IoU.

    Args:
        probe_file_path: Probe checkpoint containing ``"model_state_dict"`` and,
            optionally, the ``"probe_config"`` it was trained with.
        features_file_path: Pickle whose ``"view_data"`` entry holds the cached
            features (a dict with a ``"features"`` key, or an indexable array).
        ground_truth_voxel_data: Occupancy grid for the same sample; any non-zero
            entry counts as occupied.
        sample_idx: Index of the sample inside the cached features.
        probe_config: Explicit probe config. If ``None``, the checkpoint's own
            ``"probe_config"`` is used, falling back to the historical default
            (``input_dim=18624``, ``voxel_resolution=32``).
    """
    # NOTE: torch.load / pickle.load can execute code embedded in the file;
    # only point this at trusted, locally produced caches.
    probe_state = torch.load(probe_file_path, map_location=device)

    # Prefer the config saved alongside the weights (as the viewpoint analysis
    # does) so probes trained with a different input size still load.
    if probe_config is None and isinstance(probe_state, dict):
        probe_config = probe_state.get('probe_config')
    if probe_config is None:
        probe_config = {
            'type': 'voxel',
            'input_dim': 18624,
            'voxel_resolution': 32,
        }
    resolution = int(probe_config.get('voxel_resolution', 32))

    probe = create_probe(probe_config)
    probe.load_state_dict(probe_state["model_state_dict"])
    probe.to(device)
    probe.eval()

    with open(features_file_path, 'rb') as f:
        features_data = pickle.load(f)["view_data"]

    # Extract a single sample, keeping a leading batch dimension of 1
    if isinstance(features_data, dict) and 'features' in features_data:
        features = features_data['features'][sample_idx:sample_idx + 1]
    else:
        features = features_data[sample_idx:sample_idx + 1]

    # .float() matches the probe's default float32 weights even if the cache
    # stores float64; reshape flattens everything after the batch dimension.
    features = torch.as_tensor(features).float().to(device)
    features = features.reshape(features.size(0), -1)

    with torch.no_grad():
        pred_logits = probe(features)
        pred_voxels = (torch.sigmoid(pred_logits) > 0.5).squeeze().cpu().numpy()
    if pred_voxels.ndim == 1 and pred_voxels.size == resolution ** 3:
        # NOTE(review): assumes a flat probe output is a row-major
        # (res, res, res) grid — confirm against VoxelProbe.
        pred_voxels = pred_voxels.reshape((resolution,) * 3)

    gt_occupied = np.asarray(ground_truth_voxel_data) != 0

    # Predicted grid is shifted along x so both shapes are visible side by side
    # (offset 35 / x-range 70 for the default 32^3 grid).
    offset = resolution + 3

    fig = go.Figure()

    # Ground truth voxels (blue)
    gt_points = np.argwhere(gt_occupied)
    if len(gt_points) > 0:
        fig.add_trace(go.Scatter3d(
            x=gt_points[:, 0],
            y=gt_points[:, 1],
            z=gt_points[:, 2],
            mode='markers',
            marker=dict(size=3, color='blue', opacity=0.6),
            name='Ground Truth'
        ))

    # Predicted voxels (red)
    pred_points = np.argwhere(pred_voxels)
    if len(pred_points) > 0:
        fig.add_trace(go.Scatter3d(
            x=pred_points[:, 0] + offset,
            y=pred_points[:, 1],
            z=pred_points[:, 2],
            mode='markers',
            marker=dict(size=3, color='red', opacity=0.6),
            name='Predicted'
        ))

    fig.update_layout(
        title=f'Voxel Reconstruction - Sample {sample_idx}',
        scene=dict(
            xaxis=dict(range=[0, 2 * offset]),
            yaxis=dict(range=[0, resolution]),
            zaxis=dict(range=[0, resolution]),
            aspectmode='data'
        ),
        width=900,
        height=600
    )

    fig.show()

    # Per-sample IoU (epsilon guards against an empty union)
    intersection = np.logical_and(gt_occupied, pred_voxels).sum()
    union = np.logical_or(gt_occupied, pred_voxels).sum()
    iou = intersection / (union + 1e-6)

    print(f"\nSample {sample_idx} metrics:")
    print(f"  IoU: {iou:.4f}")
    print(f"  GT voxels: {gt_occupied.sum()}")
    print(f"  Predicted voxels: {pred_voxels.sum()}")

In [None]:
# Load ground truth data for visualization
# Note: This requires the dataset configuration
from omegaconf import OmegaConf
from hydra import compose, initialize_config_dir
from hydra.core.global_hydra import GlobalHydra
import traceback
# Clear any existing hydra instance
GlobalHydra.instance().clear()

# Load existing config instead of creating from scratch
config_dir = str(Path("../configs").resolve())
try:
    with initialize_config_dir(config_dir=config_dir, version_base=None):
        # Load the existing voxel dataset config
        cfg = compose(config_name="datasets/shapenet_voxel_meshes").datasets

        # Override only what's needed for this visualization
        cfg["categories"] = ["chair"]
        cfg["dataloader"] = {'batch_size': 1, 'num_workers': 0}

        _, _, test_loader = create_3dr2n2_reconstruction_dataloaders(
            cfg, batch_size=1, num_workers=0
        )

        # NOTE(review): ground truth comes from the first test-loader batch while
        # the features come from index 0 of the cached test split; these only
        # refer to the same object if both use the same unshuffled order — confirm.
        sample = next(iter(test_loader))
        gt_voxels = sample['voxel_gt'][0].squeeze().numpy()

        # Use the best-IoU layer from section 2.1. `best_layer` only exists when
        # a voxel experiment was found, so it is read only in that case.
        vis_layer = best_layer if voxel_exp else None

        if vis_layer is None:
            print("No voxel experiment / best layer available; skipping visualization")
        else:
            probe_files = loader.get_probe_files(voxel_exp)
            feature_files = loader.get_feature_files(voxel_exp)

            # Look up probe and features for the SAME layer
            probe_path = probe_files.get('voxel', {}).get(vis_layer)
            features_path = feature_files.get(vis_layer, {}).get('test')

            if probe_path is None:
                print(f"No voxel probe found for layer {vis_layer}")
            elif features_path is None:
                print(f"No test features found for layer {vis_layer}")
            else:
                print(f"\nVisualizing reconstruction from layer {vis_layer}")
                visualize_voxel_reconstruction(str(probe_path), str(features_path), gt_voxels, sample_idx=0)

except Exception as e:
    print(f"Could not load dataset for visualization: {e}")
    print(traceback.format_exc())
    print("Please ensure ShapeNet datasets are available")
finally:
    # Leave no global Hydra state behind for later cells
    GlobalHydra.instance().clear()

## 3. Viewpoint Estimation Analysis

We compare how DINOv2 and I-JEPA encode viewpoint information across their layers.

### 3.1 Quantitative Analysis

In [None]:
def plot_viewpoint_performance(experiment_data_list: List["ExperimentResult"]):
    """Compare per-layer viewpoint-estimation test metrics across experiments.

    Draws a 2 x N grid with one column per experiment (DINOv2 experiments
    first, the rest in input order). The top row shows test MAE and the bottom
    row mean angular distance, with linear and MLP probes overlaid.

    Args:
        experiment_data_list: Experiments whose ``results['results']`` maps
            ``layer_<n>`` keys to per-probe (``'linear'``/``'mlp'``) dicts that
            hold ``test_metrics`` with ``mae`` and ``angular_distance_mean``.
            An empty list prints a message and draws nothing.
    """
    if not experiment_data_list:
        print("No viewpoint experiments to plot")
        return None

    colors = {'dinov2': 'royalblue', 'ijepa': 'forestgreen'}
    markers = {'linear': 'o', 'mlp': 's'}
    linestyles = {'linear': '-', 'mlp': '--'}
    probe_types = ('linear', 'mlp')
    # (metric key in test_metrics, y-axis label, title suffix) — one subplot row each
    metric_rows = (
        ('mae', 'MAE', 'Mean Absolute Error'),
        ('angular_distance_mean', 'Angular Distance (degrees)', 'Mean Angular Distance'),
    )

    # Stable sort: DINOv2 first (matching the old fixed layout), others keep input order.
    # One column per experiment so no experiment overwrites another's axes.
    ordered = sorted(experiment_data_list, key=lambda e: e.model != 'dinov2')
    n_cols = len(ordered)
    fig, axes = plt.subplots(2, n_cols, figsize=(8 * n_cols, 12), squeeze=False)
    fig.suptitle('Viewpoint Estimation Performance Comparison', fontsize=16)

    for col, exp_data in enumerate(ordered):
        model_name = exp_data.model
        # Fallback colour so models outside the palette (e.g. supervised_vit) still plot
        color = colors.get(model_name, 'dimgray')
        metrics_dict = exp_data.results.get('results', {})

        # probe_type -> [(layer_num, test_metrics), ...]
        per_probe = {probe_type: [] for probe_type in probe_types}
        for layer_key, layer_data in metrics_dict.items():
            if not layer_key.startswith('layer_'):
                continue
            layer_num = int(layer_key.split('_')[1])
            for probe_type in probe_types:
                if probe_type in layer_data:
                    per_probe[probe_type].append(
                        (layer_num, layer_data[probe_type].get('test_metrics', {}))
                    )

        for row, (metric_key, ylabel, title) in enumerate(metric_rows):
            ax = axes[row, col]
            for probe_type, entries in per_probe.items():
                if not entries:
                    continue
                entries = sorted(entries, key=lambda item: item[0])
                layers = [layer_num for layer_num, _ in entries]
                # NaN (a gap in the line) rather than 0 for a missing metric:
                # 0 would read as a perfect score.
                values = np.array(
                    [metrics.get(metric_key, np.nan) for _, metrics in entries],
                    dtype=float,
                )
                ax.plot(layers, values,
                        marker=markers[probe_type],
                        linestyle=linestyles[probe_type],
                        linewidth=2, markersize=8,
                        label=f'{probe_type.capitalize()} Probe',
                        color=color,
                        alpha=0.8 if probe_type == 'mlp' else 1.0)

            ax.set_xlabel('Layer')
            ax.set_ylabel(ylabel)
            ax.set_title(f'{model_name.upper()} - {title}')
            if ax.get_legend_handles_labels()[0]:
                ax.legend()
            ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

In [None]:
def plot_linearity_gap(experiment_data_list: List["ExperimentResult"]):
    """Plot, per layer, how much an MLP probe beats a linear probe on viewpoint MAE.

    The gap is ``linear MAE - MLP MAE``: positive means the MLP probe is better,
    i.e. the viewpoint information is less linearly decodable at that layer.
    Only layers with a test MAE for both probe types are plotted.

    Args:
        experiment_data_list: Viewpoint experiments whose ``results['results']``
            maps ``layer_<n>`` to ``{'linear'|'mlp': {'test_metrics': {'mae': ...}}}``.
            An empty list prints a message and draws nothing.
    """
    if not experiment_data_list:
        print("No viewpoint experiments to plot")
        return None

    colors = {'dinov2': 'royalblue', 'ijepa': 'forestgreen'}

    plt.figure(figsize=(10, 6))
    plotted_any = False

    for exp_data in experiment_data_list:
        model_name = exp_data.model
        metrics_dict = exp_data.results.get('results', {})

        # (layer_num, gap) for every layer that has both probe types
        gap_points = []
        for layer_key, layer_data in metrics_dict.items():
            if not layer_key.startswith('layer_'):
                continue
            linear_mae = layer_data.get('linear', {}).get('test_metrics', {}).get('mae')
            mlp_mae = layer_data.get('mlp', {}).get('test_metrics', {}).get('mae')
            if linear_mae is not None and mlp_mae is not None:
                gap_points.append((int(layer_key.split('_')[1]), linear_mae - mlp_mae))

        if not gap_points:
            continue

        gap_points.sort()
        layers, gaps = zip(*gap_points)
        plt.plot(layers, gaps,
                 'o-', linewidth=2.5, markersize=8,
                 label=model_name.upper(),
                 # Fallback colour so models outside the palette still plot
                 color=colors.get(model_name, 'dimgray'))
        plotted_any = True

    plt.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    plt.xlabel('Layer')
    plt.ylabel('Linearity Gap (Linear MAE - MLP MAE)')
    plt.title('Linearity of Viewpoint Representation')
    if plotted_any:
        plt.legend()
    plt.grid(True, alpha=0.3)

    plt.text(0.02, 0.98, 'Higher = Less Linear\n(MLP much better than Linear)',
             transform=plt.gca().transAxes,
             verticalalignment='top',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout()
    plt.show()

In [None]:
# Collect every experiment whose task is viewpoint estimation
viewpoint_exps = [
    experiment
    for experiment in experiments.values()
    if experiment.task == 'viewpoint_estimation'
]

if not viewpoint_exps:
    print("No viewpoint estimation experiments found")
else:
    print(f"Found {len(viewpoint_exps)} viewpoint experiments")
    plot_viewpoint_performance(viewpoint_exps)
    plot_linearity_gap(viewpoint_exps)

### 3.2 Qualitative Analysis - Prediction Visualization

In [None]:
def _viewpoint_degrees_and_errors(pred_norm: np.ndarray, true_norm: np.ndarray):
    """Denormalize (azimuth, elevation) pairs and compute per-sample errors in degrees.

    Assumes column 0 is azimuth and column 1 elevation, both scaled to [-1, 1]
    (azimuth -> [0, 360], elevation -> [-90, 90]) — confirm against the dataset.

    Returns:
        ``(pred_az, true_az, pred_el, true_el, az_error, el_error, great_circle)``.
        ``az_error`` is the shortest way around the circle, so it stays in
        [0, 180] even when a prediction falls outside the nominal range.
        ``great_circle`` is the angle between the two viewing directions.
    """
    pred_az = (pred_norm[:, 0] + 1) * 180
    true_az = (true_norm[:, 0] + 1) * 180
    pred_el = pred_norm[:, 1] * 90
    true_el = true_norm[:, 1] * 90

    # Reduce modulo 360 first; otherwise an error > 360 makes `360 - e` negative
    az_error = np.abs(pred_az - true_az) % 360
    az_error = np.minimum(az_error, 360 - az_error)
    el_error = np.abs(pred_el - true_el)

    # Spherical law of cosines; clip guards arccos against rounding outside [-1, 1]
    pe, te, daz = np.deg2rad(pred_el), np.deg2rad(true_el), np.deg2rad(az_error)
    cos_d = np.sin(pe) * np.sin(te) + np.cos(pe) * np.cos(te) * np.cos(daz)
    great_circle = np.rad2deg(np.arccos(np.clip(cos_d, -1.0, 1.0)))

    return pred_az, true_az, pred_el, true_el, az_error, el_error, great_circle


def plot_predicted_vs_true_viewpoint(probe: nn.Module, features: torch.Tensor,
                                     targets: torch.Tensor, device: str):
    """Scatter predicted vs. true azimuth/elevation for a trained viewpoint probe.

    Args:
        probe: Trained probe mapping ``features`` to normalized (azimuth, elevation).
        features: Batch of probe inputs.
        targets: Normalized ground-truth viewpoints (tensor on any device, or array);
            column 0 azimuth, column 1 elevation.
        device: Device the probe lives on.
    """
    probe.eval()

    with torch.no_grad():
        predictions = probe(features.to(device)).cpu().numpy()

    # Accept tensors on any device (and with grad) as well as plain arrays
    targets = targets.detach().cpu().numpy() if torch.is_tensor(targets) else np.asarray(targets)

    (pred_azimuth, true_azimuth, pred_elevation, true_elevation,
     az_error, el_error, great_circle) = _viewpoint_degrees_and_errors(predictions, targets)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # Azimuth scatter
    scatter1 = ax1.scatter(true_azimuth, pred_azimuth, c=az_error,
                           cmap='viridis', alpha=0.6, s=20)
    ax1.plot([0, 360], [0, 360], 'r--', alpha=0.5)
    ax1.set_xlabel('True Azimuth (degrees)')
    ax1.set_ylabel('Predicted Azimuth (degrees)')
    ax1.set_title('Azimuth Predictions')
    ax1.grid(True, alpha=0.3)
    cbar1 = plt.colorbar(scatter1, ax=ax1)
    cbar1.set_label('Absolute Error (degrees)')

    # Elevation scatter
    scatter2 = ax2.scatter(true_elevation, pred_elevation, c=el_error,
                           cmap='viridis', alpha=0.6, s=20)
    ax2.plot([-90, 90], [-90, 90], 'r--', alpha=0.5)
    ax2.set_xlabel('True Elevation (degrees)')
    ax2.set_ylabel('Predicted Elevation (degrees)')
    ax2.set_title('Elevation Predictions')
    ax2.grid(True, alpha=0.3)
    cbar2 = plt.colorbar(scatter2, ax=ax2)
    cbar2.set_label('Absolute Error (degrees)')

    plt.tight_layout()
    plt.show()

    print(f"\nPrediction Statistics:")
    print(f"  Azimuth MAE: {np.mean(az_error):.2f}°")
    print(f"  Elevation MAE: {np.mean(el_error):.2f}°")
    # Euclidean norm in (azimuth, elevation) space — not the angle between directions
    print(f"  Combined Az/El L2 Error: {np.mean(np.sqrt(az_error**2 + el_error**2)):.2f}°")
    print(f"  Mean Great-Circle Angular Distance: {np.mean(great_circle):.2f}°")

In [None]:
def plot_viewpoint_error_sphere(true_viewpoints: np.ndarray, pred_viewpoints: np.ndarray,
                                max_arrows: int = 200, seed: Optional[int] = None):
    """Draw true -> predicted viewing directions as segments on the unit sphere.

    Args:
        true_viewpoints: Array whose columns 0 and 1 hold azimuth and elevation
            normalized to [-1, 1] (mapped here to [0, 360] and [-90, 90] degrees;
            assumed convention — confirm against the dataset).
        pred_viewpoints: Predictions in the same format.
        max_arrows: Maximum number of randomly chosen samples to draw (default 200).
        seed: Optional seed for that subsample, for reproducible figures.
    """
    def viewpoint_to_xyz(azimuth, elevation, radius=1):
        """Spherical (degrees) -> Cartesian, with elevation measured from the xy-plane."""
        az_rad = np.deg2rad(azimuth)
        el_rad = np.deg2rad(elevation)
        x = radius * np.cos(el_rad) * np.cos(az_rad)
        y = radius * np.cos(el_rad) * np.sin(az_rad)
        z = radius * np.sin(el_rad)
        return x, y, z

    # Denormalize viewpoints
    true_az = (true_viewpoints[:, 0] + 1) * 180
    true_el = true_viewpoints[:, 1] * 90
    pred_az = (pred_viewpoints[:, 0] + 1) * 180
    pred_el = pred_viewpoints[:, 1] * 90

    true_x, true_y, true_z = viewpoint_to_xyz(true_az, true_el)
    pred_x, pred_y, pred_z = viewpoint_to_xyz(pred_az, pred_el)

    # Chord length between the two unit vectors (range 0..2), used to colour segments
    errors = np.sqrt((pred_x - true_x)**2 + (pred_y - true_y)**2 + (pred_z - true_z)**2)

    fig = go.Figure()

    # Translucent reference sphere
    u = np.linspace(0, 2 * np.pi, 50)
    v = np.linspace(0, np.pi, 50)
    sphere_x = np.outer(np.cos(u), np.sin(v))
    sphere_y = np.outer(np.sin(u), np.sin(v))
    sphere_z = np.outer(np.ones(np.size(u)), np.cos(v))

    fig.add_trace(go.Surface(
        x=sphere_x, y=sphere_y, z=sphere_z,
        opacity=0.2, colorscale='gray',
        showscale=False, name='Unit Sphere'
    ))

    # Random subset of samples for readability
    n_samples = min(max_arrows, len(true_x))
    rng = np.random.default_rng(seed)
    indices = rng.choice(len(true_x), n_samples, replace=False)

    # Shared colour range for all segments (guard against an all-zero error vector)
    err_max = float(errors.max()) if errors.size and errors.max() > 0 else 1.0

    for k, i in enumerate(indices):
        fig.add_trace(go.Scatter3d(
            x=[true_x[i], pred_x[i]],
            y=[true_y[i], pred_y[i]],
            z=[true_z[i], pred_z[i]],
            mode='lines+markers',
            # One value per vertex plus a shared cmin/cmax puts every segment on
            # the same scale; a scalar colour per trace would be auto-scaled
            # independently, so all segments would look alike.
            line=dict(color=[errors[i], errors[i]], colorscale='Viridis',
                      cmin=0.0, cmax=err_max, width=3,
                      showscale=(k == 0), colorbar=dict(title='Chord error')),
            marker=dict(size=[4, 6], color=['blue', 'red']),
            showlegend=False
        ))

    fig.update_layout(
        title='Viewpoint Prediction Errors on Unit Sphere',
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
            aspectmode='cube'
        ),
        # Legend top-left so it does not overlap the colour bar on the right
        legend=dict(x=0.01, y=0.99),
        width=800,
        height=800
    )

    # Invisible traces that only provide legend entries
    fig.add_trace(go.Scatter3d(
        x=[None], y=[None], z=[None],
        mode='markers',
        marker=dict(size=6, color='blue'),
        name='True Viewpoint'
    ))
    fig.add_trace(go.Scatter3d(
        x=[None], y=[None], z=[None],
        mode='markers',
        marker=dict(size=6, color='red'),
        name='Predicted Viewpoint'
    ))

    fig.show()

In [None]:

import traceback  # imported here too so the except-branch works if section 2.2 was skipped

# First DINOv2 viewpoint experiment, if any
dinov2_viewpoint = next((exp for exp in viewpoint_exps if exp.model == 'dinov2'), None)

if dinov2_viewpoint:
    # Find the (layer, probe type) with the lowest test MAE
    best_layer = None
    best_mae = float('inf')
    best_probe_type = None

    results = dinov2_viewpoint.results.get('results', {})
    for layer_key, layer_data in results.items():
        if not layer_key.startswith('layer_'):
            continue
        layer_num = int(layer_key.split('_')[1])
        for probe_type in ['linear', 'mlp']:
            mae = layer_data.get(probe_type, {}).get('test_metrics', {}).get('mae')
            # Skip missing/null metrics instead of crashing on `None < float`
            if mae is not None and mae < best_mae:
                best_mae = mae
                best_layer = layer_num
                best_probe_type = probe_type

    print(f"Best layer for viewpoint: Layer {best_layer} ({best_probe_type} probe, MAE: {best_mae:.4f})")

    if best_layer is not None:
        try:
            probe_files = loader.get_probe_files(dinov2_viewpoint)
            feature_files = loader.get_feature_files(dinov2_viewpoint)

            probe_path = probe_files.get(best_probe_type, {}).get(best_layer)
            features_path = feature_files.get(best_layer, {}).get('test')

            if probe_path is None:
                print(f"No {best_probe_type} probe found for layer {best_layer}")
            elif features_path is None:
                print(f"No test features found for layer {best_layer}")
            else:
                # NOTE: torch.load / pickle.load execute code from the file; trusted caches only.
                probe_state = torch.load(probe_path, map_location=device)

                # Used only if the checkpoint did not save the config it was trained with
                fallback_config = {
                    'type': best_probe_type,
                    'input_dim': 768,
                    'output_dim': 2,
                    'hidden_dims': [256] if best_probe_type == 'mlp' else None,
                    'task_type': 'regression'
                }
                probe = create_probe(probe_state.get('probe_config', fallback_config))
                probe.load_state_dict(probe_state["model_state_dict"])
                probe.to(device)
                probe.eval()

                with open(features_path, 'rb') as f:
                    test_data = pickle.load(f)

                # First 500 samples; .float() matches float32 probe weights
                features = torch.as_tensor(test_data['features'][:500]).float()
                targets = torch.as_tensor(test_data['targets'][:500])

                plot_predicted_vs_true_viewpoint(probe, features, targets, device)

                with torch.no_grad():
                    predictions = probe(features.to(device)).cpu().numpy()

                plot_viewpoint_error_sphere(targets.numpy(), predictions)

        except Exception as e:
            print(f"Could not visualize predictions: {e}")
            # print_exc() writes the traceback itself (and returns None)
            traceback.print_exc()

## 4. Advanced Feature Analysis

We perform a deeper analysis of the learned representations, focusing on DINOv2.

### 4.1 Latent Space Cartography

In [None]:
def analyze_latent_space(features_file_path: str, metadata_file_path: Optional[str] = None,
                         dataset_split: str = 'test'):
    """Visualize the organization of a layer's feature space with PCA and t-SNE.

    Loads pickled features, randomly subsamples at most 5000 of them, projects
    them to 2-D with PCA and t-SNE, and shows a 2x2 figure: the left column is
    colored by object category (if the pickle provides
    ``metadata['categories']``), the right column by azimuth (if it provides
    ``targets``). Panels whose labels are unavailable are left empty.

    Args:
        features_file_path: Pickle file containing either a dict with a
            ``'features'`` array (plus optional ``'targets'`` and
            ``'metadata'``) or a bare feature array.
        metadata_file_path: Currently unused; kept for API compatibility.
        dataset_split: Currently unused; kept for API compatibility.
    """
    import inspect

    # NOTE: pickle.load can execute arbitrary code; only load trusted, locally produced files.
    with open(features_file_path, 'rb') as f:
        data = pickle.load(f)

    categories = []
    viewpoints = None
    if isinstance(data, dict):
        features = data['features']
        raw_categories = (data.get('metadata') or {}).get('categories')
        if raw_categories is not None:
            categories = list(raw_categories)
        viewpoints = data.get('targets')
    else:
        features = data

    # Use arrays (not lists) so the random index array below can fancy-index them.
    features = np.asarray(features)
    if categories and len(categories) != len(features):
        print(f"Ignoring categories: {len(categories)} labels for {len(features)} samples")
        categories = []
    if viewpoints is not None:
        viewpoints = np.asarray(viewpoints)
        # The azimuth panels index viewpoints[:, 0], so a 2-D array aligned with features is required.
        if viewpoints.ndim != 2 or len(viewpoints) != len(features):
            print(f"Ignoring targets with shape {viewpoints.shape} for {len(features)} samples")
            viewpoints = None

    n_samples = min(5000, len(features))
    if n_samples < 2:
        print(f"Need at least 2 samples for PCA/t-SNE, got {n_samples}; skipping.")
        return

    indices = np.random.choice(len(features), n_samples, replace=False)
    features = features[indices]
    if categories:
        categories = [categories[i] for i in indices]
    if viewpoints is not None:
        viewpoints = viewpoints[indices]

    print(f"Analyzing {n_samples} samples...")

    # Perform PCA
    print("Running PCA...")
    pca = PCA(n_components=2)
    features_pca = pca.fit_transform(features)
    print(f"PCA explained variance: {pca.explained_variance_ratio_.sum():.3f}")

    # Perform t-SNE
    print("Running t-SNE...")
    # t-SNE requires perplexity < n_samples, so clamp it for small inputs.
    tsne_kwargs = {'n_components': 2, 'perplexity': min(30.0, n_samples - 1.0)}
    # scikit-learn 1.5 renamed `n_iter` to `max_iter` (the old name was later removed).
    iter_arg = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'
    tsne_kwargs[iter_arg] = 300
    features_tsne = TSNE(**tsne_kwargs).fit_transform(features)

    # Create visualizations
    fig, axes = plt.subplots(2, 2, figsize=(15, 15))

    if categories:
        # Sorted so the category -> color mapping is stable across runs
        # (set iteration order of strings depends on hash randomization).
        unique_cats = sorted(set(categories), key=str)
        cat_to_idx = {cat: i for i, cat in enumerate(unique_cats)}
        cat_indices = [cat_to_idx[cat] for cat in categories]

        # PCA colored by category
        axes[0, 0].scatter(features_pca[:, 0], features_pca[:, 1],
                           c=cat_indices, cmap='tab20', alpha=0.6, s=10)
        axes[0, 0].set_title('PCA - Colored by Object Category')
        axes[0, 0].set_xlabel('PC 1')
        axes[0, 0].set_ylabel('PC 2')

        # t-SNE colored by category
        axes[1, 0].scatter(features_tsne[:, 0], features_tsne[:, 1],
                           c=cat_indices, cmap='tab20', alpha=0.6, s=10)
        axes[1, 0].set_title('t-SNE - Colored by Object Category')
        axes[1, 0].set_xlabel('t-SNE 1')
        axes[1, 0].set_ylabel('t-SNE 2')

    if viewpoints is not None:
        # Assumes column 0 is azimuth normalized to [-1, 1] -- TODO confirm against the dataset.
        azimuth = (viewpoints[:, 0] + 1) * 180  # Denormalize

        # PCA colored by viewpoint
        scatter_pca = axes[0, 1].scatter(features_pca[:, 0], features_pca[:, 1],
                                         c=azimuth, cmap='hsv', alpha=0.6, s=10)
        axes[0, 1].set_title('PCA - Colored by Azimuth')
        axes[0, 1].set_xlabel('PC 1')
        axes[0, 1].set_ylabel('PC 2')
        cbar1 = plt.colorbar(scatter_pca, ax=axes[0, 1])
        cbar1.set_label('Azimuth (degrees)')

        # t-SNE colored by viewpoint
        scatter_tsne = axes[1, 1].scatter(features_tsne[:, 0], features_tsne[:, 1],
                                          c=azimuth, cmap='hsv', alpha=0.6, s=10)
        axes[1, 1].set_title('t-SNE - Colored by Azimuth')
        axes[1, 1].set_xlabel('t-SNE 1')
        axes[1, 1].set_ylabel('t-SNE 2')
        cbar2 = plt.colorbar(scatter_tsne, ax=axes[1, 1])
        cbar2.set_label('Azimuth (degrees)')

    plt.tight_layout()
    plt.show()

In [None]:
# Project the test-split features of the best DINOv2 viewpoint layer into 2-D,
# reporting which prerequisite is missing when the analysis cannot run.
if not dinov2_viewpoint or best_layer is None:
    print("No DINOv2 viewpoint experiment or best layer found")
else:
    feature_files = loader.get_feature_files(dinov2_viewpoint)
    if best_layer not in feature_files or 'test' not in feature_files[best_layer]:
        print(f"No features found for layer {best_layer}")
    else:
        features_path = feature_files[best_layer]['test']
        print(f"\nAnalyzing latent space for DINOv2 layer {best_layer}")
        analyze_latent_space(str(features_path))

### 4.2 Transformer Attention Analysis

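Note: attention weights are not always exposed directly by the model's forward output. For HuggingFace models (e.g. DINOv2) we request them with `output_attentions=True`. For timm models (e.g. I-JEPA) we instead disable the fused attention kernel. We then capture the attention probabilities with a forward hook on each block's `attn_drop` layer.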

In [None]:
import torch.nn.functional as F

def _num_prefix_tokens(model) -> int:
    """Return how many non-patch tokens (CLS/register) precede the patch tokens."""
    # timm VisionTransformer exposes this directly; it is 0 when the model has no
    # class token (e.g. timm GAP variants such as the I-JEPA checkpoints).
    n_prefix = getattr(model, 'num_prefix_tokens', None)
    if n_prefix is not None:
        return int(n_prefix)
    # HuggingFace ViTs: one CLS token plus any register tokens declared in the config.
    config = getattr(model, 'config', None)
    return 1 + int(getattr(config, 'num_register_tokens', 0) or 0)


def _collect_attentions(model, model_name: str, inputs: torch.Tensor) -> list:
    """Run one forward pass and return per-layer attention tensors shaped (B, heads, N, N)."""
    if 'timm' not in model_name and model_name != 'ijepa':
        # HuggingFace logic: the model returns attentions when asked.
        outputs = model(inputs, output_attentions=True)
        return list(outputs.attentions or [])

    # TIMM logic: attention weights are not returned, so capture the input of
    # each block's `attn_drop` (the attention matrix on timm's non-fused path).
    blocks = getattr(model, 'blocks', None)
    if not blocks:
        print("Model has no transformer blocks to hook; cannot extract attention.")
        return []

    captured = {}
    handles = []
    fused_states = {}

    def make_hook(layer_idx):
        def hook_fn(module, input_args, output_tensor):
            captured[layer_idx] = input_args[0].detach()
        return hook_fn

    try:
        for i, block in enumerate(blocks):
            attn_module = block.attn
            if hasattr(attn_module, 'fused_attn'):
                # The fused SDPA kernel never materializes the attention matrix.
                fused_states[i] = attn_module.fused_attn
                attn_module.fused_attn = False
            handles.append(attn_module.attn_drop.register_forward_hook(make_hook(i)))
        model(inputs)
    finally:
        # Always undo the instrumentation, even if the forward pass fails.
        for handle in handles:
            handle.remove()
        for i, state in fused_states.items():
            blocks[i].attn.fused_attn = state

    num_heads = blocks[0].attn.num_heads
    batch_size = inputs.shape[0]
    attentions = []
    for i in sorted(captured):
        attn_probs = captured[i]
        if attn_probs.ndim == 3:
            # Heads flattened into the batch dimension; restore (B, heads, N, N).
            num_tokens = attn_probs.shape[-1]
            attn_probs = attn_probs.view(batch_size, num_heads, num_tokens, num_tokens)
        attentions.append(attn_probs)
    return attentions


def _to_patch_grid(attention_vector: torch.Tensor, model) -> torch.Tensor:
    """Reshape a flat per-patch vector into a 2-D (grid_h, grid_w) patch map."""
    num_patch_tokens = attention_vector.shape[-1]
    side = int(np.sqrt(num_patch_tokens))
    if side * side == num_patch_tokens:
        return attention_vector.reshape(side, side)

    grid = getattr(getattr(model, 'patch_embed', None), 'grid_size', None)
    if grid is None:
        raise RuntimeError("Cannot determine grid size for non-square patch number.")
    grid_h, grid_w = grid
    expected_tokens = grid_h * grid_w
    if num_patch_tokens < expected_tokens:
        attention_vector = F.pad(attention_vector, (0, expected_tokens - num_patch_tokens), value=0)
    elif num_patch_tokens > expected_tokens:
        attention_vector = attention_vector[:expected_tokens]
    return attention_vector.reshape(grid_h, grid_w)


def visualize_attention_maps(
    image_tensor: torch.Tensor,
    model_name: str,
    device: torch.device,
    method: str = 'aggregate',
    layer: int = -1,
    head: Optional[int] = None,
):
    """Plot an attention heatmap for one transformer layer over an input image.

    Loads ``model_name`` via ``load_model_and_preprocessor``, runs a single
    forward pass, and extracts per-layer attention probabilities. HuggingFace
    models use ``output_attentions=True``; timm/I-JEPA models use forward hooks
    on ``attn_drop``. It then shows the original image, the min-max normalized
    attention map, and an overlay.

    Args:
        image_tensor: Image as (C, H, W) or (B, C, H, W); only the first batch
            item is plotted.
        model_name: Model identifier for ``load_model_and_preprocessor``. Names
            containing 'timm', or equal to 'ijepa', use the timm hook path.
        device: Device to run the model on.
        method: 'cls' shows the attention from the CLS token to every patch
            (requires a CLS token). 'aggregate' shows the total attention each
            patch receives from all patch queries (suited to GAP models such as
            I-JEPA).
        layer: Transformer layer to visualize; negative values count from the end.
        head: Attention head to visualize; None averages over heads.

    Raises:
        ValueError: If ``method`` is unknown, or 'cls' is requested for a model
            without a CLS token.
        RuntimeError: If the patch tokens cannot be arranged into a 2-D grid.
    """
    if method not in ('cls', 'aggregate'):
        raise ValueError(f"Unknown method {method!r}; expected 'cls' or 'aggregate'.")

    model, _preprocessor = load_model_and_preprocessor(model_name, ckpt_path=None, device=device, cache_dir=None)
    model.to(device).eval()
    if image_tensor.dim() == 3:
        image_tensor = image_tensor.unsqueeze(0)
    inputs_on_device = image_tensor.to(device)

    with torch.no_grad():
        all_attentions = _collect_attentions(model, model_name, inputs_on_device)

    if not all_attentions:
        print("Failed to extract any attention maps.")
        return

    attn_from_layer = all_attentions[layer]
    if head is not None:
        attn_to_vis = attn_from_layer[:, head]  # Select specific head
        title_head_info = f"Head {head}"
    else:
        attn_to_vis = attn_from_layer.mean(dim=1)  # Average over heads
        title_head_info = "Averaged Heads"

    # CLS/register tokens are not image patches and must be excluded from the
    # spatial map; GAP models (e.g. I-JEPA) have none, so nothing is dropped.
    n_prefix = _num_prefix_tokens(model)
    if method == 'cls':
        if n_prefix == 0:
            raise ValueError(f"{model_name} has no CLS token; use method='aggregate'.")
        attention_vector = attn_to_vis[0, 0, n_prefix:]
        plot_main_title = "CLS Token Attention"
    else:
        # Sum over query patches -> total attention received by each key patch.
        attention_vector = attn_to_vis[0, n_prefix:, n_prefix:].sum(dim=0)
        plot_main_title = "Aggregated Patch Attention"

    attention_map = _to_patch_grid(attention_vector, model).float().cpu().numpy()
    attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min() + 1e-6)

    attn_resized = F.interpolate(
        torch.from_numpy(attention_map).unsqueeze(0).unsqueeze(0),
        size=(image_tensor.shape[-2], image_tensor.shape[-1]),
        mode='bilinear',
        align_corners=False
    ).squeeze().cpu().numpy()

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
    # Resolve negative indices (not only -1) to the actual layer number.
    layer_info = f"Layer {layer % len(all_attentions)}"
    fig.suptitle(f'{plot_main_title} for {model_name}\n({layer_info}, {title_head_info})', fontsize=16)

    img_for_show = image_tensor[0].permute(1, 2, 0).cpu().numpy()
    min_val, max_val = img_for_show.min(), img_for_show.max()
    img_for_show = (img_for_show - min_val) / (max_val - min_val + 1e-6)
    ax1.imshow(img_for_show); ax1.set_title('Original Image'); ax1.axis('off')
    im = ax2.imshow(attn_resized, cmap='hot'); ax2.set_title('Attention Map'); ax2.axis('off')
    fig.colorbar(im, ax=ax2, fraction=0.046, pad=0.04)
    ax3.imshow(img_for_show); ax3.imshow(attn_resized, cmap='hot', alpha=0.5); ax3.set_title('Overlay'); ax3.axis('off')
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()


In [None]:
from omegaconf import OmegaConf
# Build the 3D-R2N2 dataloaders restricted to the 'chair' category (a quick-test
# override) and pull one image from the test loader for attention visualization.
dataset_config = OmegaConf.load('../configs/datasets/shapenet_3dr2n2.yaml')
dataset_config.categories = ['chair']

_, _, test_loader = create_3dr2n2_dataloaders(dataset_config, batch_size=1, num_workers=0)

# First batch from the test loader; keep only its first image.
sample = next(iter(test_loader))
image = sample['image'][0]



In [None]:
# Visualize attention for both backbones on the sample image loaded above.
# `traceback` is imported here because the handler below relies on it and it is
# not guaranteed to be imported by an earlier cell.
import traceback

try:
    print("Visualizing DINOv2 attention...")
    visualize_attention_maps(image, model_name='dinov2', layer=11, head=1, device=device, method="cls")

    print("\nVisualizing I-JEPA (TIMM) attention...")
    visualize_attention_maps(image, model_name='ijepa', layer=11, device=device, method="aggregate")
except Exception as e:
    # Notebook boundary: report the failure with its traceback instead of aborting the run.
    print(f"Could not visualize attention: {e}")
    traceback.print_exc()