# Tag-to-animal frame: re-orienting the tag to match the animal's axes

For background on this (often confusing) step, see Max's tutorial on deriving animal orientation from IMU tag data: https://flukeandfeather.com/posts/2024-08-30-animal-orientation-with-imu-ta/

## Load and inspect data
Load the pickle file and inspect its contents.

In [None]:
import os
import pickle

# Import necessary pyologger utilities
from pyologger.load_data.datareader import DataReader
from pyologger.load_data.metadata import Metadata
from pyologger.plot_data.plotter import *
from pyologger.process_data.sampling import upsample
from pyologger.calibrate_data.zoc import *
from pyologger.plot_data.plotter import plot_depth_correction
from pyologger.calibrate_data.calibrate_acc_mag import *

# Change the current working directory to the repository root. Set the
# PYOLOGGER_ROOT environment variable to your own checkout instead of editing
# this cell; when it is unset, the default path below is used.
os.chdir(os.environ.get("PYOLOGGER_ROOT", "/Users/jessiekb/Documents/GitHub/pyologger"))

root_dir = os.getcwd()
data_dir = os.path.join(root_dir, "data")
color_mapping_path = os.path.join(root_dir, "color_mappings.json")

# Verify the current working directory
print(f"Current working directory: {root_dir}")

In [None]:
# Set up the metadata handler and fetch its databases (quietly).
metadata = Metadata()
metadata.fetch_databases(verbose=False)

# Pull each database out by name, in the same order as before.
dep_db, logger_db, rec_db, animal_db = (
    metadata.get_metadata(db_name)
    for db_name in ("dep_DB", "logger_DB", "rec_DB", "animal_DB")
)

# Resolve the deployment folder from dep_db under data_dir; only read (and
# save CSV/parquet copies of) the files when a folder was returned.
datareader = DataReader()
deployment_folder = datareader.check_deployment_folder(dep_db, data_dir)
if deployment_folder:
    datareader.read_files(metadata, save_csv=True, save_parq=True)

In [None]:
# Load the data_reader object from the pickle file written by the cell above.
# NOTE: pickle.load can execute arbitrary code -- only load files you created.
if not deployment_folder:
    raise RuntimeError(
        "No deployment folder was found; re-run the metadata cell above before loading data.pkl."
    )
pkl_path = os.path.join(deployment_folder, 'outputs', 'data.pkl')

with open(pkl_path, 'rb') as pkl_file:
    data_pkl = pickle.load(pkl_file)

for logger_id, info in data_pkl.info.items():
    # `or {}` also covers a datetime_metadata entry that is present but None.
    sampling_frequency = (info.get('datetime_metadata') or {}).get('fs')
    if sampling_frequency is None:
        print(f"No sampling frequency available for {logger_id}")
    else:
        # Report the sampling frequency to 5 significant digits
        print(f"Sampling frequency for {logger_id}: {float(sampling_frequency):.5g} Hz")

In [11]:
# numpy is imported explicitly here: this cell runs before the cell that
# imports it, so relying on the star imports above is fragile.
import numpy as np

# Change out preferred source of IMU or ephys data depending on your deployment
imu_logger = 'CC-96'
ephys_logger = 'UF-01'


def _stack_xyz(df, prefix, suffix=''):
    """Return columns ``{prefix}X{suffix}``, ``{prefix}Y{suffix}``, ``{prefix}Z{suffix}`` of *df* as an (n, 3) array."""
    return np.column_stack([df[f'{prefix}{axis}{suffix}'].values for axis in 'XYZ'])


imu_columns = data_pkl.data[imu_logger].columns
if 'accX_adjusted' in imu_columns and 'magX_adjusted' in imu_columns:
    # Use the calibrated accelerometer/magnetometer values from Step 2.
    acc_data = _stack_xyz(data_pkl.data[imu_logger], 'acc', '_adjusted')
    mag_data = _stack_xyz(data_pkl.data[imu_logger], 'mag', '_adjusted')
    print("Using calibrated accelerometer and magnetometer values from Step 2.")
else:
    # Fall back to the raw values rather than failing with a KeyError.
    print("Warning: Accelerometer values have not yet been adjusted. Check Step 2: Calibrating accelerometer and magnetometer. Using raw accelerometer values (accX, accY, accZ) instead.")
    acc_data = _stack_xyz(data_pkl.data[imu_logger], 'acc')
    mag_data = _stack_xyz(data_pkl.data[imu_logger], 'mag')

# The gyroscope has no adjusted variant; always use the raw channels.
gyr_data = _stack_xyz(data_pkl.data[imu_logger], 'gyr')

In [None]:
import numpy as np
from magnetic_field_calculator import MagneticFieldCalculator

def orientation_and_heading_correction(abar0, acc_data, mag_data, gyr_data=None,
                                       latitude=32.764567, longitude=-117.228665,
                                       declination=None):
    """
    Corrects the orientation and heading of tag data to align with the reference frame of an animal.

    Parameters
    ----------
    abar0 : array_like
        A 3-element vector representing the accelerometer readings when the animal is stationary on its belly.
        Only its direction is used; it is normalized internally.
    acc_data : array_like
        A 2D array where each row represents a 3-element accelerometer reading.
    mag_data : array_like
        A 2D array where each row represents a 3-element magnetometer reading.
    gyr_data : array_like, optional
        A 2D array where each row represents a 3-element gyroscope reading. If provided, gyroscope data will also be corrected.
    latitude, longitude : float, optional
        Location used to look up the magnetic declination when ``declination`` is None.
        Defaults are the example location previously hard-coded in this function.
    declination : float, optional
        Magnetic declination in degrees. If None (default), it is looked up with
        ``MagneticFieldCalculator`` for ``latitude``/``longitude`` (presumably a
        web-service call -- pass a value explicitly to work offline).

    Returns
    -------
    pitch_deg : ndarray
        The pitch angles in degrees for the entire dataset.
    roll_deg : ndarray
        The roll angles in degrees for the entire dataset.
    heading_deg : ndarray
        The heading angles in degrees (declination applied), wrapped to [-180, 180).
    acc_corr : ndarray
        The corrected accelerometer data.
    mag_corr : ndarray
        The corrected magnetometer data.
    gyr_corr : ndarray, optional
        The corrected gyroscope data. Only returned if gyr_data is provided.

    Raises
    ------
    ValueError
        If ``abar0`` has zero or non-finite length.

    Notes
    -----
    The stationary vector ``abar0`` is normalized and converted to a pitch ``p0`` and roll ``r0``.
    Every accelerometer/magnetometer/gyroscope row ``v`` is then mapped to ``rotP(p0) @ rotR(r0) @ v``.
    Per-sample pitch and roll are computed from the corrected accelerometer. The corrected
    magnetometer is rotated by the same per-sample roll and pitch ("gimbaled"), and heading
    is taken from its horizontal components.

    Examples
    --------
    >>> abar0 = np.array([0.1, 0.2, -0.98])
    >>> acc_data = np.random.rand(100, 3)
    >>> mag_data = np.random.rand(100, 3)
    >>> pitch_deg, roll_deg, heading_deg, acc_corr, mag_corr = orientation_and_heading_correction(abar0, acc_data, mag_data, declination=0.0)
    >>> pitch_deg, roll_deg, heading_deg, acc_corr, mag_corr, gyr_corr = orientation_and_heading_correction(abar0, acc_data, mag_data, gyr_data=np.random.rand(100, 3), declination=0.0)
    """
    # Normalize abar0 to create abar
    abar0 = np.asarray(abar0, dtype=float)
    abar0_norm = np.linalg.norm(abar0)
    if not np.isfinite(abar0_norm) or abar0_norm == 0:
        raise ValueError("abar0 must be a finite, non-zero 3-element vector.")
    abar = abar0 / abar0_norm

    # Initial pitch (p0) and roll (r0). arcsin already returns a value in
    # [-pi/2, pi/2], so p0 needs no further constraint; the clip only guards
    # against |abar[0]| exceeding 1 by rounding error.
    p0 = -np.arcsin(np.clip(abar[0], -1.0, 1.0))
    r0 = np.arctan2(abar[1], abar[2])

    # Define rotation matrices for pitch and roll
    def rotP(p):
        return np.array([[np.cos(p), 0, np.sin(p)],
                         [0, 1, 0],
                         [-np.sin(p), 0, np.cos(p)]])

    def rotR(r):
        return np.array([[1, 0, 0],
                         [0, np.cos(r), -np.sin(r)],
                         [0, np.sin(r), np.cos(r)]])

    # Row-vector form: v_row @ W == (rotP(p0) @ rotR(r0) @ v_col).T
    W = np.matmul(rotP(p0), rotR(r0)).T

    # Correct the accelerometer and magnetometer data for the entire dataset
    acc_corr = np.matmul(np.asarray(acc_data, dtype=float), W)
    mag_corr = np.matmul(np.asarray(mag_data, dtype=float), W)
    gyr_corr = None if gyr_data is None else np.matmul(np.asarray(gyr_data, dtype=float), W)

    # Per-sample pitch and roll from the corrected accelerometer data
    A = np.linalg.norm(acc_corr, axis=1)
    pitch_rad = -np.arcsin(np.clip(acc_corr[:, 0] / A, -1.0, 1.0))
    roll_rad = np.arctan2(acc_corr[:, 1], acc_corr[:, 2])
    pitch_deg = np.degrees(pitch_rad)
    roll_deg = np.degrees(roll_rad)

    # Gimbal the magnetometer: for each sample compute rotP(pitch) @ rotR(roll) @ m.
    # This is the closed form of the former per-sample matrix loop; only the
    # horizontal x/y components are needed for heading.
    cos_r, sin_r = np.cos(roll_rad), np.sin(roll_rad)
    cos_p, sin_p = np.cos(pitch_rad), np.sin(pitch_rad)
    mx, my, mz = mag_corr[:, 0], mag_corr[:, 1], mag_corr[:, 2]
    y_unrolled = cos_r * my - sin_r * mz
    z_unrolled = sin_r * my + cos_r * mz
    mag_horiz_x = cos_p * mx + sin_p * z_unrolled
    mag_horiz_y = y_unrolled

    if declination is None:
        # Get the declination using MagneticFieldCalculator
        calculator = MagneticFieldCalculator()
        result = calculator.calculate(latitude=latitude, longitude=longitude)
        declination_info = result['field-value']['declination']
        print(f"The declination at latitude {latitude} and longitude {longitude} is {declination_info} degrees.")
        declination = declination_info['value']

    # Heading from the gimbaled magnetometer, corrected to true north and
    # wrapped back into [-180, 180) after adding the declination.
    heading_deg = np.degrees(np.arctan2(mag_horiz_y, mag_horiz_x)) + float(declination)
    heading_deg = (heading_deg + 180.0) % 360.0 - 180.0

    # Return the corrected pitch, roll, and heading for the entire dataset
    if gyr_corr is not None:
        return pitch_deg, roll_deg, heading_deg, acc_corr, mag_corr, gyr_corr
    return pitch_deg, roll_deg, heading_deg, acc_corr, mag_corr

# Stationary "resting on its belly" accelerometer reference. Only its direction
# matters (the correction normalizes it).
# NOTE(review): [0, 0, -9.8] treats resting gravity as pointing along the tag's
# -Z axis, which makes the correction flip the Y and Z axes; replace it with a
# measured resting vector when one is available.
abar0 = [0, 0, -9.8]
correction_outputs = orientation_and_heading_correction(
    abar0, acc_data=acc_data, mag_data=mag_data, gyr_data=gyr_data
)
# gyr_data is passed, so six values come back (including corrected gyro data).
pitch_deg, roll_deg, heading_deg, acc_corr, mag_corr, gyr_corr = correction_outputs


In [None]:
import matplotlib.pyplot as plt

# Plot the three corrected gyroscope axes. plt.plot on a 2D array draws one
# line per column against the row (sample) index, not against time.
plt.plot(gyr_corr)
plt.xlabel('Sample index')
plt.ylabel('Gyroscope Data')
plt.title('Gyroscope Data Plot')
plt.legend(['corr_gyrX', 'corr_gyrY', 'corr_gyrZ'])
plt.show()

In [None]:
# Inspect the channelinfo metadata entry for the raw accX channel (shown as the cell output).
data_pkl.info[imu_logger]['channelinfo']['accX']

In [None]:
# Keep the first five rows of the IMU logger's table for a quick look (presumably a
# positional row slice). The result is stored in `data`, not displayed.
data = data_pkl.data[imu_logger][0:5]

In [None]:
# Peek at the first ten rows of the corrected accelerometer array (shown as the cell output).
acc_corr[0:10]

In [None]:
# Store the corrected sensor channels and orientation angles on the IMU logger,
# along with channelinfo metadata, then plot them next to the raw channels.

# Corrected arrays keyed by the raw channel prefix they were derived from
# (e.g. 'acc' -> accX/accY/accZ become corr_accX/corr_accY/corr_accZ).
corrected_sensor_arrays = {'acc': acc_corr, 'mag': mag_corr, 'gyr': gyr_corr}
# Orientation angles in degrees, with the display name stored in channelinfo.
orientation_channels = {
    'pitch': (pitch_deg, 'Pitch [degrees]'),
    'roll': (roll_deg, 'Roll [degrees]'),
    'heading': (heading_deg, 'Heading [degrees]'),
}

imu_channel_info = data_pkl.info[imu_logger]['channelinfo']

for sensor, corrected in corrected_sensor_arrays.items():
    if corrected is None:
        # Nothing to store (e.g. no gyroscope data was passed to the correction).
        continue
    for axis_index, axis in enumerate('XYZ'):
        raw_channel = f'{sensor}{axis}'
        corr_channel = f'corr_{raw_channel}'
        data_pkl.data[imu_logger][corr_channel] = corrected[:, axis_index]
        # Derive the metadata from the uncorrected channel's entry
        imu_channel_info[corr_channel] = {
            'original_name': f"Corrected {imu_channel_info[raw_channel]['original_name']}",
            'unit': imu_channel_info[raw_channel]['unit'],
        }

for channel, (angle_values, display_name) in orientation_channels.items():
    data_pkl.data[imu_logger][channel] = angle_values
    imu_channel_info[channel] = {'original_name': display_name, 'unit': 'degrees'}

imu_channels_to_plot = ['depth',
                        'accX', 'accY', 'accZ',
                        'corr_accX', 'corr_accY', 'corr_accZ',
                        'gyrX', 'gyrY', 'gyrZ',
                        'corr_gyrX', 'corr_gyrY', 'corr_gyrZ',
                        'corr_magX', 'corr_magY', 'corr_magZ',
                        'pitch', 'roll', 'heading']
ephys_channels_to_plot = []
imu_logger_to_use = imu_logger
ephys_logger_to_use = ephys_logger

# Get the overlapping time range
imu_df = data_pkl.data[imu_logger_to_use]
ephys_df = data_pkl.data[ephys_logger_to_use]
start_time = max(imu_df['datetime'].min(), ephys_df['datetime'].min()).to_pydatetime()
end_time = min(imu_df['datetime'].max(), ephys_df['datetime'].max()).to_pydatetime()

# Define notes to plot
notes_to_plot = {
    'exhalation_breath': 'depth'
}

plot_tag_data_interactive2(data_pkl, imu_channels_to_plot, imu_sampling_rate=5, ephys_channels=ephys_channels_to_plot,
                           imu_logger=imu_logger_to_use, ephys_logger=ephys_logger_to_use, note_annotations=notes_to_plot,
                           time_range=(start_time, end_time), color_mapping_path=color_mapping_path)

In [None]:
# Optional: write data_pkl back to pkl_path, overwriting the file loaded earlier.
with open(pkl_path, 'wb') as pkl_out:
    pickle.dump(data_pkl, pkl_out)

In [None]:
# IN PROGRESS, NOT WORKING YET

import matplotlib.pyplot as plt
import numpy as np

def plot_acc_for_exhalation_breaths(data_pkl, logger_id='CC-96', event_key='exhalation_breath',
                                    window_s=5.0, plot=True):
    """
    Plot accX, accY, and accZ around each breath event and return the mean
    acceleration vector (abar0) over those windows.

    Parameters
    ----------
    data_pkl : object
        The structured data object. It must provide ``data[logger_id]`` (a DataFrame
        with 'datetime', 'accX', 'accY', 'accZ' columns),
        ``info[logger_id]['datetime_metadata']['fs']``, and ``notes_df``
        (with 'key' and 'datetime' columns).
    logger_id : str, optional
        Logger whose accelerometer channels are used (default 'CC-96').
    event_key : str, optional
        Value of ``notes_df['key']`` that marks the events (default 'exhalation_breath').
    window_s : float, optional
        Half-width of the window around each event, in seconds (default 5.0,
        i.e. about 10 seconds per event).
    plot : bool, optional
        If True (default), plot the windows with matplotlib; set False to only
        compute abar0.

    Returns
    -------
    abar0 : numpy.ndarray
        ``[mean accX, mean accY, mean accZ]`` over all samples in all event windows.
        Overlapping windows contribute their shared samples more than once.

    Raises
    ------
    ValueError
        If the logger has no samples, or if no events of ``event_key`` fall within
        the logger's recording.
    """
    import pandas as pd

    logger_df = data_pkl.data[logger_id]
    acc_xyz = np.column_stack((logger_df['accX'].values,
                               logger_df['accY'].values,
                               logger_df['accZ'].values))
    # Convert to numpy.datetime64[ns] so event times can be subtracted directly
    datetime_data = np.asarray(logger_df['datetime'].values, dtype='datetime64[ns]')
    if datetime_data.size == 0:
        raise ValueError(f"Logger {logger_id!r} has no samples.")

    fs = data_pkl.info[logger_id]['datetime_metadata']['fs']
    # At least one sample on each side so a window is never empty
    half_window = max(int(window_s * fs), 1)

    events = data_pkl.notes_df[data_pkl.notes_df['key'] == event_key]

    # (start, end, event_index) for every usable event
    windows = []
    for event_time in events['datetime']:
        if pd.isna(event_time):
            continue
        event_dt64 = pd.Timestamp(event_time).to_datetime64()
        if event_dt64 < datetime_data[0] or event_dt64 > datetime_data[-1]:
            # Event lies outside this logger's recording
            continue
        # Nearest sample. An exact match after rounding is fragile, and
        # datetime64[ns].tolist() yields ints, so list.index(Timestamp) never matches.
        event_index = int(np.abs(datetime_data - event_dt64).argmin())
        start_index = max(event_index - half_window, 0)
        end_index = min(event_index + half_window, len(datetime_data))
        windows.append((start_index, end_index, event_index))

    if not windows:
        raise ValueError(f"No {event_key!r} events fall within the recording of logger {logger_id!r}.")

    if plot:
        import matplotlib.pyplot as plt

        plt.figure(figsize=(15, 10))
        for n, (start_index, end_index, event_index) in enumerate(windows):
            time_segment = datetime_data[start_index:end_index]
            # Label only the first event's lines so the legend has one entry per axis
            first = n == 0
            plt.plot(time_segment, acc_xyz[start_index:end_index, 0], label='accX' if first else None, color='blue', alpha=0.5)
            plt.plot(time_segment, acc_xyz[start_index:end_index, 1], label='accY' if first else None, color='green', alpha=0.5)
            plt.plot(time_segment, acc_xyz[start_index:end_index, 2], label='accZ' if first else None, color='red', alpha=0.5)
            # Highlight the breath event
            plt.axvline(datetime_data[event_index], color='black', linestyle='--', alpha=0.7)

        plt.xlabel('Datetime')
        plt.ylabel('Acceleration (g)')
        plt.title('Accelerometer Data (accX, accY, accZ) During Exhalation Breaths')
        plt.legend(loc='upper right')
        plt.xticks(rotation=45)
        plt.grid(True)
        plt.show()

    # Calculate the mean vector abar0 over all window samples
    abar0 = np.concatenate([acc_xyz[start:end] for start, end, _ in windows]).mean(axis=0)

    return abar0

# Example usage:
#abar0 = plot_acc_for_exhalation_breaths(data_pkl)
#print("Average acceleration vector (abar0):", abar0)
