# Data pipeline for DiveDB
Uses classes `info` and `DataReader` to facilitate data intake, processing, and alignment. 

In [1]:
# Import libraries and set working directory (adjust to fit your preferences)
import os
import sys
import numpy as np
import pandas as pd
import pytz
import matplotlib.pyplot as plt
from notion_client import Client
from dotenv import load_dotenv
from pyologger.load_data.datareader import DataReader
from pyologger.load_data.metadata import Metadata
from pyologger.plot_data.plotter import plot_tag_data
import pickle

# Change the current working directory to the root directory
# os.chdir("/Users/fbar/Documents/GitHub/pyologger")
os.chdir("/Users/jessiekb/Documents/GitHub/pyologger")

root_dir = os.getcwd()
data_dir = os.path.join(root_dir, "data")

# Verify the current working directory
print(f"Current working directory: {root_dir}")

Current working directory: /Users/jessiekb/Documents/GitHub/pyologger


### Query metadata
Use Notion and [info entry form](https://forms.fillout.com/t/8UNuTLMaRfus) to start a recording and to generate identifiers for the Recording and Deployment. 


In [2]:
# Initialize the info class
metadata = Metadata()
metadata.fetch_databases(verbose=False)

# Save databases
dep_db = metadata.get_metadata("dep_DB")
logger_db = metadata.get_metadata("logger_DB")
rec_db = metadata.get_metadata("rec_DB")
animal_db = metadata.get_metadata("animal_DB")

Loaded Notion secret token.
Loaded database ID for dep_DB.
Loaded database ID for rec_DB.
Loaded database ID for logger_DB.
Loaded database ID for animal_DB.


### Steps for Processing Deployment Data:

1. **Select Deployment Folder**:
   - **Description:** Asks the user for input to select a deployment folder to kick off the data reading process. In your folder name, you can have any suffix after Deployment ID. It will check and stop if there are two that fit.
   - **Function Used:** `check_deployment_folder()`

2. **Initialize Deployment Folder**:
   - **Description:** Starts the main `read_files` process with the selected deployment folder.
   - **Function Used:** `read_files()`

3. **Fetch Metadata**:
   - **Description:** Retrieve necessary data from the metadata database, including logger information.
   - **Function Used:** `metadata.fetch_databases()`

4. **Organize Files by Logger ID**:
   - **Description:** Group files by logger ID for processing.
   - **Function Used:** `read_files()` (This is the main function)

5. **Check for Existing Processed Files**:
   - **Description:** Verify if the outputs folder already contains processed files for each logger. Skip reprocessing if all necessary files are present.
   - **Function Used:** `check_outputs_folder()`

6. **Process UBE Files**:
   - **Description:** For each UFI logger with UBE files, process and save the data.
   - **Function Used:** `process_ube_file()`

7. **Process CSV Files**:
   - **Description:** For each logger with multiple CSV files, concatenate them, and save the combined data.
   - **Function Used:** `concatenate_and_save_csvs()`

8. **Final Outputs**:
   - **Description:** Ensure all processed data is saved in the outputs folder with appropriate filenames.
   - **Functions Used:** `save_data()`

In [3]:
# Find your deployment ID index and remember it for the next cell, where you have to enter it.
dep_db

Unnamed: 0,Rec Date,Notes,Animal,Start Time Precision,End Time,Actual Start Time,Time Zone,Start time,Deployment Name
0,2024-06-19,Third day of deployments at SeaWorld with Cork...,orca,,,,America/Los_Angeles,08:00:00,2024-06-19_oror-001a
1,2024-06-18,Deployment second day at Sea World,orca,,,,America/Los_Angeles,,2024-06-18_oror-001a
2,2024-06-17,ECG recordings with Ashley and Paul at SeaWorl...,orca,,,,America/Los_Angeles,,2024-06-17_oror-002-001a
3,2024-01-16,Shuka ECG and CATS,orca,,,,America/Los_Angeles,,2024-01-16_oror-002a
4,2024-06-06,Boat calibrations with Bill Hagey,boat,Approximative,,,America/Los_Angeles,09:31:30,2024-06-06_boat-001a


In [3]:
# Assuming you have the metadata and dep_db loaded:
datareader = DataReader()
deployment_folder = datareader.check_deployment_folder(dep_db, data_dir)

if deployment_folder:
    datareader.read_files(metadata, save_csv=True, save_parq=True)

Step 1: Displaying deployments to help you select one.
            Deployment Name                                              Notes
0      2024-06-19_oror-001a  Third day of deployments at SeaWorld with Cork...
1      2024-06-18_oror-001a                 Deployment second day at Sea World
2  2024-06-17_oror-002-001a  ECG recordings with Ashley and Paul at SeaWorl...
3      2024-01-16_oror-002a                                 Shuka ECG and CATS
4      2024-06-06_boat-001a                  Boat calibrations with Bill Hagey
Step 1: You selected the deployment: 2024-01-16_oror-002a
Description: Shuka ECG and CATS
Step 2: Deployment folder path: /Users/jessiekb/Documents/GitHub/pyologger/data/2024-01-16_oror-002a
Folder /Users/jessiekb/Documents/GitHub/pyologger/data/2024-01-16_oror-002a not found. Searching for folders with a similar name...
Using the found folder: /Users/jessiekb/Documents/GitHub/pyologger/data/2024-01-16_oror-002a_Shuka-HR
Ready to process deployment folder: /Users/jes

In [5]:
# Optionally look at first notes that have been read in
#datareader.selected_deployment['Time Zone']
#datareader.info['UF-01']
datareader.notes_df[0:5]
#datareader.data['CC-96']
#datareader.data['UF-01']
#datareader.metadata['channelnames']

Unnamed: 0,date,time,type,key,value,short_description,long_description,datetime,datetime_utc,time_unix_ms,sec_diff
0,2024-01-16,10:01:00.170000,point,heartbeat_manual_ok,63.829787,heartbeat detection,,2024-01-16 10:01:00.170000-08:00,2024-01-16 18:01:00.170000+00:00,1705428060170,-334.86
1,2024-01-16,10:01:02.780000,point,heartbeat_manual_reject,,heartbeat detection,,2024-01-16 10:01:02.780000-08:00,2024-01-16 18:01:02.780000+00:00,1705428062780,2.61
2,2024-01-16,10:01:03.959000,point,exhalation_breath,,exhalation followed by breath,start exhale Breath; snapshots for first brea...,2024-01-16 10:01:03.959000-08:00,2024-01-16 18:01:03.959000+00:00,1705428063959,
3,2024-01-16,10:01:07,point,heartbeat_manual_reject,,heartbeat detection,,2024-01-16 10:01:07-08:00,2024-01-16 18:01:07+00:00,1705428067000,4.22
4,2024-01-16,10:01:07.830000,point,heartbeat_manual_ok,72.289157,heartbeat detection,,2024-01-16 10:01:07.830000-08:00,2024-01-16 18:01:07.830000+00:00,1705428067830,0.83


In [None]:
datareader.data['CC-96']

In [None]:
#data_test = pd.read_csv(os.path.join(deployment_folder, "2024-01-16_oror-002a_CC-96_001.csv"))
datareader.selected_deployment['Time Zone']

### Inspect the pickle file output

Load in the generated pickle file to inspect the output.

In [4]:
# Load the data_reader object from the pickle file
pkl_path = os.path.join(deployment_folder, 'outputs', 'data.pkl')

with open(pkl_path, 'rb') as file:
    data_pkl = pickle.load(file)

for logger_id, info in data_pkl.info.items():
    sampling_frequency = info.get('datetime_metadata', {}).get('fs', None)
    if sampling_frequency is not None:
        # Format the sampling frequency to 5 significant digits
        print(f"Sampling frequency for {logger_id}: {sampling_frequency} Hz")
    else:
        print(f"No sampling frequency available for {logger_id}")

data_pkl.info['CC-96']['datetime_metadata']['fs']

Sampling frequency for CC-96: 400 Hz
Sampling frequency for UF-01: 100 Hz


'400'

In [5]:
data_pkl.info['CC-96']['dives']

KeyError: 'dives'

In [None]:
data_pkl.notes_df[0:5]

In [None]:
data_pkl.data['UF-01']  #data['UF-01'][0:5] # browse column names

In [None]:
data_pkl.data['CC-96'][0:5] # browse column names

### Pre-process data for plots
Downsample high-resolution data and filter notes down to notes of interest to include in plot.

In [None]:
data_pkl.info['CC-96']['datetime_metadata']['fs']

In [None]:
data_pkl.data['CC-96']['depth'].values

## Plot
### Generic plot function

## Calibrations
### Check Accel and Mag Function

In [None]:
import numpy as np
from scipy.signal import butter, filtfilt

def check_AM(A, M=None, fs=None, find_incl=True):
    """
    Compute field intensity of acceleration and magnetometer data,
    and the inclination angle of the magnetic field (in degrees).
    
    Parameters:
    A (numpy.ndarray or dict): An accelerometer sensor matrix with columns [ax, ay, az].
                               Can also be a sensor data dictionary.
    M (numpy.ndarray or None): A magnetometer sensor matrix with columns [mx, my, mz]. Optional.
    fs (float or None): The sampling rate of the sensor data in Hz. Required if A is not a sensor dictionary.
    find_incl (bool): Whether to compute and return the inclination angle. Defaults to True.
    
    Returns:
    dict or numpy.ndarray: 
        - If find_incl is False: returns the field intensity as a numpy array.
        - If find_incl is True: returns a dictionary with field intensity and inclination angle (in degrees).
    """
    
    fc = 5  # low-pass filter frequency in Hz

    if isinstance(A, dict):
        if M is not None:
            if A['sampling_rate'] == M['sampling_rate'] and len(A['data']) == len(M['data']):
                fs = A['sampling_rate']
                A = A['data']
                M = M['data']
        else:
            fs = A['sampling_rate']
            A = A['data']
        if len(A) == 0:
            raise ValueError("No data found in input argument A")
    else:
        if M is None and fs is None:
            raise ValueError("Sampling rate (fs) is required if A is not a sensor dictionary")
        if fs is None:
            raise ValueError("Need to specify sampling frequency for matrix arguments")

    # Handle single vector inputs
    if M is not None and M.ndim == 1:
        M = M.reshape(1, -1)
    if A.ndim == 1:
        A = A.reshape(1, -1)

    # Check that sizes of A and M are compatible
    if M is not None and A.shape[0] != M.shape[0]:
        n = min(A.shape[0], M.shape[0])
        A = A[:n, :]
        M = M[:n, :]

    # Low-pass filter the data if sampling rate is greater than 10 Hz
    if fs > 10:
        nf = int(round(4 * fs / fc))
        if A.shape[0] > nf:
            b, a = butter(4, fc / (fs / 2), btype='low')
            A = filtfilt(b, a, A, axis=0)
            if M is not None:
                M = filtfilt(b, a, M, axis=0)

    # Compute field intensity of the first input argument (A)
    fstr = np.sqrt(np.sum(A**2, axis=1))
    fstr = fstr.reshape(-1, 1)

    if M is not None:
        # Compute field intensity of the second input argument (M)
        fstr2 = np.sqrt(np.sum(M**2, axis=1))
        fstr2 = fstr2.reshape(-1, 1)
        fstr = np.hstack((fstr, fstr2))

    if find_incl and M is not None:
        AMprod = np.sum(A * M, axis=1)
        incl = -np.degrees(np.arcsin(AMprod / (fstr[:, 0] * fstr[:, 1])))
        return {'fstr': fstr, 'incl': incl}
    else:
        return fstr


### Check Accel and Mag implementation

In [None]:
import numpy as np

# Assuming `data_pkl` is already loaded and contains your data
accX = data_pkl.data['CC-96']['accX'].values
accY = data_pkl.data['CC-96']['accY'].values
accZ = data_pkl.data['CC-96']['accZ'].values
magX = data_pkl.data['CC-96']['magX'].values
magY = data_pkl.data['CC-96']['magY'].values
magZ = data_pkl.data['CC-96']['magZ'].values

# Combine the accelerometer and magnetometer data into nx3 matrices
acc_data = np.vstack((accX, accY, accZ)).T
mag_data = np.vstack((magX, magY, magZ)).T

# Get the sampling rate from the data structure
sampling_rate = int(data_pkl.info['CC-96']['datetime_metadata']['fs'])

# Call the check_AM function
AMcheck = check_AM(acc_data, mag_data, sampling_rate)

# Access the field intensity and inclination angle
field_intensity_acc = AMcheck['fstr'][:, 0]  # Field intensity of accelerometer data
field_intensity_mag = AMcheck['fstr'][:, 1]  # Field intensity of magnetometer data
inclination_angle = AMcheck['incl']

# Print results
print("Field Intensity:\n", field_intensity_acc)
print("Field Intensity:\n", field_intensity_mag)
print("Inclination Angle (degrees):\n", inclination_angle)



In [None]:
# Save new fields to data_pkl.data['CC-96']
data_pkl.data['CC-96']['field_intensity_acc'] = field_intensity_acc
data_pkl.data['CC-96']['field_intensity_mag'] = field_intensity_mag
data_pkl.data['CC-96']['inclination_angle'] = inclination_angle

# Example usage in Streamlit
imu_channels_to_plot = ['accX', 'accY', 'accZ', 'field_intensity_acc', 'field_intensity_mag', 'inclination_angle']
ephys_channels_to_plot = []
imu_logger_to_use = 'CC-96'
ephys_logger_to_use = 'UF-01'

# Get the overlapping time range
imu_df = data_pkl.data[imu_logger_to_use]
ephys_df = data_pkl.data[ephys_logger_to_use]
start_time = max(imu_df['datetime'].min(), ephys_df['datetime'].min()).to_pydatetime()
end_time = min(imu_df['datetime'].max(), ephys_df['datetime'].max()).to_pydatetime()

plot_tag_data_interactive(data_pkl, imu_channels_to_plot, ephys_channels=ephys_channels_to_plot, imu_logger=imu_logger_to_use, ephys_logger=ephys_logger_to_use, time_range=(start_time, end_time))



### Fix offset


In [None]:
import numpy as np
from numpy.linalg import solve, inv
from scipy.linalg import lstsq

def fix_offset_3d(X):
    """
    Estimate the offset in each axis of a triaxial field measurement.

    Parameters:
    X (numpy.ndarray or dict): A sensor matrix or dictionary containing measurements from a triaxial field sensor.

    Returns:
    dict: A dictionary with two elements:
        - 'X': The adjusted triaxial sensor measurements (same size and units as input).
        - 'G': A calibration dictionary containing the offset added to each column of X.
    """
    poly1 = np.ones((3, 1))
    poly2 = np.zeros((3, 1))
    poly = np.hstack((poly1, poly2))
    G = {'poly': poly}

    if X is None:
        raise ValueError("Input for X is required")

    if isinstance(X, dict):
        x = X['data']
    else:
        x = X

    if x.shape[1] != 3:
        raise ValueError("Input data must be from a 3-axis sensor")

    # Filter valid (complete) rows
    valid_rows = np.all(np.isfinite(x), axis=1)
    x_valid = x[valid_rows, :]

    # Compute the squared magnitude of each vector and the mean magnitude
    bsq = np.sum(x_valid**2, axis=1)
    mb = np.sqrt(np.mean(bsq))

    XX = np.hstack((2 * x_valid, np.full((len(x_valid), 1), mb)))

    R = np.dot(XX.T, XX)

    if np.linalg.cond(R) > 1e3:
        raise ValueError("Condition too poor to get reliable solution")

    P = np.dot(bsq, XX)
    H = -solve(R, P)

    G['poly'] = np.hstack((poly1, H[:3].reshape(3, 1)))

    # Adjust the sensor data by adding the offset
    x = x + H[:3]

    if not isinstance(X, dict):
        return {'X': x, 'G': G}

    X['data'] = x

    # Check and adjust for cal_map or cal_cross if present
    if 'cal_map' in X:
        G['poly'][:, 1] = np.dot(inv(X['cal_map']), G['poly'][:, 1])

    if 'cal_cross' in X:
        G['poly'][:, 1] = np.dot(inv(X['cal_cross']), G['poly'][:, 1])

    X['cal_poly'] = G['poly']

    # Update the history
    if 'history' in X and X['history']:
        X['history'].append('fix_offset_3d')
    else:
        X['history'] = ['fix_offset_3d']

    return {'X': X, 'G': G}


In [None]:
# Assuming `data_pkl` is already loaded and contains your data
accX = data_pkl.data['CC-96']['accX'].values
accY = data_pkl.data['CC-96']['accY'].values
accZ = data_pkl.data['CC-96']['accZ'].values

# Combine into a single matrix
acc_data = np.vstack((accX, accY, accZ)).T  # Shape should be (n_samples, 3)

# Apply the fix_offset_3d function
result = fix_offset_3d(acc_data)

# Extract the adjusted data and calibration info
adjusted_data_acc = result['X']
calibration_info_acc = result['G']

print("Adjusted Data:\n", adjusted_data_acc)
print("Calibration Info:\n", calibration_info_acc)

data_pkl.data['CC-96']['accX_adjusted'] = adjusted_data_acc[:, 0]
data_pkl.data['CC-96']['accY_adjusted'] = adjusted_data_acc[:, 1]
data_pkl.data['CC-96']['accZ_adjusted'] = adjusted_data_acc[:, 2]
data_pkl.info['CC-96']['calibration_info'] = {}
data_pkl.info['CC-96']['calibration_info'] = calibration_info_acc

In [None]:
# Assuming `data_pkl` is already loaded and contains your data
magX = data_pkl.data['CC-96']['magX'].values
magY = data_pkl.data['CC-96']['magY'].values
magZ = data_pkl.data['CC-96']['magZ'].values

# Combine into a single matrix
mag_data = np.vstack((magX, magY, magZ)).T  # Shape should be (n_samples, 3)

# Apply the fix_offset_3d function
result = fix_offset_3d(mag_data)

# Extract the adjusted data and calibration info
adjusted_data_mag = result['X']
calibration_info_mag = result['G']

print("Adjusted Data:\n", adjusted_data_mag)
print("Calibration Info:\n", calibration_info_mag)

data_pkl.data['CC-96']['magX_adjusted'] = adjusted_data_mag[:, 0]
data_pkl.data['CC-96']['magY_adjusted'] = adjusted_data_mag[:, 1]
data_pkl.data['CC-96']['magZ_adjusted'] = adjusted_data_mag[:, 2]
data_pkl.info['CC-96']['calibration_info'] = {}
data_pkl.info['CC-96']['calibration_info'] = calibration_info_mag

In [None]:
import numpy as np

# Assuming `data_pkl` is already loaded and contains your data
accX = data_pkl.data['CC-96']['accX_adjusted'].values
accY = data_pkl.data['CC-96']['accY_adjusted'].values
accZ = data_pkl.data['CC-96']['accZ_adjusted'].values
magX = data_pkl.data['CC-96']['magX_adjusted'].values
magY = data_pkl.data['CC-96']['magY_adjusted'].values
magZ = data_pkl.data['CC-96']['magZ_adjusted'].values

# Combine the accelerometer and magnetometer data into nx3 matrices
acc_data = np.vstack((accX, accY, accZ)).T
mag_data = np.vstack((magX, magY, magZ)).T

# Get the sampling rate from the data structure
sampling_rate = int(data_pkl.info['CC-96']['datetime_metadata']['fs'])

# Call the check_AM function
AMcheck2 = check_AM(acc_data, mag_data, sampling_rate)

# Access the field intensity and inclination angle
field_intensity_acc2 = AMcheck2['fstr'][:, 0]  # Field intensity of accelerometer data
field_intensity_mag2 = AMcheck2['fstr'][:, 1]  # Field intensity of magnetometer data
inclination_angle2 = AMcheck2['incl']

# Print results
print("Field Intensity:\n", field_intensity_acc2)
print("Field Intensity:\n", field_intensity_mag2)
print("Inclination Angle (degrees):\n", inclination_angle2)

# Save new fields to data_pkl.data['CC-96']
data_pkl.data['CC-96']['field_intensity_acc2'] = field_intensity_acc2
data_pkl.data['CC-96']['field_intensity_mag2'] = field_intensity_mag2
data_pkl.data['CC-96']['inclination_angle2'] = inclination_angle2

# Example usage in Streamlit
imu_channels_to_plot = ['accX', 'accY', 'accZ', 'field_intensity_acc2', 'field_intensity_mag2', 'inclination_angle2']
ephys_channels_to_plot = []
imu_logger_to_use = 'CC-96'
ephys_logger_to_use = 'UF-01'

# Get the overlapping time range
imu_df = data_pkl.data[imu_logger_to_use]
ephys_df = data_pkl.data[ephys_logger_to_use]
start_time = max(imu_df['datetime'].min(), ephys_df['datetime'].min()).to_pydatetime()
end_time = min(imu_df['datetime'].max(), ephys_df['datetime'].max()).to_pydatetime()

plot_tag_data_interactive(data_pkl, imu_channels_to_plot, ephys_channels=ephys_channels_to_plot, imu_logger=imu_logger_to_use, ephys_logger=ephys_logger_to_use, time_range=(start_time, end_time))



### Fix Inclination Angle

Here is where you would fix the axes of triaxial sensor data (if the sensor axis differs from the tag axis)

## Tag frame to animal frame

### Find dives method

In [None]:
import numpy as np
import pandas as pd
from scipy.signal import decimate, medfilt
from scipy.linalg import norm, svd
from scipy.stats import iqr, linregress

def calculate_depth_rate(depth_data, sampling_rate, smoothing_factor=0.2):
    """
    Calculate the rate of change of depth (depth rate).
    
    Parameters:
    depth_data (numpy.ndarray): Depth data in meters.
    sampling_rate (float): Sampling rate in Hz.
    smoothing_factor (float): Smoothing factor to reduce noise in the derivative.
    
    Returns:
    numpy.ndarray: Smoothed depth rate.
    """
    depth_diff = np.gradient(depth_data) * sampling_rate
    smoothed_depth_rate = medfilt(depth_diff, kernel_size=int(smoothing_factor * sampling_rate))
    return smoothed_depth_rate

def prh_predictor2(depth_data, accel_data, sampling_rate, max_depth_threshold=10):
    """
    Predict the tag position on a diving animal from depth and acceleration data.
    
    Parameters:
    depth_data (numpy.ndarray): Depth vector in meters.
    accel_data (numpy.ndarray): Acceleration matrix with columns ax, ay, az.
    sampling_rate (float): The sampling rate of the sensor data in Hz.
    max_depth_threshold (float): The maximum depth of near-surface dives. Default is 10 m.
    
    Returns:
    pandas.DataFrame: DataFrame with columns 'cue', 'p0', 'r0', 'h0', and 'q'.
    """
    
    min_segment_length = 30  # minimum surface segment length in seconds
    max_segment_length = 300  # maximum surface segment length in seconds
    gap_time = 5  # gap time in seconds to avoid dive edges
    
    if depth_data is None or accel_data is None:
        raise ValueError("prh_predictor2 requires inputs depth_data and accel_data.")
    
    if sampling_rate is None:
        raise ValueError("sampling_rate must be specified.")
    
    # Decimate data to 5Hz if needed
    if sampling_rate >= 7.5:
        decimation_factor = round(sampling_rate / 5)
        depth_data = decimate(depth_data, decimation_factor, zero_phase=True)
        accel_data = decimate(accel_data, decimation_factor, axis=0, zero_phase=True)
        sampling_rate /= decimation_factor

    # Normalize acceleration to 1 g
    accel_data = accel_data / np.linalg.norm(accel_data, axis=1)[:, np.newaxis]

    # Calculate depth rate
    depth_rate = calculate_depth_rate(depth_data, sampling_rate)

    # Detect dive start/ends using the find_dives function
    dive_times = find_dives(depth_data, min_depth_threshold=max_depth_threshold, 
                            sampling_rate=sampling_rate, duration_threshold=10)

    if dive_times.shape[0] == 0:
        raise ValueError(f"No dives deeper than {max_depth_threshold:.0f} found - change max_depth_threshold.")
    
    # Augment all dive-start and dive-end times by gap_time seconds
    dive_times['start'] -= gap_time
    dive_times['end'] += gap_time
    
    # Check for segments before first dive and after last dive
    first_segment_start = max(dive_times.iloc[0]['start'] - max_segment_length, 0)
    first_segment_end = dive_times.iloc[0]['start']
    last_segment_start = dive_times.iloc[-1]['end']
    last_segment_end = min(dive_times.iloc[-1]['end'] + max_segment_length, len(depth_data) / sampling_rate)
    
    # Adjust the first and last segments based on depth
    first_segment_indices = np.where(depth_data[int(sampling_rate * first_segment_start):int(sampling_rate * first_segment_end)] > max_depth_threshold)[0]
    if len(first_segment_indices) > 0:
        first_segment_start += first_segment_indices[-1] / sampling_rate

    last_segment_indices = np.where(depth_data[int(sampling_rate * last_segment_start):int(sampling_rate * last_segment_end)] > max_depth_threshold)[0]
    if len(last_segment_indices) > 0:
        last_segment_end = last_segment_start + (last_segment_indices[0] - 1) / sampling_rate

    # Combine all segments
    all_segments = np.vstack([
        [first_segment_start, first_segment_end],
        dive_times[['end', 'start']].values[:-1],
        [last_segment_start, last_segment_end]
    ])

    # Filter out segments that are too short
    segment_durations = np.diff(all_segments, axis=1)[:, 0]
    valid_segments = all_segments[segment_durations > min_segment_length]
    
    # Break up long surfacing intervals
    while True:
        long_segments = np.where(np.diff(valid_segments, axis=1)[:, 0] > max_segment_length)[0]
        if len(long_segments) == 0:
            break
        segment_index = long_segments[0]
        valid_segments = np.vstack([
            valid_segments[:segment_index],
            [valid_segments[segment_index, 0], valid_segments[segment_index, 0] + max_segment_length],
            [valid_segments[segment_index, 0] + max_segment_length, valid_segments[segment_index, 1]],
            valid_segments[segment_index + 1:]
        ])
    
    # Check for segments with sufficient variation in orientation
    # orientation_variation = np.zeros(valid_segments.shape[0])
    # for segment_index in range(valid_segments.shape[0]):
    #     indices = np.arange(int(valid_segments[segment_index, 0] * sampling_rate), 
    #                         int(valid_segments[segment_index, 1] * sampling_rate))
        
    #     # Boundary check to avoid out-of-bounds indices
    #     indices = indices[indices < accel_data.shape[0]]

    #     orientation_variation[segment_index] = norm(np.std(accel_data[indices, :], axis=0))
    
    # variation_threshold = np.median(orientation_variation) + 1.5 * iqr(orientation_variation) * np.array([-1, 1])
    # valid_segments = valid_segments[(orientation_variation > variation_threshold[0]) & 
    #                                  (orientation_variation < variation_threshold[1])]

    # PRH inference
    prh_data = np.empty((valid_segments.shape[0], 5))
    for segment_index in range(valid_segments.shape[0]):
        indices = np.arange(int(valid_segments[segment_index, 0] * sampling_rate), 
                            int(valid_segments[segment_index, 1] * sampling_rate))
        
        # Boundary check to avoid out-of-bounds indices
        indices = indices[indices < accel_data.shape[0]]

        prh = applymethod2(accel_data[indices, :], depth_rate, sampling_rate, valid_segments[segment_index, :])
        if prh is not None:
            prh_data[segment_index, :] = np.hstack([np.mean(valid_segments[segment_index, :]), prh])
    
    # Convert prh_data array to DataFrame
    prh_df = pd.DataFrame(prh_data, columns=['cue', 'p0', 'r0', 'h0', 'q'])

    return prh_df

def applymethod2(accel_data, depth_rate, sampling_rate, segment_times):
    """
    Apply PRH predictor method 2 to estimate tag-to-animal orientation.

    Parameters:
    accel_data (numpy.ndarray): Acceleration data for the segment.
    depth_rate (numpy.ndarray): Depth rate data for the segment.
    sampling_rate (float): The sampling rate of the sensor data in Hz.
    segment_times (numpy.ndarray): Start and end times of the segment.

    Returns:
    numpy.ndarray: Estimated orientation angles [p0, r0, h0, q].
    """

    # Extract the segment indices
    segment_indices = np.arange(int(segment_times[0] * sampling_rate), int(segment_times[1] * sampling_rate))
    segment_indices = segment_indices[segment_indices < accel_data.shape[0]]
    
    As = accel_data[segment_indices, :]
    vs = depth_rate[segment_indices]

    # Energy ratio between plane-of-motion and axis of rotation
    QQ = As.T @ As  # Form outer product of acceleration
    if np.any(np.isnan(QQ)):
        return None

    svd_out = svd(QQ)
    pow_ratio = svd_out[1][2] / svd_out[1][1]  # Power ratio from singular values

    # Axis of rotation to restore V to tag Y axis
    aa = np.arccos(np.dot([0, 1, 0], svd_out[2][:, 2]))
    Phi = np.cross([0, 1, 0], svd_out[2][:, 2]) / np.sin(aa)
    S = np.array([[0, -Phi[2], Phi[1]],
                  [Phi[2], 0, -Phi[0]],
                  [-Phi[1], Phi[0], 0]])

    Q = np.eye(3) + (1 - np.cos(aa)) * S @ S - np.sin(aa) * S  # Generate rotation matrix

    am = np.mean(As, axis=0) @ Q.T
    p0 = np.arctan2(am[0], am[2])
    Q = euler2rotmat(p=p0, r=0, h=0) @ Q

    prh = np.array([np.arcsin(Q[2, 0]), np.arctan2(Q[2, 1], Q[2, 2]), np.arctan2(Q[1, 0], Q[0, 0])])

    aa_transformed = As @ Q[1, :].T
    prh_quality = np.mean([pow_ratio, np.std(aa_transformed)])

    # Check that h0 is not 180 degrees out by checking the regression
    # between Aa[:, 0] and depth_rate is negative.
    Q_final = euler2rotmat(prh[0], prh[1], prh[2])
    Aa = As @ Q_final.T
    slope, _, _, _, _ = linregress(Aa[:, 0], vs)

    if slope > 0:
        prh[2] = (prh[2] - np.pi) % (2 * np.pi)  # Correct if necessary

    # Constrain r0 and h0 to the interval -pi:pi
    for i in range(1, 3):
        if abs(prh[i]) > np.pi:
            prh[i] -= np.sign(prh[i]) * 2 * np.pi

    return np.hstack((prh, prh_quality))

def euler2rotmat(p, r, h):
    """
    Convert Euler angles to a rotation matrix.
    
    Parameters:
    p (float): Pitch angle in radians.
    r (float): Roll angle in radians.
    h (float): Heading angle in radians.
    
    Returns:
    numpy.ndarray: 3x3 rotation matrix.
    """
    # Rotation matrix from Euler angles (assuming the same order of rotations as in R)
    cp, sp = np.cos(p), np.sin(p)
    cr, sr = np.cos(r), np.sin(r)
    ch, sh = np.cos(h), np.sin(h)
    
    rot_matrix = np.array([
        [ch * cr, ch * sr * sp - sh * cp, ch * sr * cp + sh * sp],
        [sh * cr, sh * sr * sp + ch * cp, sh * sr * cp - ch * sp],
        [-sr, cr * sp, cr * cp]
    ])
    
    return rot_matrix


# Example usage
P = data_pkl.data['CC-96']['corrdepth'].values
A = np.vstack((data_pkl.data['CC-96']['accX'].values, 
               data_pkl.data['CC-96']['accY'].values, 
               data_pkl.data['CC-96']['accZ'].values)).T
sampling_rate = int(data_pkl.info['CC-96']['datetime_metadata']['fs'])

PRH = prh_predictor2(P, A, sampling_rate=sampling_rate, max_depth_threshold=1)
print(PRH)


In [None]:
field_intensity_acc2 = AMcheck2['fstr'][:, 0]  # Field intensity of accelerometer data
field_intensity_mag2 = AMcheck2['fstr'][:, 1]  # Field intensity of magnetometer data
inclination_angle2 = AMcheck2['incl']

# Print results
print("Field Intensity:\n", field_intensity_acc2)
print("Field Intensity:\n", field_intensity_mag2)
print("Inclination Angle (degrees):\n", inclination_angle2)

# Save new fields to data_pkl.data['CC-96']
data_pkl.data['CC-96']['field_intensity_acc2'] = field_intensity_acc2
data_pkl.data['CC-96']['field_intensity_mag2'] = field_intensity_mag2
data_pkl.data['CC-96']['inclination_angle2'] = inclination_angle2

### PRH method 1

### PRH method 2
for short-surfacing animals like beaked whales

In [None]:
import numpy as np
import pandas as pd
from scipy.signal import decimate
from scipy.linalg import norm
from scipy.stats import iqr

# Helper function for PRH inference
def prh_predictor2_gui(depth_data, accel_data, sampling_rate, max_depth_threshold=10):
    """
    Predict the tag position on a diving animal from depth and acceleration data.
    Returns: DataFrame with p0, r0, h0, and quality estimates.
    """
    min_segment_length = 30  # minimum surface segment length in seconds
    max_segment_length = 300  # maximum surface segment length in seconds
    gap_time = 5  # gap time in seconds to avoid dive edges

    # Decimate data to 5Hz if needed
    if sampling_rate >= 7.5:
        decimation_factor = round(sampling_rate / 5)
        depth_data = decimate(depth_data, decimation_factor, zero_phase=True)
        accel_data = decimate(accel_data, decimation_factor, axis=0, zero_phase=True)
        sampling_rate /= decimation_factor

    # Normalize acceleration to 1 g
    accel_data = accel_data / np.linalg.norm(accel_data, axis=1)[:, np.newaxis]

    # Detect dive start/ends using the find_dives function
    dive_times = find_dives(depth_data, min_depth_threshold=max_depth_threshold, 
                            sampling_rate=sampling_rate, duration_threshold=10)

    if dive_times.shape[0] == 0:
        raise ValueError(f"No dives deeper than {max_depth_threshold:.0f} found - change max_depth_threshold.")

    # Augment all dive-start and dive-end times by gap_time seconds
    dive_times['start'] -= gap_time
    dive_times['end'] += gap_time

    # Initialize the PRH and quality estimates DataFrame
    prh_data = []

    for i, row in dive_times.iterrows():
        start_time = int(row['start'] * sampling_rate)
        end_time = int(row['end'] * sampling_rate)

        # Analyze orientation segments
        accel_segment = accel_data[start_time:end_time, :]
        orientation_variation = norm(np.std(accel_segment, axis=0))
        quality = np.abs(orientation_variation) / np.mean(orientation_variation)

        # Estimate p0, r0, h0 (random example for simplicity)
        p0 = np.mean(accel_segment[:, 0])
        r0 = np.mean(accel_segment[:, 1])
        h0 = np.mean(accel_segment[:, 2])

        # Append the estimates to the list
        prh_data.append({
            'cue': row['tmax'],
            'p0': p0,
            'r0': r0,
            'h0': h0,
            'quality': quality
        })

    # Convert the list of dictionaries into a DataFrame
    prh_data = pd.DataFrame(prh_data)

    return prh_data

# Example usage
P = data_pkl.data['CC-96']['corrdepth'].values
A = np.vstack((data_pkl.data['CC-96']['accX'].values, 
               data_pkl.data['CC-96']['accY'].values, 
               data_pkl.data['CC-96']['accZ'].values)).T
sampling_rate = int(data_pkl.info['CC-96']['datetime_metadata']['fs'])

PRH = prh_predictor2(P, A, sampling_rate=sampling_rate, max_depth_threshold=1)




## Filtering

### Dominant Stroke Frequency function

In [None]:
import numpy as np
from scipy.signal import butter, filtfilt, welch
from scipy.fft import fft
from scipy import polyfit

def dsf(acc_data, sampling_rate, fc=2.5, Nfft=None):
    """
    Estimate the dominant stroke frequency from triaxial accelerometer data.

    Parameters:
    acc_data (numpy.ndarray): nx3 acceleration matrix with columns [ax, ay, az].
    sampling_rate (float): The sampling rate of the sensor data in Hz.
    fc (float, optional): The cut-off frequency in Hz of a low-pass filter to apply to acc_data. Defaults to 2.5 Hz.
    Nfft (int, optional): The FFT length and therefore the frequency resolution. Defaults to the power of two closest to 20*sampling_rate.

    Returns:
    dict: A dictionary with two elements:
        - 'fpk': The dominant stroke frequency (Hz).
        - 'q': The quality of the peak.
    """

    # Ensure acc_data is a numpy array
    acc_data = np.asarray(acc_data)

    # Determine the default FFT length if not provided
    if Nfft is None:
        Nfft = int(2**np.round(np.log2(20 * sampling_rate)))

    # Low-pass filter
    if fc and fc < (sampling_rate / 2):
        b, a = butter(6, fc / (sampling_rate / 2), btype='low')
        acc_data_filtered = filtfilt(b, a, acc_data, axis=0)
    else:
        acc_data_filtered = acc_data

    # Calculate the power spectral density for each axis
    freqs, power_spectrum = welch(acc_data_filtered, fs=sampling_rate, nperseg=Nfft, axis=0)

    # Sum the power spectral densities across the three axes
    summed_power_spectrum = np.sum(power_spectrum, axis=1)

    # Find the frequency with the maximum power
    max_power = np.max(summed_power_spectrum)
    peak_index = np.argmax(summed_power_spectrum)

    if 1 < peak_index < len(freqs) - 1:
        # Quadratic interpolation to refine peak frequency
        p = polyfit(freqs[peak_index-1:peak_index+2], summed_power_spectrum[peak_index-1:peak_index+2], 2)
        fpk = -p[1] / (2 * p[0])
    else:
        fpk = freqs[peak_index]

    # Quality of the peak
    q = max_power / np.mean(summed_power_spectrum)

    return {'fpk': fpk, 'q': q}



import numpy as np

# Assuming `data_pkl` is already loaded and contains your data
accX = data_pkl.data['CC-96']['accX'].values
accY = data_pkl.data['CC-96']['accY'].values
accZ = data_pkl.data['CC-96']['accZ'].values

# Combine the accelerometer data into an nx3 matrix
acc_data = np.vstack((accX, accY, accZ)).T

# Get the sampling rate from the data structure
sampling_rate = int(data_pkl.info['CC-96']['datetime_metadata']['fs'])  # Replace with the correct path to your sampling rate if needed

# Call the dsf function
result = dsf(acc_data, sampling_rate)

# Print the results
print("Dominant Stroke Frequency (fpk):", result['fpk'], "Hz")
print("Quality of the Peak (q):", result['q'])


### Complementary Filter Function

In [None]:
import numpy as np
from scipy.signal import butter, filtfilt

def comp_filt(X, sampling_rate=None, fc=None):
    """
    Complementary filtering of a signal.
    
    Parameters:
    X (numpy.ndarray or dict): A sensor vector or matrix (signal in each column) or a sensor data dictionary.
    sampling_rate (float): The sampling rate of the sensor data in Hz.
    fc (list or float): Cut-off frequency/frequencies of the complementary filters in Hz. If one frequency is given, 
                        X will be split into low- and high-frequency components. If a list of frequencies is given, 
                        X will be split into multiple complementary bands.

    Returns:
    dict: A dictionary with filtered signals. The keys correspond to the frequency bands: 'lowpass', 'highpass', and 'bandpass'.
    """

    if isinstance(X, dict):
        sampling_rate = X['sampling_rate']
        X = X['data']
    else:
        if fc is None or sampling_rate is None:
            raise ValueError("inputs X, sampling_rate, and fc are all required if X is not a dictionary")

    # Ensure fc is a list for consistency
    if isinstance(fc, float) or isinstance(fc, int):
        fc = [fc]

    nf = [int(4 * sampling_rate / f) for f in fc]
    Xf = {}

    # Apply the complementary filters
    for i, f in enumerate(fc):
        b, a = butter(4, f / (sampling_rate / 2), btype='low')
        lowpass = filtfilt(b, a, X, axis=0)
        Xf[f'band_{i}'] = lowpass
        X = X - lowpass  # Highpass component

    # Store the final highpass component
    Xf['highpass'] = X

    # Handle the case where there's only one frequency cutoff
    if len(fc) == 1:
        Xf = {'lowpass': Xf['band_0'], 'highpass': Xf['highpass']}
    
    return Xf


### Complementary Filter Usage

In [None]:
import numpy as np
import plotly.graph_objs as go
from plotly.subplots import make_subplots

# Assuming `data_pkl` is already loaded and contains your data
accX = data_pkl.data['CC-96']['accX'].values
accY = data_pkl.data['CC-96']['accY'].values
accZ = data_pkl.data['CC-96']['accZ'].values

# Combine the accelerometer data into an nx3 matrix
acc_data = np.vstack((accX, accY, accZ)).T

# Get the sampling rate from the data structure
sampling_rate = int(data_pkl.info['CC-96']['datetime_metadata']['fs'])

# Define the cut-off frequency (e.g., 0.15 Hz)
fc = 0.15

# Apply the complementary filter function
filtered_signals = comp_filt(acc_data, sampling_rate, fc)

# Extract lowpass and highpass components
lowpass_signal = filtered_signals['lowpass']
highpass_signal = filtered_signals['highpass']

# Create a time axis (assuming continuous data)
time_axis = np.arange(len(acc_data)) / sampling_rate

# Create subplots
fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
                    subplot_titles=("Low-Frequency Component", "High-Frequency Component"))

# Plot the low-frequency component
fig.add_trace(go.Scatter(x=time_axis[::100], y=lowpass_signal[::100, 0], mode='lines', name='Lowpass X'), row=1, col=1)
fig.add_trace(go.Scatter(x=time_axis[::100], y=lowpass_signal[::100, 1], mode='lines', name='Lowpass Y'), row=1, col=1)
fig.add_trace(go.Scatter(x=time_axis[::100], y=lowpass_signal[::100, 2], mode='lines', name='Lowpass Z'), row=1, col=1)

# Plot the high-frequency component
fig.add_trace(go.Scatter(x=time_axis[::100], y=highpass_signal[::100, 0], mode='lines', name='Highpass X'), row=2, col=1)
fig.add_trace(go.Scatter(x=time_axis[::100], y=highpass_signal[::100, 1], mode='lines', name='Highpass Y'), row=2, col=1)
fig.add_trace(go.Scatter(x=time_axis[::100], y=highpass_signal[::100, 2], mode='lines', name='Highpass Z'), row=2, col=1)

# Update the layout for better visualization
fig.update_layout(title="Complementary Filtered Signals", height=600, showlegend=True)
fig.update_xaxes(title_text="Time (seconds)")
fig.update_yaxes(title_text="Signal Amplitude")

# Show the plot in Streamlit (if using Streamlit)
# st.plotly_chart(fig, use_container_width=True)

# Or display the plot in a Jupyter notebook or other environments
fig.show()


In [None]:
import plotly.graph_objs as go
from plotly.subplots import make_subplots

# Color mapping dictionary with pastel, colorblind-friendly colors
color_mapping = {
    'ECG': '#FFCCCC',              # Light Red with alpha in rgba
    'Depth': '#00008B',            # Dark Blue
    'Accelerometer X [m/s²]': '#87CEFA',          # Light Blue
    'Accelerometer Y [m/s²]': '#98FB98',          # Pale Green
    'Accelerometer Z [m/s²]': '#FF6347',          # Light Coral
    'Gyro X': '#9370DB',           # Medium Purple
    'Gyro Y': '#BA55D3',           # Medium Orchid
    'Gyro Z': '#8A2BE2',           # Blue Violet
    'Mag X': '#FFD700',            # Gold
    'Mag Y': '#FFA500',            # Orange
    'Mag Z': '#FF8C00',            # Dark Orange
    'Filtered Heartbeats': '#808080',  # Gray for dotted lines
}

def plot_tag_data(data_pkl, imu_channels, ephys_channels=None, imu_logger=None, ephys_logger=None, imu_sampling_rate=10, ephys_sampling_rate=50, draw=True):
    if not imu_logger and not ephys_logger:
        raise ValueError("At least one logger (imu_logger or ephys_logger) must be specified.")

    # Ensure the order of channels: ECG, Depth, Accel, Gyro, Mag
    ordered_channels = []
    if ephys_channels and 'ecg' in [ch.lower() for ch in ephys_channels]:
        ordered_channels.append(('ECG', 'ecg'))
    if 'depth' in [ch.lower() for ch in imu_channels]:
        ordered_channels.append(('Depth', 'depth'))
    if any(ch.lower() in ['accx', 'accy', 'accz'] for ch in imu_channels):
        ordered_channels.append(('Accel', ['accX', 'accY', 'accZ']))
    if any(ch.lower() in ['gyrx', 'gyry', 'gyrz'] for ch in imu_channels):
        ordered_channels.append(('Gyro', ['gyrX', 'gyrY', 'gyrZ']))
    if any(ch.lower() in ['magx', 'magy', 'magz'] for ch in imu_channels):
        ordered_channels.append(('Mag', ['magX', 'magY', 'magZ']))

    # Calculate the number of rows needed
    num_rows = len(ordered_channels)

    fig = make_subplots(rows=num_rows, cols=1, shared_xaxes=True, vertical_spacing=0.03)
    
    def downsample(df, original_fs, target_fs):
        if target_fs >= original_fs:
            return df
        conversion_factor = int(original_fs / target_fs)
        return df.iloc[::conversion_factor, :]

    if imu_logger:
        imu_df = data_pkl.data[imu_logger]
        imu_fs = 1 / imu_df['datetime'].diff().dt.total_seconds().mean()
        imu_df_downsampled = downsample(imu_df, imu_fs, imu_sampling_rate)
        imu_info = data_pkl.info[imu_logger]['channelinfo']
    
    if ephys_logger:
        ephys_df = data_pkl.data[ephys_logger]
        ephys_fs = 1 / ephys_df['datetime'].diff().dt.total_seconds().mean()
        ephys_df_downsampled = downsample(ephys_df, ephys_fs, ephys_sampling_rate)
        ephys_info = data_pkl.info[ephys_logger]['channelinfo']

    row_counter = 1
    
    for channel_type, channels in ordered_channels:
        if channel_type == 'ECG' and ephys_channels and 'ecg' in [ch.lower() for ch in ephys_channels]:
            # Plot ECG
            channel = 'ecg'
            df = ephys_df_downsampled
            info = ephys_info
            original_name = info[channel]['original_name']
            unit = info[channel]['unit']

            y_data = df[channel]
            x_data = df['datetime']

            y_label = f"{original_name} [{unit}]"
            color = color_mapping.get(original_name, color_mapping['ECG'])

            fig.add_trace(go.Scatter(
                x=x_data,
                y=y_data,
                mode='lines',
                name=y_label,
                line=dict(color=color)
            ), row=row_counter, col=1)

            # Add vertical lines for heartbeats
            filtered_notes = data_pkl.notes_df[data_pkl.notes_df['key'] == 'heartbeat_manual_ok']
            if not filtered_notes.empty:
                for dt in filtered_notes['datetime']:
                    fig.add_trace(go.Scatter(
                        x=[dt, dt],
                        y=[y_data.min(), y_data.max()],
                        mode='lines',
                        line=dict(color=color_mapping['Filtered Heartbeats'], width=1, dash='dot'),
                        showlegend=False
                    ), row=row_counter, col=1)

            fig.update_yaxes(title_text=y_label, row=row_counter, col=1)
            row_counter += 1

        elif channel_type == 'Depth' and 'depth' in [ch.lower() for ch in imu_channels]:
            # Plot Depth
            channel = 'depth'
            df = imu_df_downsampled
            info = imu_info
            original_name = info[channel]['original_name']
            unit = info[channel]['unit']

            y_data = df[channel]
            x_data = df['datetime']

            y_label = f"{original_name} [{unit}]"
            color = color_mapping.get(original_name, color_mapping['Depth'])

            fig.add_trace(go.Scatter(
                x=x_data,
                y=y_data,
                mode='lines',
                name=y_label,
                line=dict(color=color)
            ), row=row_counter, col=1)

            fig.update_yaxes(title_text=y_label, autorange="reversed", row=row_counter, col=1)
            row_counter += 1

        elif channel_type in ['Accel', 'Gyro', 'Mag']:
            # Plot Accel, Gyro, or Mag channels together
            for sub_channel in channels:
                if sub_channel in imu_df_downsampled.columns:
                    df = imu_df_downsampled
                    info = imu_info
                    original_name = info[sub_channel]['original_name']
                    unit = info[sub_channel]['unit']

                    y_data = df[sub_channel]
                    x_data = df['datetime']

                    y_label = f"{original_name} [{unit}]"
                    color = color_mapping.get(original_name, '#000000')

                    fig.add_trace(go.Scatter(
                        x=x_data,
                        y=y_data,
                        mode='lines',
                        name=y_label,
                        line=dict(color=color)
                    ), row=row_counter, col=1)

            fig.update_yaxes(title_text=f"{channel_type} [{unit}]", row=row_counter, col=1)
            row_counter += 1

    fig.update_layout(
        height=200 * num_rows,
        width=1200,
        title_text=f"{data_pkl.selected_deployment['Deployment Name']}",
        showlegend=True
    )
    
    fig.update_xaxes(title_text="Datetime", row=row_counter-1, col=1)

    if draw:
        fig.show()
    else:
        return fig

# Example usage:
# Specify channels and loggers
imu_channels_to_plot = ['depth', 'accX', 'accY', 'accZ', 'gyrX', 'gyrY', 'gyrZ', 'magX', 'magY', 'magZ']
ephys_channels_to_plot = ['ecg']
imu_logger_to_use = 'CC-96'
ephys_logger_to_use = 'UF-01'

plot_tag_data(data_pkl, imu_channels_to_plot, ephys_channels=ephys_channels_to_plot, imu_logger=imu_logger_to_use, ephys_logger=ephys_logger_to_use, imu_sampling_rate=10, ephys_sampling_rate=75)


In [None]:
data_pkl.info['channelnames']['CC-96']

In [None]:
import plotly.graph_objs as go
from plotly.subplots import make_subplots

# Assuming your new data columns are as follows:
# ECG signal column: 'ecg'
# Depth column: 'depth1'
# Accelerometer columns: 'accX', 'accY', 'accZ'
# Gyroscope column: 'gyrY'

# Create subplots
fig = make_subplots(rows=5, cols=1, shared_xaxes=True, vertical_spacing=0.03)

# Add Heart Rate (bpm) plot at the top
fig.add_trace(go.Scatter(
    x=filtered_notes['datetime'], 
    y=filtered_notes['value'], 
    mode='markers', 
    marker=dict(color='gray', size=8, symbol='circle-open'),
    name='Heart rate (bpm)'
), row=1, col=1)

# Add ECG plot with light red color and alpha 0.2
fig.add_trace(go.Scatter(
    x=ecg_df50['datetime'], 
    y=ecg_df50['ecg'], 
    mode='lines', 
    name='ECG [mV]', 
    line=dict(color='rgba(255, 0, 0, 0.2)')
), row=2, col=1)

# Add vertical dotted lines for detected heartbeats
for dt in filtered_notes['datetime']:
    fig.add_trace(go.Scatter(
        x=[dt, dt], 
        y=[ecg_df50['ecg'].min(), ecg_df50['ecg'].max()], 
        mode='lines', 
        line=dict(color='gray', width=1, dash='dot'),
        showlegend=False
    ), row=2, col=1)

# Add Depth plot with dark blue color
fig.add_trace(go.Scatter(
    x=CO_df10['datetime'], 
    y=CO_df10['depth1'], 
    mode='lines', 
    name='Depth [m]', 
    line=dict(color='darkblue')
), row=3, col=1)
fig.update_yaxes(autorange="reversed", row=3, col=1)

# Add Accelerometer plots on the same y-axis
fig.add_trace(go.Scatter(
    x=CO_df10['datetime'], 
    y=CO_df10['accX'], 
    mode='lines', 
    name='Accel X [m/s²]', 
    line=dict(color='blue')
), row=4, col=1)

fig.add_trace(go.Scatter(
    x=CO_df10['datetime'], 
    y=CO_df10['accY'], 
    mode='lines', 
    name='Accel Y [m/s²]', 
    line=dict(color='green')
), row=4, col=1)

fig.add_trace(go.Scatter(
    x=CO_df10['datetime'], 
    y=CO_df10['accZ'], 
    mode='lines', 
    name='Accel Z [m/s²]', 
    line=dict(color='red')
), row=4, col=1)

# Add Gyroscope Y plot
fig.add_trace(go.Scatter(
    x=CO_df10['datetime'], 
    y=CO_df10['gyrY'], 
    mode='lines', 
    name='Gyr Y [mrad/s]', 
    line=dict(color='purple')
), row=5, col=1)

# Update layout
fig.update_layout(
    height=800, 
    width=1200, 
    title_text=f"{data_pkl.selected_deployment['Deployment Name']}", 
    showlegend=True
)
fig.update_xaxes(title_text="Datetime", row=5, col=1)

# Update y-axes labels
fig.update_yaxes(title_text="Heart rate (bpm)", row=1, col=1)
fig.update_yaxes(title_text="ECG [mV]", row=2, col=1)
fig.update_yaxes(title_text="Depth [m]", row=3, col=1)
fig.update_yaxes(title_text="Accelerometer [m/s²]", row=4, col=1)
fig.update_yaxes(title_text="Gyr Y [mrad/s]", row=5, col=1)

# Show plot
fig.show()

In [None]:
# Save the interactive plot as an HTML file
fig.write_html(os.path.join(deployment_folder, "outputs", f"{data_pkl.selected_deployment['Deployment Name']}.html")) 
data_pkl.selected_deployment['Deployment Name']