In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.markers import MarkerStyle
import seaborn as sns
import tdt
import trompy as tp

import dill

In [None]:
# Build paths with pathlib joins rather than "..\\data"-style literals:
# backslashes are only separators on Windows, so on Linux/macOS the old
# literal named a single file called "..\data" instead of ../data.
PROJECT_ROOT = Path("..")
DATAFOLDER = PROJECT_ROOT / "data"
# Experiment metadata sheet; later cells read its `tank`, `rat` and
# `condition` columns.
df = pd.read_excel(PROJECT_ROOT / "experiment_info_gsheet.xlsx", sheet_name="Sheet1")

In [None]:
# Extract baseline-corrected peri-event snips from a single TDT tank.
# TODO: further preprocessing options could be exposed as parameters too.
def get_snips(tank, pre=10, post=20, bins=300):
    """Load a TDT tank and return baseline-corrected snips around each Sper onset.

    Parameters
    ----------
    tank : str
        Path to the TDT block, passed to ``tdt.read_block``.
    pre, post : int or float, optional
        Window before / after each event onset, forwarded to ``tp.snipper``
        (defaults 10 and 20, the values previously hard-coded).
    bins : int, optional
        Number of bins per snip, forwarded to ``tp.snipper`` (default 300).

    Returns
    -------
    numpy.ndarray
        One row per event. Each row is shifted so that the minimum of its
        pre-event portion is 0.

    Raises
    ------
    ValueError
        If ``pre``/``post``/``bins`` leave no pre-event bins to baseline on.
    """
    data = tdt.read_block(tank)

    # Preprocess the two photometry streams into a single filtered signal.
    blue = data.streams["x65A"].data
    uv = data.streams["x05A"].data
    fs = data.streams["x05A"].fs
    filtered_sig = tp.processdata(blue, uv, fs=fs)

    # Event onsets used to align the snips.
    onsets = data.epocs.Sper.onset

    # [0] picks the snip array out of snipper's return value; the slicing
    # below relies on it being 2-D (events x bins).
    snips = tp.snipper(filtered_sig, onsets, fs=fs, pre=pre, post=post, bins=bins)[0]

    # Number of bins in the pre-event window. This assumes snipper spreads
    # `bins` evenly over the pre + post window. With the defaults
    # (300 * 10 / 30) this gives the previously hard-coded 100.
    n_pre_bins = int(round(bins * pre / (pre + post)))
    if n_pre_bins < 1:
        raise ValueError(
            "pre/post/bins leave no pre-event bins for baseline correction"
        )

    # Baseline correction: shift each snip so its pre-event minimum is 0.
    # TODO: maybe also scale all snips to a common range.
    snips = snips - snips[:, :n_pre_bins].min(axis=1)[:, None]

    return snips

# Quick single-tank check before processing the whole experiment sheet.
TANK = str(Path(DATAFOLDER, "Svg130-210727-115203"))
snips = get_snips(tank=TANK)

In [None]:
# One snips array per recording, keyed "<rat>_<condition>".
# Built with a comprehension so this cell does not overwrite the notebook-level
# `TANK` / `snips` variables as a side effect. Rows that share a rat/condition
# pair overwrite each other, with the last row winning, exactly as the original
# loop did.
snips_dict = {
    f"{row.rat}_{row.condition}": get_snips(str(DATAFOLDER / row.tank))
    for row in df.itertuples(index=False)
}

In [None]:
# Persist all snips so later sessions can reload them without re-reading the tanks.
with (DATAFOLDER / "snips_dict.pickle").open("wb") as f:
    dill.dump(
        snips_dict,
        f,
    )

In [None]:
def snips_fig(snips):
    """Draw every snip in translucent black, with the mean of all snips on top.

    Parameters
    ----------
    snips : iterable of 1-D array-like
        One trace per trial; also passed to ``np.mean(..., axis=0)``.
    """
    _, ax = plt.subplots()

    for trace in snips:
        ax.plot(trace, color="black", alpha=0.3)

    mean_trace = np.mean(snips, axis=0)
    ax.plot(mean_trace)
    
def snips_changing_baseline_fig(snips, highlights=None):
    """Plot snips stacked vertically, shifting trial ``i`` up by ``i``.

    Parameters
    ----------
    snips : iterable of 1-D numeric arrays
        One trace per trial; trace ``i`` is drawn as ``snip + i``.
    highlights : iterable of int, optional
        Trial indices drawn in red. All other trials are drawn in
        translucent black. Defaults to no highlighted trials.
    """
    # Use None rather than a mutable `[]` default. A set gives O(1)
    # membership checks inside the loop.
    highlighted = set(highlights) if highlights is not None else set()

    fig, ax = plt.subplots()
    for idx, snip in enumerate(snips):
        if idx in highlighted:
            ax.plot(snip + idx, color="red", alpha=0.99)
        else:
            ax.plot(snip + idx, color="black", alpha=0.3)

def heatmap(snips, vmin=0, vmax=5, cmap="Greys"):
    """Draw the snips as a heatmap with one row per trial.

    Parameters
    ----------
    snips : 2-D array-like
        Passed straight to ``sns.heatmap``.
    vmin, vmax : float, optional
        Colour-scale limits. Defaults 0 and 5 match the previous hard-coded
        values; the composite figure elsewhere uses ``vmax=4``.
    cmap : str, optional
        Colormap name (default ``"Greys"``).
    """
    fig, ax = plt.subplots()
    sns.heatmap(snips, cmap=cmap, ax=ax, vmin=vmin, vmax=vmax)
    

In [None]:
# Composite figure for SVG135_FD:
#   top-left     heatmap of all snips (ax1), with its colorbar in ax2;
#   bottom-left  highlighted trials in red plus the mean trace (ax3);
#   bottom-right unused (ax4, removed at the end).
snips = snips_dict["SVG135_FD"]
# Trial (row) indices to mark next to the heatmap and overlay in red below.
highlights = [1, 8, 10, 23, 26]

# 2x2 grid: a wide left column for data and a narrow right column for the colorbar.
f, [[ax1, ax2],
    [ax3, ax4]] = plt.subplots(nrows=2, ncols=2,
                     gridspec_kw={'width_ratios': [20, 1],
                                  "hspace": 0.1,
                                  "wspace": 0.1},
                     figsize=(5, 5),
                     )

# Colour scale fixed to 0-4; the colorbar is drawn into ax2.
sns.heatmap(snips, cmap="Greys", ax=ax1, vmin=0, vmax=4, cbar_ax=ax2)
ax1.set_xticks([])
ax1.set_yticks([])

# Regular triangle rotated by -90 degrees so that it points right, towards the heatmap.
right_triangle = MarkerStyle(marker=(3, 0, -90))
for hl in highlights:
    # x=-5 is just left of the first column; +0.5 centres the marker on row `hl`.
    # clip_on=False lets the marker draw outside the axes area.
    ax1.plot(-5, hl+0.5, marker=right_triangle, color="red", zorder=20, clip_on=False)

# Horizontal time scale bar spanning 50 bins, labelled "5 s". This assumes
# 10 bins/s, i.e. get_snips' 300 bins over a 10 s + 20 s window.
# NOTE(review): y=32 is hard-coded and presumably sits just below the last
# row for this rat's trial count; consider deriving it from snips.shape[0].
ax1.plot((250, 300), (32, 32), color="black", clip_on=False)
ax1.text(275, 33, "5 s", ha="center", va="top")


# Bottom panel: highlighted trials in translucent red, mean of all trials in black.
for hl in highlights:
    ax3.plot(snips[hl,:], color="red", alpha=0.4)

ax3.plot(np.mean(snips, axis=0), color="black")
ax3.set_xticks([])
ax3.set_yticks([])
# Share the x-axis (bin index) with the heatmap so the traces line up with its columns.
# NOTE(review): heatmap cells appear centred at i + 0.5 (cf. the hl+0.5 above),
# so traces plotted at x = i sit half a bin to the left; confirm this is acceptable.
ax3.sharex(ax1)
# Dashed zero line for reference.
ax3.axhline(0, linestyle="--", color="black")

# Hide every spine so that only the scale bars indicate units.
for spines in ["top", "right", "left", "bottom"]:
    ax3.spines[spines].set_visible(False)

# Vertical amplitude scale bar of 1 unit, drawn just right of the data
# (assumes 300 bins). NOTE(review): "1Z" presumably means one z-score unit
# of tp.processdata output; confirm.
ax3.plot((310, 310), (1, 2), color="black", clip_on=False)
ax3.text(312, 1.5, "1Z", ha="left", va="center")

# The bottom-right grid cell is not used.
ax4.remove()


In [None]:
heatmap(snips=snips_dict["SVG135_FD"])

In [None]:
snips_changing_baseline_fig(snips=snips_dict["SVG135_FD"], highlights=[7, 12, 23])

In [None]:
def calculate_minima(data, window_size):
    """Return the trailing rolling minimum of ``data``.

    Parameters
    ----------
    data : pd.Series or pd.DataFrame
        Values to scan.
    window_size : int
        Number of consecutive observations in each window.

    Returns
    -------
    pd.Series or pd.DataFrame
        Same shape as ``data``. Because pandas requires a full window by
        default, the first ``window_size - 1`` entries are NaN.
    """
    roller = data.rolling(window_size)
    return roller.min()

calculate_minima(pd.Series(snips[0]), window_size=10)