# Reconstruction Experiment on SSL Models (I-JEPA)

This notebook adapts the probing experiment framework to focus on 3D voxel reconstruction using `VoxelProbe`. It will:
- Load a pre-trained I-JEPA model.
- Extract features from specified layers.
- Train `VoxelProbe` instances on these features to predict 3D voxel occupancy.
- Evaluate performance using IoU, Precision, Recall, and F1-score.
- Analyze and visualize results to determine which layers are best for reconstruction.

### Imports, Logging Setup

In [1]:
# Set environment variables before imports
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

# Imports
import hydra
from omegaconf import DictConfig, OmegaConf
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from pathlib import Path
import logging
import wandb
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm

# Model and dataset imports
from src.models.reconstruction_feature_extractor import ReconstructionFeatureExtractor, load_image_feature_extractor 
from src.datasets.shapenet_voxel_meshes import create_3dr2n2_reconstruction_dataloaders 

# Probing imports using new modular structure
from src.probing import (
    create_probe, 
    ProbeTrainer,
    ReconstructionPipeline,
    ReconstructionDataset,
    compute_voxel_metrics,
    MetricsTracker,
)
from src.analysis.layer_analysis import LayerWiseAnalyzer

# Re-running cells in Jupyter can leave several handlers attached to the root
# logger, which prints every record more than once. Strip them all first.
root_logger = logging.getLogger()
while root_logger.handlers:
    root_logger.removeHandler(root_logger.handlers[0])

# Install a single fresh handler; force=True makes basicConfig reconfigure the
# root logger even if it was configured earlier in this kernel session.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    force=True,
)

# Logger used by every cell of this notebook
logger = logging.getLogger(__name__)

### Reconstruction Experiment Setup


In [2]:
class ReconstructionExperiment:
    """Orchestrates 3D reconstruction experiments using VoxelProbes.

    The experiment owns the resolved Hydra config, the compute device, and the
    result/cache directories. It exposes helpers to load the source data and
    the feature extractor, to build per-layer probe input datasets, to train
    and evaluate one VoxelProbe per layer, and to persist results as JSON.
    """

    def __init__(self, config: DictConfig):
        """Store the config, pick a device, optionally start W&B, create directories.

        Args:
            config: Experiment configuration. Reads ``device`` (optional),
                ``wandb``, ``experiment.name``, ``results_dir`` and ``cache_dir``.
        """
        self.config = config
        # An explicit `device` entry wins; otherwise fall back to cuda -> mps -> cpu.
        device_to_use = config.get("device")
        if device_to_use:
            self.device = device_to_use
        else:
            self.device = (
                "cuda"
                if torch.cuda.is_available()
                else "mps" if torch.backends.mps.is_available() else "cpu"
            )
        logger.info(f"Using device: {self.device}")

        if config.get("wandb", {}).get("enabled", False):
            wandb.init(
                project=config.wandb.project,
                entity=config.wandb.get("entity"),
                name=config.experiment.name + "_reconstruction",
                config=OmegaConf.to_container(config, resolve=True),
            )

        # Results and caches are namespaced by experiment name.
        self.results_dir = Path(config.get("results_dir", "./results")) / (config.experiment.name)
        self.results_dir.mkdir(parents=True, exist_ok=True)
        self.cache_dir = Path(config.get("cache_dir", "./cache")) / (config.experiment.name)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Trained probe checkpoints are written under the cache directory.
        self.probe_save_dir = self.cache_dir / "probes"
        self.probe_save_dir.mkdir(parents=True, exist_ok=True)

        self.analyzer = LayerWiseAnalyzer(self.results_dir)

    def load_source_dataset(self) -> Tuple[DataLoader, DataLoader, DataLoader]:
        """Load the source ShapeNet dataset with images and voxels.

        Returns:
            Whatever ``create_3dr2n2_reconstruction_dataloaders`` returns; the
            annotation and the notebook treat it as (train, val, test) loaders.
        """
        subset_percentage = self.config.datasets.get("subset_percentage", None)
        return create_3dr2n2_reconstruction_dataloaders(
            self.config.datasets,
            batch_size=self.config.datasets.get("source_batch_size", 16),
            num_workers=self.config.get("num_workers", 4),
            subset_percentage=subset_percentage,
        )

    def load_reconstruction_feature_extractor(self) -> ReconstructionFeatureExtractor:
        """Load and set up the ReconstructionFeatureExtractor.

        The ``models`` config is converted to a plain dict and extended with the
        experiment device and a model cache directory before being handed to
        ``load_image_feature_extractor``.
        """
        model_config = self.config.models
        model_config_dict = OmegaConf.to_container(model_config, resolve=True)
        model_config_dict["device"] = self.device
        model_config_dict["cache_dir"] = str(self.cache_dir / "models")

        feature_extractor = load_image_feature_extractor(model_config_dict)
        logger.info(f"Loaded {model_config_dict.get('model_name')} reconstruction feature extractor")
        return feature_extractor

    def prepare_reconstruction_input_datasets_for_layer(
        self,
        reconstruction_pipeline: ReconstructionPipeline,
        train_source_loader: DataLoader,
        val_source_loader: DataLoader,
        test_source_loader: DataLoader,
        layer: int,
        image_feature_type: str,
    ) -> Tuple[ReconstructionDataset, ReconstructionDataset, ReconstructionDataset]:
        """Build the train/val/test probe input datasets for one layer.

        Each split is produced by ``reconstruction_pipeline.create_dataset`` with
        a split-specific cache key derived from the model name, the experiment
        name and the layer index.

        Args:
            reconstruction_pipeline: Pipeline that creates the datasets.
            train_source_loader: Source loader for the training split.
            val_source_loader: Source loader for the validation split.
            test_source_loader: Source loader for the test split.
            layer: Index of the layer whose features are requested.
            image_feature_type: Feature type forwarded to the pipeline.

        Returns:
            The (train, val, test) datasets, in that order.
        """
        experiment_id = f"{self.config.models.model_name}_{self.config.experiment.name}_layer_{layer}"
        force_recompute = self.config.get("force_recompute_processed_data", False)

        # One create_dataset call per split; order (train, val, test) is preserved.
        split_loaders = (
            ("train", train_source_loader),
            ("val", val_source_loader),
            ("test", test_source_loader),
        )
        train_input_dataset, val_input_dataset, test_input_dataset = (
            reconstruction_pipeline.create_dataset(
                dataloader=source_loader,
                layers=[layer],
                feature_type=image_feature_type,
                cache_key=f"{experiment_id}_{split_name}",
                force_recompute=force_recompute,
            )
            for split_name, source_loader in split_loaders
        )

        return train_input_dataset, val_input_dataset, test_input_dataset

    def run_voxel_probe_experiment(
        self,
        train_processed_loader: DataLoader,
        val_processed_loader: DataLoader,
        test_processed_loader: DataLoader,
        feature_dim: int,
        layer: int,
    ) -> Dict:
        """Train, checkpoint and evaluate a single VoxelProbe for one layer.

        Args:
            train_processed_loader: Loader over processed training inputs.
            val_processed_loader: Loader over processed validation inputs.
            test_processed_loader: Loader over processed test inputs.
            feature_dim: Flattened input feature size, stored as ``input_dim``
                in the probe config.
            layer: Layer index the features were taken from.

        Returns:
            Dict with ``train_history``, ``val_history``, ``test_metrics``,
            ``detailed_metrics``, ``best_epoch`` and ``total_epochs``.
        """
        probe_type = "voxel"
        logger.info(
            f"Running VoxelProbe on layer {layer} (input_feature_dim: {feature_dim})"
        )

        # Work on a detached, resolved copy of the probing config so per-layer
        # values (input_dim, task_type) do not leak into the shared experiment
        # config that later runs and save_results() read.
        probe_config = OmegaConf.create(
            OmegaConf.to_container(self.config.probing, resolve=True)
        )
        probe_config["input_dim"] = feature_dim
        probe_config["task_type"] = "voxel_reconstruction"

        # A probe-level device overrides the experiment device for this run only;
        # a missing or null entry falls back to the experiment device.
        probe_device = probe_config.get("device") or self.device

        probe = create_probe(probe_config)

        metrics_tracker = MetricsTracker()
        trainer = ProbeTrainer(
            probe, device=probe_device, MetricsTracker=metrics_tracker
        )

        # Optimizer/scheduler settings may live directly under `probing` or
        # under `probing.training`; the former takes precedence.
        training_config = probe_config.get("training", {})
        optimizer_specific_config = probe_config.get("optimizer", training_config.get("optimizer", {}))
        scheduler_specific_config = probe_config.get("scheduler", training_config.get("scheduler", {}))

        optimizer = self.create_optimizer(probe, optimizer_specific_config)
        scheduler = self.create_scheduler(optimizer, scheduler_specific_config)

        epochs = training_config.get("epochs", 50)  # Potentially more epochs for reconstruction
        early_stopping_patience = training_config.get("early_stopping_patience", 10)
        wandb_enabled = self.config.get("wandb", {}).get("enabled", False)

        best_model_state_dict, best_val_loss = trainer.train(
            epochs,
            optimizer,
            scheduler,
            early_stopping_patience,
            train_processed_loader,
            val_processed_loader,
            probe_type=probe_type,  # Pass "voxel"
            layer=layer,
            wandb_enabled=wandb_enabled,
        )

        probe_filename = f"{self.config.models.model_name}_{probe_type}_layer_{layer}_probe.pth"
        probe_save_path = self.probe_save_dir / probe_filename

        torch.save({
            'model_state_dict': best_model_state_dict,
            'probe_config': probe_config,
            'layer': layer,
            'probe_type': probe_type,
            'experiment_name': self.config.experiment.name,
            'model_name': self.config.models.model_name,
            'best_val_loss': best_val_loss,
            'input_feature_dim': feature_dim
        }, probe_save_path)
        logger.info(f"Saved VoxelProbe for layer {layer} to {probe_save_path}")

        # Evaluate with the best weights returned by training.
        probe.load_state_dict(best_model_state_dict)

        test_metrics = trainer.evaluate(
            test_loader=test_processed_loader,
            wandb_enabled=wandb_enabled,
            probe_type=probe_type,
            layer=layer
        )

        detailed_metrics = self._compute_detailed_metrics(
            probe, test_processed_loader, device=probe_device
        )

        train_history = metrics_tracker.get_history("train")

        results = {
            "train_history": train_history,
            "val_history": metrics_tracker.get_history("val"),
            "test_metrics": test_metrics,
            "detailed_metrics": detailed_metrics,
            "best_epoch": metrics_tracker.best_epoch,
            "total_epochs": len(train_history),
        }
        return results

    def create_optimizer(
        self, model: nn.Module, optimizer_config: Dict
    ) -> torch.optim.Optimizer:
        """Instantiate an optimizer for ``model`` from a Hydra ``_target_`` config.

        Raises:
            ValueError: If ``optimizer_config`` has no ``_target_`` field.
        """
        from hydra.utils import instantiate
        opt_config_copy = OmegaConf.create(optimizer_config)
        if "_target_" not in opt_config_copy:
            raise ValueError("Optimizer config must have a _target_ field")

        return instantiate(opt_config_copy, params=model.parameters())

    def create_scheduler(
        self, optimizer: torch.optim.Optimizer, scheduler_config: Dict
    ):
        """Instantiate an LR scheduler from a Hydra ``_target_`` config.

        Returns:
            The instantiated scheduler, or ``None`` when the config is empty or
            has no ``_target_``.
        """
        if not scheduler_config or not scheduler_config.get("_target_"):
            return None
        from hydra.utils import instantiate
        sched_config_copy = OmegaConf.create(scheduler_config)
        return instantiate(sched_config_copy, optimizer=optimizer)

    def _compute_detailed_metrics(
        self, probe: nn.Module, test_loader: DataLoader, device: Optional[str] = None
    ) -> Dict:
        """Run ``probe`` over ``test_loader`` and compute voxel metrics on all outputs.

        Args:
            probe: Trained probe; moved to the evaluation device and put in eval mode.
            test_loader: Loader whose batches provide ``"processed_views"`` and
                ``"target_voxels"``.
            device: Device to evaluate on. Defaults to ``self.device``.

        Returns:
            The metrics produced by ``compute_voxel_metrics`` on the concatenated
            predictions and targets, or an empty dict if the loader yields no batches.
        """
        eval_device = device if device is not None else self.device
        probe.to(eval_device)
        probe.eval()
        all_predictions = []
        all_targets = []

        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Computing detailed metrics"):
                features = batch["processed_views"].to(eval_device)
                targets = batch["target_voxels"].to(eval_device)

                # Voxel probes consume one flat feature vector per sample.
                if probe.task_type == "voxel_reconstruction":
                    features = features.view(features.size(0), -1)

                outputs = probe(features)  # Voxel logits: [B, 1, D, H, W]

                # Accumulate on CPU so device memory is not held across batches.
                all_predictions.append(outputs.cpu())
                all_targets.append(targets.cpu())

        if not all_predictions:
            # torch.cat rejects an empty list; report instead of crashing.
            logger.warning("Test loader yielded no batches; skipping detailed metrics.")
            return {}

        predictions_cat = torch.cat(all_predictions, dim=0)
        targets_cat = torch.cat(all_targets, dim=0)

        metrics = {}
        metrics.update(compute_voxel_metrics(predictions_cat, targets_cat))

        return metrics

    def save_results(self, results: Dict) -> str:
        """Write the resolved config and ``results`` to ``reconstruction_results.json``.

        Returns:
            Path of the written file, as a string.
        """
        import json
        results_file = self.results_dir / "reconstruction_results.json"
        serializable_results = self.make_json_serializable(results)
        combined_results = {
            "config": OmegaConf.to_container(self.config, resolve=True),
            "results": serializable_results,
        }
        with open(results_file, "w") as f:
            json.dump(combined_results, f, indent=2)
        logger.info(f"Reconstruction results saved to {results_file}")
        return str(results_file)

    def make_json_serializable(self, obj):
        """Recursively convert ``obj`` into types the ``json`` module can encode.

        Dicts are converted value by value; lists and tuples become lists;
        tensors and arrays become nested lists; NumPy scalars become Python
        scalars; paths become strings. Anything else is returned unchanged.
        """
        if isinstance(obj, dict):
            return {k: self.make_json_serializable(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            # Tuples may hold tensors/arrays too, so they must be walked as well.
            return [self.make_json_serializable(v) for v in obj]
        if isinstance(obj, (torch.Tensor, np.ndarray)):
            return obj.tolist()
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, Path):
            return str(obj)
        return obj

### Hydra Configuration Loading / Setup


In [3]:
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra
import os 
from pathlib import Path

CONFIG_PATH = "../configs"
CONFIG_NAME = "exp_ijepa_reconstruction"  # Hydra config composed for this notebook

cfg: Optional[DictConfig] = None

# Re-running this cell would otherwise hit an already-initialized Hydra instance.
if GlobalHydra.instance().is_initialized():
    logger.info("Clearing existing Hydra global state.")
    GlobalHydra.instance().clear()

try:
    # Export DATA_DIR as <parent of the working directory>/data before composing.
    project_root = Path(os.getcwd()).parent
    data_dir_abs = project_root / "data"
    os.environ["DATA_DIR"] = str(data_dir_abs)

    logger.info(f"Initializing Hydra with config_path: '{CONFIG_PATH}' relative to {os.getcwd()}")
    initialize(version_base=None, config_path=CONFIG_PATH, job_name="reconstruction_experiment")

    logger.info(f"Composing configuration with config_name: '{CONFIG_NAME}'")
    cfg = compose(config_name=CONFIG_NAME)
except Exception as e:
    # Best effort: log the failure with traceback and leave cfg unset so the
    # following cells skip their work instead of raising.
    logger.error(f"Error initializing Hydra or loading configuration: {e}", exc_info=True)
    cfg = None

if not cfg:
    logger.error("Failed to load Hydra configuration. Please check paths and config files.")
else:
    logger.info("Hydra configuration loaded successfully for reconstruction experiment.")
    logger.info(f"Experiment name: {cfg.experiment.name}")
    logger.info(f"Task type: {cfg.probing.task_type}")
    logger.info(f"Probe types: {cfg.probing.probe_types}")

    # Quick check for critical reconstruction settings
    if cfg.probing.task_type != "voxel_reconstruction":
        logger.warning(f"Configured task_type is '{cfg.probing.task_type}', expected 'voxel_reconstruction' for this notebook.")
    if "voxel" not in cfg.probing.probe_types:
        logger.warning(f"Configured probe_types are '{cfg.probing.probe_types}', 'voxel' probe might not run.")


2025-06-05 00:09:10,414 - __main__ - INFO - Initializing Hydra with config_path: '../configs' relative to /Users/druhi/Documents/+Programming/GitHub/LatentInvestigation/notebooks
2025-06-05 00:09:10,577 - __main__ - INFO - Composing configuration with config_name: 'exp_ijepa_reconstruction'
2025-06-05 00:09:10,577 - __main__ - INFO - Composing configuration with config_name: 'exp_ijepa_reconstruction'
2025-06-05 00:09:10,635 - __main__ - INFO - Hydra configuration loaded successfully for reconstruction experiment.
2025-06-05 00:09:10,636 - __main__ - INFO - Experiment name: phase2_ijepa_voxel_reconstruction
2025-06-05 00:09:10,636 - __main__ - INFO - Task type: voxel_reconstruction
2025-06-05 00:09:10,636 - __main__ - INFO - Probe types: ['voxel']
2025-06-05 00:09:10,635 - __main__ - INFO - Hydra configuration loaded successfully for reconstruction experiment.
2025-06-05 00:09:10,636 - __main__ - INFO - Experiment name: phase2_ijepa_voxel_reconstruction
2025-06-05 00:09:10,636 - __main

## Running the Reconstruction Experiment

In [4]:
# Results stay None until the per-layer training cell fills them in.
reconstruction_results = None
if not cfg:
    # Without a config there is nothing to run; downstream cells check `experiment`.
    logger.error("Configuration not loaded. Cannot start experiment.")
    experiment = None
else:
    logger.info("Starting reconstruction experiment execution")
    experiment = ReconstructionExperiment(cfg)
    

2025-06-05 00:09:10,650 - __main__ - INFO - Starting reconstruction experiment execution
2025-06-05 00:09:10,650 - __main__ - INFO - Using device: mps
2025-06-05 00:09:10,650 - __main__ - INFO - Using device: mps
2025-06-05 00:09:10,943 - wandb.jupyter - ERROR - Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
2025-06-05 00:09:10,943 - wandb.jupyter - ERROR - Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdruhin-bhowal[0m ([33mcse493g1_drn[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Currently logged in as: [33mdruhin-bhowal[0m ([33mcse493g1_drn[0m). Use [1m`wandb login --relogin`[0m to force relogin


### Load Source Data and Feature Extractor

In [None]:
if not experiment:
    logger.error("Experiment not initialized. Skipping feature extractor and dataset loading.")
else:
    image_feature_extractor = experiment.load_reconstruction_feature_extractor()

    # Which layers to probe and which feature type to read, with fallbacks
    # when the model config has no `feature_extraction` section.
    extraction_config = cfg.models.get("feature_extraction", {})
    layers_to_probe = extraction_config.get("layers", [0, 2, 5, 8, 11])
    image_feature_type = extraction_config.get("feature_type", "cls_token")

    (
        train_source_loader,
        val_source_loader,
        test_source_loader,
    ) = experiment.load_source_dataset()

    # The pipeline keeps its processed data in a subfolder of the experiment cache.
    processed_cache_dir = experiment.cache_dir / "processed_reconstruction_data"
    reconstruction_pipeline = ReconstructionPipeline(
        image_pipeline=image_feature_extractor,
        device=experiment.device,
        cache_dir=str(processed_cache_dir),
    )


2025-06-05 00:09:14,322 - src.models.model_loader - INFO - Loading timm model 'facebook/vit_huge_patch14_224_ijepa'
2025-06-05 00:09:19,344 - timm.models._builder - INFO - Loading pretrained weights from url (https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar)
2025-06-05 00:09:19,344 - timm.models._builder - INFO - Loading pretrained weights from url (https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar)
2025-06-05 00:09:21,333 - src.models.reconstruction_feature_extractor - INFO - Loaded and froze ijepa on mps.
2025-06-05 00:09:21,334 - __main__ - INFO - Loaded ijepa reconstruction feature extractor
2025-06-05 00:09:21,333 - src.models.reconstruction_feature_extractor - INFO - Loaded and froze ijepa on mps.
2025-06-05 00:09:21,334 - __main__ - INFO - Loaded ijepa reconstruction feature extractor


Created train DataLoader with 2451 samples, batch size 32.
Created val DataLoader with 525 samples, batch size 32.
Created test DataLoader with 525 samples, batch size 32.


: 

### Process Data and Train VoxelProbes for Each Layer

In [None]:
def _build_processed_loader(dataset, batch_size, shuffle):
    """Wrap a processed reconstruction dataset in a DataLoader.

    The worker count is read from ``cfg`` (``num_workers``, default 4) and
    ``pin_memory`` is always enabled.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=cfg.get("num_workers", 4),
        pin_memory=True,
    )


if not experiment:
    logger.error("Experiment not initialized. Skipping layer processing and probe training.")
else:
    reconstruction_results = {}

    if "voxel" in cfg.probing.probe_types:
        logger.info(f"Will train VoxelProbes for layers: {layers_to_probe}")

        for layer_idx in tqdm(layers_to_probe, desc="Processing Layers"):
            logger.info(f"Processing layer {layer_idx} for reconstruction...")

            # Build (or load from cache) the probe inputs for this layer.
            train_input_ds, val_input_ds, test_input_ds = experiment.prepare_reconstruction_input_datasets_for_layer(
                reconstruction_pipeline=reconstruction_pipeline,
                train_source_loader=train_source_loader,
                val_source_loader=val_source_loader,
                test_source_loader=test_source_loader,
                layer=layer_idx,
                image_feature_type=image_feature_type,
            )

            for split_label, split_ds in (
                ("Train", train_input_ds),
                ("Val", val_input_ds),
                ("Test", test_input_ds),
            ):
                logger.info(f"Layer {layer_idx}: {split_label} Input Dataset size: {len(split_ds)}")

            # Only the training loader is shuffled.
            processed_batch_size = cfg.training.get("batch_size", 32)
            train_processed_loader = _build_processed_loader(train_input_ds, processed_batch_size, True)
            val_processed_loader = _build_processed_loader(val_input_ds, processed_batch_size, False)
            test_processed_loader = _build_processed_loader(test_input_ds, processed_batch_size, False)

            # The probe input size is the element count of one processed sample.
            sample_processed_data = train_input_ds[0]["processed_views"]
            input_feature_dim_for_probe = sample_processed_data.numel()
            logger.info(f"Input feature dimension for VoxelProbe at layer {layer_idx}: {input_feature_dim_for_probe} (Shape of sample: {sample_processed_data.shape})")

            logger.info(f"Running VoxelProbe on layer {layer_idx}...")
            probe_run_results = experiment.run_voxel_probe_experiment(
                train_processed_loader=train_processed_loader,
                val_processed_loader=val_processed_loader,
                test_processed_loader=test_processed_loader,
                feature_dim=input_feature_dim_for_probe,
                layer=layer_idx,
            )
            reconstruction_results[f"layer_{layer_idx}"] = {"voxel": probe_run_results}
    else:
        logger.error(f"'voxel' not in configured probe_types: {cfg.probing.probe_types}. VoxelProbes will not be trained.")

2025-06-05 00:09:21,394 - __main__ - INFO - Will train VoxelProbes for layers: [2, 4, 6, 8, 10, 11]
Processing Layers:   0%|          | 0/6 [00:00<?, ?it/s]2025-06-05 00:09:21,457 - __main__ - INFO - Processing layer 2 for reconstruction...
2025-06-05 00:09:21,458 - src.probing.base_pipeline - INFO - Processing 76 batches...
Processing Layers:   0%|          | 0/6 [00:00<?, ?it/s]2025-06-05 00:09:21,457 - __main__ - INFO - Processing layer 2 for reconstruction...
2025-06-05 00:09:21,458 - src.probing.base_pipeline - INFO - Processing 76 batches...
2025-06-05 00:09:26,411 - src.models.base_feature_extractor - INFO - Processing 768 images in chunks of 8
2025-06-05 00:09:26,411 - src.models.base_feature_extractor - INFO - Processing 768 images in chunks of 8

Aborted!

Aborted!


In [None]:
# Persist results only when an experiment exists and produced a non-empty result dict.
if not (experiment and reconstruction_results):
    logger.warning("No results to save or experiment not run.")
else:
    logger.info("Saving reconstruction experiment results...")
    result_file_path = experiment.save_results(reconstruction_results)
    logger.info(f"Results saved to: {result_file_path}")

### Analyze and Visualize Results


In [None]:
from src.analysis.layer_analysis import analyze_experiment_results

# Analysis needs the JSON file written by the save cell; skip when it was never produced.
if not (experiment and reconstruction_results and 'result_file_path' in locals()):
    logger.warning("No results to analyze or result file path not available.")
else:
    logger.info("Analyzing reconstruction results...")

    # Analysis output goes next to the results file.
    analyze_experiment_results(
        results_file=result_file_path,
        output_dir=Path(result_file_path).parent,
    )
