In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tifffile as tf
from scipy.ndimage import gaussian_filter1d
from scipy.signal import correlate, find_peaks

In [None]:
session = 'point_5ms'  # one of: point_1ms, point_1.7ms, point_5ms, spiral_5ms_10um
ch = 'Ch2'  # channel tag, matched as a substring of the TIFF file names

sigma = 100  # Gaussian smoothing sigma, in samples of the flattened pixel stream

# Acquisition parameters -- used only to convert ms into pixel-stream indices.
stim_len = 5  # stimulus duration, ms
interstim_len = 1000  # gap between stimuli, ms
n_reps = 9  # number of stimuli
frame_period = 33.602823  # ms per frame
xy_pixels = 512  # pixels per frame side (square frames assumed)
# Time per pixel, spreading the full frame period over xy_pixels**2 pixels.
# This ignores line/frame flyback, the likely cause of the misalignment reported below.
px_time = frame_period / xy_pixels**2

# Nominal onset of stimulus k (k = 0 .. n_reps-1) lies (k + 1) stimulus periods
# (stim_len + interstim_len) after acquisition start, truncated to a sample index.
# NOTE(review): the first onset therefore lands one full period after start -- confirm
# this matches the stimulation protocol.
onsets = [int((k + 1) * (stim_len + interstim_len) / px_time) for k in range(n_reps)]
print('Calculated onsets (currently this is missaligned because of flyback time most likely):', onsets)

In [None]:
# Locate the acquisition: exactly one TSeries* entry inside the session folder,
# and inside it exactly one .tif whose name contains the channel tag `ch`.
root = '/Volumes/data_jm_share/data_raw/calibration/shutter_speed/'
session_path = os.path.join(root, session)

tseries = list(filter(lambda name: name.startswith('TSeries'), os.listdir(session_path)))
if len(tseries) == 1:
    tseries_path = os.path.join(session_path, tseries[0])
else:
    raise ValueError(f"Expected exactly one TSeries file, found {len(tseries)}: {tseries}")

# `ch in name` is a plain substring test on the file name.
tiff_files = list(filter(lambda name: ch in name and name.endswith('.tif'), os.listdir(tseries_path)))
if len(tiff_files) == 1:
    tiff_path = os.path.join(tseries_path, tiff_files[0])
else:
    raise ValueError(f"Expected exactly one tif file with channel {ch}, found {len(tiff_files)}: {tiff_files}")

In [None]:
print("Loading TIFF file:", tiff_path)  # echo which file is about to be read

In [None]:
# Read the whole TIFF stack and unroll it (C order) into one 1-D sample stream.
# Treating this as a pixel-time series assumes the stack is (frames, rows, cols) and
# raster-scanned in that order, with flyback ignored -- NOTE(review): confirm.
tiff_flat = np.ravel(tf.imread(tiff_path))


In [None]:
# Smooth the flattened pixel stream with a 1-D Gaussian (sigma in samples).
# output=np.float32: by default gaussian_filter1d returns the *input* dtype, so an
# integer stack (raw TIFFs are commonly uint16 -- check tiff_flat.dtype) would come
# back cast to integers, quantizing the smoothed trace. float32 keeps the fractional
# values at half the memory of float64 for this very long array.
tiff_smooth = gaussian_filter1d(tiff_flat, sigma=sigma, output=np.float32)


In [None]:
# Detect one peak per stimulus in the smoothed stream.
#   prominence: at least half of the global maximum of the smoothed trace
#   distance:   at least one frame (xy_pixels**2 samples) between accepted peaks
# Tune prominence/distance if peaks are missed or duplicated.
peaks, _ = find_peaks(tiff_smooth, prominence=np.max(tiff_smooth) / 2, distance=xy_pixels**2)

In [None]:
# Overview of the whole recording: smoothed pixel stream, nominal stimulus onsets
# (red, drawn at y = 0) and empirically detected peaks (green).
onsets_px = np.asarray(onsets)
plt.figure(figsize=(10, 2))
plt.plot(tiff_smooth)
plt.scatter(onsets_px, np.zeros_like(onsets_px), color='red', label='Stimulus Onsets')
plt.scatter(peaks, tiff_smooth[peaks], color='green', label='Empirically detected Peaks')
plt.xlabel('Pixel index')
plt.ylabel('F')
plt.title(f'Session: {session}, Channel: {ch}')
plt.legend();  # without this call the scatter labels above are never shown

In [None]:
# Peak-centred snippets for overlay: wind_pre samples before and wind_post samples
# after each detected peak, one row per peak.
wind_pre = 10000
wind_post = 50000

# A plain slice breaks near the recording edges: a negative start wraps around to
# the end of the array, and a stop past the end yields a short slice. Either way the
# row assignment fails on shape. Out-of-range parts are NaN-padded instead, so they
# show up as gaps in the plot.
n_samples = len(tiff_smooth)
all_tr = np.full((len(peaks), wind_pre + wind_post), np.nan)
for i, peak in enumerate(peaks):
    start, stop = peak - wind_pre, peak + wind_post
    src_start, src_stop = max(start, 0), min(stop, n_samples)
    all_tr[i, src_start - start:src_stop - start] = tiff_smooth[src_start:src_stop]
    

In [None]:
# Overlay of all peak-aligned snippets; x axis in ms from the start of the window.
plt.figure(figsize=(6, 4), dpi=300)
plt.plot(all_tr.T, label=np.arange(len(peaks)))
plt.legend()
plt.xlabel('Time (ms)')

wind_npx = wind_pre + wind_post
wind_ms = int(wind_npx * px_time)  # whole ms that fit in the window (floored)
print(f"Window rounded to int ms: {wind_ms} ms")

# One tick per ms: tick k sits at sample k / px_time. Ticks must not be spread evenly
# over wind_npx, because the window is generally not a whole number of ms and the
# integer labels would then be mis-scaled.
tick_ms = np.arange(wind_ms + 1)
plt.xticks(tick_ms / px_time, tick_ms)

plt.ylabel('F')
plt.title(f'Align. on peak. Session: {session}, Channel: {ch}')

In [None]:
# Wider peak-centred windows, then registration of every trace to the mean trace.
wind_pre = 50000
wind_post = 50000

# Windows running past either end of the recording are edge-padded (not NaN) because
# the cross-correlation below needs finite samples. A plain slice would wrap or come
# out short and fail the row assignment.
n_samples = len(tiff_smooth)
all_tr = np.zeros((len(peaks), wind_pre + wind_post))
for i, peak in enumerate(peaks):
    start, stop = peak - wind_pre, peak + wind_post
    segment = tiff_smooth[max(start, 0):min(stop, n_samples)]
    all_tr[i, :] = np.pad(segment, (max(0, -start), max(0, stop - n_samples)), mode='edge')

# Align each trace to the mean trace at the lag of maximal cross-correlation.
# Both signals are demeaned first. With a DC baseline, 'full' correlation carries a
# triangular overlap term that peaks at zero lag and biases the lag estimate toward 0.
template = all_tr.mean(axis=0)
template = template - template.mean()
aligned_tr = np.zeros_like(all_tr)
for i, trace in enumerate(all_tr):
    corr = correlate(trace - trace.mean(), template, mode='full')
    lag = np.argmax(corr) - (len(trace) - 1)  # > 0: trace is delayed vs. the template
    # np.roll is circular, so |lag| samples wrap around to the opposite window edge.
    aligned_tr[i] = np.roll(trace, -lag)



In [None]:
# Crop the registered traces: drop wind_trunc_pre samples at the start and
# wind_trunc_post at the end. Keep the two summing to 40000 so the plotted window
# length stays constant when shifting the crop.
wind_trunc_pre = 30000  # samples removed from the start
wind_trunc_post = 40000 - wind_trunc_pre  # samples removed from the end
# Use an explicit stop index: `[:, pre:-post]` returns an empty array when post == 0.
aligned_tr_plot = aligned_tr[:, wind_trunc_pre:aligned_tr.shape[1] - wind_trunc_post]

plt.figure(figsize=(6, 4), dpi=300)
plt.plot(aligned_tr_plot.T, label=np.arange(len(peaks)))
plt.legend()
plt.xlabel('Time (ms)')

wind_npx = aligned_tr_plot.shape[1]
wind_ms = int(wind_npx * px_time)  # whole ms that fit in the cropped window (floored)
print(f"Window rounded to int ms: {wind_ms} ms")

# One tick per ms, at sample k / px_time, so labels stay exact for windows that
# are not a whole number of ms long.
tick_ms = np.arange(wind_ms + 1)
plt.xticks(tick_ms / px_time, tick_ms)

plt.ylabel('F')
plt.title(f'Align. on peak + register. Session: {session}, Channel: {ch}')
