In [None]:
# Importing necessary libraries
import os
import re
import json
import matplotlib
matplotlib.use('nbagg') # Or 'notebook'
import pandas as pd
import numpy as np
import time
import scipy
import math
import scipy.io as sio

import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches
import matplotlib.dates as mdates
from matplotlib.patches import Patch

from datetime import datetime
from datetime import datetime, timedelta, date

from pathlib import Path

import tifffile

from open_ephys.analysis import Session

import xml.etree.ElementTree as ET

import formic as fm



In [None]:
# Switch the working directory to the FLEX checkout before importing the PMI
# classes (presumably so that `pmi_core` resolves from there - confirm).
flex_repo_dir = os.path.expanduser("~/FLEX/")
os.chdir(flex_repo_dir)
from pmi_core import PopulationMuscleImaging as PMI
from pmi_core.image.image_series import ImageSeries

In [None]:
# Names of the eight event channels. Other cells map a name to its event line
# with `event_channel_names.index(name) + 1`, so list position i <-> line i + 1.
event_channel_names = [
    'airpuff',
    'airpuff session',
    'UNUSED',
    'camera',
    'speaker 1',
    'speaker 2',
    'audio session',
    'bpod lick trials',
]


In [None]:
# Label combined with the series name when printing the short name
bundle_name = "experiment1"

# Condition to analyse. Supported values: 'reward', 'startle', 'airpuff'
which_condition = 'reward'

In [None]:
# set directory names
bundle_dir = '/vol/cortex/cd2/machadolab/Data/madhav/FLEX_trial/M252/05_28_25/'
oe_dir = bundle_dir
flex_dir = bundle_dir+'M252_Arduino_2025-05-28_kinetix/'
bpod_dir = bundle_dir+'M252_Bpod_2025-05-28_kinetix/'

# Initialize PMI object
pmi = PMI(bundle_dir)

# Series number and event channels to look at, per supported condition
condition_settings = {
    'reward': (0, ['bpod lick trials']),
    'airpuff': (1, ['airpuff']),
    'startle': (2, ['speaker 1', 'speaker 2']),
}

# Fail here with a clear message instead of hitting a NameError on
# `which_series` further down when the condition name is mistyped.
if which_condition not in condition_settings:
    raise ValueError(
        f"Unknown which_condition {which_condition!r}; "
        f"expected one of {sorted(condition_settings)}"
    )

which_series, event_channels = condition_settings[which_condition]

series_name = f'series{which_series:03}'

short_name = f"{bundle_name}, {series_name}"

print(f"Short name: {short_name}")

In [None]:
# Open the Open Ephys session and pull out its event table and start time

# Minimum gap between consecutive camera exposures that separates two bursts
# (same units as E.timestamp, which later cells treat as seconds)
inter_burst_interval = 0.5

# NOTE(review): which_recording is assigned here, but the events below are
# read from recordings[0] - confirm which recording index is intended.
which_recording = 2

# Session object for the recording directory
session = Session(oe_dir)

# Event table of the first recording on the first record node
first_record_node = session.recordnodes[0]
E = first_record_node.recordings[0].events

# Absolute start time of the ephys data, read from the INFO/DATE entry of
# settings.xml
settings_root = ET.parse(os.path.join(oe_dir,'settings.xml')).getroot()
oe_start_text = settings_root.find('./INFO/DATE').text
oe_start = datetime.strptime(oe_start_text,'%d %b %Y %H:%M:%S')


In [None]:
# Get camera exposure times and group them into bursts

# Event line carrying the camera signal, looked up by name (lines are 1-based,
# matching how the bpod channel is resolved elsewhere in this notebook)
camera_channel = event_channel_names.index('camera') + 1

# Timestamps of the camera line going high (exposure start) and low (exposure end)
is_camera_event = E['line'] == camera_channel
cam_starts = np.array(E.timestamp[is_camera_event & (E['state'] == 1)])
cam_ends = np.array(E.timestamp[is_camera_event & (E['state'] == 0)])

# Everything below indexes into cam_starts, so stop early with a clear message
if len(cam_starts) == 0:
    raise ValueError(f"No camera exposure events found on event line {camera_channel}")

# If the first falling edge precedes the first rising edge, it has no matching
# exposure start: drop it so starts and ends pair up
if len(cam_ends) > 0 and cam_ends[0] < cam_starts[0]:
    cam_ends = cam_ends[1:]

# Time between consecutive exposure starts, and the length of each exposure
# (paired up to the shorter of the two arrays)
cam_start_deltas = np.diff(cam_starts)
n_paired = min(len(cam_starts), len(cam_ends))
cam_lengths = cam_ends[:n_paired] - cam_starts[:n_paired]


# Identify bursts

# A gap larger than inter_burst_interval after exposure i means exposure i is
# the last of its burst and exposure i+1 is the first of the next one.
big_deltas = np.where(cam_start_deltas > inter_burst_interval)[0]
new_burst_start_inds = np.concatenate(([0], big_deltas + 1))
# End indices are inclusive. The previous version subtracted 1 from every end
# index, which dropped the last exposure of each burst and produced end < start
# for single-exposure bursts.
# NOTE(review): if the final exposure of a burst is known not to be saved to
# disk, handle that explicitly where images are matched to exposures.
new_burst_end_inds = np.concatenate((big_deltas, [len(cam_starts) - 1]))

n_bursts = len(new_burst_start_inds)

# Identify long bursts, i.e. those likely to have been saved as image series:
# longer than half the median burst duration
burst_durations = cam_starts[new_burst_end_inds] - cam_starts[new_burst_start_inds]
min_burst_duration = np.median(burst_durations)/2
long_bursts = np.where(burst_durations > min_burst_duration)[0]


print(f"Recording start: {oe_start}")
print(f"   Events begin: {oe_start + timedelta(seconds=E.timestamp.iloc[0])}")
print(f"     Events end: {oe_start + timedelta(seconds=E.timestamp.iloc[-1])}")
print("")
print(f"Camera starts: {len(cam_starts):,}")
print(f"  Camera ends: {len(cam_ends):,}")

In [None]:
# Find the correct bpod mat file using # trials

# Load bpod triggers from OE file (event lines are 1-based)
bpod_channel = event_channel_names.index('bpod lick trials') + 1
is_bpod_event = E['line'] == bpod_channel
bpod_starts = np.array(E.timestamp[is_bpod_event & (E['state'] == 1)])
bpod_ends = np.array(E.timestamp[is_bpod_event & (E['state'] == 0)])

# Show bpod event count from OE file
print(f"{len(bpod_starts):,} bpod starts")
print(f"{len(bpod_ends):,} bpod ends")
print("")

# List the .mat files in bpod_dir. `glob` is never imported in this notebook,
# so use pathlib (already imported); dotfiles are skipped to match what
# glob.glob('*.mat') would have returned.
bpod_files = sorted(
    str(p) for p in Path(bpod_dir).glob('*.mat') if not p.name.startswith('.')
)

# Show trial count from each bpod file
if len(bpod_files) == 0:
    print("No .mat files found in bpod_dir!")
    print(f"bpod_dir: {bpod_dir}")
    print(f"Directory exists: {os.path.exists(bpod_dir)}")
else:
    for bb, file_name in enumerate(bpod_files):
        # NOTE(review): `helper` is not imported in the visible cells; if it is
        # missing, the NameError is caught below and reported per file.
        try:
            behavior_info = helper.load_one_behavior_info(file_name)
            n_trials = behavior_info['n_trials']
            print(f"{n_trials} trials in file {bb}")
            print(f'     {file_name}')

            # Compare the file's trial count with the bpod triggers seen by OE
            if n_trials != len(bpod_starts):
                print(f"     ⚠️  MISMATCH: {n_trials} trials in .mat vs {len(bpod_starts)} bpod starts in OE")
            else:
                print(f"     ✅ MATCH: {n_trials} trials matches {len(bpod_starts)} bpod starts")
            print()

        except Exception as e:
            # Best effort: report the unreadable file and keep checking the rest
            print(f"Error reading file {bb}: {e}")
            print(f'     {file_name}\n')

print(f"Bpod files found: {bpod_files}")


In [None]:
# load bpod file

# Stop with a clear message when the search above found nothing, rather than
# an IndexError on bpod_files[0]
if len(bpod_files) == 0:
    raise FileNotFoundError(f"No .mat files found in bpod_dir: {bpod_dir}")

# NOTE(review): the first file is always used, whether or not its trial count
# matched the bpod starts in the check above - confirm this is the right file.
bpod_file = bpod_files[0]
bpod_session = sio.loadmat(bpod_file)["SessionData"]
behav_info = helper.load_one_behavior_info(bpod_file)

# convert bpod mat file events to OE times
bpod_events = helper.convert_bpod_mat_to_oe(bpod_session, bpod_starts)

In [None]:
# Load basic info for each image series (including first index)

# For each image series:
#
#  - identify files
#  - store file names
#  - load exposure index of first image (if saved)

# Check if we have series### directories or just session_1.
# iterdir() yields entries in arbitrary order, so sort the folders to make
# series_dicts[k] line up with the k-th series name on every filesystem.
regex = re.compile(r"series\d{3}$")
series_folders = sorted(
    f for f in Path(flex_dir).iterdir() if f.is_dir() and regex.match(f.name)
)

# If no series directories found, check for session_1 structure
if len(series_folders) == 0:
    session_dir = Path(flex_dir) / "session_1"
    if session_dir.exists():
        print("No series### directories found, using session_1 structure")
        # Treat the main flex_dir as series000
        series_folders = [Path(flex_dir)]
    else:
        raise ValueError(f"No series directories or session_1 found in {flex_dir}")

print(f"Found {len(series_folders)} series/session directories")

# One dict per series folder, in the same order as series_folders
series_dicts = []

for ss, series_folder in enumerate(series_folders):

    print('\033[1m' + 'Series folder: ' + '\033[0m' + f"{series_folder}")

    # Find the metadata file: try each candidate name and keep the last one
    # that exists (same precedence as before); "" means none was found
    tiff_metadata_paths = [os.path.join(series_folder,'frame_timestamps.csv'),
                           os.path.join(series_folder,'other_possible_name.txt')]
    existing_metadata_paths = [p for p in tiff_metadata_paths if os.path.exists(p)]
    tiff_metadata_path = existing_metadata_paths[-1] if existing_metadata_paths else ""

    # Missing metadata is only a warning here; the series is still recorded
    if tiff_metadata_path == "":
        print(f"Warning: No metadata found for {series_folder}")
    else:
        print('\033[1m' + '     Metadata:\n' + '\033[0m'+f"          {tiff_metadata_path}")

    # Search for tif files inside the series' session_1 folder
    tiff_folder = os.path.join(series_folder,'session_1')

    if not os.path.exists(tiff_folder):
        raise ValueError(f"Could not find tiff folder {tiff_folder}")

    # file names matching 'session*.tif', as strings
    tiff_files_maybe = [str(f) for f in Path(tiff_folder).glob("session*.tif")]

    # error if none found
    if len(tiff_files_maybe) == 0:
        raise ValueError(f"No .tif files found in {tiff_folder}")

    # sort file names with the project's natural-order key
    tiff_files_maybe = sorted(tiff_files_maybe, key=helper.natural_key)

    print('\033[1m' + f"     Tif files ({len(tiff_files_maybe)}):" + '\033[0m')
    for t in tiff_files_maybe[:3]:  # Show first 3 files
        print(f"          {t}")
    if len(tiff_files_maybe) > 3:
        print(f"          ... and {len(tiff_files_maybe)-3} more files")

    # load exposure index of first image (if file exists);
    # an empty list marks "no first_index.txt on disk"
    first_index_filename = os.path.join(series_folder,"first_index.txt")

    if Path(first_index_filename).is_file():
        with open(first_index_filename, "r") as file:
            first_image_exposure_index = int(file.read())
        print('\033[1m' + "     First index: " + '\033[0m' + f'{first_image_exposure_index}')
    else:
        first_image_exposure_index = []
        print('\033[1m' + "     First index: " + '\033[0m' + "not found")

    print('')

    series_dicts.append({"series_path":series_folder,
                         "tiff_dir":tiff_folder,
                         "tiff_metadata_path":tiff_metadata_path,
                         "tiff_files":tiff_files_maybe,
                         "first_index":first_image_exposure_index})

In [None]:
# Display info from the OE file and first index (guess?)


# Date and time

print('\033[1m' + oe_dir + '\n' + '\033[0m')

print('\033[1m' + f"{'Date':<13}" + '\033[0m' + oe_start.strftime('%a, %b %d, %Y'))
print('\033[1m' + f"{'Start time':<13}" + '\033[0m' + oe_start.strftime('%H:%M:%S.%f'))
print('\n')


# Event counts per channel

print('\033[1m' + f"{'Event count':>11}   {'Channel':<9} Data" + '\033[0m')
print('')

# One row per named channel; channel cc is stored on event line cc+1
for cc, channel_name in enumerate(event_channel_names):
    n_events = E.timestamp[E['line'] == cc+1].shape[0]
    print(f"{n_events:>11,}   {cc:>4}      {channel_name:<20}")


# Camera exposures

print('\n')
print('\033[1m' + 'Bursts of camera exposures\n' + '\033[0m')

print('\033[1m' + f"{'Burst':<7}{'Event count':>11}{'Duration':>12}{'Hz':>12}{'T_start':>12}{'T_end':>10}"
      f"{'Index_start':>15}{'Index_end':>13}{'Series':>11}" + '\033[0m')
print('')


for bb in range(n_bursts):
    start_time = cam_starts[new_burst_start_inds[bb]]
    end_time = cam_starts[new_burst_end_inds[bb]]
    n_triggers = new_burst_end_inds[bb] - new_burst_start_inds[bb] + 1
    duration = end_time - start_time
    # A burst whose first and last exposure coincide has no measurable rate
    hz = n_triggers/duration if duration > 0 else float('nan')

    # guess which image series this might correspond to: the position of this
    # burst among the long bursts, or blank if it is not a long burst
    guess_match = np.where(bb==long_bursts)[0]
    if len(guess_match)==0:
        which_series_guess = '   '
    else:
        which_series_guess = f"guess: {guess_match[0]}, "

    # check if a stored first index (from first_index.txt) points at this burst
    which_series_disk = 'no first_index file'
    if 'series_dicts' in locals():
        this_start = new_burst_start_inds[bb]
        for series_idx in range(len(series_dicts)):
            first_idx = series_dicts[series_idx]["first_index"]
            # A missing file is stored as []; test the type rather than the
            # truthiness so that a valid first index of 0 still matches.
            if isinstance(first_idx, (int, np.integer)) and this_start == first_idx:
                which_series_disk = f'first_index.txt: {series_idx}'
                break

    print(f"{bb:>4}   {n_triggers:>11,}{duration:>12.4f}{hz:>12.4f}{start_time:>12.2f}{end_time:>10.2f}"
          f"{new_burst_start_inds[bb]:>15}{new_burst_end_inds[bb]:>13}{which_series_guess:>13}{which_series_disk}")

In [None]:
# Create files noting the correspondence between camera bursts and image series folders

#raise Exception('Are you sure you want to create these files?????')

# Make text files noting the first index of each image series
# Note: With session_1 structure, we only have one main directory, so we'll create
# first_index.txt for the long bursts that correspond to actual imaging series

for bb in range(n_bursts):

    # Position of this burst among the long bursts (empty if it is a short one).
    # Kept in its own name: the earlier code reused `which_series` here and so
    # overwrote the series number chosen in the configuration cell.
    long_burst_match = np.where(bb==long_bursts)[0]
    if len(long_burst_match)==0:
        print(f"Burst {bb}: Short burst, skipping (not an imaging series)")
        continue

    series_idx = long_burst_match[0]
    print(f"Burst {bb}: Long burst {series_idx}, corresponds to imaging series")

    # Only write the first long burst's index (subsequent long bursts are different series
    # but in session_1 structure they're all in the same directory)
    if series_idx != 0:
        print(f"    Skipping file creation for subsequent long burst (series_idx={series_idx})")
        continue

    # Since we have session_1 structure, use the main directory
    file_name = os.path.join(series_folders[0], "first_index.txt")
    first_index_value = int(new_burst_start_inds[bb])

    with open(file_name, "w") as file:
        print(f"Writing to file {file_name}:")
        print(f"    {first_index_value}")
        file.write(f"{first_index_value}")

    # Mirror the value into the already-loaded series info, so the following
    # cells see it without re-running the series-loading cell.
    if 'series_dicts' in globals() and len(series_dicts) > 0:
        series_dicts[0]["first_index"] = first_index_value

In [None]:
# Incorporate exposure times from ephys data

print(f'Series {0}\n')
print('Data from metadata file\n')

# Load metadata for the first series
series_dict = series_dicts[0]
image_metadata_info = helper.get_image_metadata_info(series_dict["tiff_files"],series_dict["tiff_metadata_path"])

# get image exposure times

first_index = series_dict["first_index"]

# The series loader stores [] when first_index.txt was not found. Comparing an
# array against [] is shape-mismatched and behaves differently across NumPy
# versions, so check for a real integer first.
if not isinstance(first_index, (int, np.integer)):
    raise ValueError(
        f"No first index available for series 0 (got {first_index!r}); "
        "create first_index.txt and reload the series info"
    )

# Find the burst whose first exposure has this index
matching_bursts = np.flatnonzero(new_burst_start_inds == first_index)

if matching_bursts.size == 0:
    raise Exception(f'Index {first_index} is not the beginning of any burst.')

the_burst_to_load = matching_bursts[0]

# Exposure start times of that burst (the end index is inclusive)
burst_first = new_burst_start_inds[the_burst_to_load]
burst_last = new_burst_end_inds[the_burst_to_load]
exposure_times = cam_starts[burst_first:burst_last+1]

print(f"\nLoaded {len(exposure_times):,} exposure times from OE file, burst {the_burst_to_load}")

# get image count from metadata file
image_count = len(image_metadata_info['fs_norm'])

# error on mismatch between 1) OE file exposures and 2) metadata image count
if len(exposure_times) != image_count:
    raise Exception(f'Image count ({image_count}) does not match the number of exposures in burst {the_burst_to_load} ({len(exposure_times)})')


In [None]:
# Build the ImageSeries for the first series and attach it to the PMI object

series_path = series_dicts[0]['series_path']

# Image folder and timestamp file inside the series directory
tiff_session_dir = os.path.join(series_path,'session_1')
frame_timestamps_csv = os.path.join(series_path,'frame_timestamps.csv')

im_se = ImageSeries(tiff_session_dir,
                    frame_timestamps_csv,
                    n_channels=2)

# NOTE(review): this appends on every run of the cell, so re-running it adds
# the same series again - confirm whether that is acceptable.
pmi.images.append(im_se)