# Probing Experiment on SSL Models

This is a notebook version of the old experiment runner script. Each stage lives in its own cell, so a small error in one step does not wipe out the state built up by the earlier steps.

### Imports, Logging Setup

In [1]:
# Set environment variables before imports
# NOTE(review): PYTORCH_ENABLE_MPS_FALLBACK is a PyTorch switch; it is presumably set here,
# ahead of `import torch`, so it is already in place when PyTorch loads -- confirm
# against the PyTorch MPS documentation.
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

# Imports
import hydra
from omegaconf import DictConfig, OmegaConf
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from pathlib import Path
import logging
import wandb
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm


from src.models.feature_extractor import FeatureExtractor, load_feature_extractor
from src.datasets.shapenet_3dr2n2 import create_3dr2n2_dataloaders
from src.probing.probes import create_probe, ProbeTrainer
from src.probing.data_preprocessing import (
    FeatureExtractorPipeline,
    create_probing_dataloaders,
    ProbingDataset,
)
from src.probing.metrics import (
    compute_regression_metrics,
    compute_viewpoint_specific_metrics,
    MetricsTracker,
)
from src.analysis.layer_analysis import LayerWiseAnalyzer

# Logger shared by every cell below; inside the notebook __name__ is "__main__",
# which is the logger name that appears in the captured output.
logger = logging.getLogger(__name__)
# Emit INFO and above with timestamp, logger name and level in front of each message.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

### Probing Setup
This class is the overarching "manager" responsible for the entire experiment. It contains all the functionality required to:

- Create & setup dataloaders 
- Extract features from the frozen layers of the ViT models 
- Train MLP & Linear probes on those layers 
- Summarize results

In [2]:
class ProbingExperiment:
    """Orchestrate probing experiments on a frozen feature extractor.

    The class wires together dataset loading, per-layer feature extraction,
    probe training/evaluation, probe checkpointing and result serialization.
    All settings are read from the Hydra/OmegaConf configuration passed to
    the constructor.
    """

    def __init__(self, config: DictConfig):
        """Resolve the device, optionally start wandb, and create output dirs.

        Args:
            config: Experiment configuration. Reads ``models``, ``experiment.name``
                and the optional ``device``, ``wandb``, ``results_dir`` and
                ``cache_dir`` entries.
        """
        self.config = config
        # Device precedence: models.device, then top-level device, then auto-detect.
        # `or` (rather than a .get default) also skips a models.device that is
        # present but null/empty, so the documented precedence always holds.
        device_to_use = config.models.get("device") or config.get("device")
        if device_to_use:
            self.device = device_to_use
        elif torch.cuda.is_available():
            self.device = "cuda"
        elif torch.backends.mps.is_available():
            self.device = "mps"
        else:
            self.device = "cpu"
        logger.info(f"Using device: {self.device}")

        # Initialize wandb
        if config.get("wandb", {}).get("enabled", False):
            wandb.init(
                project=config.wandb.project,
                entity=config.wandb.get("entity"),
                name=config.experiment.name,
                config=OmegaConf.to_container(config, resolve=True),
            )

        # Setup paths
        self.results_dir = Path(config.get("results_dir", "./results"))
        self.results_dir.mkdir(parents=True, exist_ok=True)
        self.cache_dir = Path(config.get("cache_dir", "./cache"))
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Setup probe save directory
        self.probe_save_dir = self.cache_dir / "probes" / self.config.experiment.name
        self.probe_save_dir.mkdir(parents=True, exist_ok=True)

        # Initialize analyzer
        self.analyzer = LayerWiseAnalyzer(self.results_dir / config.experiment.name)

    def load_dataset(self) -> Tuple[DataLoader, DataLoader, DataLoader]:
        """Build the train/val/test dataloaders from ``config.datasets``.

        Returns:
            Whatever ``create_3dr2n2_dataloaders`` returns; annotated here as a
            ``(train, val, test)`` tuple of dataloaders.
        """
        subset_percentage = self.config.datasets.get("subset_percentage", None)
        return create_3dr2n2_dataloaders(
            self.config.datasets, subset_percentage=subset_percentage
        )

    def load_feature_extractor(self) -> FeatureExtractor:
        """Load the feature extractor described by ``config.models``.

        Note:
            This writes ``device`` and ``cache_dir`` into ``self.config.models``
            in place, so the values also show up in the config saved by
            ``save_results``.
        """
        model_config = self.config.models
        model_config.device = self.device
        model_config.cache_dir = str(self.cache_dir / "models")

        feature_extractor = load_feature_extractor(OmegaConf.to_container(model_config))
        logger.info(f"Loaded {model_config.model_name} feature extractor")
        return feature_extractor

    def extract_features_for_layer(
        self,
        feature_extractor: FeatureExtractor,
        train_loader: DataLoader,
        val_loader: DataLoader,
        test_loader: DataLoader,
        layer: int,
        feature_type: str,
        task_type: str,
    ) -> Tuple[ProbingDataset, ProbingDataset, ProbingDataset]:
        """Extract features of a single layer for all three splits.

        Args:
            feature_extractor: Model the features are taken from.
            train_loader: Loader over the training split.
            val_loader: Loader over the validation split.
            test_loader: Loader over the test split.
            layer: Index of the layer to extract.
            feature_type: Forwarded to the extraction pipeline.
            task_type: Forwarded to the extraction pipeline.

        Returns:
            The result of ``FeatureExtractorPipeline.create_probing_datasets``,
            annotated as ``(train, val, test)`` probing datasets.
        """
        pipeline = FeatureExtractorPipeline(
            feature_extractor=feature_extractor,
            device=self.device,
            batch_size=self.config.get("extraction_batch_size", 32),
            cache_dir=str(self.cache_dir / "features"),
        )

        # The name combines model, experiment and layer; it is handed to the
        # pipeline, which is given the feature cache directory above.
        experiment_name = f"{self.config.models.model_name}_{self.config.experiment.name}_layer_{layer}"

        return pipeline.create_probing_datasets(
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            layers=[layer],
            feature_type=feature_type,
            task_type=task_type,
            experiment_name=experiment_name,
        )

    def run_probe_experiment(
        self,
        probe_type: str,
        train_loader: DataLoader,
        val_loader: DataLoader,
        test_loader: DataLoader,
        feature_dim: int,
        layer: int,
    ) -> Dict:
        """Train, checkpoint and evaluate one probe on one layer.

        Args:
            probe_type: Key under ``config.probing`` holding the probe settings.
            train_loader: Probing loader for the training split.
            val_loader: Probing loader for the validation split.
            test_loader: Probing loader for the test split.
            feature_dim: Input dimensionality of the probe.
            layer: Layer index (used for logging and the checkpoint name).

        Returns:
            Dict with ``train_history``, ``val_history``, ``test_metrics``,
            ``detailed_metrics``, ``best_epoch`` and ``total_epochs``.

        Raises:
            ValueError: If the probe's optimizer config has no ``_target_``.
        """
        logger.info(
            f"Running {probe_type} probe on layer {layer} (feature_dim: {feature_dim})"
        )

        # Plain, mutable copy of the probe settings. A probe type without a
        # config section yields an empty dict instead of handing a non-OmegaConf
        # object to OmegaConf.to_container (which rejects it).
        raw_probe_config = self.config.probing.get(probe_type)
        probe_config = (
            OmegaConf.to_container(raw_probe_config, resolve=True)
            if raw_probe_config is not None
            else {}
        )

        probe_config["input_dim"] = feature_dim
        probe_config["output_dim"] = self.config.probing.get("output_dim", 2)

        # Map experiment-level task names onto the generic probe task types.
        main_task_type = self.config.probing.get("task_type", "regression")
        task_aliases = {
            "viewpoint_regression": "regression",
            "view_classification": "classification",
        }
        probe_config["task_type"] = task_aliases.get(main_task_type, main_task_type)

        probe = create_probe(probe_config)

        # Build the trainer exactly once, already wired to the metrics tracker.
        metrics_tracker = MetricsTracker()
        trainer = ProbeTrainer(
            probe, device=self.device, MetricsTracker=metrics_tracker
        )

        # Setup optimizer and scheduler
        training_config = probe_config.get("training", {})
        optimizer = self.create_optimizer(probe, training_config.get("optimizer", {}))
        scheduler = self.create_scheduler(
            optimizer, training_config.get("scheduler", {})
        )

        # Training parameters
        epochs = training_config.get("epochs", 30)
        early_stopping_patience = training_config.get("early_stopping_patience", 15)

        # Check if wandb is enabled
        wandb_enabled = self.config.get("wandb", {}).get("enabled", False)

        best_model, best_val_loss = trainer.train(
            epochs,
            optimizer,
            scheduler,
            early_stopping_patience,
            train_loader,
            val_loader,
            probe_type=probe_type,
            layer=layer,
            wandb_enabled=wandb_enabled,
        )

        # Save the trained probe. The directory is the one prepared in __init__;
        # mkdir is repeated in case it was removed in the meantime.
        self.probe_save_dir.mkdir(parents=True, exist_ok=True)
        probe_save_path = self.probe_save_dir / f"{probe_type}_layer_{layer}_probe.pth"

        torch.save({
            'model_state_dict': best_model,  # best_model is already a state_dict
            'probe_config': probe_config,
            'layer': layer,
            'probe_type': probe_type,
            'experiment_name': self.config.experiment.name,
            'model_name': self.config.models.model_name,
            'best_val_loss': best_val_loss,
            'feature_dim': feature_dim
        }, probe_save_path)

        logger.info(f"Saved {probe_type} probe for layer {layer} to {probe_save_path}")

        # NOTE(review): the probe is evaluated in whatever state trainer.train()
        # leaves it in; confirm that the trainer restores the best weights.
        test_metrics = trainer.evaluate(test_loader)
        detailed_metrics = self.compute_detailed_metrics(probe, test_loader)

        train_history = metrics_tracker.get_history("train")

        return {
            "train_history": train_history,
            "val_history": metrics_tracker.get_history("val"),
            "test_metrics": test_metrics,
            "detailed_metrics": detailed_metrics,
            "best_epoch": metrics_tracker.best_epoch,
            "total_epochs": len(train_history),
        }

    def save_probe(self, probe: nn.Module, probe_type: str, layer: int, probe_config: Dict):
        """Save a probe's weights plus a JSON copy of its configuration.

        Two files are written to ``self.probe_save_dir``:
        ``{model}_{probe_type}_layer_{layer}.pth`` and
        ``{model}_{probe_type}_layer_{layer}_config.json``.

        Args:
            probe: Trained probe whose ``state_dict`` is stored.
            probe_type: Probe identifier used in the file name.
            layer: Layer index used in the file name.
            probe_config: Configuration needed to rebuild the probe.
        """
        import json

        model_name = self.config.models.model_name
        stem = f"{model_name}_{probe_type}_layer_{layer}"
        probe_path = self.probe_save_dir / f"{stem}.pth"
        config_path = self.probe_save_dir / f"{stem}_config.json"

        # Metadata shared by the checkpoint and the JSON side-car.
        metadata = {
            'probe_config': probe_config,
            'model_name': model_name,
            'probe_type': probe_type,
            'layer': layer,
            'experiment_name': self.config.experiment.name,
        }

        torch.save({'model_state_dict': probe.state_dict(), **metadata}, probe_path)

        with open(config_path, 'w', encoding='utf-8') as f:
            # Normalise first so numpy/torch values in the config cannot break json.
            json.dump(self.make_json_serializable(metadata), f, indent=2)

        logger.info(f"Probe saved to {probe_path}")
        logger.info(f"Probe config saved to {config_path}")

    def load_probe(self, probe_type: str, layer: int, device: Optional[str] = None) -> nn.Module:
        """Load a previously saved probe.

        Looks for the file written by ``save_probe`` first and falls back to
        the file name used by ``run_probe_experiment``.

        Args:
            probe_type: Probe identifier used when saving.
            layer: Layer index used when saving.
            device: Target device; defaults to ``self.device``.

        Returns:
            The rebuilt probe with its weights loaded, moved to ``device``.

        Raises:
            FileNotFoundError: If neither checkpoint file exists.
        """
        if device is None:
            device = self.device

        model_name = self.config.models.model_name
        candidates = [
            self.probe_save_dir / f"{model_name}_{probe_type}_layer_{layer}.pth",
            # Name used by run_probe_experiment when it checkpoints a probe.
            self.probe_save_dir / f"{probe_type}_layer_{layer}_probe.pth",
        ]
        probe_path = next((path for path in candidates if path.exists()), None)
        if probe_path is None:
            raise FileNotFoundError(
                f"Probe not found at {candidates[0]} (also tried {candidates[1]})"
            )

        # Load the saved data
        saved_data = torch.load(probe_path, map_location=device)

        # Recreate the probe using the saved config, then restore its weights.
        probe = create_probe(saved_data['probe_config'])
        probe.load_state_dict(saved_data['model_state_dict'])
        probe.to(device)

        logger.info(f"Probe loaded from {probe_path}")
        return probe

    def create_optimizer(
        self, model: nn.Module, optimizer_config: Dict
    ) -> torch.optim.Optimizer:
        """Create an optimizer from config using Hydra ``instantiate``.

        Args:
            model: Module whose parameters are optimized.
            optimizer_config: Hydra-style config; must contain ``_target_``.

        Raises:
            ValueError: If ``optimizer_config`` is empty or has no ``_target_``.
        """
        # Validate up front: without a target there is nothing to instantiate,
        # and the failure would otherwise only surface later as an obscure error.
        if not optimizer_config or "_target_" not in optimizer_config:
            raise ValueError(
                "Optimizer config must define '_target_' "
                "(e.g. under probing.<probe_type>.training.optimizer)"
            )

        from hydra.utils import instantiate

        # Work on a copy so the caller's config is not mutated.
        optimizer_config = optimizer_config.copy()
        optimizer_config["params"] = model.parameters()

        return instantiate(optimizer_config)

    def create_scheduler(
        self, optimizer: torch.optim.Optimizer, scheduler_config: Dict
    ):
        """Create a learning-rate scheduler from config using Hydra ``instantiate``.

        Args:
            optimizer: Optimizer the scheduler is attached to.
            scheduler_config: Hydra-style config; empty/None disables scheduling.

        Returns:
            The instantiated scheduler, or ``None`` when no config is given.
        """
        if not scheduler_config:
            return None

        from hydra.utils import instantiate

        # Work on a copy so the caller's config is not mutated.
        scheduler_config = scheduler_config.copy()
        scheduler_config["optimizer"] = optimizer

        return instantiate(scheduler_config)

    def compute_detailed_metrics(
        self, probe: nn.Module, test_loader: DataLoader
    ) -> Dict:
        """Compute detailed metrics for ``probe`` over ``test_loader``.

        Each batch must be a mapping with ``"features"`` and ``"targets"``.

        Args:
            probe: Probe to evaluate (switched to eval mode here).
            test_loader: Loader yielding the evaluation batches.

        Returns:
            The regression metrics, extended with viewpoint-specific metrics
            when the probe output has exactly two columns. An empty dict is
            returned when the loader yields no batches.
        """
        probe.eval()

        all_predictions = []
        all_targets = []

        with torch.no_grad():
            for batch in test_loader:
                features = batch["features"].to(self.device)
                outputs = probe(features)

                # Collect everything on the CPU so predictions and targets can
                # be compared regardless of where the loader placed them.
                all_predictions.append(outputs.cpu())
                all_targets.append(batch["targets"].detach().cpu())

        if not all_predictions:
            # torch.cat refuses an empty list; report "no metrics" instead.
            logger.warning("Test loader produced no batches; skipping detailed metrics")
            return {}

        predictions = torch.cat(all_predictions, dim=0)
        targets = torch.cat(all_targets, dim=0)

        # Basic regression metrics
        metrics = compute_regression_metrics(predictions, targets, return_per_dim=True)

        # Viewpoint-specific metrics: column 0 is passed as azimuth, column 1
        # as elevation. The ndim check avoids an IndexError on 1-D output.
        if predictions.ndim == 2 and predictions.shape[1] == 2:
            viewpoint_metrics = compute_viewpoint_specific_metrics(
                azimuth_pred=predictions[:, 0],
                elevation_pred=predictions[:, 1],
                azimuth_target=targets[:, 0],
                elevation_target=targets[:, 1],
            )
            metrics.update(viewpoint_metrics)

        return metrics

    def save_results(self, results: Dict) -> Path:
        """Write results and the resolved config to ``results.json``.

        Args:
            results: Nested results structure; converted with
                ``make_json_serializable`` before writing.

        Returns:
            Path of the written file
            (``<results_dir>/<experiment name>/results.json``).
        """
        import json

        # Create experiment directory
        exp_dir = self.results_dir / self.config.experiment.name
        exp_dir.mkdir(parents=True, exist_ok=True)

        results_file = exp_dir / "results.json"

        combined_results = {
            "config": OmegaConf.to_container(self.config, resolve=True),
            # Convert tensors/arrays to lists for JSON serialization
            "results": self.make_json_serializable(results),
        }

        with open(results_file, "w", encoding="utf-8") as f:
            json.dump(combined_results, f, indent=2)

        logger.info(f"Results saved to {results_file}")
        return results_file

    def make_json_serializable(self, obj):
        """Recursively convert ``obj`` into JSON-serializable Python values.

        Dict values are converted recursively (keys are left untouched);
        lists, tuples and sets become lists; tensors and numpy arrays become
        nested lists; numpy scalars become Python scalars; paths become
        strings. Anything else is returned unchanged.
        """
        if isinstance(obj, dict):
            return {k: self.make_json_serializable(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple, set, frozenset)):
            return [self.make_json_serializable(v) for v in obj]
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu().tolist()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # Covers numpy ints, floats and bools alike.
            return obj.item()
        if isinstance(obj, Path):
            return str(obj)
        return obj



### Hydra Configuration Loading / Setup

In [3]:
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra
import os
from pathlib import Path

# Location and name of the Hydra config that drives the experiment.
CONFIG_PATH = "../configs"
CONFIG_NAME = "experiment_config"

# Stays None when loading fails, so later cells can detect the failure.
cfg: Optional[DictConfig] = None

# Hydra keeps process-wide state; clear it so this cell can be re-run.
if GlobalHydra.instance().is_initialized():
    logger.info("Clearing existing Hydra global state.")
    GlobalHydra.instance().clear()

try:
    # Export the data directory (sibling of the current working directory's
    # parent) before composing.
    # NOTE(review): presumably read by the config via an env interpolation -- confirm.
    project_root = Path.cwd().parent
    os.environ["DATA_DIR"] = str(project_root / "data")

    logger.info(f"Initializing Hydra with config_path: '{CONFIG_PATH}'")
    initialize(version_base=None, config_path=CONFIG_PATH)

    logger.info(f"Composing configuration with config_name: '{CONFIG_NAME}'")
    cfg = compose(config_name=CONFIG_NAME)
except Exception as e:
    # Deliberately not re-raised: the failure is logged with its traceback and
    # cfg stays None.
    logger.exception(f"Error initializing Hydra or loading configuration: {e}")
else:
    # Report success whenever composition succeeded; truthiness of the config
    # is not a valid test because an empty DictConfig is falsy.
    logger.info("Hydra configuration loaded successfully.")


2025-06-01 20:37:26,470 - __main__ - INFO - Initializing Hydra with config_path: '../configs'
2025-06-01 20:37:26,656 - __main__ - INFO - Composing configuration with config_name: 'experiment_config'
2025-06-01 20:37:26,716 - __main__ - INFO - Hydra configuration loaded successfully.


## Running the Experiment
The following code uses the above configurations and utility functions to run the actual experiment.

In [4]:
results = None

# The configuration cell logs load errors instead of raising and leaves cfg as
# None; stop here with a clear message rather than failing inside __init__.
if cfg is None:
    raise RuntimeError(
        "Hydra configuration was not loaded; re-run the configuration cell "
        "and check its log output."
    )

logger.info("Starting experiment execution")
experiment = ProbingExperiment(cfg)
    

2025-06-01 20:37:26,724 - __main__ - INFO - Starting experiment execution
2025-06-01 20:37:26,725 - __main__ - INFO - Using device: mps
2025-06-01 20:37:27,176 - wandb.jupyter - ERROR - Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdruhin-bhowal[0m ([33mcse493g1_drn[0m). Use [1m`wandb login --relogin`[0m to force relogin


### Load the Feature Extractor & Dataset

In [5]:
# Load the feature extractor described by cfg.models.
feature_extractor = experiment.load_feature_extractor()
# Extraction settings live under models.feature_extraction; every key is optional.
extraction_config = cfg.models.get("feature_extraction", {})
# Layers to probe; falls back to layer 11 only.
layers = extraction_config.get("layers", [11])
# Which feature to take from each layer; falls back to "cls_token".
feature_type = extraction_config.get("feature_type", "cls_token")
# Probing task name, passed to feature extraction in the training cell.
task_type = cfg.probing.get("task_type", "viewpoint_regression")

2025-06-01 20:37:31,287 - src.models.feature_extractor - INFO - Loaded dinov2 model on mps
2025-06-01 20:37:31,287 - __main__ - INFO - Loaded dinov2 feature extractor


In [6]:
# Build the train/val/test dataloaders from cfg.datasets (including the optional subset_percentage).
train_loader, val_loader, test_loader = experiment.load_dataset()

100%|██████████| 30648/30648 [00:30<00:00, 995.22it/s] 


Using 0.05% of train data: 367 samples.


100%|██████████| 6567/6567 [00:06<00:00, 1000.67it/s]


Using 0.05% of val data: 78 samples.


100%|██████████| 6569/6569 [00:06<00:00, 1012.40it/s]

Using 0.05% of test data: 78 samples.





### Train the Probes

In [7]:
# Maps "layer_<n>" -> {probe_type: results of that probe}.
results = {}

# Loader settings do not depend on the layer, so read them once up front.
probe_batch_size = cfg.probing.get("training", {}).get("batch_size", 64)
probe_num_workers = cfg.get("num_workers", 4)

for layer in tqdm(layers):
    logger.info(f"Processing layer {layer}...")

    # Features of this layer for all three splits.
    train_dataset, val_dataset, test_dataset = experiment.extract_features_for_layer(
        feature_extractor=feature_extractor,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        layer=layer,
        feature_type=feature_type,
        task_type=task_type,
    )

    # Wrap the extracted features in loaders for probe training.
    probe_train_loader, probe_val_loader, probe_test_loader = create_probing_dataloaders(
        train_dataset,
        val_dataset,
        test_dataset,
        batch_size=probe_batch_size,
        num_workers=probe_num_workers,
    )

    # Probe input size is the second dimension of the stored features.
    feature_dim = train_dataset.features.shape[1]

    # Train and evaluate every configured probe type on this layer.
    layer_results = {}
    for probe_type in cfg.probing.probe_types:
        logger.info(f"Running {probe_type} probe on layer {layer}...")
        probe_results = experiment.run_probe_experiment(
            probe_type=probe_type,
            train_loader=probe_train_loader,
            val_loader=probe_val_loader,
            test_loader=probe_test_loader,
            feature_dim=feature_dim,
            layer=layer,
        )
        layer_results[probe_type] = probe_results

    results[f"layer_{layer}"] = layer_results

  0%|          | 0/2 [00:00<?, ?it/s]2025-06-01 20:38:15,277 - __main__ - INFO - Processing layer 2...
2025-06-01 20:38:15,279 - src.probing.data_preprocessing - INFO - Loading cached features from cache/features/dinov2_phase1_dinov2_viewpoint_probing_layer_2_train.pkl
2025-06-01 20:38:15,283 - src.probing.data_preprocessing - INFO - TRAIN dataset: 352 samples
2025-06-01 20:38:15,285 - src.probing.data_preprocessing - INFO - Loading cached features from cache/features/dinov2_phase1_dinov2_viewpoint_probing_layer_2_val.pkl
2025-06-01 20:38:15,287 - src.probing.data_preprocessing - INFO - VAL dataset: 78 samples
2025-06-01 20:38:15,288 - src.probing.data_preprocessing - INFO - Loading cached features from cache/features/dinov2_phase1_dinov2_viewpoint_probing_layer_2_test.pkl
2025-06-01 20:38:15,292 - src.probing.data_preprocessing - INFO - TEST dataset: 78 samples
2025-06-01 20:38:15,292 - __main__ - INFO - Running linear probe on layer 2...
2025-06-01 20:38:15,293 - __main__ - INFO - Ru

Epoch 0: train_loss=0.2296, val_loss=0.2422


2025-06-01 20:38:26,097 - __main__ - INFO - Running mlp probe on layer 2...
2025-06-01 20:38:26,098 - __main__ - INFO - Running mlp probe on layer 2 (feature_dim: 768)
Training 1/1: 100%|██████████| 1/1 [00:02<00:00,  2.40s/it]
2025-06-01 20:38:30,786 - __main__ - INFO - Saved mlp probe for layer 2 to cache/probes/phase1_dinov2_viewpoint_probing/mlp_layer_2_probe.pth


Epoch 0: train_loss=0.2306, val_loss=0.2035


 50%|█████     | 1/2 [00:20<00:20, 20.03s/it]2025-06-01 20:38:35,304 - __main__ - INFO - Processing layer 11...
2025-06-01 20:38:35,306 - src.probing.data_preprocessing - INFO - Loading cached features from cache/features/dinov2_phase1_dinov2_viewpoint_probing_layer_11_train.pkl
2025-06-01 20:38:35,308 - src.probing.data_preprocessing - INFO - TRAIN dataset: 352 samples
2025-06-01 20:38:35,309 - src.probing.data_preprocessing - INFO - Loading cached features from cache/features/dinov2_phase1_dinov2_viewpoint_probing_layer_11_val.pkl
2025-06-01 20:38:35,311 - src.probing.data_preprocessing - INFO - VAL dataset: 78 samples
2025-06-01 20:38:35,311 - src.probing.data_preprocessing - INFO - Loading cached features from cache/features/dinov2_phase1_dinov2_viewpoint_probing_layer_11_test.pkl
2025-06-01 20:38:35,312 - src.probing.data_preprocessing - INFO - TEST dataset: 78 samples
2025-06-01 20:38:35,313 - __main__ - INFO - Running linear probe on layer 11...
2025-06-01 20:38:35,313 - __main_

Epoch 0: train_loss=0.7528, val_loss=0.5066


2025-06-01 20:38:44,199 - __main__ - INFO - Running mlp probe on layer 11...
2025-06-01 20:38:44,200 - __main__ - INFO - Running mlp probe on layer 11 (feature_dim: 768)
Training 1/1: 100%|██████████| 1/1 [00:02<00:00,  2.28s/it]
2025-06-01 20:38:48,747 - __main__ - INFO - Saved mlp probe for layer 11 to cache/probes/phase1_dinov2_viewpoint_probing/mlp_layer_11_probe.pth


Epoch 0: train_loss=0.8335, val_loss=1.2173


100%|██████████| 2/2 [00:37<00:00, 18.99s/it]


In [8]:
logger.info("Saving results...")
# Writes the results together with the resolved config and keeps the file path for the analysis cell.
result_path = experiment.save_results(results)

2025-06-01 20:38:53,268 - __main__ - INFO - Saving results...
2025-06-01 20:38:53,272 - __main__ - INFO - Results saved to results/phase1_dinov2_viewpoint_probing/results.json


In [None]:
from src.analysis.layer_analysis import analyze_experiment_results

logger.info("Creating analysis and visualizations...")
# result_path comes from the save cell; its parent directory is passed as output_dir.
analyze_experiment_results(result_path, output_dir=result_path.parent)

logger.info("Results analyzed! Please see the results and analysis_results folders for the outcomes.")

2025-06-01 20:38:53,288 - __main__ - INFO - Creating analysis and visualizations...
2025-06-01 20:38:55,227 - src.analysis.layer_analysis - INFO - Analysis report saved to analysis_results/layer_analysis_report.json
2025-06-01 20:38:55,228 - __main__ - INFO - Results analyzed! Please see the results and analysis_results folders for the outcomes.
