# stim_timing.ipynb (still need to reorganise substantially)

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tifffile as tf
from scipy.ndimage import gaussian_filter1d
from scipy.signal import find_peaks, welch

In [None]:
# Root folder of the raw stim-timing calibration recordings. Set the
# STIM_TIMING_ROOT environment variable when the share is mounted elsewhere.
root = os.environ.get('STIM_TIMING_ROOT', '/Volumes/data_jm_share/data_raw/calibration/stim_timing/')
save_dir = 'utils/calibration_stim_timing/'

# Each session is a sub-directory of `root`. Skip hidden entries (e.g. macOS
# '.DS_Store') and loose files so that only selectable sessions are listed.
sessions = sorted(
    name for name in os.listdir(root)
    if not name.startswith('.') and os.path.isdir(os.path.join(root, name))
)
print("Available sessions:")
for s in sessions:
    print(f"{s}")

In [None]:
session = '2025-11-07_1isi_5ms' 
ch = 'Ch1'
align_to = 'onset'  # 'onset' or 'peak'
sigma = 50  # gaussian smoothing sigma in pixel time

# dataset specifics (only to calculate units)

stim_len = 50 if 'default' in session else session.split('_')[-1]
stim_len = float(stim_len.replace('ms','')) if 'ms' in str(stim_len) else stim_len

interstim_len = 1000 # in ms
n_reps = 9 
frame_period = 33.602823 # in ms 
xy_pixels = 512
onset_threshold = 0.5 # threshold for onset detection (relative to max)
px_time = frame_period / xy_pixels**2

wind_pre_ms = 10 # in ms
wind_post_ms = 110 # in ms
# wind_post_ms = 52


# now get onsets 
onsets = []
for i in range(n_reps):
    onset = i * (stim_len + interstim_len) + (interstim_len + stim_len)
    # convert to px_index
    onset = int(onset / px_time)
    onsets.append(onset)  
print('Calculated onsets (currently this is missaligned because of flyback time most likely):', onsets)

In [None]:
# Convert the analysis window from ms to pixel-time samples (int() truncates).
wind_pre, wind_post = (int(span_ms / px_time) for span_ms in (wind_pre_ms, wind_post_ms))

In [None]:
# Resolve the single TSeries folder of the session, then the single .tif file
# of the selected channel inside it. Zero or multiple matches are an error.
session_path = os.path.join(root, session)

tseries = list(filter(lambda entry: entry.startswith('TSeries'), os.listdir(session_path)))
if len(tseries) == 1:
    tseries_path = os.path.join(session_path, tseries[0])
else:
    raise ValueError(f"Expected exactly one TSeries file, found {len(tseries)}: {tseries}")

tiff_files = list(filter(lambda entry: ch in entry and entry.endswith('.tif'), os.listdir(tseries_path)))
if len(tiff_files) == 1:
    tiff_path = os.path.join(tseries_path, tiff_files[0])
else:
    raise ValueError(f"Expected exactly one tif file with channel {ch}, found {len(tiff_files)}: {tiff_files}")

In [None]:
print("Loading TIFF file:", tiff_path)  # path resolved in the lookup cell

In [None]:
# Load the whole channel and lay every sample out in file (C) order as one 1-D trace.
# NOTE(review): this assumes C order matches acquisition (pixel-time) order - confirm for this scanner.
tiff_flat = tf.imread(tiff_path).reshape(-1)


In [None]:
# Smooth with a Gaussian filter (sigma in pixel-time samples).
# Request float output explicitly: scipy.ndimage returns the input dtype by
# default, so integer TIFF samples would give integer-rounded smoothed values.
tiff_smooth = gaussian_filter1d(tiff_flat, sigma=sigma, output=np.float64)


In [None]:
# Detect the stimulus artefacts in the smoothed trace.
# NOTE: in BOTH modes the detected event indices are stored in `peaks` (the name
# is kept for the cells below). With align_to == 'onset' they are rising
# threshold crossings (onsets), not local maxima.


def _drop_close_events(events, min_gap):
    """Keep only the first event of each cluster of nearby events.

    An event is kept when it lies more than `min_gap` samples after the most
    recently *kept* event; the first event is always kept (including the last
    cluster of the recording).

    Parameters
    ----------
    events : sorted 1-D sequence of sample indices.
    min_gap : minimum separation, in samples, between kept events.

    Returns
    -------
    np.ndarray of int holding the kept indices (may be empty).
    """
    kept = []
    last_kept = -np.inf
    for event in events:
        if event - last_kept > min_gap:
            kept.append(event)
            last_kept = event
    return np.asarray(kept, dtype=int)


if align_to == 'peak':
    # At most one peak per frame: a frame spans xy_pixels**2 pixel samples.
    peaks, _ = find_peaks(tiff_smooth, prominence=np.max(tiff_smooth) / 2, distance=xy_pixels**2)

elif align_to == 'onset':
    # Binarise at onset_threshold * global max. Rising edges (0 -> 1) are onsets
    # and falling edges (1 -> 0) are offsets. prepend=0 makes a trace that starts
    # above threshold register an onset at sample 0.
    above = tiff_smooth > (np.max(tiff_smooth) * onset_threshold)
    derivative = np.diff(above.astype(int), prepend=0)

    # Crossings closer together than twice the analysis window are treated as
    # duplicates of the same artefact; only the first of each cluster is kept.
    cutoff = (wind_pre + wind_post) * 2  # in px samples
    peaks = _drop_close_events(np.flatnonzero(derivative == 1), cutoff)
    offsets = _drop_close_events(np.flatnonzero(derivative == -1), cutoff)

    print('Detected onsets at:', peaks)

    # Onsets and offsets are paired by position. E.g. a trace that ends above
    # threshold has no final offset, so truncate to the shorter list.
    # NOTE(review): a missing offset in the middle would shift this pairing.
    if len(offsets) != len(peaks):
        print(f"WARNING: Number of detected offsets ({len(offsets)}) does not match number of detected peaks ({len(peaks)}).")
        print('truncating to the minimum of the two.')
        min_len = min(len(offsets), len(peaks))
        offsets = offsets[:min_len]
        peaks = peaks[:min_len]

    durations = offsets - peaks
    durations_ms = durations * px_time

    print('Detected artefact durations (in ms):', durations_ms)

else:
    raise ValueError(f"align_to must be 'onset' or 'peak', got {align_to!r}")

if len(peaks) == 0:
    raise RuntimeError('No stimulus artefacts detected; check onset_threshold, sigma or the selected channel.')


In [None]:
# Overview of the whole smoothed trace with every detected event marked:
# event positions at y=0 (red), offsets at y=0 in onset mode (blue), and the
# trace value at each event (green).
event_name = 'onsets' if align_to == 'onset' else 'peaks'
plt.figure(figsize=(10, 2))
plt.plot(tiff_smooth, label='Smoothed signal')
plt.scatter(peaks, np.zeros(len(peaks)), color='red', label=f'Detected {event_name}')
if align_to == 'onset':
    # `offsets` only exists in onset mode
    plt.scatter(offsets, np.zeros(len(offsets)), color='blue', label='Detected offsets')
plt.scatter(peaks, tiff_smooth[peaks], color='green', label=f'Detected {event_name} (on trace)')
plt.xlabel('Sample (pixel time)')
plt.ylabel('F (smoothed)')
plt.legend(fontsize='small');

In [None]:
# Zoom on the first detected event. The protocol-based expected onsets (red)
# can be compared against the detected events (green) and, in onset mode, the
# detected offsets (blue).
event_name = 'onsets' if align_to == 'onset' else 'peaks'
plt.figure(figsize=(10, 2))
plt.plot(tiff_smooth, label='Smoothed signal')
plt.scatter(onsets, np.zeros(len(onsets)), color='red', label='Calculated onsets (protocol)')
if align_to == 'onset':
    # `offsets` only exists in onset mode
    plt.scatter(offsets, np.zeros(len(offsets)), color='blue', label='Detected offsets')
plt.scatter(peaks, tiff_smooth[peaks], color='green', label=f'Detected {event_name}')
plt.xlim(peaks[0] - wind_pre, peaks[0] + wind_post)
plt.xlabel('Sample (pixel time)')
plt.legend(fontsize='small');



In [None]:

# Cut a fixed window around every detected event: one row per trial, one
# column per pixel sample from -wind_pre to +wind_post relative to the event.
all_tr = np.zeros((len(peaks), wind_pre + wind_post))
for i, peak in enumerate(peaks):
    start, stop = peak - wind_pre, peak + wind_post
    # Fail loudly: a negative start would silently wrap around and a short
    # slice would give an opaque broadcasting error.
    if start < 0 or stop > len(tiff_smooth):
        raise ValueError(
            f"Trial {i} (event at sample {peak}) needs samples [{start}, {stop}) "
            f"but the recording has {len(tiff_smooth)}; adjust wind_pre_ms / wind_post_ms "
            "or drop events near the recording edges."
        )
    all_tr[i, :] = tiff_smooth[start:stop]
    

In [None]:
# Overlay of all aligned trials. The x axis is relabelled in ms relative to the
# maximum of the trial-averaged trace (dashed line).
plt.figure(figsize=(6, 4))
if align_to == 'peak':
    label, legend_title = np.arange(len(peaks)), 'Trial'
else:
    label, legend_title = np.round(durations_ms, 1), 'Duration (ms)'
# Label each trial line individually; this does not depend on plot() accepting
# a sequence of labels, which older matplotlib versions do not support.
for trial_line, trial_label in zip(plt.plot(all_tr.T), label):
    trial_line.set_label(str(trial_label))

wind_npx = wind_pre + wind_post
wind_ms = int(wind_npx * px_time)
print(f"Window rounded to int ms: {wind_ms} ms")

mean_tr = np.mean(all_tr, axis=0)
peak_idx = np.argmax(mean_tr)

# 7 evenly spaced ticks across the whole window, labelled (1 decimal) in ms
# relative to peak_idx. 0 ms is the mean-trace maximum, which need not be a tick.
tick_positions = np.linspace(mean_tr.shape[0], 0, num=7, dtype=float)
tick_labels = [f"{(pos - peak_idx) * px_time:.1f}"
               for pos in tick_positions]
plt.xticks(tick_positions, tick_labels)

plt.axvline(x=peak_idx, color='grey', linestyle='--', label='Peak of mean trace')

plt.xlabel('Time (ms)')
plt.ylabel('F')
plt.title(f'Aligned on {align_to}. Session: {session}, Channel: {ch}')
plt.legend(title=legend_title, fontsize='small');


In [None]:
# All trial windows drawn back to back as one continuous trace.
plt.plot(all_tr.ravel())

In [None]:
# Power spectral density of all trial windows concatenated end to end
# (looking for very high-frequency structure in the artefact).
# px_time is in ms, so the sampling rate in Hz is 1000 / px_time; passing
# 1 / px_time would give kHz while the axis is labelled Hz.
# NOTE(review): all_tr is cut from the Gaussian-smoothed trace, which
# attenuates high frequencies, and concatenation adds jumps at trial
# boundaries - consider computing this from the unsmoothed tiff_flat instead.
fs_hz = 1000.0 / px_time
f, Pxx = welch(all_tr.flatten(), fs=fs_hz, nperseg=1024 * 4)
plt.figure(figsize=(6, 4))
plt.semilogy(f, Pxx)
plt.xlabel('Frequency (Hz)')
plt.ylabel('Power Spectral Density')
plt.title('Power Spectral Density of Stimulus Artefact Signal');


In [None]:
# Trial-averaged trace. The x ticks are in ms relative to the mean-trace
# maximum (dashed line), matching the trial overlay above.
plt.figure(figsize=(6, 4))
plt.plot(mean_tr, color='black', label='Mean Trace')
plt.axvline(x=peak_idx, color='grey', linestyle='--', label='Peak of mean trace')
plt.xticks(tick_positions, tick_labels)
plt.xlabel('Time (ms)')
plt.ylabel('F')
plt.title(f'Mean of {len(all_tr)} trials. Session: {session}, Channel: {ch}')
plt.legend(fontsize='small');