In [1]:
import polars as pl
import duckdb
import os
from pathlib import Path
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
from datetime import datetime

# Record in the cell output when this notebook run reached this point.
print("Cell executed at: " + str(datetime.now()))

# This notebook creates a DuckDB database from the Sleep aDBS project's baseline data (i.e., overnight data collected on cDBS)
## It also filters out known-bad sessions, sessions missing a settings file, and sessions with incorrect power band or other device settings
### Intended use: Call from 'execute_papermill_parameterized_notebook.py' with a participant configuration file as the argument. It can cycle through multiple devices and participants if desired, but ideally call it only once per participant so that the appropriate reports are generated. The resulting DuckDB tables are subsequently used by integrated_rcs_analysis to create the NREM classification models.

Run with the sleepclass2, sleepclass3, or bayes_opt conda environment.


In [None]:
# Default parameters for this notebook (presumably overridden when run via papermill,
# as described in the intro markdown).

# DuckDB file that receives one `<device_id>.<table_name>` table per device.
database_path = "/media/longterm_hdd/Clay/Sleep_aDBS/data/sql_databases/baseline_data.duckdb"
# Devices to process; each becomes a schema in the DuckDB database.
device_ids = ["RCS01L", "RCS01R"]

# this is the name of the table in the duckdb database. Usually corresponds to the session type (e.g. baseline, baseline_2, sleep_profiler, etc.)
session_type = "baseline"
table_name = f"{session_type}"

# Path templates filled in with str.format for each device (and session, for the settings CSV).
session_settings_csv_template_path = "/media/longterm_hdd/Clay/Sleep_aDBS/data/{session_type}/{device}/session_settings/{session}/FftAndPowerSettings.csv"
parquet_path = "/media/longterm_hdd/Clay/Sleep_aDBS/data/{session_type}/{device_id}/time_domain_data/*.parquet"

# Expected values for each checked column of FftAndPowerSettings.csv; a session whose
# settings do not match is excluded from the table.
settings_QA_dict = {
    'TDsampleRates': [500],
    'fft_interval': [1000],
    'Power_Band5': ['0.73-4.15'],
    'Power_Band6': ['4.64-12.45'],
    'Power_Band7': ['13.43-30.03'],
    'Power_Band8': ['31.01-59.81'],
}

# Power band columns drawn in the per-device pairwise QA plot.
power_bands = [
    'Power_Band1', 'Power_Band2',
    'Power_Band5', 'Power_Band6', 'Power_Band7', 'Power_Band8',
]

# CSV listing sessions to always exclude (columns RCS#, Side, Session#).
known_bad_sessions_path = "/media/longterm_hdd/Clay/Sleep_aDBS/Sleep_aDBS_sessions_to_skip.csv"
# Additional SessionNumber values to drop for every device.
other_sessions_to_skip = []
# Directory for the QA plots (None = no output directory configured).
output_path = None

# Expected parquet columns.
# NOTE(review): not referenced by any active code in this notebook — confirm whether it is still needed.
parquet_columns = [
    'localTime',
    'DerivedTime',
    'TD_key0',
    'TD_key1',
    'TD_key2',
    'TD_key3',
    'TD_samplerate',
    'Power_FftSize',
    'Power_IsPowerChannelOverrange',
    'Power_Band1',
    'Power_Band2',
    'Power_Band3',
    'Power_Band4',
    'Power_Band5',
    'Power_Band6',
    'Power_Band7',
    'Power_Band8',
    'Accel_XSamples',
    'Accel_YSamples',
    'Accel_ZSamples',
    'Accel_samplerate',
    'SessionNumber',
]

In [4]:
def verify_session_settings(session_settings: dict, settings_QA_dict: dict):
    """Check one session's recorded settings against the expected QA values.

    Args:
        session_settings: Mapping of setting name -> list of recorded values, one
            entry per row of the settings CSV (as produced by
            ``polars.DataFrame.to_dict(as_series=False)``).
        settings_QA_dict: Mapping of setting name -> list of expected values.

    Returns:
        True when every expected setting is present and matches. False on the
        first missing or mismatching setting, after printing a diagnostic.

    A setting matches when its recorded list equals the expected list, or when it
    has more than one row and every row equals the first expected value (the same
    setting repeated across rows).
    """
    for setting_name, expected in settings_QA_dict.items():
        if setting_name not in session_settings:
            print(f"{setting_name} not found in session_settings_df")
            return False

        recorded = session_settings[setting_name]
        if recorded == expected:
            continue

        # Several rows that all repeat the (first) expected value still count as a match.
        repeats_expected = len(recorded) > 1 and all(item == expected[0] for item in recorded)
        if repeats_expected:
            continue

        print(f"{setting_name} does not match expected value")
        print(f"Expected: {expected}, Actual: {recorded}")
        return False

    return True

In [5]:
# Load the manually curated exclusion list and derive the two keys the processing
# loop uses:
#   Device        = RCS# + first character of Side (presumably e.g. "RCS01" + "L",
#                   so it can be compared with the entries of `device_ids` — verify
#                   against the CSV contents)
#   SessionNumber = Session#, the key for the anti-join against the parquet data.
# NOTE(review): that anti-join requires SessionNumber to have the same dtype in this
# CSV and in the parquet files.
known_bad_sessions = pl.read_csv(known_bad_sessions_path).with_columns(
    Device=pl.col("RCS#") + pl.col("Side").str.slice(0, 1),
    SessionNumber=pl.col("Session#"),
)

In [None]:
def _save_power_band_plot(device_data: "pl.DataFrame", device_id: str) -> None:
    """Draw a pairwise scatter/histogram grid of the `power_bands` columns for one device.

    Rows with a null/NaN or non-positive value in any of `power_bands` are dropped
    before plotting. When `output_path` is set, the figure is saved to
    ``{output_path}/{device_id}_{table_name}_power_band.png``. When it is None
    (the parameter-cell default), the figure is shown inline instead of crashing
    in ``os.makedirs(None)``.

    Args:
        device_data: Filtered polars frame for one device; must contain every
            column listed in `power_bands`.
        device_id: Device identifier, used in the title and the file name.
    """
    print(f"Plotting power bands for {device_id}")
    # Only the power band columns are plotted, so convert just those to pandas.
    band_df = device_data.select(power_bands).to_pandas()
    for band in power_bands:
        # notnull() drops None/NaN; `> 0` drops zero and negative values.
        band_df = band_df[band_df[band].notnull() & (band_df[band] > 0)]
    if band_df.empty:
        print(f"No valid power band rows for {device_id}; skipping plot")
        return

    g = sns.PairGrid(band_df)
    g.map_upper(sns.scatterplot, alpha=0.1)
    g.map_lower(sns.scatterplot, alpha=0.1)
    g.map_diag(sns.histplot)
    plt.suptitle(f'Power Band Relationships - {device_id}')
    plt.tight_layout()
    if output_path is None:
        plt.show()
    else:
        os.makedirs(output_path, exist_ok=True)
        plt.savefig(f'{output_path}/{device_id}_{table_name}_power_band.png', dpi=300, bbox_inches='tight')
    plt.close()


# For each device, build `<device_id>.<table_name>` in the DuckDB database from the
# time-domain parquet files. Sessions are excluded if they appear in the known-bad
# list or `other_sessions_to_skip`, or if their FftAndPowerSettings.csv is missing or
# fails `verify_session_settings`.
identified_bad_sessions = {}

conn = duckdb.connect(database_path)
try:
    for device_id in device_ids:
        print(f"Processing {device_id}")
        identified_bad_sessions[device_id] = []

        # Bind to a new name. Reassigning `parquet_path` would overwrite the template with
        # the first device's path, and every later device would silently re-read it.
        device_parquet_path = parquet_path.format(session_type=session_type, device_id=device_id)
        data = (
            pl.scan_parquet(device_parquet_path, extra_columns="ignore", missing_columns="insert")
            .collect()
        )

        # Drop sessions listed for this device in the known-bad-sessions CSV.
        remove_sessions = known_bad_sessions.filter(pl.col("Device") == device_id)
        data = data.join(remove_sessions, on="SessionNumber", how="anti")

        if other_sessions_to_skip:
            data = data.filter(~pl.col("SessionNumber").is_in(other_sessions_to_skip))

        # Get unique session numbers for this device
        session_numbers = data.select(pl.col("SessionNumber").unique()).sort("SessionNumber")

        if len(session_numbers) == 0:
            print(f"No session numbers found for {device_id}")
            continue

        print(f"Found {len(session_numbers)} sessions for {device_id}")

        # For each session, verify settings and mark bad sessions
        for session in session_numbers['SessionNumber']:
            print(f"\nChecking session {session}")

            session_settings_csv_path = session_settings_csv_template_path.format(
                session_type=session_type,
                device=device_id,
                session=session
            )

            # A session without a settings file cannot be verified, so treat it as bad.
            if not os.path.exists(session_settings_csv_path):
                print(f"Device {device_id}, Session {session}: Settings file not found")
                identified_bad_sessions[device_id].append(session)
                continue

            session_settings = pl.read_csv(session_settings_csv_path)

            # Print relevant settings for debugging
            relevant_columns = [col for col in session_settings.columns if col in settings_QA_dict.keys()]
            print("Settings found:")
            print(session_settings.select(relevant_columns))

            if verify_session_settings(session_settings.to_dict(as_series=False), settings_QA_dict):
                print(f"Session {session} settings verified ✓")
            else:
                print(f"Session {session} settings verification failed ✗")
                identified_bad_sessions[device_id].append(session)

        # Remove bad sessions from data
        if identified_bad_sessions[device_id]:
            print(f"\nRemoving {len(identified_bad_sessions[device_id])} bad sessions from {device_id}")
            data = data.filter(~pl.col("SessionNumber").is_in(identified_bad_sessions[device_id]))

        # (Re)create this device's table. `FROM data` refers to the Python variable `data`,
        # which DuckDB resolves through its replacement scan of in-scope Python objects.
        conn.execute(f"CREATE SCHEMA IF NOT EXISTS {device_id}")
        conn.execute(f"DROP TABLE IF EXISTS {device_id}.{table_name}")
        conn.execute(f"CREATE TABLE {device_id}.{table_name} AS SELECT * FROM data")

        print(f"\nSaved {len(data)} rows for {device_id}")

        _save_power_band_plot(data, device_id)
finally:
    # Always release the DuckDB file, even if a device fails part-way through.
    conn.close()

# Print summary of bad sessions
print("\nSummary of bad sessions:")
for device_id, sessions in identified_bad_sessions.items():
    if sessions:
        print(f"{device_id}: {len(sessions)} bad sessions - {sessions}")
    else:
        print(f"{device_id}: All sessions passed verification")
