In [None]:
import sys
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings

# Silence every Python warning for the rest of the kernel session.
# NOTE(review): this also hides NumPy/SciPy RuntimeWarnings (for example the
# "mean of empty slice" warning), so NaN results further down appear without
# any diagnostic -- consider narrowing the filter.
warnings.filterwarnings('ignore')

# Put the local satvision-pix4d checkout on the import path so the
# `satvision_pix4d` imports below resolve.
sys.path.append('/explore/nobackup/people/jacaraba/development/satvision-pix4d')

from satvision_pix4d.configs.config import _C, _update_config_from_file
from satvision_pix4d.pipelines import PIPELINES, get_available_pipelines
from satvision_pix4d.datasets.abi_temporal_benchmark_dataset import ABITemporalBenchmarkDataset
from satvision_pix4d.datasets.abi_temporal_dataset import ABITemporalDataset

In [None]:
# --- Output options ---------------------------------------------------------
# Toggle for writing figures to a PDF file, and the file name to use.
save_to_pdf = False
pdf_path = "chip_plot.pdf"

# --- Band selection ---------------------------------------------------------
# Channel indices used to build an RGB composite from the 16-channel ABI data.
# Adjust these to match the channel order of your ABI tiles.
rgb_index = [0, 2, 1]

# --- Local model artefacts (nothing is downloaded) --------------------------
# Checkpoint file; the path is split across lines only for readability.
model_filename = (
    '/explore/nobackup/projects/pix4dcloud/jacaraba/model_development/satmae/'
    'satmae_satvision_pix4d_pretrain-dev/satmae_satvision_pix4d_pretrain-dev/'
    'epoch-epoch=40.ckpt/checkpoint/mp_rank_00_model_states.pt'
)

# YAML configuration that is merged into the default config.
config_filename = '/explore/nobackup/people/jacaraba/development/satvision-pix4d/tests/configs/test_satmae_dev.yaml'

In [None]:
# Load configuration: start from a copy of the default config object and merge
# the YAML file named by `config_filename` into it.
config = _C.clone()
_update_config_from_file(config, config_filename)
print("Loaded configuration file.")

# Update config with your paths. The config is unfrozen, pointed at the local
# checkpoint and the current directory for output, then frozen again.
config.defrost()
config.MODEL.PRETRAINED = model_filename
config.OUTPUT = '.'
config.freeze()
print("Updated configuration file.")

# Get pipeline and load model: list what is registered, then pick the entry
# named by `config.PIPELINE` from the PIPELINES mapping.
available_pipelines = get_available_pipelines()
print("Available pipelines:", available_pipelines)

pipeline = PIPELINES[config.PIPELINE]
print(f'Using {pipeline}')

# Instantiate the selected pipeline with the finished configuration.
ptlPipeline = pipeline(config)

# Load model weights from the checkpoint path stored in the config.
# NOTE(review): the exact checkpoint format expected by `load_checkpoint` is
# not visible here -- confirm it matches `mp_rank_00_model_states.pt`.
print(f'Attempting to load checkpoint from {config.MODEL.PRETRAINED}')
model = ptlPipeline.load_checkpoint(config.MODEL.PRETRAINED, config)
print('Successfully applied checkpoint')

# Inference below runs on CPU with the model switched to evaluation mode.
model.cpu()
model.eval()
print('Successfully moved to CPU and eval mode')

In [None]:
import os
import random

# Directory that holds the .zarr tiles to evaluate.
data_dir = '/explore/nobackup/projects/pix4dcloud/jacaraba/tiles_pix4d/3-tiles/convection/20200101'

# Tunables for this cell: the maximum number of dataset samples to reconstruct
# and how often (in samples) a per-sample loss line is printed.
max_samples = 5
log_every = 10

# The file listing is informational only; the dataset does its own discovery
# from `data_paths`.
all_zarr_files = [f for f in os.listdir(data_dir) if f.endswith('.zarr')]
print(f"Found {len(all_zarr_files)} zarr files in directory")

train_ds = ABITemporalDataset(
    data_paths=[data_dir],
    img_size=512,
    in_chans=16,
    data_var='__xarray_dataarray_variable__'
)

print(f"Dataset contains {len(train_ds)} samples from all files")

# Process multiple samples from the dataset
num_samples = min(max_samples, len(train_ds))
print(f"Processing {num_samples} samples from multiple files...")

# Parallel lists: position k in each list belongs to the same sample.
all_inputs = []
all_outputs = []
all_masks = []
all_losses = []
# Dataset indices that raised during loading or inference.
failed_samples = []

print("\nProcessing samples for reconstruction...")
for sample_idx in tqdm(range(num_samples)):
    try:
        imgs, ts = train_ds[sample_idx]

        # NO NORMALIZATION - the raw data is used as-is; only a batch
        # dimension is added and the timestamps are converted to a tensor.
        imgs_batch = imgs.unsqueeze(0).cpu()
        if isinstance(ts, np.ndarray):
            ts = torch.from_numpy(ts).float()
        ts_batch = ts.unsqueeze(0).cpu()

        with torch.no_grad():
            loss, pred, mask = model(imgs_batch, ts_batch)

            B, T, C, H, W = imgs_batch.shape
            pred_imgs = model.model.unpatchify(pred, T, H, W)
            # NOTE(review): predictions are clamped to [0, 1] while the inputs
            # are stored unclamped; if the raw data leaves that range the
            # metrics computed later include the clamping error -- confirm.
            pred_imgs = torch.clamp(pred_imgs, 0, 1)

        # Materialise every result before touching the lists, so a failure in
        # any of these conversions cannot leave the lists with unequal lengths.
        sample_input = imgs_batch.cpu().squeeze(0)
        sample_output = pred_imgs.cpu().squeeze(0)
        sample_mask = mask.cpu().squeeze(0)
        loss_value = loss.cpu().item()

    except Exception as e:
        # Best effort: report the failure, remember the index, keep going.
        print(f"Error processing sample {sample_idx}: {e}")
        failed_samples.append(sample_idx)
        continue

    # Store results
    all_inputs.append(sample_input)
    all_outputs.append(sample_output)
    all_masks.append(sample_mask)
    all_losses.append(loss_value)

    if sample_idx % log_every == 0:
        print(f"Sample {sample_idx}: Loss = {loss_value:.2f}")

print("\n=== FINAL RESULTS ===")
print(f"Total samples processed: {len(all_inputs)}")
if failed_samples:
    print(f"Samples skipped because of errors: {failed_samples}")
if all_losses:
    print(f"Average loss across all samples: {np.mean(all_losses):.4f}")
else:
    # np.mean([]) would print "nan"; say explicitly that nothing succeeded.
    print("Average loss across all samples: n/a (no samples were processed)")

# Aliases used by the following cells.
inputs = all_inputs
outputs = all_outputs
masks = all_masks
losses = all_losses

In [None]:
from skimage.metrics import structural_similarity as ssim
import numpy as np

def calculate_mse(original, reconstructed):
    """Return the mean squared error between two tensors as a Python float.

    Args:
        original: Reference tensor.
        reconstructed: Tensor to compare; must broadcast against ``original``.

    Returns:
        The mean of the element-wise squared difference.
    """
    residual = original - reconstructed
    squared_error = torch.square(residual)
    return torch.mean(squared_error).item()

def calculate_psnr(mse, data_range=1.0):
    """Convert a mean squared error into a peak signal-to-noise ratio in dB.

    Args:
        mse: Mean squared error of the reconstruction.
        data_range: Peak value the signal can take.

    Returns:
        ``20 * log10(data_range / sqrt(mse))``, or positive infinity when
        ``mse`` is exactly zero.
    """
    if mse != 0:
        rmse = np.sqrt(mse)
        return 20 * np.log10(data_range / rmse)
    # A perfect reconstruction has unbounded PSNR.
    return float('inf')

def calculate_mae(original, reconstructed):
    """Return the mean absolute error between two tensors as a Python float.

    Args:
        original: Reference tensor.
        reconstructed: Tensor to compare; must broadcast against ``original``.

    Returns:
        The mean of the element-wise absolute difference.
    """
    absolute_error = torch.abs(original - reconstructed)
    mean_error = torch.mean(absolute_error)
    return mean_error.item()

def calculate_ssim_fast(original, reconstructed, sample_channels=4, data_range=1.0):
    """Approximate SSIM by scoring only a few evenly spaced (timestep, channel) planes.

    Args:
        original: Reference data indexed as ``[timestep, channel, ...]``;
            a torch tensor or a NumPy array.
        reconstructed: Reconstruction with the same layout as ``original``.
        sample_channels: Number of planes to sample across the flattened
            timestep*channel axis.
        data_range: Value range handed to the SSIM implementation.

    Returns:
        The mean SSIM over the planes that could be scored, or ``0.0`` when
        no plane could be scored.
    """
    # detach()/cpu() make the conversion work for tensors that require grad
    # or live on an accelerator; a bare .numpy() raises for both.
    if isinstance(original, torch.Tensor):
        original = original.detach().cpu().numpy()
    if isinstance(reconstructed, torch.Tensor):
        reconstructed = reconstructed.detach().cpu().numpy()

    n_channels = original.shape[1]
    total_channels = original.shape[0] * n_channels  # T * C
    if total_channels == 0:
        return 0.0

    # Evenly spaced flat indices. np.unique removes the repeats that appear
    # when more planes are requested than exist, so no plane is counted twice.
    flat_indices = np.unique(
        np.linspace(0, total_channels - 1, sample_channels, dtype=int)
    )

    ssim_scores = []
    for flat_idx in flat_indices:
        # Flat index -> (timestep, channel); always in range by construction.
        t_idx, c_idx = divmod(int(flat_idx), n_channels)
        try:
            score = ssim(original[t_idx, c_idx], reconstructed[t_idx, c_idx],
                         data_range=data_range)
        except Exception:
            # Best effort: skip planes the SSIM routine rejects, without
            # swallowing KeyboardInterrupt/SystemExit as a bare except does.
            continue
        ssim_scores.append(score)

    return np.mean(ssim_scores) if ssim_scores else 0.0


print("Calculating comprehensive metrics for large dataset...")
print(f"Processing {len(inputs)} samples...")
print("=" * 70)

# One dictionary per sample, in the same order as `inputs`.
metrics_results = []
batch_size = 20

# Samples are walked in fixed-size batches purely so that progress and
# per-batch averages can be printed along the way.
for batch_start in range(0, len(inputs), batch_size):
    batch_end = min(len(inputs), batch_start + batch_size)
    batch_indices = range(batch_start, batch_end)

    batch_number = batch_start // batch_size + 1
    total_batches = (len(inputs) - 1) // batch_size + 1
    print(f"Processing batch {batch_number}/{total_batches} (samples {batch_start}-{batch_end-1})")

    batch_metrics = []

    for i in batch_indices:
        original, reconstructed = inputs[i], outputs[i]

        # Whole-sample metrics over every timestep and channel.
        mse = calculate_mse(original, reconstructed)
        psnr = calculate_psnr(mse)
        mae = calculate_mae(original, reconstructed)
        ssim_score = calculate_ssim_fast(original, reconstructed, sample_channels=4)

        # Per-band metrics for the three channels labelled red/green/blue.
        # NOTE(review): this band order differs from `rgb_index = [0, 2, 1]`
        # set in the configuration cell -- confirm which one is intended.
        rgb_indices = [1, 0, 2]
        rgb_metrics = {}

        for band_name, band_idx in zip(['red', 'green', 'blue'], rgb_indices):
            if band_idx >= original.shape[1]:
                # The sample has fewer channels than this band index.
                continue

            # The band is averaged over the first axis before the comparison.
            orig_band = original[:, band_idx, :, :].mean(dim=0)
            recon_band = reconstructed[:, band_idx, :, :].mean(dim=0)

            band_mse = calculate_mse(orig_band, recon_band)
            band_psnr = calculate_psnr(band_mse)

            rgb_metrics[band_name] = dict(mse=band_mse, psnr=band_psnr)

        sample_metrics = dict(
            sample_idx=i,
            loss=losses[i],
            mse=mse,
            mae=mae,
            psnr=psnr,
            ssim=ssim_score,
            rgb_metrics=rgb_metrics,
        )

        batch_metrics.append(sample_metrics)

    metrics_results += batch_metrics

    # Print batch summary
    batch_mse = [entry['mse'] for entry in batch_metrics]
    batch_psnr = [entry['psnr'] for entry in batch_metrics]
    print(f"  Batch avg MSE: {np.mean(batch_mse):.6f}, avg PSNR: {np.mean(batch_psnr):.2f}")

print("\n" + "=" * 70)
print("LARGE DATASET METRICS SUMMARY")
print("=" * 70)

# Flatten the per-sample dictionaries into one list per metric. Positions in
# these lists are positions in `metrics_results`, not dataset indices.
all_mse = [m['mse'] for m in metrics_results]
all_mae = [m['mae'] for m in metrics_results]
all_psnr = [m['psnr'] for m in metrics_results]
all_ssim = [m['ssim'] for m in metrics_results]
all_loss = [m['loss'] for m in metrics_results]

percentiles = [10, 25, 50, 75, 90]

# Defined up front so `performance_categories` exists even without samples.
good_samples, medium_samples, bad_samples = [], [], []


def _print_category_summary(label, indices, mse_values, psnr_values):
    """Print the size and the mean MSE/PSNR of one performance tier."""
    print(f"  {label}: {len(indices)} samples")
    if not indices:
        # With very few samples a tier can be empty; avoid printing NaN.
        print("    (no samples in this category)")
        return
    print(f"    Avg MSE: {np.mean([mse_values[i] for i in indices]):.6f}")
    print(f"    Avg PSNR: {np.mean([psnr_values[i] for i in indices]):.2f} dB")


if not metrics_results:
    # np.argmin and the percentile calls below cannot handle empty input.
    print("No samples available - skipping summary statistics.")
else:
    # Overall statistics.
    # NOTE: a sample with MSE == 0 has infinite PSNR, which makes the PSNR
    # mean infinite and its standard deviation NaN.
    print(f"Dataset Statistics (n={len(metrics_results)} samples):")
    print(f"  MSE:  {np.mean(all_mse):.6f} ± {np.std(all_mse):.6f}")
    print(f"  MAE:  {np.mean(all_mae):.6f} ± {np.std(all_mae):.6f}")
    print(f"  PSNR: {np.mean(all_psnr):.2f} ± {np.std(all_psnr):.2f} dB")
    print(f"  SSIM: {np.mean(all_ssim):.4f} ± {np.std(all_ssim):.4f}")
    print(f"  Loss: {np.mean(all_loss):.2f} ± {np.std(all_loss):.2f}")

    print("\nPerformance Distribution Analysis:")
    mse_percentiles = np.percentile(all_mse, percentiles)
    psnr_percentiles = np.percentile(all_psnr, percentiles)

    print("MSE Percentiles:")
    for p, val in zip(percentiles, mse_percentiles):
        print(f"  {p}th: {val:.6f}")

    print("PSNR Percentiles:")
    for p, val in zip(percentiles, psnr_percentiles):
        print(f"  {p}th: {val:.2f} dB")

    # Split the samples into thirds by MSE (lower is better).
    mse_threshold_good = np.percentile(all_mse, 33)  # Bottom third
    mse_threshold_bad = np.percentile(all_mse, 67)   # Top third

    good_samples = [i for i, m in enumerate(all_mse) if m <= mse_threshold_good]
    medium_samples = [i for i, m in enumerate(all_mse)
                      if mse_threshold_good < m <= mse_threshold_bad]
    bad_samples = [i for i, m in enumerate(all_mse) if m > mse_threshold_bad]

    print("\nPerformance Categories:")
    _print_category_summary("Good samples (bottom 33%)", good_samples, all_mse, all_psnr)
    _print_category_summary("Medium samples (middle 33%)", medium_samples, all_mse, all_psnr)
    _print_category_summary("Bad samples (top 33%)", bad_samples, all_mse, all_psnr)

    print("\nRGB Band Performance Summary:")
    for band_name in ['red', 'green', 'blue']:
        # Only samples that actually recorded this band contribute.
        band_mse = [m['rgb_metrics'][band_name]['mse'] for m in metrics_results
                    if band_name in m['rgb_metrics']]
        band_psnr = [m['rgb_metrics'][band_name]['psnr'] for m in metrics_results
                     if band_name in m['rgb_metrics']]

        if band_mse:
            print(f"  {band_name.capitalize()}: MSE={np.mean(band_mse):.6f}, PSNR={np.mean(band_psnr):.2f}dB")

    # Best and worst samples
    best_idx = np.argmin(all_mse)
    worst_idx = np.argmax(all_mse)

    print("\nExtreme Samples:")
    print(f"  Best sample #{best_idx}: MSE={all_mse[best_idx]:.6f}, PSNR={all_psnr[best_idx]:.2f}dB")
    print(f"  Worst sample #{worst_idx}: MSE={all_mse[worst_idx]:.6f}, PSNR={all_psnr[worst_idx]:.2f}dB")
    if all_mse[best_idx] > 0:
        print(f"  Performance range: {all_mse[worst_idx]/all_mse[best_idx]:.1f}x difference")
    else:
        # Dividing by a zero best-case MSE would raise ZeroDivisionError.
        print("  Performance range: undefined (best sample has zero MSE)")

# Save results summary
performance_categories = {
    'good': good_samples,
    'medium': medium_samples,
    'bad': bad_samples
}

print("\nMetrics calculation complete!")
print("Results stored in 'metrics_results' and 'performance_categories' variables")

In [None]:
# Enhanced visualization of reconstruction results for ABI satellite data
def visualize_reconstruction(inputs, outputs, sample_idx=0, timestep=0, rgb_index=(1, 0, 2)):
    """Plot one timestep of a sample beside its reconstruction and print error metrics.

    The figure has two rows: RGB composites (original, reconstruction, and a
    difference map) and, per selected band, an "original | reconstruction |
    difference" strip.

    Args:
        inputs: Indexable collection where ``inputs[sample_idx][timestep]`` is a
            CPU tensor indexed by band on its first axis.
        outputs: Reconstructions with the same layout as ``inputs``.
        sample_idx: Which sample to show.
        timestep: Which timestep of that sample to show.
        rgb_index: Three band indices displayed as the R, G and B channels.

    Returns:
        None. The figure is shown and the metrics are printed.
    """
    # Must be a list: indexing a tensor with a tuple is interpreted as one
    # index per dimension instead of "pick these bands".
    rgb_index = list(rgb_index)

    original = inputs[sample_idx][timestep]
    reconstructed = outputs[sample_idx][timestep]

    print(f"Data shapes - Original: {original.shape}, Reconstructed: {reconstructed.shape}")
    print(f"Data ranges - Original: [{original.min():.3f}, {original.max():.3f}], "
          f"Reconstructed: [{reconstructed.min():.3f}, {reconstructed.max():.3f}]")

    def percentile_stretch(band):
        """Rescale a 2-D band to [0, 1] between its 2nd and 98th percentiles."""
        p2, p98 = np.percentile(band, [2, 98])
        span = p98 - p2
        if not span > 0:
            # Constant (or all-NaN) band: dividing would give NaN/inf pixels.
            return np.zeros_like(band)
        return np.clip((band - p2) / span, 0, 1)

    def normalize_abi_rgb(rgb_data):
        """Apply the percentile stretch to each channel of an (H, W, 3) array."""
        normalized = np.zeros_like(rgb_data)
        for channel in range(3):
            normalized[:, :, channel] = percentile_stretch(rgb_data[:, :, channel])
        return normalized

    # Create RGB images using the selected bands, channels last.
    original_rgb = original[rgb_index].permute(1, 2, 0).numpy()
    reconstructed_rgb = reconstructed[rgb_index].permute(1, 2, 0).numpy()

    # Each image is stretched with its own percentiles, so `diff` compares the
    # two pictures after per-image contrast stretching, not the raw values.
    original_display = normalize_abi_rgb(original_rgb)
    reconstructed_display = normalize_abi_rgb(reconstructed_rgb)

    diff = np.abs(original_display - reconstructed_display)

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # Top row: RGB composites
    axes[0, 0].imshow(original_display)
    axes[0, 0].set_title(f'Original RGB (Sample {sample_idx}, Time {timestep})\nBands {rgb_index}')
    axes[0, 0].axis('off')

    axes[0, 1].imshow(reconstructed_display)
    axes[0, 1].set_title(f'Reconstructed RGB\nBands {rgb_index}')
    axes[0, 1].axis('off')

    # imshow ignores `cmap` for (H, W, 3) input, which would leave the
    # colourbar unrelated to the picture. Collapse the channels to one scalar
    # map so the 'hot' colormap and its colourbar describe the same data.
    diff_map = diff.mean(axis=2)
    im_diff = axes[0, 2].imshow(diff_map, cmap='hot')
    axes[0, 2].set_title(f'RGB Difference\nMean: {np.mean(diff):.4f}')
    axes[0, 2].axis('off')
    plt.colorbar(im_diff, ax=axes[0, 2], fraction=0.046, pad=0.04)

    # Bottom row: one strip per selected band.
    for i, (band_idx, color_name) in enumerate(zip(rgb_index, ['Red', 'Green', 'Blue'])):
        orig_norm = percentile_stretch(original[band_idx].numpy())
        recon_norm = percentile_stretch(reconstructed[band_idx].numpy())

        band_diff = np.abs(orig_norm - recon_norm)

        # Create side-by-side comparison
        combined = np.hstack([orig_norm, recon_norm, band_diff])

        axes[1, i].imshow(combined, cmap='viridis')
        axes[1, i].set_title(f'Band {band_idx} ({color_name})\nOrig | Recon | Diff')
        axes[1, i].axis('off')

        # Add vertical lines to separate the three sections.
        w = orig_norm.shape[1]
        axes[1, i].axvline(x=w - 0.5, color='white', linewidth=2)
        axes[1, i].axvline(x=2 * w - 0.5, color='white', linewidth=2)

    plt.tight_layout()
    plt.show()

    # Print quantitative metrics. The overall MSE uses the raw values of every
    # band; the RGB figures use the stretched display images.
    mse = np.mean((original.numpy() - reconstructed.numpy()) ** 2)
    rgb_mse = np.mean((original_display - reconstructed_display) ** 2)

    print("\nQuantitative Metrics:")
    print(f"  Overall MSE (all bands): {mse:.6f}")
    print(f"  RGB MSE: {rgb_mse:.6f}")
    print(f"  RGB Mean Abs Diff: {np.mean(diff):.4f}")
    print(f"  RGB Max Abs Diff: {np.max(diff):.4f}")
    print(f"  Pixels with >10% difference: {np.sum(diff > 0.1) / diff.size * 100:.1f}%")

# Visualise the first few reconstructed samples (at most three), or say so
# when there is nothing to show.
# NOTE(review): the band order [1, 0, 2] is passed literally here and differs
# from `rgb_index = [0, 2, 1]` in the configuration cell -- confirm.
if len(inputs) == 0:
    print("No reconstruction data available for visualization")
else:
    print("ABI Satellite Data Reconstruction Results:", "=" * 50, sep="\n")

    for i in range(min(3, len(inputs))):
        print(f"\n--- Sample {i} ---")
        visualize_reconstruction(
            inputs,
            outputs,
            sample_idx=i,
            timestep=0,
            rgb_index=[1, 0, 2],
        )

In [None]:
# Simple stratification analysis
def categorize_by_performance(losses):
    """Split sample positions into three difficulty tiers by ascending loss.

    Args:
        losses: Sequence of per-sample loss values.

    Returns:
        Dict with keys ``'easy'``, ``'medium'`` and ``'hard'``; each value is
        an array of positions into ``losses``. The lowest third of the losses
        is ``'easy'`` and any remainder goes to the later tiers.
    """
    ranking = np.argsort(losses)
    total = len(losses)

    # Integer cut points of the ranking at one third and two thirds.
    easy_end = total // 3
    medium_end = 2 * total // 3

    return {
        'easy': ranking[:easy_end],
        'medium': ranking[easy_end:medium_end],
        'hard': ranking[medium_end:],
    }

def analyze_by_category(inputs, outputs, losses, categories):
    """Aggregate reconstruction quality for each category of samples.

    Args:
        inputs: Indexable collection of original samples.
        outputs: Indexable collection of reconstructions, aligned with ``inputs``.
        losses: Per-sample loss values, aligned with ``inputs``.
        categories: Mapping from category name to the sample positions in it.

    Returns:
        Dict keyed by category name. Each value holds ``count`` plus the mean
        and standard deviation of MSE and PSNR and the mean loss of the
        category's samples.
    """
    results = {}

    for cat_name, indices in categories.items():
        # MSE per sample first; PSNR is derived from each MSE value.
        cat_mse = [calculate_mse(inputs[idx], outputs[idx]) for idx in indices]
        cat_psnr = [calculate_psnr(value) for value in cat_mse]
        cat_losses = [losses[i] for i in indices]

        results[cat_name] = dict(
            count=len(indices),
            avg_mse=np.mean(cat_mse),
            std_mse=np.std(cat_mse),
            avg_psnr=np.mean(cat_psnr),
            std_psnr=np.std(cat_psnr),
            avg_loss=np.mean(cat_losses),
        )

    return results

# Tier the samples by loss, then aggregate quality metrics per tier.
categories = categorize_by_performance(losses)
results = analyze_by_category(inputs, outputs, losses, categories)

# Report: one block of lines per tier, followed by a blank line.
print("Performance Analysis by Category:", "=" * 50, sep="\n")
for cat_name, metrics in results.items():
    print(
        f"{cat_name.upper()} Category:",
        f"  Count: {metrics['count']}",
        f"  Avg MSE: {metrics['avg_mse']:.4f} ± {metrics['std_mse']:.4f}",
        f"  Avg PSNR: {metrics['avg_psnr']:.2f} ± {metrics['std_psnr']:.2f}",
        f"  Avg Loss: {metrics['avg_loss']:.4f}",
        "",
        sep="\n",
    )

In [None]:
from scipy import stats

# Test if performance differences are significant
def test_category_differences(inputs, outputs, categories):
    """Test whether the easy/medium/hard tiers differ significantly in MSE.

    Runs a one-way ANOVA across the three tiers and independent two-sample
    t-tests for each pair, printing the statistics and a verdict at the 0.05
    level. Tests that cannot be evaluated (an empty tier, a tier with a
    single sample, or a NaN result) are reported as such instead of being
    labelled "not significant".

    Args:
        inputs: Indexable collection of original samples.
        outputs: Indexable collection of reconstructions, aligned with ``inputs``.
        categories: Mapping with keys ``'easy'``, ``'medium'`` and ``'hard'``
            whose values are sample positions.

    Returns:
        None. Results are printed.
    """
    # Get MSE for each category
    mse_data = {
        name: [calculate_mse(inputs[i], outputs[i]) for i in categories[name]]
        for name in ('easy', 'medium', 'hard')
    }

    print("Statistical Testing Results:")

    # ANOVA test
    empty_groups = [name for name, values in mse_data.items() if len(values) == 0]
    if empty_groups:
        print(f"⚠️ ANOVA skipped: no samples in {', '.join(empty_groups)}")
    else:
        f_stat, p_value = stats.f_oneway(
            mse_data['easy'], mse_data['medium'], mse_data['hard']
        )
        print(f"ANOVA F-statistic: {f_stat:.3f}")
        print(f"ANOVA p-value: {p_value:.3f}")

        if np.isnan(p_value):
            # `nan < 0.05` is False, which would wrongly read as "no difference".
            print("⚠️ ANOVA result is undefined (too few samples or zero variance)")
        elif p_value < 0.05:
            print("✅ Categories have significantly different performance")
        else:
            print("❌ No significant difference between categories")

    # Pairwise t-tests
    pairs = [('easy', 'medium'), ('easy', 'hard'), ('medium', 'hard')]

    print("\nPairwise comparisons:")
    for cat1, cat2 in pairs:
        # Each side needs at least two samples for a variance estimate.
        if min(len(mse_data[cat1]), len(mse_data[cat2])) < 2:
            print(f"{cat1} vs {cat2}: skipped (fewer than 2 samples in a category)")
            continue

        t_stat, p_val = stats.ttest_ind(mse_data[cat1], mse_data[cat2])
        if np.isnan(p_val):
            significance = "⚠️ Undetermined"
        elif p_val < 0.05:
            significance = "✅ Significant"
        else:
            significance = "❌ Not significant"
        print(f"{cat1} vs {cat2}: t={t_stat:.3f}, p={p_val:.3f} - {significance}")

# Run statistical tests on the tiers computed in the previous cell
# (`categories` comes from categorize_by_performance(losses)).
test_category_differences(inputs, outputs, categories)