In [1]:
import numpy as np
from pathlib import Path
import pandas as pd
from tqdm import tqdm

# --- 1. CONFIGURATION ---
# Temperature thresholds (K): for each one we measure how many target pixels
# are at or below it.
THRESHOLDS_K = [230.0, 220.0, 210.0]

# Dataset sub-directories to analyze, in reporting order.
SPLITS = ["train", "val", "test"]

# Project folder on Google Drive and the derived input/output locations.
PROJECT_PATH = Path('/content/drive/My Drive/AR_Downscaling')
DATA_DIR = PROJECT_PATH.joinpath('final_dataset_multi_variable')
OUTPUT_DIR = PROJECT_PATH.joinpath('training_data_stats')
# NOTE: this runs when the cell is executed, i.e. before the Drive mount in the
# __main__ guard below; without parents=True it only succeeds if PROJECT_PATH
# already exists.
OUTPUT_DIR.mkdir(exist_ok=True)

def _freq_column(thr):
    """Return the per-sample CSV column name used for threshold ``thr``.

    Integral thresholds keep the historical ``freq_<int>K`` form (e.g.
    ``freq_230K``); fractional thresholds keep their decimals so that, say,
    215.5 and 215.0 do not collapse onto the same column and overwrite it.
    """
    return f'freq_{int(thr)}K' if float(thr).is_integer() else f'freq_{thr}K'


def analyze_event_rarity(data_dir=None, output_dir=None, thresholds=None,
                         splits=None, *, show_progress=True):
    """Measure how often target pixels fall at or below each threshold.

    For every split, each ``*_target.npy`` file in ``data_dir / split`` is
    loaded (cast to float32) and, for every threshold, the percentage of its
    elements with value ``<= threshold`` is recorded. Splits with no target
    files are reported and skipped. For each analyzed split the mean/std of
    those per-sample percentages is printed and the per-sample table is saved
    as ``event_rarity_statistics_<split>.csv``; a combined summary of all
    analyzed splits is saved as ``event_rarity_summary_all.csv``.

    Args:
        data_dir: Dataset root with one sub-directory per split. Defaults to
            the module-level ``DATA_DIR`` (looked up at call time, so a later
            notebook cell that rebinds ``DATA_DIR`` changes it — pass the path
            explicitly to pin it).
        output_dir: Directory receiving the CSV files; created (with parents)
            right before the first write. Defaults to ``OUTPUT_DIR``, with the
            same call-time caveat.
        thresholds: Iterable of thresholds, in the units of the stored
            targets. Defaults to ``THRESHOLDS_K``.
        splits: Iterable of split names. Defaults to ``SPLITS``.
        show_progress: Wrap the per-file loop in a ``tqdm`` progress bar.

    Returns:
        The combined summary ``DataFrame`` (columns ``split``, ``threshold``,
        ``mean_%``, ``std_%``), or ``None`` when nothing was summarized. Note
        that ``std_%`` is the pandas sample std (ddof=1), so it is NaN for a
        split with a single sample.
    """
    # Conditional expressions only touch the module globals when needed.
    data_dir = Path(DATA_DIR if data_dir is None else data_dir)
    output_dir = Path(OUTPUT_DIR if output_dir is None else output_dir)
    thresholds = list(THRESHOLDS_K if thresholds is None else thresholds)
    splits = list(SPLITS if splits is None else splits)

    # (threshold, column-name) pairs, computed once for all splits.
    columns = [(thr, _freq_column(thr)) for thr in thresholds]
    all_split_summaries = []

    for split in splits:
        print(f"\n🔬 Analyzing split: {split.upper()}")

        split_dir = data_dir / split
        # glob() on a missing directory simply yields nothing.
        target_files = sorted(split_dir.glob('*_target.npy'))
        if not target_files:
            print(f"❌ No target files found in {split_dir}. Skipping.")
            continue

        print(f"Found {len(target_files)} {split} samples to analyze.")

        file_iter = (tqdm(target_files, desc=f"Processing {split} samples")
                     if show_progress else target_files)

        # One row (dict) per sample: filename + one percentage per threshold.
        all_frequencies = []
        for target_file in file_iter:
            target_k = np.load(target_file).astype(np.float32, copy=False)
            total_pixels = target_k.size
            sample_freq = {'filename': target_file.name}
            for thr, col in columns:
                if total_pixels == 0:
                    # A percentage of an empty array is undefined; avoid 0/0.
                    sample_freq[col] = np.nan
                else:
                    cold_pixels = np.count_nonzero(target_k <= thr)
                    sample_freq[col] = cold_pixels / total_pixels * 100
            all_frequencies.append(sample_freq)

        df = pd.DataFrame(all_frequencies)
        # Mean/std per threshold, reused for both the report and the summary.
        split_stats = {col: (df[col].mean(), df[col].std()) for _, col in columns}

        print("\n📊 FINAL EVENT RARITY REPORT for", split.upper())
        for thr, col in columns:
            mean_freq, std_freq = split_stats[col]
            print(f"  - T <= {thr:g}K: {mean_freq:.4f}% (±{std_freq:.4f}%)")

        # Ensure the directory exists at write time (after any Drive mount).
        output_dir.mkdir(parents=True, exist_ok=True)
        per_sample_path = output_dir / f'event_rarity_statistics_{split}.csv'
        df.to_csv(per_sample_path, index=False)
        print(f"💾 Detailed per-sample results saved to: {per_sample_path}")

        for thr, col in columns:
            mean_freq, std_freq = split_stats[col]
            all_split_summaries.append({
                'split': split,
                'threshold': thr,
                'mean_%': mean_freq,
                'std_%': std_freq,
            })

    if not all_split_summaries:
        return None

    summary_df = pd.DataFrame(all_split_summaries)
    summary_path = output_dir / 'event_rarity_summary_all.csv'
    summary_df.to_csv(summary_path, index=False)
    print(f"\n✅ Combined summary for all splits saved to: {summary_path}")
    return summary_df

if __name__ == "__main__":
    # Mount Google Drive when running inside Colab. Only the import is guarded:
    # an ImportError there means "not Colab", while any error raised by
    # drive.mount() itself should surface instead of being misreported as a
    # non-Colab environment.
    try:
        from google.colab import drive
    except ImportError:
        print("Not running in Google Colab. Assuming local file paths.")
    else:
        drive.mount('/content/drive', force_remount=True)

    analyze_event_rarity()


Mounted at /content/drive

🔬 Analyzing split: TRAIN
Found 1200 train samples to analyze.


Processing train samples: 100%|██████████| 1200/1200 [02:41<00:00,  7.42it/s]



📊 FINAL EVENT RARITY REPORT for TRAIN
  - T <= 230K: 46.2223% (±34.4570%)
  - T <= 220K: 18.7507% (±19.0172%)
  - T <= 210K: 6.6292% (±9.1181%)
💾 Detailed per-sample results saved to: /content/drive/My Drive/AR_Downscaling/training_data_stats/event_rarity_statistics_train.csv

🔬 Analyzing split: VAL
Found 150 val samples to analyze.


Processing val samples: 100%|██████████| 150/150 [00:09<00:00, 15.52it/s]



📊 FINAL EVENT RARITY REPORT for VAL
  - T <= 230K: 34.5756% (±34.5958%)
  - T <= 220K: 13.9523% (±19.7722%)
  - T <= 210K: 4.2478% (±8.0255%)
💾 Detailed per-sample results saved to: /content/drive/My Drive/AR_Downscaling/training_data_stats/event_rarity_statistics_val.csv

🔬 Analyzing split: TEST
Found 150 test samples to analyze.


Processing test samples: 100%|██████████| 150/150 [00:08<00:00, 16.78it/s]


📊 FINAL EVENT RARITY REPORT for TEST
  - T <= 230K: 49.2168% (±32.6427%)
  - T <= 220K: 19.6215% (±19.2401%)
  - T <= 210K: 6.2629% (±9.1381%)
💾 Detailed per-sample results saved to: /content/drive/My Drive/AR_Downscaling/training_data_stats/event_rarity_statistics_test.csv

✅ Combined summary for all splits saved to: /content/drive/My Drive/AR_Downscaling/training_data_stats/event_rarity_summary_all.csv





In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import gc

# --- 1. CONFIGURATION ---
# NOTE: PROJECT_PATH / DATA_DIR / OUTPUT_DIR are also assigned by the previous
# cell; in a shared notebook kernel these assignments replace those values.
PROJECT_PATH = Path('/content/drive/My Drive/AR_Downscaling')
DATA_DIR = PROJECT_PATH.joinpath('final_dataset_multi_variable')
OUTPUT_DIR = PROJECT_PATH.joinpath('hip_eef_analysis_results')
# Created when the cell runs; without parents=True this requires PROJECT_PATH
# to exist already (i.e. Drive must be mounted at /content/drive).
OUTPUT_DIR.mkdir(exist_ok=True)

# --- Analysis Parameters ---
# Cold-pixel thresholds, in Kelvin.
THRESHOLDS_TO_ANALYZE = [220.0, 210.0]

# --- 2. MAIN ANALYSIS SCRIPT ---
def _plot_percentage_histograms(pixel_percentages, thresholds, save_path):
    """Draw one histogram of per-sample cold-pixel percentages per threshold and save it.

    ``pixel_percentages`` maps each threshold to a non-empty sequence of
    percentages. The figure is written to ``save_path`` (parent created if
    missing) and left open, as before.
    """
    # squeeze=False always returns a 2-D grid, so indexing works even when
    # only a single threshold is analyzed.
    fig, axes = plt.subplots(1, len(thresholds), figsize=(16, 6), sharey=True,
                             squeeze=False)
    fig.suptitle('Distribution of Extreme Event Pixels in Training Data', fontsize=20)

    for i, (ax, thresh) in enumerate(zip(axes[0], thresholds)):
        percentages = pixel_percentages[thresh]
        mean_pct = np.mean(percentages)
        # Keep at least a 0-10% x-range so low-percentage bins stay readable.
        ax.hist(percentages, bins=50, range=(0, max(10, np.max(percentages))),
                edgecolor='black')
        ax.set_title(f'Threshold <= {thresh}K', fontsize=16)
        ax.set_xlabel('Percentage of Pixels in Sample (%)', fontsize=12)
        if i == 0:
            ax.set_ylabel('Number of Training Samples', fontsize=12)
        ax.grid(True, linestyle='--', alpha=0.6)
        ax.axvline(mean_pct, color='r', linestyle='--', linewidth=2,
                   label=f'Mean ({mean_pct:.2f}%)')
        ax.legend()

    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    save_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path, dpi=150)


def analyze_data_distribution():
    """Report how rare cold ("extreme event") pixels are in the training split.

    Loads the 'train' split through ``MultiVariableARDataset`` (defined
    elsewhere in the notebook), de-normalizes the target of every sample with
    ``target_mean`` / ``target_std`` from ``dataset.stats``, and records the
    percentage of pixels ``<=`` each value in ``THRESHOLDS_TO_ANALYZE``. It
    prints summary statistics per threshold and saves a histogram figure to
    ``OUTPUT_DIR / 'training_data_distribution.png'``.
    """
    print("--- Starting Training Data Distribution Analysis ---")

    # We don't need predictors; the existing dataset class is reused as-is.
    train_dataset = MultiVariableARDataset(DATA_DIR, 'train')
    stats = train_dataset.stats
    # NOTE(review): assumes targets were normalized as (x - mean) / (std + 1e-8),
    # so this inverts that normalization — TODO confirm against the dataset class.
    mean, std = stats['target_mean'], stats['target_std'] + 1e-8

    n_samples = len(train_dataset)
    pixel_percentages = {thresh: [] for thresh in THRESHOLDS_TO_ANALYZE}

    print(f"Analyzing {n_samples} training samples...")
    if n_samples == 0:
        print("No training samples found; nothing to analyze.")
        return

    # Index explicitly: iterating the dataset object directly falls back to the
    # __getitem__ sequence protocol (unless the class defines __iter__, which is
    # not visible here), and that only terminates when IndexError is raised.
    for idx in tqdm(range(n_samples), desc="Processing samples"):
        _, target_norm, _ = train_dataset[idx]
        ground_truth_k = target_norm.numpy().squeeze() * std + mean
        # Use each sample's own size so differently shaped samples are handled.
        total_pixels = ground_truth_k.size
        for thresh in THRESHOLDS_TO_ANALYZE:
            cold_pixels = np.count_nonzero(ground_truth_k <= thresh)
            pixel_percentages[thresh].append(cold_pixels / total_pixels * 100)

    print("\n" + "="*80)
    print("                 Training Data Distribution Summary")
    print("="*80)

    for thresh in THRESHOLDS_TO_ANALYZE:
        percentages = np.asarray(pixel_percentages[thresh])
        n = len(percentages)
        n_zero = int(np.sum(percentages == 0))
        n_below_one = int(np.sum(percentages < 1))
        print(f"\n--- Analysis for Threshold <= {thresh}K ---")
        print(f"  - Mean Percentage Across All Samples: {percentages.mean():.2f}%")
        print(f"  - Median Percentage Across All Samples: {np.median(percentages):.2f}%")
        print(f"  - Max Percentage in a Single Sample: {percentages.max():.2f}%")
        print(f"  - Samples with ZERO relevant pixels: {n_zero} / {n} ({n_zero / n * 100:.1f}%)")
        print(f"  - Samples with < 1% relevant pixels: {n_below_one} / {n} ({n_below_one / n * 100:.1f}%)")

    # --- Visualization --- (subplots(1, 0) would fail, so skip with no thresholds)
    if THRESHOLDS_TO_ANALYZE:
        save_path = OUTPUT_DIR / 'training_data_distribution.png'
        _plot_percentage_histograms(pixel_percentages, THRESHOLDS_TO_ANALYZE, save_path)
        print(f"\n✅ Distribution plot saved to: {save_path}")

    gc.collect()

if __name__ == '__main__':
    # MultiVariableARDataset comes from another notebook cell. Look it up in the
    # namespace instead of referencing it directly: a direct reference raises
    # NameError when that cell has not been run, so the else-branch below could
    # never be reached.
    if globals().get('MultiVariableARDataset') is not None:
        analyze_data_distribution()
    else:
        print("Analysis script aborted due to missing dataset class.")
