# Data pipeline for DiveDB
Uses the `Metadata` and `DataReader` classes to handle data intake, processing, and alignment.

In [None]:
# Import libraries and set working directory (adjust to fit your preferences)
import os
import sys
import numpy as np
import pandas as pd
import pytz
import matplotlib.pyplot as plt
from notion_client import Client
from dotenv import load_dotenv
from datareader import DataReader
from metadata import Metadata
from loggerdata import LoggerData
#from plotter import plot_tag_data
import plotly.express as px
import pickle
import nbformat
print(nbformat.__version__)

# Change the current working directory to the repository root so relative
# paths such as "data/" resolve. The default below is one contributor's
# checkout; set the PYOLOGGER_ROOT environment variable to point at your own
# clone instead of editing this cell.
os.chdir(os.environ.get("PYOLOGGER_ROOT", "/Users/jessiekb/Documents/GitHub/pyologger"))
root_dir = os.getcwd()
data_dir = os.path.join(root_dir, "data")

# Verify the current working directory
print(f"Current working directory: {root_dir}")

### Query metadata
Use Notion and the [metadata entry form](https://forms.fillout.com/t/8UNuTLMaRfus) to start a recording and generate identifiers for the Recording and Deployment.


In [None]:
# Create the Metadata client and pull its databases (without verbose output).
metadata = Metadata()
metadata.fetch_databases(verbose=False)

# Keep a handle on each table: deployments, loggers, recordings and animals.
dep_db, logger_db, rec_db, animal_db = (
    metadata.get_metadata(db_name)
    for db_name in ("dep_DB", "logger_DB", "rec_DB", "animal_DB")
)

### Steps for Processing Deployment Data:

1. **Select Deployment Folder**:
   - **Description:** Prompts the user to select a deployment folder, which starts the data-reading process. Folder names may carry any suffix after the Deployment ID; if more than one folder matches, the check stops.
   - **Function Used:** `check_deployment_folder()`

2. **Initialize Deployment Folder**:
   - **Description:** Starts the main `read_files` process with the selected deployment folder.
   - **Function Used:** `read_files()`

3. **Fetch Metadata**:
   - **Description:** Retrieve necessary data from the metadata database, including logger information.
   - **Function Used:** `metadata.fetch_databases()`

4. **Organize Files by Logger ID**:
   - **Description:** Group files by logger ID for processing.
   - **Function Used:** `read_files()` (This is the main function)

5. **Check for Existing Processed Files**:
   - **Description:** Verify if the outputs folder already contains processed files for each logger. Skip reprocessing if all necessary files are present.
   - **Function Used:** `check_outputs_folder()`

6. **Process UBE Files**:
   - **Description:** For each UFI logger with UBE files, process and save the data.
   - **Function Used:** `process_ube_file()`

7. **Process CSV Files**:
   - **Description:** For each logger with multiple CSV files, concatenate them and save the combined data.
   - **Function Used:** `concatenate_and_save_csvs()`

8. **Final Outputs**:
   - **Description:** Ensure all processed data is saved in the outputs folder with appropriate filenames.
   - **Functions Used:** `save_data()`

In [None]:
# Display the deployment table. Note the index of your deployment ID:
# the next cell asks you to enter it.
dep_db

In [None]:
# Requires `metadata` and `dep_db` from the cells above.
# Ask the user to pick a deployment folder under data_dir (matched against
# dep_db) and, if one is chosen, read its files with CSV and Parquet saving on.
datareader = DataReader()
deployment_folder = datareader.check_deployment_folder(dep_db, data_dir)

if deployment_folder:
    datareader.read_files(metadata, save_csv=True, save_parq=True)
else:
    # Later cells build paths from deployment_folder and read datareader.data,
    # so report the skip explicitly instead of letting them fail obscurely.
    print("No deployment folder selected; skipping read_files(). "
          "Re-run this cell and choose a deployment before continuing.")

In [None]:
# Optionally inspect what was read in. The active expression displays the
# first five rows of the notes table; uncomment another line to view it instead.
#datareader.selected_deployment['Time Zone']
#datareader.info['UF-01']
datareader.notes_df[0:5]
#datareader.data['CC-96']
#datareader.data['UF-01']
#datareader.metadata['channelnames']

In [None]:
# Display the data read for logger 'CC-96'.
datareader.data['CC-96']

In [None]:
# Display the 'Time Zone' field of the selected deployment record.
#data_test = pd.read_csv(os.path.join(deployment_folder, "2024-01-16_oror-002a_CC-96_001.csv"))
datareader.selected_deployment['Time Zone']

### Inspect the pickle file output

Load in the generated pickle file to inspect the output.

In [None]:
# Load the pickled DataReader output for this deployment (presumably written
# by datareader.read_files() above).
# NOTE: pickle.load can execute arbitrary code; only load data.pkl files you
# generated yourself.
pkl_path = os.path.join(deployment_folder, 'outputs', 'data.pkl')

with open(pkl_path, 'rb') as pkl_file:
    data_pkl = pickle.load(pkl_file)

# Report the sampling frequency recorded for each logger, if any.
for logger_id, info in data_pkl.info.items():
    sampling_frequency = info.get('datetime_metadata', {}).get('fs')
    if sampling_frequency is not None:
        # Format the sampling frequency to 5 significant digits
        print(f"Sampling frequency for {logger_id}: {float(sampling_frequency):.5g} Hz")
    else:
        print(f"No sampling frequency available for {logger_id}")

data_pkl.info['CC-96']['datetime_metadata']['fs']

In [None]:
# First five rows of the notes table stored in the pickled object.
data_pkl.notes_df[0:5]

In [None]:
data_pkl.data['UF-01']  # full UF-01 DataFrame; append [0:5] to show only the first rows

In [None]:
data_pkl.data['CC-96'][0:5] # first five rows of CC-96 (handy for browsing column names)

### Pre-process data for plots
Downsample the high-resolution data and keep only the notes of interest for plotting.

In [None]:
import plotly.graph_objs as go
from plotly.subplots import make_subplots

# Assuming you have loaded your data into data_pkl
imu_df = data_pkl.data['CC-96']   # IMU / depth logger
ecg_df = data_pkl.data['UF-01']   # ECG logger

# Estimate sampling frequencies (Hz) from the mean spacing of 'datetime'
CO_fs = 1 / imu_df['datetime'].diff().dt.total_seconds().mean()
ecg_fs = 1 / ecg_df['datetime'].diff().dt.total_seconds().mean()

# Define new desired sampling rates
new_CATS_sampling_rate = 10  # Hz
new_ecg_sampling_rate = 50  # Hz

# Decimation steps (keep every k-th row). Round instead of truncating: the
# estimated rates are rarely exact integers, so e.g. 399.99 Hz would truncate
# to a step of 39 rather than 40. Clamp to >= 1 so a logger already at or
# below the target rate is kept whole instead of raising on iloc[::0].
CATS_conversion = max(1, int(round(CO_fs / new_CATS_sampling_rate)))
ecg_conversion = max(1, int(round(ecg_fs / new_ecg_sampling_rate)))

# Subsample the dataframes
ecg_df50 = ecg_df.iloc[::ecg_conversion, :]
imu_df10 = imu_df.iloc[::CATS_conversion, :]

# Filter notes for those with key == 'heartbeat_manual_ok'
filtered_notes = data_pkl.notes_df[data_pkl.notes_df['key'] == 'heartbeat_manual_ok']


In [None]:
# Per-channel info for CC-96; plot_tag_data reads each entry's
# 'original_name' and 'unit' to build trace labels.
data_pkl.info['CC-96']['channelinfo']

In [None]:
import plotly.graph_objs as go
from plotly.subplots import make_subplots

# Trace colours used by plot_tag_data. Keys are looked up by each channel's
# 'original_name' from channelinfo; 'ECG' and 'Depth' also act as the fallback
# colours for those rows, and 'Filtered Heartbeats' colours the heartbeat lines.
# NOTE(review): the per-axis keys (Accelerometer/Gyro/Mag) only take effect if
# they exactly match the loggers' original channel names; otherwise black is
# used. Verify against data_pkl.info[<logger>]['channelinfo'].
color_mapping = {
    'ECG': '#FFCCCC',              # Light red (opaque hex, no alpha channel)
    'Depth': '#00008B',            # Dark Blue
    'Accelerometer X [m/s²]': '#87CEFA',          # Light Sky Blue
    'Accelerometer Y [m/s²]': '#98FB98',          # Pale Green
    'Accelerometer Z [m/s²]': '#FF6347',          # Tomato
    'Gyro X': '#9370DB',           # Medium Purple
    'Gyro Y': '#BA55D3',           # Medium Orchid
    'Gyro Z': '#8A2BE2',           # Blue Violet
    'Mag X': '#FFD700',            # Gold
    'Mag Y': '#FFA500',            # Orange
    'Mag Z': '#FF8C00',            # Dark Orange
    'Filtered Heartbeats': '#808080',  # Gray, for the dotted heartbeat lines
}

# IMU sensor triads plotted together on one subplot row each, in display order.
_IMU_SENSOR_GROUPS = (
    ('Accel', ('accX', 'accY', 'accZ')),
    ('Gyro', ('gyrX', 'gyrY', 'gyrZ')),
    ('Mag', ('magX', 'magY', 'magZ')),
)


def _estimate_fs(df):
    """Estimate a sampling rate in Hz from the mean spacing of ``df['datetime']``."""
    return 1 / df['datetime'].diff().dt.total_seconds().mean()


def _downsample(df, original_fs, target_fs):
    """Keep every k-th row of ``df`` so its rate is roughly ``target_fs`` Hz.

    ``df`` is returned unchanged when ``original_fs`` is not finite (e.g. a
    single-row frame) or is already at or below ``target_fs``. The step is
    rounded rather than truncated, because ``original_fs`` is an estimate and
    rarely an exact integer. It is clamped to >= 1 so ``iloc`` never gets a
    zero step.
    """
    if not np.isfinite(original_fs) or target_fs >= original_fs:
        return df
    step = max(1, int(round(original_fs / target_fs)))
    return df.iloc[::step, :]


def _heartbeat_trace(notes_df, y_min, y_max):
    """Return one dotted-line trace marking every 'heartbeat_manual_ok' note, or None.

    All beats share a single trace: each vertical segment is separated by a
    None gap. This avoids one plotly trace per beat, which is very slow for
    long recordings.
    """
    beat_times = notes_df.loc[notes_df['key'] == 'heartbeat_manual_ok', 'datetime']
    if beat_times.empty:
        return None
    xs, ys = [], []
    for dt in beat_times:
        xs.extend((dt, dt, None))
        ys.extend((y_min, y_max, None))
    return go.Scatter(
        x=xs,
        y=ys,
        mode='lines',
        line=dict(color=color_mapping['Filtered Heartbeats'], width=1, dash='dot'),
        showlegend=False
    )


def plot_tag_data(data_pkl, imu_channels, ephys_channels=None, imu_logger=None, ephys_logger=None, imu_sampling_rate=10, ephys_sampling_rate=50, draw=True):
    """Plot ECG, depth and IMU channels as vertically stacked, x-linked subplots.

    Rows always appear in the order ECG, Depth, Accel, Gyro, Mag. A row is
    drawn only when its channels were requested AND the logger providing them
    was given. Within a sensor triad, only the requested axes that exist in
    the data are drawn.

    Args:
        data_pkl: Loaded DataReader output. ``data[logger]`` is a DataFrame
            with a ``datetime`` column. ``info[logger]['channelinfo'][channel]``
            holds ``'original_name'`` and ``'unit'``. ``notes_df`` has ``key``
            and ``datetime`` columns. ``selected_deployment`` supplies the title.
        imu_channels: Channels to plot from ``imu_logger``, matched
            case-insensitively (e.g. ``['depth', 'accX', 'gyrY']``). ``None``
            is treated as empty.
        ephys_channels: Channels to plot from ``ephys_logger``; only ``'ecg'``
            is supported.
        imu_logger: Logger id of the IMU/depth tag, or None.
        ephys_logger: Logger id of the ECG logger, or None.
        imu_sampling_rate: Target display rate (Hz) for IMU data.
        ephys_sampling_rate: Target display rate (Hz) for ECG data.
        draw: If True, show the figure and return None; otherwise return it.

    Raises:
        ValueError: If neither logger is given, or none of the requested
            channels can be plotted with the given loggers.
    """
    if not imu_logger and not ephys_logger:
        raise ValueError("At least one logger (imu_logger or ephys_logger) must be specified.")

    imu_requested = {ch.lower() for ch in (imu_channels or [])}
    ephys_requested = {ch.lower() for ch in (ephys_channels or [])}

    # Load and decimate each logger's data once.
    if imu_logger:
        imu_df = data_pkl.data[imu_logger]
        imu_df = _downsample(imu_df, _estimate_fs(imu_df), imu_sampling_rate)
        imu_info = data_pkl.info[imu_logger]['channelinfo']
    if ephys_logger:
        ephys_df = data_pkl.data[ephys_logger]
        ephys_df = _downsample(ephys_df, _estimate_fs(ephys_df), ephys_sampling_rate)
        ephys_info = data_pkl.info[ephys_logger]['channelinfo']

    # Resolve rows up front so the subplot count matches what is actually drawn.
    # Each row: (row label, DataFrame, channelinfo, channels to draw).
    rows = []
    if ephys_logger and 'ecg' in ephys_requested:
        rows.append(('ECG', ephys_df, ephys_info, ['ecg']))
    if imu_logger:
        if 'depth' in imu_requested:
            rows.append(('Depth', imu_df, imu_info, ['depth']))
        for group, group_channels in _IMU_SENSOR_GROUPS:
            present = [ch for ch in group_channels
                       if ch.lower() in imu_requested and ch in imu_df.columns]
            if present:
                rows.append((group, imu_df, imu_info, present))
    if not rows:
        raise ValueError("None of the requested channels can be plotted with the given loggers.")

    fig = make_subplots(rows=len(rows), cols=1, shared_xaxes=True, vertical_spacing=0.03)

    for row, (group, df, info, channels) in enumerate(rows, start=1):
        x_data = df['datetime']
        for channel in channels:
            original_name = info[channel]['original_name']
            unit = info[channel]['unit']
            y_label = f"{original_name} [{unit}]"
            # Prefer a per-channel colour, then the row's colour, then black.
            color = color_mapping.get(original_name, color_mapping.get(group, '#000000'))
            fig.add_trace(go.Scatter(
                x=x_data,
                y=df[channel],
                mode='lines',
                name=y_label,
                line=dict(color=color)
            ), row=row, col=1)

        if group == 'ECG':
            beats = _heartbeat_trace(data_pkl.notes_df, df['ecg'].min(), df['ecg'].max())
            if beats is not None:
                fig.add_trace(beats, row=row, col=1)
            fig.update_yaxes(title_text=y_label, row=row, col=1)
        elif group == 'Depth':
            # Depth increases downward, so flip the axis.
            fig.update_yaxes(title_text=y_label, autorange="reversed", row=row, col=1)
        else:
            # Triad axes share one y-axis; the title uses the last drawn channel's unit.
            fig.update_yaxes(title_text=f"{group} [{unit}]", row=row, col=1)

    fig.update_layout(
        height=200 * len(rows),
        width=1200,
        title_text=f"{data_pkl.selected_deployment['Deployment Name']}",
        showlegend=True
    )

    fig.update_xaxes(title_text="Datetime", row=len(rows), col=1)

    if draw:
        fig.show()
    else:
        return fig

# Example: ECG from the UF-01 logger stacked above depth plus every
# accelerometer, gyroscope and magnetometer axis from the CC-96 tag.
imu_logger_to_use = 'CC-96'
ephys_logger_to_use = 'UF-01'
imu_channels_to_plot = ['depth'] + [
    f"{sensor}{axis}" for sensor in ('acc', 'gyr', 'mag') for axis in 'XYZ'
]
ephys_channels_to_plot = ['ecg']

plot_tag_data(
    data_pkl,
    imu_channels=imu_channels_to_plot,
    ephys_channels=ephys_channels_to_plot,
    imu_logger=imu_logger_to_use,
    ephys_logger=ephys_logger_to_use,
)


In [None]:
# Channel names available for CC-96. data_pkl.info is keyed by logger ID, and
# each logger's 'channelinfo' dict is keyed by channel name.
list(data_pkl.info['CC-96']['channelinfo'])

In [None]:
import plotly.graph_objs as go
from plotly.subplots import make_subplots

# Data used below (all produced by the preprocessing cell):
#   filtered_notes: 'heartbeat_manual_ok' notes ('datetime', 'value')
#   ecg_df50: ECG logger decimated to ~50 Hz ('datetime', 'ecg')
#   imu_df10: IMU/depth logger decimated to ~10 Hz
#             ('datetime', depth, 'accX', 'accY', 'accZ', 'gyrY')
# The depth column may be named 'depth1' or 'depth' (plot_tag_data reads
# 'depth'), so use whichever one imu_df10 actually has.
depth_col = 'depth1' if 'depth1' in imu_df10.columns else 'depth'

# Create subplots
fig = make_subplots(rows=5, cols=1, shared_xaxes=True, vertical_spacing=0.03)

# Row 1: heart rate (bpm) from the heartbeat notes
fig.add_trace(go.Scatter(
    x=filtered_notes['datetime'],
    y=filtered_notes['value'],
    mode='markers',
    marker=dict(color='gray', size=8, symbol='circle-open'),
    name='Heart rate (bpm)'
), row=1, col=1)

# Row 2: ECG in light red with alpha 0.2
fig.add_trace(go.Scatter(
    x=ecg_df50['datetime'],
    y=ecg_df50['ecg'],
    mode='lines',
    name='ECG [mV]',
    line=dict(color='rgba(255, 0, 0, 0.2)')
), row=2, col=1)

# Vertical dotted lines at detected heartbeats, drawn as ONE trace (None
# entries break the line between beats) instead of one trace per beat, which
# keeps the figure responsive for long recordings.
ecg_min, ecg_max = ecg_df50['ecg'].min(), ecg_df50['ecg'].max()
beat_x, beat_y = [], []
for dt in filtered_notes['datetime']:
    beat_x.extend((dt, dt, None))
    beat_y.extend((ecg_min, ecg_max, None))
if beat_x:
    fig.add_trace(go.Scatter(
        x=beat_x,
        y=beat_y,
        mode='lines',
        line=dict(color='gray', width=1, dash='dot'),
        showlegend=False
    ), row=2, col=1)

# Row 3: depth in dark blue, axis flipped so deeper is lower
fig.add_trace(go.Scatter(
    x=imu_df10['datetime'],
    y=imu_df10[depth_col],
    mode='lines',
    name='Depth [m]',
    line=dict(color='darkblue')
), row=3, col=1)
fig.update_yaxes(autorange="reversed", row=3, col=1)

# Row 4: the three accelerometer axes on a shared y-axis
fig.add_trace(go.Scatter(
    x=imu_df10['datetime'],
    y=imu_df10['accX'],
    mode='lines',
    name='Accel X [m/s²]',
    line=dict(color='blue')
), row=4, col=1)

fig.add_trace(go.Scatter(
    x=imu_df10['datetime'],
    y=imu_df10['accY'],
    mode='lines',
    name='Accel Y [m/s²]',
    line=dict(color='green')
), row=4, col=1)

fig.add_trace(go.Scatter(
    x=imu_df10['datetime'],
    y=imu_df10['accZ'],
    mode='lines',
    name='Accel Z [m/s²]',
    line=dict(color='red')
), row=4, col=1)

# Row 5: gyroscope Y axis
fig.add_trace(go.Scatter(
    x=imu_df10['datetime'],
    y=imu_df10['gyrY'],
    mode='lines',
    name='Gyr Y [mrad/s]',
    line=dict(color='purple')
), row=5, col=1)

# Update layout
fig.update_layout(
    height=800,
    width=1200,
    title_text=f"{data_pkl.selected_deployment['Deployment Name']}",
    showlegend=True
)
fig.update_xaxes(title_text="Datetime", row=5, col=1)

# Update y-axes labels
fig.update_yaxes(title_text="Heart rate (bpm)", row=1, col=1)
fig.update_yaxes(title_text="ECG [mV]", row=2, col=1)
fig.update_yaxes(title_text="Depth [m]", row=3, col=1)
fig.update_yaxes(title_text="Accelerometer [m/s²]", row=4, col=1)
fig.update_yaxes(title_text="Gyr Y [mrad/s]", row=5, col=1)

# Show plot
fig.show()

In [None]:
# Export the figure from the previous cell as an interactive HTML file named
# after the deployment, inside <deployment_folder>/outputs, then echo the name.
deployment_name = data_pkl.selected_deployment['Deployment Name']
html_path = os.path.join(deployment_folder, "outputs", f"{deployment_name}.html")
fig.write_html(html_path)
deployment_name