In [14]:
%matplotlib qt

import re
import h5py
from tqdm import tqdm
import numpy as np
from scipy.interpolate import interp1d
from scipy.signal import detrend
from matplotlib import pyplot as plt

from vrAnalysis import fileManagement as files
from vrAnalysis.helpers import errorPlot

# Root data folder; the loaders below expect one sub-folder per mouse inside it.
data_path = files.localDataPath()

# Dataset paths inside a .doric (HDF5) file, keyed by the short names used
# throughout this notebook.
file_tree = {
    "in_time": "DataAcquisition/FPConsole/Signals/Series0001/AnalogIn/Time",
    "in_data": "DataAcquisition/FPConsole/Signals/Series0001/AnalogIn/AIN03",
    "out_time": "DataAcquisition/FPConsole/Signals/Series0001/AnalogOut/Time",
    "out1": "DataAcquisition/FPConsole/Signals/Series0001/AnalogOut/AOUT01",
    "out2": "DataAcquisition/FPConsole/Signals/Series0001/AnalogOut/AOUT02",
    "out3": "DataAcquisition/FPConsole/Signals/Series0001/AnalogOut/AOUT03",
}

def create_file_dict(file):
    """Copy every dataset listed in ``file_tree`` out of an open file.

    ``file`` is indexed with each path in ``file_tree``; the result is a dict
    with the same keys as ``file_tree`` whose values are numpy arrays.
    """
    return {name: np.array(file[path]) for name, path in file_tree.items()}

def get_doric_files(mouse_name):
    """Get all doric files and their dates from the data path.

    Every ``*.doric`` file inside each sub-directory of
    ``data_path / mouse_name`` is loaded with ``create_file_dict``. Files whose
    name does not end in ``_<digits>.doric`` are reported with a printed
    message and skipped.

    Parameters
    ----------
    mouse_name : str
        Name of the mouse folder inside ``data_path``.

    Returns
    -------
    directory : list of str
        Name of the containing (date) directory for each loaded file.
    file_index : list of int
        Numeric index parsed from each file name.
    data : list of dict
        One dict per file as returned by ``create_file_dict``, with an extra
        ``"index"`` entry holding that file's own index.
    """
    directory = []
    file_index = []
    data = []
    mouse_directory = data_path / mouse_name
    date_directories = [x for x in mouse_directory.iterdir() if x.is_dir()]
    for date_directory in date_directories:
        for file in date_directory.glob('*.doric'):
            # The dot is escaped so only a literal ".doric" suffix is accepted.
            file_index_match = re.match(r'.*_(\d+)\.doric', file.name)
            if not file_index_match:
                print(f"Could not parse file index from {file.parent}/{file.name}")
                continue
            c_file_index = int(file_index_match.group(1))
            with h5py.File(file, 'r') as f:
                file_data = create_file_dict(f)
            # Store this file's index, not the shared list of all indices
            # (which keeps growing and would be identical for every file).
            file_data["index"] = c_file_index
            file_index.append(c_file_index)
            directory.append(date_directory.name)
            data.append(file_data)
    return directory, file_index, data

def check_doric_filetree(mouse_name):
    """Print the internal tree of the first doric file found for a mouse.

    Walks the sub-directories of ``data_path / mouse_name`` and stops at the
    first ``*.doric`` file: every object name in that file is printed and the
    function returns None. Raises ValueError if that file's name does not
    contain a parsable ``_<digits>`` index.
    """
    mouse_directory = data_path / mouse_name
    for date_directory in mouse_directory.iterdir():
        if not date_directory.is_dir():
            continue
        for file in date_directory.glob('*.doric'):
            if re.match(r'.*_(\d+).doric', file.name) is None:
                raise ValueError(f"Could not parse file index from {file.name}")
            with h5py.File(file, 'r') as f:
                # visit() calls print once per object name in the file
                f.visit(print)
            # Only one file is needed to inspect the layout
            return None
        
def _filter_cycle_markers(markers, first, last, keep_first=False, keep_last=False):
    """Filter cycle markers to only include valid cycles."""
    if keep_first:
        markers = markers[markers >= first]
    else:
        markers = markers[markers > first]
    if keep_last:
        markers = markers[markers <= last]
    else:
        markers = markers[markers < last]
    return markers

def get_cycles(data, cycle_period_tolerance=0.1):
    """Get output cycles with interleaved data on out1 and out2.

    Finds the rising (start) and falling (stop) edges of ``data["out1"]`` and
    ``data["out2"]``, trims them to a common window, checks that both channels
    have matching numbers of well-ordered cycles, and discards cycles whose
    duration deviates too much from the channel's mean duration.

    Target cycle definition:
    First cycle is always on out1 (will clip if necessary) - last cycle is out2.

    Parameters
    ----------
    data : dict
        Must contain ``"out1"`` and ``"out2"``; edges are detected as sample
        differences of exactly +1 and -1.
    cycle_period_tolerance : float
        Maximum allowed relative deviation of a cycle's duration from the mean
        duration of its channel.

    Returns
    -------
    start1, stop1, start2, stop2 : np.ndarray
        Sample indices of the retained cycle starts and stops on each channel.
        All four arrays have the same length.

    Raises
    ------
    ValueError
        If the marker counts disagree, a start does not precede its stop, too
        many neighbouring cycles have a bad duration, or the signal value at a
        retained marker is not the expected 1 (start) / 0 (stop).
    IndexError
        If a channel has no edges at all (the ``[0]`` / ``[-1]`` lookups run
        before any of the explicit checks).
    """
    # +1 turns the position in the diff array into the index of the first
    # sample after the edge.
    diff1 = np.diff(data["out1"])
    diff2 = np.diff(data["out2"])
    start1 = np.where(diff1 == 1)[0] + 1
    start2 = np.where(diff2 == 1)[0] + 1
    stop1 = np.where(diff1 == -1)[0] + 1
    stop2 = np.where(diff2 == -1)[0] + 1
    # The analysed window runs from the first out1 start to the last out2 stop.
    first_valid_idx = start1[0]
    last_valid_idx = stop2[-1]

    start1 = _filter_cycle_markers(start1, first_valid_idx, last_valid_idx, keep_first=True)
    start2 = _filter_cycle_markers(start2, first_valid_idx, last_valid_idx)
    stop1 = _filter_cycle_markers(stop1, first_valid_idx, last_valid_idx)
    stop2 = _filter_cycle_markers(stop2, first_valid_idx, last_valid_idx, keep_last=True)

    # Drop out1 starts at or after the last out1 stop, and out2 stops at or
    # before the first out2 start, so every start has a stop following it.
    start1 = _filter_cycle_markers(start1, first_valid_idx, stop1[-1], keep_first=True)
    stop2 = _filter_cycle_markers(stop2, start2[0], last_valid_idx, keep_last=True)
    
    if len(start1) != len(start2):
        raise ValueError("Unequal number of start markers")
    if len(stop1) != len(stop2):
        raise ValueError("Unequal number of stop markers")
    if len(start1) != len(stop1):
        raise ValueError("Unequal number of start and stop markers")
    if not np.all(start1 < stop1):
        raise ValueError("Start marker after stop marker for channel 1")
    if not np.all(start2 < stop2):
        raise ValueError("Start marker after stop marker for channel 2")
    
    # Cycle durations in samples, expressed relative to the channel mean.
    period1 = stop1 - start1
    period2 = stop2 - start2
    period1_deviation = period1 / np.mean(period1)
    period2_deviation = period2 / np.mean(period2)
    bad_period1 = np.abs(period1_deviation - 1) > cycle_period_tolerance
    bad_period2 = np.abs(period2_deviation - 1) > cycle_period_tolerance
    # Count pairs of bad cycles that are direct neighbours (index gap < 2);
    # more than two such pairs is treated as a corrupted recording.
    if np.sum(np.diff(np.where(bad_period1)[0]) < 2) > 2:
        raise ValueError("Too many consecutive bad periods in channel 1")
    if np.sum(np.diff(np.where(bad_period2)[0]) < 2) > 2:
        raise ValueError("Too many consecutive bad periods in channel 2")
    
    # Remove bad periods and filter stop / start signals
    # (a cycle pair is kept only when it is good on both channels)
    valid_period = ~bad_period1 & ~bad_period2
    start1 = start1[valid_period]
    stop1 = stop1[valid_period]
    start2 = start2[valid_period]
    stop2 = stop2[valid_period]

    # Final sanity check: the signal is high at each start index and low at
    # each stop index.
    if not np.all(data["out1"][start1] == 1) or not np.all(data["out2"][start2] == 1):
        raise ValueError("Start indices are not positive for out1 / out2!")
    if not np.all(data["out1"][stop1] == 0) or not np.all(data["out2"][stop2] == 0):
        raise ValueError("Stop indices are not zero for out1 / out2!")
    
    return start1, stop1, start2, stop2

def get_opto_cycles(data, min_period=1, cycle_period_tolerance=0.01):
    """Get opto cycles (out3) with a minimum period.

    Rising edges of ``data["out3"]`` are candidate cycle starts. A candidate is
    kept only if it occurs more than ``min_period`` (in the units of
    ``data["out_time"]``) after the previously kept start, so several pulses
    close together count as a single cycle. Every cycle is then given the same
    length: the shortest start-to-start interval, in samples.

    Parameters
    ----------
    data : dict
        Must contain ``"out3"`` (edges are differences of exactly +1 / -1) and
        ``"out_time"`` (one timestamp per sample of ``out3``).
    min_period : float
        Minimum time between two kept cycle starts.
    cycle_period_tolerance : float
        Maximum allowed relative deviation of any start-to-start interval from
        the mean interval.

    Returns
    -------
    start3 : np.ndarray
        Sample index of each cycle start.
    stop3 : np.ndarray
        ``start3`` plus the shortest start-to-start interval. The last cycle is
        dropped when its stop would reach the end of ``out3``.
    average_cycle : np.ndarray
        Mean of ``out3`` across the returned cycles.

    Raises
    ------
    ValueError
        If out3 has no complete pulse, fewer than two cycles are found, or any
        interval deviates from the mean by more than the tolerance.
    """
    diff3 = np.diff(data["out3"])
    rising = np.where(diff3 == 1)[0] + 1
    falling = np.where(diff3 == -1)[0] + 1
    if rising.size == 0 or falling.size == 0:
        raise ValueError("No complete opto pulses found in out3")

    # Keep only rising edges that are followed by a falling edge.
    start3 = rising[rising < falling[-1]]
    if start3.size == 0:
        raise ValueError("No complete opto pulses found in out3")
    start_time = data["out_time"][start3]

    # Collapse pulses closer together than min_period into one cycle start.
    valid_starts = [start3[0]]
    last_time = start_time[0]
    for idx, c_time in zip(start3[1:], start_time[1:]):
        if c_time > (last_time + min_period):
            valid_starts.append(idx)
            last_time = c_time

    # Convert valid starts to numpy array (reuse start3 for consistent terminology with get_cycles)
    start3 = np.array(valid_starts)
    if len(start3) < 2:
        raise ValueError("Need at least two opto cycles to measure the cycle period")

    # Measure period between cycles (in samples)
    period3 = np.diff(start3)
    period3_deviation = period3 / np.mean(period3)
    # Both bounds must hold for every interval. The previous
    # `not all(lower) and all(upper)` only raised when the upper bound held.
    within_tolerance = (period3_deviation >= 1 - cycle_period_tolerance) & (
        period3_deviation <= 1 + cycle_period_tolerance
    )
    if not np.all(within_tolerance):
        shortest = np.min(period3)
        longest = np.max(period3)
        raise ValueError(f"Excess period variation in opto cycles! min={shortest:.2f}, max={longest:.2f}")

    # A separate name keeps the sample count apart from the min_period argument.
    cycle_length = np.min(period3)
    stop3 = start3 + cycle_length

    if stop3[-1] >= len(data["out3"]):
        start3 = start3[:-1]
        stop3 = stop3[:-1]

    cycles = [data["out3"][istart:istop] for istart, istop in zip(start3, stop3)]
    average_cycle = np.mean(np.stack(cycles), axis=0)

    return start3, stop3, average_cycle

def get_cycle_data(signal, start, stop, keep_fraction=0.5, signal_cv_tolerance=0.05):
    """Average the tail of each cycle of ``signal``.

    For every (start, stop) pair the window begins ``keep_fraction`` of the
    way from ``start`` to ``stop - 1`` and ends just before ``stop - 1``.
    Returns the mean of each window and a boolean array that is True where
    the window's std / mean exceeds ``signal_cv_tolerance``.
    """
    num_samples = len(start)
    assert 0 < keep_fraction < 1, "Invalid keep_fraction, must be in between 0 and 1"
    assert num_samples == len(stop), "Start and stop indices mismatch"

    window_means = []
    exceeds_cv = []
    for first, end in zip(start, stop):
        last = end - 1
        begin = first + int(keep_fraction * (last - first))
        window = signal[begin:last]
        window_mean = np.mean(window)
        exceeds_cv.append(np.std(window) / window_mean > signal_cv_tolerance)
        window_means.append(window_mean)
    return np.array(window_means), np.array(exceeds_cv)
    
def analyze_data(data, preperiod=0.1, cycle_period_tolerance=0.5, keep_fraction=0.5, signal_cv_tolerance=0.05):
    """Process a data file, return results and filtered signals.

    Steps:
    1. Validate that all channels have the same length and that out1/out2/out3
       contain exactly the values 0 and 1.
    2. Find the interleaved out1/out2 cycles and reduce ``in_data`` to one
       value per cycle and channel (``get_cycles`` / ``get_cycle_data``).
    3. Linearly interpolate the per-cycle values back onto the original
       timestamps and detrend them.
    4. Cut a window around every opto (out3) cycle, starting ``preperiod``
       before the cycle start.

    Parameters
    ----------
    data : dict
        Needs ``in_data``, ``in_time``, ``out_time``, ``out1``, ``out2`` and
        ``out3``, all of equal length.
    preperiod : float
        Time before each opto start to include, in the units of ``in_time``.
    cycle_period_tolerance : float
        Passed to both ``get_cycles`` and ``get_opto_cycles``.
    keep_fraction, signal_cv_tolerance : float
        Passed to ``get_cycle_data``.

    Returns
    -------
    dict with keys
        ``in1_opto``, ``in2_opto``, ``out3_opto`` : (num opto cycles, window) arrays,
        ``time_opto`` : window time axis relative to the opto start,
        ``opto_start_time`` : timestamp of each opto start,
        ``data_in1``, ``data_in2``, ``data_opto``, ``time_data`` : full-length
        interpolated signals, the matching out3 samples and their timestamps.

    Raises
    ------
    ValueError
        If there is no data, a digital channel has unexpected values, or no
        opto cycle fits inside the interpolated signal (plus anything raised
        by the helpers).
    """
    # First check if the data is valid and meets criteria for processing.
    num_samples = len(data["in_data"])
    if not num_samples > 0:
        raise ValueError("No data found! in_data has 0 samples.")
    for key in ["out1", "out2", "out3"]:
        assert num_samples == len(data[key]), f"{key} and in_data length mismatch"
        uvals = np.unique(data[key])
        if not np.array_equal(uvals, np.array([0.0, 1.0])):
            raise ValueError(f"Invalid values in {key}: {uvals}")
    for key in ["in_time", "out_time"]:
        assert num_samples == len(data[key]), f"{key} and in_data length mismatch"

    # Get start and stop indices for the interleaved cycles
    time = data["in_time"]
    start1, stop1, start2, stop2 = get_cycles(data, cycle_period_tolerance=cycle_period_tolerance)
    cycle_timestamps = (time[stop2] + time[start1]) / 2  # Midpoint of full cycles
    in1, invalid1 = get_cycle_data(data["in_data"], start1, stop1, keep_fraction=keep_fraction, signal_cv_tolerance=signal_cv_tolerance)
    in2, invalid2 = get_cycle_data(data["in_data"], start2, stop2, keep_fraction=keep_fraction, signal_cv_tolerance=signal_cv_tolerance)

    # Upsample the cycle data to match the original timestamps
    upsample_cycle_timestamps = time[(time >= cycle_timestamps[0]) & (time <= cycle_timestamps[-1])]
    upsample_in1 = detrend(interp1d(cycle_timestamps, in1)(upsample_cycle_timestamps))
    upsample_in2 = detrend(interp1d(cycle_timestamps, in2)(upsample_cycle_timestamps))
    # Index in the original arrays of the first upsampled sample
    upsample_offset = np.nonzero(time >= upsample_cycle_timestamps[0])[0][0]
    num_upsampled = len(upsample_cycle_timestamps)

    if np.any(invalid1) or np.any(invalid2):
        # Percentages are relative to the number of cycles, not raw samples.
        percent_invalid1 = np.mean(invalid1) * 100
        percent_invalid2 = np.mean(invalid2) * 100
        print(f"Warning: excess coefficient of variation detected in {percent_invalid1:.2f}% of cycles for channel 1 and {percent_invalid2:.2f}% for channel 2.")

    # Get start indices for opto cycles, shifted into upsampled indexing
    start3, stop3, _ = get_opto_cycles(data, min_period=1.0, cycle_period_tolerance=cycle_period_tolerance)
    start3 = start3 - upsample_offset
    stop3 = stop3 - upsample_offset

    samples_pre = int(preperiod / np.mean(np.diff(upsample_cycle_timestamps)))

    # Windows reaching outside the upsampled signal would give wrapped or
    # truncated slices of differing length, so drop those cycles up front.
    fits_window = (start3 - samples_pre >= 0) & (stop3 <= num_upsampled)
    if not np.all(fits_window):
        print(f"Warning: dropping {np.sum(~fits_window)} opto cycle(s) that extend outside the interpolated signal.")
        start3 = start3[fits_window]
        stop3 = stop3[fits_window]
    if len(start3) == 0:
        raise ValueError("No opto cycles fall inside the interpolated signal.")

    # Get opto start time in upsampled time
    opto_start_time = upsample_cycle_timestamps[start3]
    upsample_opto_data = data["out3"][upsample_offset:]
    upsample_opto_data = upsample_opto_data[:num_upsampled]

    # Get cycle data for opto cycles
    in1_opto = []
    in2_opto = []
    out3_opto = []
    time_opto = []
    for istart, istop in zip(start3, stop3):
        window = slice(istart - samples_pre, istop)
        in1_opto.append(upsample_in1[window])
        in2_opto.append(upsample_in2[window])
        out3_opto.append(data["out3"][istart + upsample_offset - samples_pre:istop + upsample_offset])
        time_opto.append(upsample_cycle_timestamps[window] - upsample_cycle_timestamps[istart])

    in1_opto = np.stack(in1_opto)
    in2_opto = np.stack(in2_opto)
    out3_opto = np.stack(out3_opto)
    time_opto = np.mean(np.stack(time_opto), axis=0)  # variance across opto cycles should be within sample error

    results = dict(
        in1_opto=in1_opto,
        in2_opto=in2_opto,
        out3_opto=out3_opto,
        time_opto=time_opto,
        opto_start_time=opto_start_time,
        data_in1=upsample_in1,
        data_in2=upsample_in2,
        data_opto=upsample_opto_data,
        time_data=upsample_cycle_timestamps,
    )

    return results

In [31]:
# Load every doric recording of one mouse. `data` is a list with one dict per
# file; `dirs` and `findex` hold the matching folder names and file indices.
mouse_name = "ATL065"
dirs, findex, data = get_doric_files(mouse_name)
print(f"Found {len(data)} files")

Found 9 files


In [16]:
# For looking at a single session
# NOTE(review): the load cell in this notebook reports 9 files (valid indices
# 0-8), so `ises = 9` only works if `data` came from a different load - confirm.
ises = 9
preperiod = 0.15
results = analyze_data(data[ises], preperiod=preperiod)

# Left panel: full-session interpolated signals for both channels, with a red
# dot at y=0 for each opto start.
fig, ax = plt.subplots(1, 2, figsize=(8, 4), layout="constrained", sharey=False)
ax[0].plot(results["time_data"], results["data_in1"], label="in1")
ax[0].plot(results["time_data"], results["data_in2"], label="in2")
ax[0].scatter(results["opto_start_time"], np.zeros_like(results["opto_start_time"]), color="red", label="opto start", s=5)

# Right panel: opto-aligned windows summarised across cycles (axis=0) by
# errorPlot - presumably mean with standard error; confirm in vrAnalysis.helpers.
errorPlot(results["time_opto"], results["in1_opto"], se=True, axis=0, label="in1", ax=ax[1], alpha=0.2)
errorPlot(results["time_opto"], results["in2_opto"], se=True, axis=0, label="in2", ax=ax[1], alpha=0.2)
ax[1].set_xlim(-preperiod, 0.8)
plt.show()


In [32]:
# For showing a single mouse across sessions
preperiod = 0.2
postperiod = 1.0
# The grid runs from -preperiod to +postperiod, so its duration is the SUM of
# the two; sizing it from the difference gave 800 points over 1.2 time units
# instead of 1000 points per unit.
samples = np.linspace(-preperiod, postperiod, int(round((postperiod + preperiod) * 1000)))

cmap = plt.get_cmap("rainbow")
average = []
for ifile, file in enumerate(data):
    print(f"Processing file {ifile+1}/{len(data)}")
    # Ask for slightly more pre-period than the grid needs so that the first
    # grid point lies inside the interpolation range.
    results = analyze_data(file, preperiod=preperiod+0.01)
    c_idx = results["time_opto"] < postperiod + preperiod
    c_time = results["time_opto"][c_idx]
    # Channel difference (in2 - in1), averaged across opto cycles
    c_data = np.mean(results["in2_opto"][:, c_idx] - results["in1_opto"][:, c_idx], axis=0)
    c_interp = interp1d(c_time, c_data, kind="cubic")(samples)
    c_interp = detrend(c_interp)
    average.append(c_interp)
average = np.stack(average)

Processing file 1/9
Processing file 2/9
Processing file 3/9
Processing file 4/9
Processing file 5/9
Processing file 6/9
Processing file 7/9
Processing file 8/9
Processing file 9/9


In [33]:
# One trace per session, coloured by session order with `cmap`.
# Relies on `samples`, `average` and `cmap` from the single-mouse averaging cell.
fig, ax = plt.subplots(1, 1, figsize=(4, 4), layout="constrained")
for iavg, avg in enumerate(average):
    ax.plot(samples, avg, color=cmap(iavg/len(average)))
ax.set_xlim(-preperiod, postperiod/2)
# Short black bar from x=0 to x=0.03 at y=-0.001 - presumably marks the opto
# pulse; TODO confirm the intended duration.
ax.plot([0, 0.03], [-0.001, -0.001], color="black", linewidth=3)

[<matplotlib.lines.Line2D at 0x1e8a46ac7f0>]

In [None]:
# For making an average across all mice
# NOTE(review): this unexecuted cell duplicates the logic of the cell that
# follows it; keep only one of the two.
preperiod = 0.2
postperiod = 1.0
# The grid spans preperiod + postperiod, so the point count is derived from the
# sum (the difference under-sampled the window: 800 points over 1.2 time units).
samples = np.linspace(-preperiod, postperiod, int(round((postperiod + preperiod) * 1000)))

mouse_list = ["ATL061", "ATL062", "ATL063", "ATL064", "ATL065"]
colors = ["red", "blue", "green", "purple", "brown"]
average = []
for mouse in tqdm(mouse_list):
    c_mouse_average = []
    dirs, findex, data = get_doric_files(mouse)
    for file in data:
        # Slightly longer pre-period than the grid so interpolation stays in range
        results = analyze_data(file, preperiod=preperiod+0.01)
        c_idx = results["time_opto"] < postperiod + preperiod
        c_time = results["time_opto"][c_idx]
        # Channel difference (in2 - in1), averaged across opto cycles
        c_data = np.mean(results["in2_opto"][:, c_idx] - results["in1_opto"][:, c_idx], axis=0)
        c_interp = interp1d(c_time, c_data, kind="cubic")(samples)
        c_mouse_average.append(c_interp)
    average.append(np.stack(c_mouse_average))
    if mouse == "ATL061":
        # Only the first three sessions of ATL061 are kept.
        # NOTE(review): the reason for the exclusion is not recorded here.
        average[-1] = average[-1][:3]

In [200]:
# Average opto response per session for every mouse: average[i] is a
# (sessions, len(samples)) array for mouse_list[i].
preperiod = 0.2
postperiod = 1.0
# The grid spans preperiod + postperiod, so the point count is derived from the
# sum (the difference under-sampled the window: 800 points over 1.2 time units).
samples = np.linspace(-preperiod, postperiod, int(round((postperiod + preperiod) * 1000)))

mouse_list = ["ATL061", "ATL062", "ATL063", "ATL064", "ATL065"]
colors = ["red", "blue", "green", "purple", "brown"]
average = []
for mouse in tqdm(mouse_list):
    c_mouse_average = []
    dirs, findex, data = get_doric_files(mouse)
    for file in data:
        # Slightly longer pre-period than the grid so interpolation stays in range
        results = analyze_data(file, preperiod=preperiod+0.01)
        c_idx = results["time_opto"] < postperiod + preperiod
        c_time = results["time_opto"][c_idx]
        # Channel difference (in2 - in1), averaged across opto cycles
        c_data = np.mean(results["in2_opto"][:, c_idx] - results["in1_opto"][:, c_idx], axis=0)
        c_interp = interp1d(c_time, c_data, kind="cubic")(samples)
        c_mouse_average.append(c_interp)
    average.append(np.stack(c_mouse_average))
    if mouse == "ATL061":
        # Only the first three sessions of ATL061 are kept.
        # NOTE(review): the reason for the exclusion is not recorded here.
        average[-1] = average[-1][:3]

100%|██████████| 5/5 [00:19<00:00,  3.80s/it]


In [207]:
# One summary trace per mouse. Each mouse's sessions are shifted by the mean of
# their first sample (avg[:, 0]) before being passed to errorPlot.
fig, ax = plt.subplots(1, 1, figsize=(4, 4), layout="constrained")
for imouse, (mouse, color, avg) in enumerate(zip(mouse_list, colors, average)):
    errorPlot(samples, avg - np.mean(avg[:, 0]), se=True, axis=0, label=mouse, ax=ax, color=color, alpha=0.2)
    # ax.plot(samples, avg.T, label=mouse, color=color)
ax.set_xlim(-preperiod, postperiod/2)
ax.legend(loc="upper right")
# Short black bar from x=0 to x=0.03 at y=-0.001 - presumably marks the opto
# pulse; TODO confirm the intended duration.
ax.plot([0, 0.03], [-0.001, -0.001], color="black", linewidth=3)

[<matplotlib.lines.Line2D at 0x192c0849850>]

In [212]:
# Overlay every session of the first mouse in mouse_list (one line per row of average[0])
plt.plot(samples, average[0].T)

[<matplotlib.lines.Line2D at 0x192ae253b20>,
 <matplotlib.lines.Line2D at 0x192ae253b50>,
 <matplotlib.lines.Line2D at 0x192ae253c40>]

: 

In [194]:
# First interpolated sample of each session for the first mouse in mouse_list
average[0][:, 0]

array([-0.03055913, -0.03442416, -0.00556974,  1.33912469])

In [160]:
# NOTE(review): `times` is not defined in any cell visible in this notebook, so
# this cell depends on leftover kernel state and fails on a fresh run.
[t.shape for t in times[0]]

[(62452,), (62361,), (62452,), (62452,)]