In [None]:
# pip install matplotlib h5py scipy seaborn scikit-learn --no-index

In [None]:
import os, sys
from pathlib import Path
# Project root: this notebook is expected to run from a directory one level
# below it, so the parent of the working directory is the root.
project_root = Path.cwd().parent
# Uncomment if the package must be imported from the repository checkout:
# sys.path.append(str(project_root))

BASE_DIR = project_root  # go up one level from the notebook directory
# NOTE(review): raw_input_dir is not used by the cells shown in this notebook.
raw_input_dir = BASE_DIR / "data" / "threshold_sweep" / "l1" / "slam" / "bernoulli"

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import h5py
# Import the analysis module
from mirrored_langevin_rnn.utils.data_pipeline.threshold_sweep_analysis import (
    load_threshold_batch_files,
    plot_threshold_heatmap,
    merge_threshold_batches
)

## Configuration

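Select the sweep to analyze: uncomment exactly one `out_dir` and the matching `FILE_PATTERN`, and set `BATCH_SIZE` to the number of parameter combinations each SLURM array job processed.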

In [None]:
# Directory holding the per-batch HDF5 results of the sweep to analyse.
# Keep exactly one assignment active and select the matching FILE_PATTERN
# in the next cell.
out_dir = BASE_DIR.joinpath("data", "threshold_sweep", "slam", "auc", "kumaraswamy", "affinity_dense_gamma")

# Alternatives:
# out_dir = BASE_DIR / "data" / "threshold_sweep" / "slam" / "auc" / "bernoulli" / "affinity_dense_gamma"
# out_dir = BASE_DIR / "data" / "threshold_sweep" / "slam" / "auc" / "bernoulli" / "affinity_sparse_gamma_sparsity_0.10"
# out_dir = BASE_DIR / "data" / "threshold_sweep" / "slam" / "auc" / "bernoulli" / "affinity_sparse_binary_sparsity_0.10"
# out_dir = BASE_DIR / "data" / "threshold_sweep" / "slam" / "auc" / "kumaraswamy" / "affinity_sparse_binary_sparsity_0.10"
# out_dir = BASE_DIR / "data" / "threshold_sweep" / "slam" / "auc" / "kumaraswamy" / "affinity_sparse_gamma_sparsity_0.10"
# out_dir = BASE_DIR / "data" / "threshold_sweep" / "poisson" / "rank" / "affinity_dense_gamma"
# out_dir = BASE_DIR / "data" / "threshold_sweep" / "poisson" / "rank" / "affinity_sparse_binary_sparsity_0.10"
# out_dir = BASE_DIR / "data" / "threshold_sweep" / "poisson" / "rank" / "affinity_sparse_gamma_sparsity_0.10"

In [None]:
# Glob pattern for the per-batch result files inside `out_dir`. Keep it in
# sync with the directory selected in the previous cell. Alternatives:
#   "slam_auc_bernoulli_threshold_results_batch*.h5"
#   "poisson_rank_binary_threshold_results_batch*.h5"
#   "poisson_rank_threshold_results_batch*.h5"
FILE_PATTERN = "slam_auc_kumaraswamy_threshold_results_batch*.h5"

# Parameter combinations handled by each SLURM array job; used further down
# to map missing grid cells back to job-array indices.
BATCH_SIZE = 1


## Load and Process Threshold Sweep Data

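First, we inspect one batch file. Then we load all batch files and align them on a common (nSens × nOdor) grid. Finally, we merge them into a single grid, which is also saved as `threshold_results_merged.h5`.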

In [None]:
# Inspect the first batch file matching FILE_PATTERN to see its layout
pattern_files = sorted(out_dir.glob(FILE_PATTERN))
if pattern_files:
    first_file = pattern_files[0]
    print(f"Inspecting first file: {first_file.name}")
    with h5py.File(first_file, "r") as f:
        print(f"Keys: {list(f.keys())}")
        print(f"Attributes: {list(f.attrs.keys())}")
        grid = f["grid"]
        print(f"Grid shape: {grid.shape}")
        # `grid` is an h5py Dataset handle, not the data itself; read it into
        # memory while the file is still open.
        print(f"Grid dataset: {grid}")
        grid_np = np.array(grid)
        print("Grid as numpy array:")
        print(grid_np)
else:
    print(f"No files found matching pattern: {FILE_PATTERN} in {out_dir}")

In [None]:
# Align every batch file matching FILE_PATTERN onto one common grid.
# Note the return order: sensor values come before odor values.
grids, files, n_sens_values, n_odor_values = load_threshold_batch_files(
    out_dir, pattern=FILE_PATTERN
)

print(
    f"Loaded {len(files)} batch files",
    f"Grid shape: {grids.shape}",
    f"nSens values: {n_sens_values}",
    f"nOdor values: {n_odor_values}",
    sep="\n",
)

In [None]:
# Combine the per-batch grids into one grid and write it next to the batches
merged_grid, merged_path = merge_threshold_batches(
    out_dir,
    pattern=FILE_PATTERN,
    output_file="threshold_results_merged.h5",
    # keep the first valid value per cell; "max", "min" and "mean" are the
    # other available strategies
    merge_method="first_valid",
)
print(f"Merged grid shape: {merged_grid.shape}", f"Saved to: {merged_path}", sep="\n")

# Completeness = share of cells that hold a (non-NaN) value
total_cells = merged_grid.size
non_nan_count = np.count_nonzero(~np.isnan(merged_grid))
print(f"Grid completeness: {non_nan_count}/{total_cells} cells filled ({non_nan_count/total_cells:.1%})")

# Peek at the top-left corner of the merged grid
print("\nSample of merged grid (first 5x5 section):")
print(merged_grid[:5, :5])

## Visualize the Merged Threshold Sweep Results

Now we'll create a heatmap visualization of the merged grid with contour lines.

In [None]:
# Create a figure and plot the heatmap
from matplotlib.colors import Normalize

# Contour levels drawn on the heatmap (also marked on the colorbar below).
contour_levels = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]

# plot_threshold_heatmap takes its own figsize and returns (fig, ax), so no
# preceding plt.figure() is needed -- one would only leave an extra empty figure.
fig, ax = plot_threshold_heatmap(
    merged_grid,
    n_odor_values,
    n_sens_values,
    figsize=(7, 6),
    cmap="Blues",
    sigma=1,  # Gaussian smoothing parameter
    contour_levels=contour_levels,
    highlight_level=20,
    annot=False,
)
# Restrict the x-axis range (presumably nOdor, passed before nSens above).
ax.set_xlim([1000, 16000])
# Colour normalisation applied to the heatmap through its colorbar below.
norm = Normalize(vmin=5, vmax=100)
def colorbar_config(norm, fig_rank, ax_rank):
    """Attach a styled colorbar for the first collection drawn on ``ax_rank``.

    The colorbar's mappable (that collection) is re-normalised with ``norm``,
    so this also changes the colour scaling of the plot itself. Ticks sit
    every 10 units from 10 to 100; only every other tick is labelled and the
    top label reads "100+".

    Returns:
        The created colorbar.
    """
    heatmap_artist = ax_rank.collections[0]
    colorbar = fig_rank.colorbar(heatmap_artist, ax=ax_rank, shrink=0.9, aspect=9)
    colorbar.mappable.set_norm(norm)

    tick_values = list(range(10, 101, 10))
    tick_labels = ["", '20', "", '40', "", '60', "", '80', "", '100+']
    colorbar.set_ticks(tick_values)
    colorbar.ax.tick_params(width=0)  # zero-width tick marks: labels only
    colorbar.set_ticklabels(tick_labels, fontsize=16)
    colorbar.outline.set_linewidth(2)
    return colorbar


cbar_slam = colorbar_config(norm, fig, ax)
def add_contour_lines_to_colorbar(colorbar, levels, norm, label_offset=0.02,
                                  highlight_level=20, highlight_color='yellow',
                                  line_color='black'):
    """Draw a short horizontal marker on ``colorbar`` at each contour level.

    Markers span the right quarter of the colorbar (axes x from 0.75 to 1).
    The marker at ``highlight_level`` uses ``highlight_color``; all others
    use ``line_color``.

    Args:
        colorbar: Matplotlib colorbar whose axes receive the markers.
        levels: Iterable of level values, drawn directly at those y values
            of the colorbar axes.
        norm: Unused; kept so existing calls keep working.
        label_offset: Unused; kept so existing calls keep working.
        highlight_level: Level to emphasise (20 by default).
        highlight_color: Colour of the emphasised marker.
        line_color: Colour of every other marker.
    """
    for level in levels:
        color = highlight_color if level == highlight_level else line_color
        colorbar.ax.axhline(level, xmin=0.75, xmax=1, color=color,
                            linestyle='solid', linewidth=2)


contour_levels = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
add_contour_lines_to_colorbar(cbar_slam, contour_levels, norm)

# The colorbar is hidden in the final figure; it still occupies the space it
# took from `ax`. NOTE(review): confirm hiding it is intentional.
cbar_slam.ax.set_visible(False)

# Save the heatmap figure explicitly rather than whichever figure is current.
fig.savefig(out_dir / "threshold_heatmap.png", dpi=300, bbox_inches="tight")
plt.show()

## Check for Missing Data

Check whether the merged grid contains NaN cells, i.e. parameter combinations for which no batch file provided a valid value. If it does, visualize where they are.

In [None]:
# Check for NaN values in the merged grid
nan_mask = np.isnan(merged_grid)
nan_count = int(nan_mask.sum())
nan_fraction = nan_count / merged_grid.size if merged_grid.size else 0.0
print(f"Number of NaN values in merged grid: {nan_count} out of {merged_grid.size} ({nan_fraction:.1%})")

# Print where data is missing as plain-int (row, col) pairs; rows index
# nSens values and columns index nOdor values (see the mask plot below).
nan_indices = np.where(nan_mask)
nan_positions = [(int(row), int(col)) for row, col in zip(*nan_indices)]
max_shown = 50  # avoid flooding the output for very incomplete grids
overflow = len(nan_positions) - max_shown
suffix = f" ... and {overflow} more" if overflow > 0 else ""
print(f"Missing data at grid positions (row, col): {nan_positions[:max_shown]}{suffix}")

if nan_count > 0:
    # Visualize where NaN values are (binary colormap: 1 = missing -> black)
    fig_nan, ax_nan = plt.subplots(figsize=(8, 6))
    X, Y = np.meshgrid(n_odor_values, n_sens_values)
    mesh = ax_nan.pcolormesh(X, Y, nan_mask.astype(int), cmap="binary", shading="auto")
    fig_nan.colorbar(mesh, ax=ax_nan, label="NaN present")
    ax_nan.set_xscale("log")
    ax_nan.set_xlabel("nOdor")
    ax_nan.set_ylabel("nSens")
    ax_nan.set_title("Missing Data in Merged Grid")
    fig_nan.savefig(out_dir / "missing_data_mask.png", dpi=300, bbox_inches="tight")
    plt.show()

## Load and Plot from the Merged Grid File

This section demonstrates how to load a previously saved merged grid and create a visualization.

In [None]:
# Function to load a grid from an HDF5 file
def load_grid_from_file(file_path):
    """Load a grid and its axis values from an HDF5 file.

    Args:
        file_path: Path to an HDF5 file with a ``grid`` dataset and
            ``nOdor_values`` / ``nSens_values`` attributes, such as the
            merged file written earlier in this notebook.

    Returns:
        ``(grid, n_odor_vals, n_sens_vals)``: the grid as a NumPy array and
        the two axis value sequences as lists of built-in Python ints.
    """
    with h5py.File(file_path, "r") as f:
        grid = f["grid"][:]
        # .tolist() yields built-in ints (not np.int64), so the lists print
        # as plain numbers under NumPy 2 and serialize without surprises.
        n_odor_vals = np.asarray(f.attrs["nOdor_values"], dtype=int).tolist()
        n_sens_vals = np.asarray(f.attrs["nSens_values"], dtype=int).tolist()
    return grid, n_odor_vals, n_sens_vals

# Re-load the merged grid that was written to disk earlier in this notebook
grid_file = out_dir.joinpath("threshold_results_merged.h5")
loaded_grid, loaded_n_odor, loaded_n_sens = load_grid_from_file(grid_file)

print(
    f"Loaded grid shape: {loaded_grid.shape}",
    f"nOdor values: {loaded_n_odor}",
    f"nSens values: {loaded_n_sens}",
    sep="\n",
)

In [None]:
# Plot the loaded grid with different visualization parameters.
# plot_threshold_heatmap returns (fig, ax) and takes care of the figure
# itself, so unpack both and title/save through them explicitly.
fig, ax = plot_threshold_heatmap(
    loaded_grid,
    loaded_n_odor,
    loaded_n_sens,
    cmap="viridis",  # Different colormap to show variation
    sigma=0.5,       # Less smoothing
    contour_levels=(10, 15, 20, 25, 30),  # Finer spacing over a narrower range
    highlight_level=20,
)

ax.set_title("Threshold Heatmap from Loaded File", fontsize=16)
fig.savefig(out_dir / "threshold_heatmap_loaded.png", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
## Check for Missing SLURM Array Job Indices

# Uses NaN values in the merged grid to determine which SLURM job-array
# indices are missing. The mapping assumes the experiment enumerated parameter
# combinations odor-major (outer loop over nOdor values, inner loop over nSens
# values) and gave each array index `batch_size` consecutive combinations.
# NOTE(review): confirm this ordering matches the submission script.


def _format_index_ranges(indices):
    """Collapse sorted integers into a compact string, e.g. [0, 1, 2, 5] -> "0-2, 5"."""
    ranges = []
    start = end = None
    for idx in indices:
        if start is not None and idx == end + 1:
            end = idx
            continue
        if start is not None:
            ranges.append(str(start) if start == end else f"{start}-{end}")
        start = end = idx
    if start is not None:
        ranges.append(str(start) if start == end else f"{start}-{end}")
    return ", ".join(ranges)


def check_missing_job_indices(out_dir, pattern, batch_size):
    """
    Check for missing SLURM array job indices by examining NaN values in the merged grid.

    Args:
        out_dir: Directory containing the per-batch HDF5 result files.
        pattern: Glob pattern selecting the batch files inside ``out_dir``.
        batch_size: Number of parameter combinations handled by each job index.

    Returns:
        A dict with keys ``missing_combinations`` (list of
        ``(combo_idx, odor_value, sens_value)``), ``missing_job_indices``
        (sorted list of int), ``total_combinations``, ``expected_jobs`` and
        ``grid_completeness`` (percentage of non-NaN cells). Returns ``None``
        if the grid or parameter values cannot be loaded, or if the grid
        shape does not match the (nSens, nOdor) axes.

    Raises:
        ValueError: if ``batch_size`` is less than 1.

    Note:
        The merged grid is obtained via ``merge_threshold_batches``, which
        also writes ``threshold_results_merged.h5`` in ``out_dir``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    # Load the merged grid
    try:
        merged_grid, _ = merge_threshold_batches(
            out_dir,
            pattern=pattern,
            output_file="threshold_results_merged.h5",
            merge_method="first_valid",
        )
        print(f"Loaded merged grid with shape: {merged_grid.shape}")
    except Exception as e:
        print(f"Error loading merged grid: {e}")
        return None

    # Get actual parameter values from files
    try:
        _, _, n_sens_actual, n_odor_actual = load_threshold_batch_files(out_dir, pattern=pattern)
        n_sens_list = list(n_sens_actual)
        n_odor_list = list(n_odor_actual)
        print(f"Parameter space: {len(n_odor_list)} odor × {len(n_sens_list)} sensor values")
    except Exception as e:
        print(f"Error loading parameter values: {e}")
        return None

    n_odor, n_sens = len(n_odor_list), len(n_sens_list)
    # The grid is indexed [sens_idx, odor_idx]; stop here rather than raise an
    # IndexError (or misread cells) when the axes and the grid disagree.
    if merged_grid.shape != (n_sens, n_odor):
        print(f"Error: merged grid shape {merged_grid.shape} does not match "
              f"(nSens, nOdor) = ({n_sens}, {n_odor})")
        return None

    total_combinations = n_odor * n_sens
    expected_jobs = (total_combinations + batch_size - 1) // batch_size  # ceil division

    print("\nMapping info:")
    print(f"  Total combinations: {total_combinations}")
    print(f"  Batch size: {batch_size}")
    print(f"  Expected jobs: {expected_jobs} (indices 0-{expected_jobs-1})")

    # Find missing combinations by checking NaN values; combo_idx is the
    # position of (odor, sens) in the odor-major enumeration.
    missing_combinations = []
    for odor_idx, odor_val in enumerate(n_odor_list):
        for sens_idx, sens_val in enumerate(n_sens_list):
            if np.isnan(merged_grid[sens_idx, odor_idx]):
                combo_idx = odor_idx * n_sens + sens_idx
                missing_combinations.append((combo_idx, odor_val, sens_val))

    # Map missing combinations to job indices
    missing_job_indices = sorted({combo_idx // batch_size for combo_idx, _, _ in missing_combinations})

    print("\nResults:")
    print(f"  Missing parameter combinations: {len(missing_combinations)}")
    print(f"  Missing job indices: {len(missing_job_indices)}")

    if missing_combinations:
        print("\nFirst 10 missing combinations (combo_idx, odor, sensor):")
        for combo_idx, odor_val, sens_val in missing_combinations[:10]:
            job_idx, pos_in_job = divmod(combo_idx, batch_size)
            # int() so the "d" format also accepts float-valued axes
            print(f"  {combo_idx:3d}: ({int(odor_val):5d}, {int(sens_val):3d}) → job {job_idx}, pos {pos_in_job}")
        if len(missing_combinations) > 10:
            print(f"  ... and {len(missing_combinations) - 10} more")

    if missing_job_indices:
        print(f"\nMissing job indices (ranges): {_format_index_ranges(missing_job_indices)}")
        print(f"Missing job indices (list): {missing_job_indices}")
    else:
        print("\n✓ No missing job indices - all combinations completed!")

    # Grid statistics
    total_cells = merged_grid.size
    nan_cells = int(np.isnan(merged_grid).sum())
    completeness = (total_cells - nan_cells) / total_cells * 100 if total_cells else 0.0

    print(f"\nGrid completeness: {completeness:.1f}% ({total_cells - nan_cells}/{total_cells})")

    return {
        'missing_combinations': missing_combinations,
        'missing_job_indices': missing_job_indices,
        'total_combinations': total_combinations,
        'expected_jobs': expected_jobs,
        'grid_completeness': completeness,
    }

# Run the missing-job analysis for the configured sweep (None if loading failed)
result = check_missing_job_indices(out_dir, FILE_PATTERN, BATCH_SIZE)

## Export Grid Data for Further Analysis

This section exports the merged grid to CSV (rows: nSens values, columns: nOdor values) for use in other tools.

In [None]:
# Export grid to CSV for use in other tools
def export_grid_to_csv(grid, n_odor_vals, n_sens_vals, output_path,
                       index_name="nSens", columns_name="nOdor"):
    """Export grid data to CSV format with row/column headers.

    Rows are labelled by ``n_sens_vals`` and columns by ``n_odor_vals``, so
    ``grid`` must have shape ``(len(n_sens_vals), len(n_odor_vals))``. The row
    axis is named ``index_name``, which becomes the CSV's top-left header cell;
    ``columns_name`` names the column axis of the returned DataFrame. NaN cells
    are written as empty fields (pandas default).

    Returns:
        The ``pandas.DataFrame`` that was written.
    """
    import pandas as pd

    df = pd.DataFrame(grid, index=n_sens_vals, columns=n_odor_vals).rename_axis(
        index=index_name, columns=columns_name
    )

    df.to_csv(output_path)
    print(f"Grid exported to {output_path}")

    return df

# Example usage: write the merged grid next to the other sweep outputs
csv_path = out_dir.joinpath("threshold_grid_merged.csv")
df = export_grid_to_csv(
    merged_grid,
    n_odor_values,
    n_sens_values,
    csv_path,
)

# Show the first rows of the exported table
df.head()