In [25]:
import duckdb
import polars as pl
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
from pathlib import Path
from typing import List
from functools import partial


In [2]:
from datetime import datetime
print(f"Cell executed at: {datetime.now()}")

Cell executed at: 2025-09-10 14:44:57.082656


In [None]:
database_path = "/media/longterm_hdd/Clay/Sleep_aDBS/data/sql_databases/baseline_data.duckdb"
device_ids = [
    "RCS12L",
]

# this is the name of the table in the duckdb database. Usually corresponds to the session type (e.g. baseline, baseline_2, sleep_profiler, etc.)
table_name = "baseline"
columns_to_collect_stats_for = ['Power_Band5', 'Power_Band6', 'Power_Band7', 'Power_Band8']

sessions_to_ignore = []

data_points_per_second = 1
num_rows_to_exclude_from_beginning_of_session = 10

num_seconds_to_include_for_wake_mean = 1800
out_path_base = "/media/longterm_hdd/Clay/Sleep_aDBS/bayes_opt_experiments"

In [34]:
# Group by SessionNumber and process each group
def process_session_group(group_df: pl.DataFrame, num_to_exclude: int = 0, num_seconds_to_include_for_wake_mean: int = 1800) -> pl.DataFrame:
    # Exclude first num_to_exclude rows
    if len(group_df) > num_to_exclude:
        group_df = group_df.slice(num_to_exclude)
    
    if len(group_df) == 0:
        return None
    
    # Calculate time since start for filtering
    first_timestamp = group_df['localTime'].min()
    group_df = group_df.with_columns(
        (pl.col('localTime') - first_timestamp).dt.total_seconds().alias('Time_Since_Start_Seconds')
    )
    
    # Filter for first num_seconds_to_include_for_wake_mean seconds
    wake_mean_df = group_df.filter(
        pl.col('Time_Since_Start_Seconds') <= num_seconds_to_include_for_wake_mean
    )

    nrem_mean_df = group_df.filter(
        pl.col('Time_Since_Start_Seconds') >= num_seconds_to_include_for_wake_mean
    )
    
    # Calculate statistics for each column
    stats = {}
    session_number = group_df['SessionNumber'].unique().item()
    stats['SessionNumber'] = session_number
    
    for col in columns_to_collect_stats_for:
        if col in group_df.columns:
            # Calculate mean of first num_seconds_to_include_for_wake_mean seconds
            if len(wake_mean_df) > 0:
                wake_mean = wake_mean_df[col].mean()
            else:
                wake_mean = None
            
            # Calculate mean of NREM data (after wake period, where State == 2)
            if len(nrem_mean_df) > 0:
                nrem_mean = nrem_mean_df[col].mean()
            else:
                nrem_mean = None
            
            stats[f'{col}_putative_wake_mean'] = wake_mean
            stats[f'{col}_putative_nrem_mean'] = nrem_mean
        else:
            print(f"Warning: Column '{col}' not found in data")
            stats[f'{col}_putative_wake_mean'] = None
            stats[f'{col}_putative_nrem_mean'] = None
    
    return pl.DataFrame([stats])

In [35]:

def analyze_session_statistics(
    device_id: str,
    table_name: str,
    sessions_to_ignore: List[int],
    columns_to_collect_stats_for: List[str],
    num_to_exclude: int,
    num_seconds_to_include_for_wake_mean: int,
    db_path: str = None
) -> pl.DataFrame:
    """
    Connect to DuckDB table, filter sessions, and calculate statistics for specified columns.
    
    Parameters
    ----------
    table_name : str
        Name of the DuckDB table to query
    sessions_to_ignore : List[int]
        List of SessionNumber values to exclude from analysis
    columns_to_collect_stats_for : List[str]
        List of column names to calculate statistics for
    num_to_exclude : int
        Number of rows to exclude from the beginning of each session
    num_seconds_to_include_for_wake_mean : int
        Number of seconds to include for wake mean calculation (from start of session)
    db_path : str, optional
        Path to DuckDB database file. If None, uses in-memory database.
    
    Returns
    -------
    pl.DataFrame
        DataFrame with session-level statistics including mean of first N seconds and top quartile
    """
    
    # Connect to DuckDB
    conn = duckdb.connect(db_path)
    
    try:
        # Query data excluding specified sessions
        if len(sessions_to_ignore) > 0:
            query = f"""
            SELECT * 
            FROM {device_id}.{table_name} 
            WHERE SessionNumber NOT IN ({','.join(map(str, sessions_to_ignore))})
            ORDER BY SessionNumber, localTime
            """
        else:
            query = f"""
            SELECT * 
            FROM {device_id}.{table_name} 
            ORDER BY SessionNumber, localTime
            """
        
        # Execute query and convert to Polars DataFrame
        df = conn.execute(query).pl()
        
        if len(df) == 0:
            print("Warning: No data found after filtering sessions")
            return pl.DataFrame()

        process_session_group_partial = partial(process_session_group, num_to_exclude=num_to_exclude, num_seconds_to_include_for_wake_mean=num_seconds_to_include_for_wake_mean)
        
        # Apply processing to each session group
        result_df = (
            df.group_by('SessionNumber')
            .map_groups(process_session_group_partial)
            .filter(pl.col('SessionNumber').is_not_null())
        )
        
        return result_df
        
    finally:
        conn.close()

In [36]:
for device_id in device_ids:
    print(f"Processing {device_id}")
    out_path = f"{out_path_base}/{device_id[:-1]}/{device_id}"

    result_df = analyze_session_statistics(
        device_id=device_id,
        table_name=table_name,
        sessions_to_ignore=sessions_to_ignore,
        columns_to_collect_stats_for=columns_to_collect_stats_for,
        num_to_exclude=num_rows_to_exclude_from_beginning_of_session,
        num_seconds_to_include_for_wake_mean=num_seconds_to_include_for_wake_mean,
        db_path=database_path
    )


    result_df.write_csv(f"{out_path}/cluster_means_init_stats.csv")

Processing RCS12L


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

In [37]:
result_df

SessionNumber,Power_Band5_putative_wake_mean,Power_Band5_putative_nrem_mean,Power_Band6_putative_wake_mean,Power_Band6_putative_nrem_mean,Power_Band7_putative_wake_mean,Power_Band7_putative_nrem_mean,Power_Band8_putative_wake_mean,Power_Band8_putative_nrem_mean
str,f64,f64,f64,f64,f64,f64,f64,f64
"""Session1754975513148""",15090.402,28934.22294,4882.942,9937.58139,16203.83,3889.558426,5441.432,1096.784715
"""Session1752041037703""",14619.681208,24842.613076,3484.001678,9897.633815,6287.753356,4508.64438,2898.788591,1729.514015
"""Session1744175574110""",22810.658318,30828.642317,4051.21288,12829.830284,6719.159213,4947.90679,2878.921288,1470.348899
"""Session1751437402486""",16000.105536,23974.069097,4498.937716,12170.876816,7749.484429,6727.654714,3445.33737,2391.242988
"""Session1751523708876""",20055.852792,27069.312009,3608.964467,10703.863207,4677.634518,5297.98092,2555.153976,1834.999208
"""Session1750830543442""",28535.569514,28535.482367,11988.336683,13734.09842,330526.477387,45805.595228,239246.358459,30113.289452
"""Session1748501008221""",21295.37604,21611.153337,4385.091514,11007.666039,8038.831947,9533.734766,1951.49584,5092.24314
"""Session1751955169272""",16181.185315,23008.234809,3172.704545,12065.386546,6697.059441,25590.954137,2858.04021,18020.816998
"""Session1752300193017""",14266.222037,30320.971118,4009.520868,20621.034543,9002.883139,87565.680406,2678.440735,64732.796807


In [54]:
result_summary = result_df.describe()
result_summary.head(10)

statistic,SessionNumber,Power_Band5_putative_wake_mean,Power_Band5_putative_nrem_mean,Power_Band6_putative_wake_mean,Power_Band6_putative_nrem_mean,Power_Band7_putative_wake_mean,Power_Band7_putative_nrem_mean,Power_Band8_putative_wake_mean,Power_Band8_putative_nrem_mean
str,str,f64,f64,f64,f64,f64,f64,f64,f64
"""count""","""9""",9.0,9.0,9.0,9.0,9.0,9.0,9.0,9.0
"""null_count""","""0""",0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""mean""",,18761.672529,26569.41123,4897.968039,12551.996784,43989.234825,21540.856641,29328.218719,14053.55958
"""std""",,4810.749231,3333.670152,2711.999233,3291.035814,107501.153704,28436.072356,78725.290249,21464.343934
"""min""","""Session1744175574110""",14266.222037,21611.153337,3172.704545,9897.633815,4677.634518,3889.558426,1951.49584,1096.784715
"""25%""",,15090.402,23974.069097,3608.964467,10703.863207,6697.059441,4947.90679,2678.440735,1729.514015
"""50%""",,16181.185315,27069.312009,4051.21288,12065.386546,7749.484429,6727.654714,2878.921288,2391.242988
"""75%""",,21295.37604,28934.22294,4498.937716,12829.830284,9002.883139,25590.954137,3445.33737,18020.816998
"""max""","""Session1754975513148""",28535.569514,30828.642317,11988.336683,20621.034543,330526.477387,87565.680406,239246.358459,64732.796807


In [53]:
for_means_init = np.round(np.concatenate([result_summary.filter(pl.col("statistic") == "50%").select(pl.col("^.*putative_wake.*$")).to_numpy(), result_summary.filter(pl.col("statistic") == "50%").select(pl.col("^.*putative_nrem.*$")).to_numpy()])).astype(int)
for_means_init

array([[16181,  4051,  7749,  2879],
       [27069, 12065,  6728,  2391]])