In [3]:
import h5py
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
import logging
from typing import Dict, List, Tuple, Optional
import pandas as pd
from datetime import datetime
import numpy as np

In [4]:
class TestDataLoader:
    """
    Loads and analyzes H5py files generated during test mode runs
    """
    
    def __init__(self, test_dir: str = "test"):
        """
        Initialize the test data loader
        
        Args:
            test_dir: Directory containing test output files
        """
        self.test_dir = Path(test_dir)
        self.logger = logging.getLogger(__name__)
        
        # Initialize data containers
        self.exposures = {}
        self.summary_stats = {}
        
        # Check if test directory exists
        if not self.test_dir.exists():
            self.logger.warning(f"Test directory {self.test_dir} does not exist")
            return
            
        self._discover_test_files()
    
    def _discover_test_files(self):
        """Discover and catalog all test files"""
        self.logger.info(f"Discovering test files in {self.test_dir}")
        
        # Look for different data types
        self.difference_files = list((self.test_dir / "raw_differences").glob("*_differences.h5")) if (self.test_dir / "raw_differences").exists() else []
        self.patch_files = list((self.test_dir / "patches").glob("*_patches_*.h5")) if (self.test_dir / "patches").exists() else []
        self.temporal_files = list((self.test_dir / "temporal_analysis").glob("*_temporal.h5")) if (self.test_dir / "temporal_analysis").exists() else []
        self.metadata_files = list((self.test_dir / "metadata").glob("*_metadata.json")) if (self.test_dir / "metadata").exists() else []
        
        self.logger.info(f"Found {len(self.difference_files)} difference files")
        self.logger.info(f"Found {len(self.patch_files)} patch files")
        self.logger.info(f"Found {len(self.temporal_files)} temporal files")
        self.logger.info(f"Found {len(self.metadata_files)} metadata files")
    
    def load_exposure_data(self, exposure_id: str) -> Dict:
        """
        Load all data for a specific exposure
        
        Args:
            exposure_id: ID of the exposure to load
            
        Returns:
            Dictionary containing all data for the exposure
        """
        data = {
            'exposure_id': exposure_id,
            'differences': None,
            'patches': {},
            'temporal': None,
            'metadata': None
        }
        
        # Load difference data
        diff_file = self.test_dir / "raw_differences" / f"{exposure_id}_differences.h5"
        if diff_file.exists():
            with h5py.File(diff_file, 'r') as f:
                data['differences'] = {
                    'data': f['differences'][:],
                    'frame_times': f['frame_times'][:] if 'frame_times' in f else None,
                    'reference_frame': f['reference_frame'][:] if 'reference_frame' in f else None,
                    'attrs': dict(f.attrs)
                }
        
        # Load patch data (multiple patch sizes)
        patch_dir = self.test_dir / "patches"
        if patch_dir.exists():
            for patch_file in patch_dir.glob(f"{exposure_id}_patches_*.h5"):
                # Extract patch size from filename
                patch_size = patch_file.stem.split('_')[-1]
                with h5py.File(patch_file, 'r') as f:
                    data['patches'][patch_size] = {
                        'patches': f['patches'][:],
                        'positions': f['positions'][:] if 'positions' in f else None,
                        'frame_indices': f['frame_indices'][:] if 'frame_indices' in f else None,
                        'anomaly_scores': f['anomaly_scores'][:] if 'anomaly_scores' in f else None,
                        'attrs': dict(f.attrs)
                    }
        
        # Load temporal analysis data
        temporal_file = self.test_dir / "temporal_analysis" / f"{exposure_id}_temporal.h5"
        if temporal_file.exists():
            with h5py.File(temporal_file, 'r') as f:
                data['temporal'] = {
                    'temporal_stats': f['temporal_stats'][:] if 'temporal_stats' in f else None,
                    'frame_statistics': f['frame_statistics'][:] if 'frame_statistics' in f else None,
                    'attrs': dict(f.attrs)
                }
        
        # Load metadata
        metadata_file = self.test_dir / "metadata" / f"{exposure_id}_metadata.json"
        if metadata_file.exists():
            with open(metadata_file, 'r') as f:
                data['metadata'] = json.load(f)
        
        self.exposures[exposure_id] = data
        return data
    
    def load_all_exposures(self) -> Dict:
        """Load data for all discovered exposures"""
        exposure_ids = set()
        
        # Extract exposure IDs from filenames
        for file_list in [self.difference_files, self.patch_files, self.temporal_files, self.metadata_files]:
            for file_path in file_list:
                # Extract exposure ID (assumes format: exposureID_datatype.h5)
                parts = file_path.stem.split('_')
                if len(parts) >= 2:
                    exposure_id = '_'.join(parts[:-1])  # Everything except the last part
                    if 'patches' in parts:
                        # For patch files: exposureID_patches_size.h5
                        exposure_id = '_'.join(parts[:-2])
                    exposure_ids.add(exposure_id)
        
        for exposure_id in exposure_ids:
            self.load_exposure_data(exposure_id)
        
        return self.exposures
    
    def get_data_summary(self) -> Dict:
        """Get summary statistics of loaded data"""
        if not self.exposures:
            self.load_all_exposures()
        
        summary = {
            'total_exposures': len(self.exposures),
            'exposures': {},
            'overall_stats': {
                'total_difference_frames': 0,
                'total_patches': 0,
                'patch_sizes': set(),
                'datasets': set()
            }
        }
        
        for exp_id, data in self.exposures.items():
            exp_summary = {
                'has_differences': data['differences'] is not None,
                'has_patches': len(data['patches']) > 0,
                'has_temporal': data['temporal'] is not None,
                'has_metadata': data['metadata'] is not None
            }
            
            if data['differences']:
                diff_shape = data['differences']['data'].shape
                exp_summary['difference_shape'] = diff_shape
                exp_summary['num_frames'] = diff_shape[0] if len(diff_shape) > 2 else 1
                summary['overall_stats']['total_difference_frames'] += exp_summary['num_frames']
            
            if data['patches']:
                exp_summary['patch_sizes'] = list(data['patches'].keys())
                summary['overall_stats']['patch_sizes'].update(exp_summary['patch_sizes'])
                
                patch_count = sum(data['patches'][size]['patches'].shape[0] 
                                 for size in data['patches'])
                exp_summary['total_patches'] = patch_count
                summary['overall_stats']['total_patches'] += patch_count
            
            if data['metadata']:
                dataset_type = data['metadata'].get('dataset_type', 'unknown')
                exp_summary['dataset_type'] = dataset_type
                summary['overall_stats']['datasets'].add(dataset_type)
            
            summary['exposures'][exp_id] = exp_summary
        
        # Convert sets to lists for JSON serialization
        summary['overall_stats']['patch_sizes'] = list(summary['overall_stats']['patch_sizes'])
        summary['overall_stats']['datasets'] = list(summary['overall_stats']['datasets'])
        
        self.summary_stats = summary
        return summary
    
    def analyze_differences(self, exposure_id: str = None) -> Dict:
        """
        Analyze frame difference data
        
        Args:
            exposure_id: Specific exposure to analyze, or None for all
            
        Returns:
            Analysis results
        """
        if exposure_id:
            exposures_to_analyze = [exposure_id] if exposure_id in self.exposures else []
        else:
            exposures_to_analyze = list(self.exposures.keys())
        
        analysis = {}
        
        for exp_id in exposures_to_analyze:
            data = self.exposures[exp_id]
            if not data['differences']:
                continue
                
            diff_data = data['differences']['data']
            
            # Calculate statistics
            stats = {
                'shape': diff_data.shape,
                'mean': float(np.mean(diff_data)),
                'std': float(np.std(diff_data)),
                'min': float(np.min(diff_data)),
                'max': float(np.max(diff_data)),
                'median': float(np.median(diff_data)),
                'percentiles': {
                    '1': float(np.percentile(diff_data, 1)),
                    '5': float(np.percentile(diff_data, 5)),
                    '95': float(np.percentile(diff_data, 95)),
                    '99': float(np.percentile(diff_data, 99))
                }
            }
            
            # Frame-by-frame statistics if multiple frames
            if len(diff_data.shape) > 2:
                frame_stats = []
                for i in range(diff_data.shape[0]):
                    frame = diff_data[i]
                    frame_stats.append({
                        'frame': i,
                        'mean': float(np.mean(frame)),
                        'std': float(np.std(frame)),
                        'rms': float(np.sqrt(np.mean(frame**2)))
                    })
                stats['frame_stats'] = frame_stats
            
            analysis[exp_id] = stats
        
        return analysis
    
    def plot_difference_analysis(self, exposure_id: str, save_dir: str = None):
        """
        Create plots for difference data analysis
        
        Args:
            exposure_id: Exposure to plot
            save_dir: Directory to save plots (optional)
        """
        if exposure_id not in self.exposures:
            self.logger.error(f"Exposure {exposure_id} not loaded")
            return
        
        data = self.exposures[exposure_id]['differences']
        if not data:
            self.logger.error(f"No difference data for exposure {exposure_id}")
            return
        
        diff_data = data['data']
        
        # Create figure with subplots
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        fig.suptitle(f'Difference Analysis: {exposure_id}', fontsize=16)
        
        # Histogram of all pixel values
        axes[0, 0].hist(diff_data.flatten(), bins=100, alpha=0.7, edgecolor='black')
        axes[0, 0].set_title('Pixel Value Distribution')
        axes[0, 0].set_xlabel('Difference Value')
        axes[0, 0].set_ylabel('Count')
        axes[0, 0].set_yscale('log')
        
        # Show first frame (or single frame)
        first_frame = diff_data[0] if len(diff_data.shape) > 2 else diff_data
        im1 = axes[0, 1].imshow(first_frame, cmap='viridis', aspect='auto')
        axes[0, 1].set_title('First Difference Frame')
        plt.colorbar(im1, ax=axes[0, 1])
        
        # Show standard deviation across frames (if multiple frames)
        if len(diff_data.shape) > 2:
            std_frame = np.std(diff_data, axis=0)
            im2 = axes[0, 2].imshow(std_frame, cmap='plasma', aspect='auto')
            axes[0, 2].set_title('Std Dev Across Frames')
            plt.colorbar(im2, ax=axes[0, 2])
            
            # Frame-by-frame statistics
            frame_means = [np.mean(diff_data[i]) for i in range(diff_data.shape[0])]
            frame_stds = [np.std(diff_data[i]) for i in range(diff_data.shape[0])]
            
            axes[1, 0].plot(frame_means, 'b-', label='Mean')
            axes[1, 0].set_title('Frame Statistics')
            axes[1, 0].set_xlabel('Frame Number')
            axes[1, 0].set_ylabel('Mean Value')
            axes[1, 0].legend()
            
            ax2 = axes[1, 0].twinx()
            ax2.plot(frame_stds, 'r-', label='Std Dev')
            ax2.set_ylabel('Std Dev')
            ax2.legend(loc='upper right')
        else:
            axes[0, 2].text(0.5, 0.5, 'Single Frame\nNo Temporal Analysis', 
                           ha='center', va='center', transform=axes[0, 2].transAxes)
            axes[1, 0].text(0.5, 0.5, 'Single Frame\nNo Frame Statistics', 
                           ha='center', va='center', transform=axes[1, 0].transAxes)
        
        # Radial profile (assuming square image)
        if len(first_frame.shape) == 2:
            center_y, center_x = np.array(first_frame.shape) // 2
            y, x = np.ogrid[:first_frame.shape[0], :first_frame.shape[1]]
            r = np.sqrt((x - center_x)**2 + (y - center_y)**2)
            
            # Bin by radius
            max_r = int(np.max(r))
            radial_profile = []
            radii = []
            for i in range(0, max_r, max(1, max_r//50)):
                mask = (r >= i) & (r < i+1)
                if np.any(mask):
                    radial_profile.append(np.mean(first_frame[mask]))
                    radii.append(i)
            
            axes[1, 1].plot(radii, radial_profile, 'g-')
            axes[1, 1].set_title('Radial Profile')
            axes[1, 1].set_xlabel('Radius (pixels)')
            axes[1, 1].set_ylabel('Mean Value')
        
        # Power spectrum (2D FFT)
        fft_data = np.fft.fft2(first_frame)
        power_spectrum = np.abs(fft_data)**2
        axes[1, 2].imshow(np.log10(np.fft.fftshift(power_spectrum) + 1), 
                         cmap='hot', aspect='auto')
        axes[1, 2].set_title('Power Spectrum (log)')
        
        plt.tight_layout()
        
        if save_dir:
            save_path = Path(save_dir) / f"{exposure_id}_difference_analysis.png"
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            self.logger.info(f"Saved plot to {save_path}")
        
        plt.show()
    
    def compare_exposures(self, exposure_ids: List[str]) -> Dict:
        """
        Compare statistics across multiple exposures
        
        Args:
            exposure_ids: List of exposure IDs to compare
            
        Returns:
            Comparison results
        """
        comparison = {
            'exposures': exposure_ids,
            'differences': {},
            'patches': {}
        }
        
        # Compare difference statistics
        diff_stats = []
        for exp_id in exposure_ids:
            if exp_id in self.exposures and self.exposures[exp_id]['differences']:
                data = self.exposures[exp_id]['differences']['data']
                stats = {
                    'exposure_id': exp_id,
                    'mean': float(np.mean(data)),
                    'std': float(np.std(data)),
                    'rms': float(np.sqrt(np.mean(data**2))),
                    'shape': data.shape
                }
                diff_stats.append(stats)
        
        comparison['differences']['stats'] = diff_stats
        
        # Compare patch counts
        patch_stats = []
        for exp_id in exposure_ids:
            if exp_id in self.exposures and self.exposures[exp_id]['patches']:
                patches = self.exposures[exp_id]['patches']
                for size, patch_data in patches.items():
                    patch_stats.append({
                        'exposure_id': exp_id,
                        'patch_size': size,
                        'count': patch_data['patches'].shape[0],
                        'mean_anomaly_score': float(np.mean(patch_data['anomaly_scores'])) if patch_data['anomaly_scores'] is not None else None
                    })
        
        comparison['patches']['stats'] = patch_stats
        
        return comparison
    
    def export_summary_report(self, filename: str = None):
        """Export a comprehensive summary report"""
        if not filename:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"test_data_summary_{timestamp}.json"
        
        report = {
            'generated_at': datetime.now().isoformat(),
            'test_directory': str(self.test_dir),
            'summary': self.get_data_summary(),
            'file_discovery': {
                'difference_files': [str(f) for f in self.difference_files],
                'patch_files': [str(f) for f in self.patch_files],
                'temporal_files': [str(f) for f in self.temporal_files],
                'metadata_files': [str(f) for f in self.metadata_files]
            }
        }
        
        # Add detailed analysis for each exposure
        if self.exposures:
            report['detailed_analysis'] = self.analyze_differences()
        
        with open(filename, 'w') as f:
            json.dump(report, f, indent=2)
        
        self.logger.info(f"Summary report exported to {filename}")
        return filename


def main():
    """Example usage of the TestDataLoader"""
    # Setup logging
    logging.basicConfig(level=logging.INFO, 
                       format='%(asctime)s - %(levelname)s - %(message)s')
    
    # Initialize loader
    loader = TestDataLoader("test")  # or whatever your test directory is
    
    # Load all exposures
    exposures = loader.load_all_exposures()
    print(f"Loaded {len(exposures)} exposures")
    
    # Get summary
    summary = loader.get_data_summary()
    print("\nData Summary:")
    print(f"Total exposures: {summary['total_exposures']}")
    print(f"Total difference frames: {summary['overall_stats']['total_difference_frames']}")
    print(f"Total patches: {summary['overall_stats']['total_patches']}")
    print(f"Patch sizes: {summary['overall_stats']['patch_sizes']}")
    print(f"Datasets: {summary['overall_stats']['datasets']}")
    
    # Analyze differences
    if exposures:
        first_exposure = list(exposures.keys())[0]
        print(f"\nAnalyzing differences for {first_exposure}:")
        analysis = loader.analyze_differences(first_exposure)
        print(f"Shape: {analysis[first_exposure]['shape']}")
        print(f"Mean: {analysis[first_exposure]['mean']:.6f}")
        print(f"Std: {analysis[first_exposure]['std']:.6f}")
        
        # Create plots
        loader.plot_difference_analysis(first_exposure)
    
    # Export report
    report_file = loader.export_summary_report()
    print(f"\nDetailed report saved to: {report_file}")


if __name__ == "__main__":
    main()



AttributeError: 'TestDataLoader' object has no attribute 'difference_files'