# Inserting into Photometry, PhotometrySynced, and BehaviorIngestion

## Requirements before beginning: 

* `dj_local_config.json` has been edited with the appropriate information for your setup
* Your data directories are set up in the proper format on the O2 filesystem (please refer to the documentation)
* You have created the appropriate .toml file to accompany your photometry recordings (please refer to the documentation)

In [None]:
# A __future__ import must be the very first statement of the cell/module;
# placing it after other statements raises a SyntaxError.
from __future__ import annotations

# Work from the repository root so the relative config path below resolves.
import os
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir("..")

# Load the local DataJoint configuration and open the database connection.
import datajoint as dj
dj.config.load('dj_local_config.json')
dj.conn()

import warnings
import typing as T
from copy import deepcopy
from pathlib import Path

import numpy as np
import pandas as pd
import scipy.io as spio
import tdt
import tomli
from scipy import signal
from scipy.fft import fft, ifft, rfft
# Window functions live in scipy.signal.windows (since SciPy 1.1); the
# scipy.signal.blackman alias was removed in SciPy 1.13.
from scipy.signal.windows import blackman

from element_interface.utils import find_full_path
from workflow import db_prefix
from workflow.pipeline import session, subject, lab, reference, ingestion, event, trial, photometry
from workflow.utils.paths import get_raw_root_data_dir
import workflow.utils.photometry_preprocessing as pp
from workflow.utils import demodulation

## Insert subject, session, and session dir.

In [None]:
# Register the test subject used throughout this notebook.
test_subject = {
    "subject": "O2Test",
    "sex": "M",
    "subject_birth_date": "2021-10-01",
    "subject_description": "TestingO2functionality",
}
subject.Subject.insert1(test_subject)


In [None]:
# Key (plus session_datetime) of the test session for subject O2Test.
session_key = dict(subject = 'O2Test', session_id=1, session_datetime = '2021-10-07 12:00:00')

# Insert the session itself. The SessionDirectory insert in the next cell is keyed
# on subject/session_id of this session, and later cells fetch this session from
# session.Session, so the row must exist before continuing.
session.Session.insert1(session_key)

In [None]:
# Remember: session_dir is relative to the path in your config file!
session_dir_entry = {
    "subject": session_key["subject"],
    "session_id": session_key["session_id"],
    "session_dir": "O2Test/Session1",
}
session.SessionDirectory.insert1(session_dir_entry)

In [None]:
# Show each registered subject, session_id, and session_dir.
session_directories = session.SessionDirectory()
session_directories

## Populate your pipeline of choice

### For Photometry pipeline:

In [None]:
o2_sessions = session.Session() & "subject='O2Test'"
session_key = o2_sessions.fetch1("KEY")  # primary key of the single O2Test session

In [None]:
sd_key = {**session_key, "session_dir": "O2Test/Session1"}  # session key + relative session directory

In [None]:
# Run the FiberPhotometry computation restricted to this session.
fiber_photometry = photometry.FiberPhotometry()
fiber_photometry.populate(sd_key)

### For PhotometrySynced pipeline:

In [None]:
# Run the FiberPhotometrySynced computation restricted to this session.
synced_photometry = photometry.FiberPhotometrySynced()
synced_photometry.populate(session_key)

### For Behavior pipeline:

In [None]:
# Run behavior ingestion restricted to this session.
behavior_ingestion = ingestion.BehaviorIngestion()
behavior_ingestion.populate(sd_key)

In [None]:
# Display the BehaviorIngestion table.
behavior_table = ingestion.BehaviorIngestion()
behavior_table

#### We can now fetch the data and view it, starting with the demodulated traces.

In [None]:
# Restrict this session's demodulated traces to a single fiber.
fiber_id = 1
fiber_restriction = f"fiber_id = '{fiber_id}'"
single_fiber = photometry.FiberPhotometry.DemodulatedTrace & session_key
single_fiber & fiber_restriction

In [None]:
# All demodulated traces recorded in this session.
session_traces = photometry.FiberPhotometry.DemodulatedTrace & session_key
session_traces

In [None]:
# Select one demodulated trace by name, emission color, and hemisphere,
# then fetch its samples.
trace_name, emission_color, hemisphere = "photom", "green", "right"

restr = dict(
    trace_name=trace_name,
    emission_color=emission_color,
    hemisphere=hemisphere,
)
query = photometry.FiberPhotometry.DemodulatedTrace() & session_key & restr
trace = query.fetch1("trace")
query

In [None]:
# Plot a single demodulated trace (fetched in the previous cell) against time.
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import sem

# The trace is stored sample-by-sample. Convert the sample index to seconds with
# the demodulated sample rate of the same row so the 'Time (s)' label is accurate
# (demod_sample_rate presumably in samples per second).
sample_rate = query.fetch1("demod_sample_rate")
time_s = np.arange(len(trace)) / sample_rate

fig, ax = plt.subplots(figsize=(15, 2))

ax.plot(time_s, trace, 'k', lw=0.5)
ax.set(xlabel='Time (s)', ylabel='Amplitude')
sns.despine()

In [None]:
# Plot all photometry traces of this session, normalized and stacked vertically.
query = photometry.FiberPhotometry.DemodulatedTrace & session_key

offset = 8            # vertical position of the first trace
inc_height = -1.5     # vertical step between consecutive traces
window_start = 1000   # first sample shown
window_stop = 3000    # one past the last sample shown

# Fetch from the session-restricted query; fetching from the bare table would plot
# traces of every session and overrun the palette sized to this session.
rows = query.fetch("trace_name", "emission_color", "hemisphere", "trace", as_dict=True)

fig, ax = plt.subplots(figsize=(10, 3))
sns.set_palette('deep', n_colors=len(rows))
palette = sns.color_palette()

for j, row in enumerate(rows):
    name = '_'.join([row["trace_name"], row["emission_color"], row["hemisphere"]])
    normalized = pp.normalize(pd.DataFrame(row["trace"]), window=500)[window_start:window_stop]
    ax.plot(normalized + offset, label=name)
    # Label each trace just right of its last plotted sample, at its baseline.
    ax.text(x=window_stop + 2, y=offset, s=name, fontsize=12, va="bottom", color=palette[j])
    offset += inc_height

ax.set_title(f"{session_key}")
# x values are the DataFrame row index, i.e. sample numbers (assuming pp.normalize
# preserves the index), not seconds.
ax.set_xlabel("Sample")
ax.set_yticks([])
sns.despine(left=True)

#### We can then take a look at the event-related photometry traces.

In [None]:
# Show the event types defined for this session.
session_event_types = event.EventType & session_key
session_event_types

In [None]:
# Show every behavioral event recorded in this session.
session_events = event.Event & session_key
session_events

In [None]:
# Fetch this session's event type names (restrict further to pick particular ones).
event_type_query = event.EventType & session_key
event_types = event_type_query.fetch("event_type")

In [None]:
# Prepare peri-event plotting: fetch event types, the synced left-hemisphere
# traces, and the session time base used by the plotting cell.
event_types = (event.EventType & session_key).fetch("event_type")
trace_name = "photom"
emission_color = "green"

restr = {
    "trace_name": trace_name,
    "emission_color": emission_color,
    "hemisphere": "left",
}
query = photometry.FiberPhotometrySynced.SyncedTrace() & session_key & restr
traces = query.fetch("trace")

time_buffer = (1, 3)  # before and after each event
# Restrict to this session: taking [0] of the unrestricted table returned the rate
# of an arbitrary row, possibly from another session. Assumes all demodulated
# traces of the session share one rate — TODO confirm.
sample_rate = (
    photometry.FiberPhotometry.DemodulatedTrace & session_key
).fetch("demod_sample_rate")[0]
timestamps = np.array((photometry.FiberPhotometrySynced & session_key).fetch1("timestamps"))
timestamps = timestamps / sample_rate


In [None]:
# Peri-event photometry: for each event type, cut a window of every synced trace
# around every event of that type and plot the mean ± SEM across windows.
# time_buffer is interpreted as (seconds before, seconds after) each event.
# squeeze=False keeps a 2-D Axes array even when there is a single event type.
fig, axes = plt.subplots(1, len(event_types), figsize=(23, 3), squeeze=False)
axes = axes[0]

for ind, (event_type, ax) in enumerate(zip(event_types, axes)):
    # Event onsets do not depend on the trace, so query them once per event type.
    event_query = event.Event & session_key & f"event_type='{event_type}'"
    event_ts = event_query.fetch("event_start_time")

    event_traces = []   # one fixed-length peri-event window per (trace, event)
    window_time = None  # seconds relative to event onset, shared by the windows

    for samples in traces:
        samples = np.asarray(samples)
        if len(samples) < 2:
            continue  # cannot derive a sampling rate from fewer than two samples
        # Spread the trace's samples evenly over the session time base.
        times = np.linspace(timestamps[0], timestamps[-1], len(samples))
        fs = (len(samples) - 1) / (times[-1] - times[0])  # samples per second
        n_before = int(round(time_buffer[0] * fs))
        n_after = int(round(time_buffer[1] * fs))
        # NOTE(review): assumes all traces have the same length, so every window
        # shares this time axis — TODO confirm for multi-fiber sessions.
        window_time = np.arange(-n_before, n_after + 1) / fs

        for ts in event_ts:
            index = np.searchsorted(times, ts)
            start = index - n_before
            stop = index + n_after + 1
            # Skip events whose window runs off either end of the recording, so all
            # windows have equal length (and a negative start cannot wrap around).
            if start < 0 or stop > len(samples):
                continue
            event_traces.append(samples[start:stop])

    if event_traces:
        event_traces = np.array(event_traces)  # trials x samples

        # Mean and standard error across trials at each time point.
        mean_trace = event_traces.mean(axis=0)
        sem_trace = sem(event_traces, axis=0)

        ax.plot(window_time, mean_trace, label=event_type, lw=2)
        ax.fill_between(window_time, mean_trace - sem_trace, mean_trace + sem_trace, alpha=0.3)

    ax.axvline(x=0, linewidth=0.5, ls='--')  # event onset
    if ind == 0:
        ax.set_ylabel(trace_name, fontsize=15)
    ax.set(xlabel='Time (s)', title=event_type)
    sns.despine()

plt.legend(loc='center left', bbox_to_anchor=(0.98, 0.5))
plt.show()