In [4]:
import numpy as np
import polars as pl 
from pathlib import Path
import sys
from typing import Optional 

# Make the repository root importable so `src.*` modules resolve. The kernel's
# working directory is expected to be one level below the root (notebooks/).
project_root = Path.cwd().resolve().parent
if not any(entry == str(project_root) for entry in sys.path):
    sys.path.insert(0, str(project_root))

from src.utils.constants import FREQUENCIES


In [11]:
def load_audiogram_data(path: Optional[Path] = None) -> pl.DataFrame:
    """
    Load audiogram data from a Parquet file.

    Args:
        path (Optional[Path]): Location of the Parquet file (a ``str`` also
            works; it is passed straight to ``pl.read_parquet``). If None,
            ``data/anonymized_cleaned_data.parquet`` under ``project_root``
            (set up in the first notebook cell) is used. This replaces a
            hard-coded, machine-specific absolute path.

    Returns:
        pl.DataFrame: The file's contents as read by ``pl.read_parquet``.
    """
    if path is None:
        # Resolved at call time so defining this function does not depend on
        # the setup cell having run.
        path = project_root / "data" / "anonymized_cleaned_data.parquet"
    return pl.read_parquet(path)

In [20]:
def compute_frequency_ear_means(
    df: Optional[pl.DataFrame] = None, frequencies: list = FREQUENCIES
) -> np.ndarray:
    """
    Compute the mean 'Value' per (Frequency, Ear) pair as a flat NumPy array.

    Args:
        df (Optional[pl.DataFrame]): Polars DataFrame with 'Frequency', 'Ear'
            and 'Value' columns. If None, ``load_audiogram_data()`` is called
            to load the default dataset.
        frequencies (list): Frequencies to keep. Rows whose 'Frequency' is not
            in this list are dropped before averaging.

    Returns:
        np.ndarray: 1D array with one mean per (Frequency, Ear) pair present
            in the filtered data, sorted by Frequency then Ear. With Ear coded
            as "L"/"R" the order is
            [250_L, 250_R, 500_L, 500_R, ..., 8000_L, 8000_R].

    Notes:
        - Null 'Value' entries are ignored by the mean.
        - A pair with no rows at all is omitted, so the array gets shorter and
          later positions shift. Check the length if you rely on fixed
          positions.
        - A pair whose values are all null produces a null mean, which should
          appear as NaN after ``to_numpy()`` (NOTE(review): confirm for the
          installed Polars version).
    """
    if df is None:
        df = load_audiogram_data()

    # Nulls are deliberately not imputed. Polars' mean() already skips nulls,
    # so filling them with the per-group mean first would leave every group's
    # mean unchanged.
    grouped_means = (
        df.filter(pl.col("Frequency").is_in(frequencies))
        .group_by(["Frequency", "Ear"])
        .agg(pl.col("Value").mean().alias("Mean_Value"))
        # group_by does not preserve order; sort to get the documented layout.
        .sort(["Frequency", "Ear"])
    )

    return grouped_means["Mean_Value"].to_numpy()


In [21]:
# Per-(Frequency, Ear) means for the default dataset.
test = compute_frequency_ear_means()
# Fail loudly on a non-finite mean (e.g. a Frequency/Ear group with only
# null or NaN 'Value' entries) instead of letting it reach later analysis.
assert np.isfinite(test).all(), "non-finite mean in Frequency/Ear means"
print(test)

[17.42306374 17.03571961 22.61782892 22.20711297 23.67202301 23.18448536
 30.33091908 29.19593821 36.46905178 34.93554267 48.15445128 46.81536024]
