In [None]:
import hdf5storage
import helpers
import get_data
import numpy as np
from bat.helpers_bat import *

%load_ext autoreload
%autoreload 2


## Loading in the bat data (LFP and positional data)
We first need to load in the LFP data, which in this case is stored in a MATLAB file. We can do this using ```hdf5storage```. The bat's positional data is stored in a different file, but luckily the accessors for this were provided by the Yartsev lab (thanks Kevin Qi!).

In [None]:
# Loading bat LFP data
# The result is indexed by variable name below ('lfp' here, 'global_sample_timestamps_usec' later).
# NOTE(review): the path hardcodes './bat' and the 32622 / 231007 identifiers that are also
# set as data_path / bat_id / date in a later cell; keep the two in sync.
lfp_mat = hdf5storage.loadmat('./bat/ephys/32622_231007_lfp.mat')

# Check the structure of lfp_mat
# (only the 'lfp' entry is inspected: its Python type and array shape are printed)
print("Structure of lfp_mat:", type(lfp_mat['lfp']), lfp_mat['lfp'].shape)


In [None]:
# Identifiers of the recording session; all three are passed to FlightRoomSession in the next cell.
data_path = './bat' # Replace this
bat_id = '32622'
date =  '231007'

In [None]:
# Build the session object from the data directory, bat id and recording date defined above.
session = FlightRoomSession(data_path, bat_id, date, use_cache = False) # use_cache = True to save time on future loads

In [None]:
# Raw position samples of the bat, taken from the session's cortex data.
pos = session.cortex_data.bat_pos # (num_timepoints, XYZ)
pos.shape  # not displayed: a notebook cell only renders its last expression
pos[:,0]  # first column of pos (X, per the layout comment above)

In [None]:
# Work on a copy so the raw `pos` array stays untouched.
cleaned_pos = np.array(pos, copy=True)
# Positions are NaN while the bat sits still on the walls; each of the three axes
# is passed through interpolate_nans separately.
# NOTE(review): the original note says large NaN gaps should not be interpolated —
# confirm that interpolate_nans actually enforces this.
for axis_idx in range(3):
    cleaned_pos[:, axis_idx] = interpolate_nans(pos[:, axis_idx])
cleaned_pos.shape

## Time synchronization
Before we get to the main attraction (the LFP data), we need to ensure our data is synchronized. To do this, we need to extract global timestamps from both the LFP and positional data and make sure they start at the same time.

In [None]:
# Per-sample timestamps of the LFP recording, read from the loaded .mat contents.
timestamps = lfp_mat['global_sample_timestamps_usec'] #global timestamps in microseconds

# Differences between consecutive (flattened) timestamps; shown as the cell output.
np.diff(timestamps.flatten()) #we will see a 400 microsecond difference between each timestamp,
#which means that the sampling rate is 2500 Hz

In [None]:
# Shape of the raw (not yet flattened) LFP timestamp array.
timestamps.shape

In the cells below, we drop every sample whose timestamp is negative relative to the global timestamp start, for both the LFP and the position recordings, and remove the corresponding position samples so that the two streams stay synchronized. We then build timebins from the LFP timestamps and bin the positional data accordingly.

In [None]:
from scipy.signal import decimate

# Downsample the LFP timestamps from 2500 Hz to 25 Hz by keeping every 100th sample.
# Timestamps are a time axis, not a signal: scipy.signal.decimate would run an
# anti-aliasing filter over them (and SciPy advises against a single IIR decimate call
# with a factor this large), which can only perturb the values. Plain strided selection
# returns the exact timestamps and the same number of samples as decimate(x, 100).
# NOTE(review): assumes the LFP itself is downsampled so that sample k corresponds to
# raw sample 100 * k — confirm against get_data.get_LFP_from_mat.
DECIMATION_FACTOR = 100  # 2500 Hz -> 25 Hz
lfp_timestamps_decimated_bins = timestamps.flatten()[::DECIMATION_FACTOR].astype(np.float64)  # float, as decimate returned
lfp_timestamps_decimated_bins.shape


In [None]:
# Boolean mask of decimated LFP timestamps that are strictly positive
# (a timestamp of exactly 0 is dropped as well). The mask is reused later to select LFP rows.
lfp_indices = lfp_timestamps_decimated_bins > 0
lfp_timestamps_decimated_bins = lfp_timestamps_decimated_bins[lfp_indices] # lop off negative timestamps
# After this insert the array has one more element than lfp_indices has True entries.
lfp_timestamps_decimated_bins = np.insert(lfp_timestamps_decimated_bins, 0, 0) # insert 0 at the beginning
# NOTE(review): this cell overwrites its own input, so re-running it recomputes lfp_indices
# from the already-filtered array; run it once per kernel session.

lfp_timestamps_decimated_bins.shape
#lfp_timestamps_decimated_bins[:10]

In [None]:
# Timestamps of the position samples, scaled by 1e6.
pos_timestamps = session.cortex_data.cortex_global_sample_timestamps_sec * 1e6 #converting to microseconds (usec)

# Boolean mask of strictly positive timestamps; it is reused in later cells to filter
# cleaned_pos and labels, so it must keep the length of the unfiltered position data.
valid_indices = pos_timestamps > 0

pos_timestamps = pos_timestamps[valid_indices] # lop off negative timestamps 

np.diff(pos_timestamps)  # spacing between consecutive retained position timestamps (cell output)
#pos_timestamps.shape

In [None]:
# Filter out cluster 1 flight position and LFP timestamps
def exclude_cluster1_indices(session, pos_timestamps, lfp_timestamps):
    """Return indices of position and LFP samples that lie outside cluster-1 flights.

    A flight counts as cluster 1 when ``session.cortex_data.cluster_ids[i] == 1``.
    Its position samples are ``timebin_start_idx .. timebin_end_idx`` (inclusive)
    and its time span is read from ``pos_timestamps`` at those two indices.

    Parameters
    ----------
    session : object
        Must expose ``cortex_data.num_flights``, ``cortex_data.cluster_ids`` and
        ``flights[i].timebin_start_idx`` / ``flights[i].timebin_end_idx``.
    pos_timestamps : array-like, 1-D
        Timestamps of the position samples.
    lfp_timestamps : array-like, 1-D
        Timestamps of the LFP samples, in the same units as ``pos_timestamps``.

    Returns
    -------
    filtered_pos_indices : np.ndarray of int
        Indices into ``pos_timestamps`` that are not part of a cluster-1 flight.
    filtered_lfp_indices : np.ndarray of int
        Indices into ``lfp_timestamps`` whose timestamp is outside every
        cluster-1 flight interval ``[t_start, t_end]``.

    Notes
    -----
    LFP samples are excluded by time interval rather than by exact equality with
    position timestamps: the two streams are sampled independently, so an
    equality test matches (almost) nothing and would leave the LFP unfiltered.
    Flight indices that fall outside ``pos_timestamps`` are clipped to its range.

    NOTE(review): flight indices are applied to ``pos_timestamps`` as given, so the
    array must use the same indexing as ``timebin_start_idx``. If leading samples
    were removed beforehand (e.g. negative timestamps), the indices are shifted —
    confirm at the call site.
    """
    pos_timestamps = np.asarray(pos_timestamps)
    lfp_timestamps = np.asarray(lfp_timestamps)
    n_pos = pos_timestamps.shape[0]

    # True = keep the sample; cluster-1 flights switch entries to False.
    pos_mask = np.ones(n_pos, dtype=bool)
    lfp_mask = np.ones(lfp_timestamps.shape[0], dtype=bool)

    for flight_number in range(session.cortex_data.num_flights):
        if session.cortex_data.cluster_ids[flight_number] != 1:
            continue

        flight = session.flights[flight_number]
        start_idx = max(int(flight.timebin_start_idx), 0)
        end_idx = min(int(flight.timebin_end_idx), n_pos - 1)
        if start_idx > end_idx:
            continue  # flight lies entirely outside the available position samples

        pos_mask[start_idx:end_idx + 1] = False

        # Drop every LFP sample recorded while this flight was in progress.
        flight_start = pos_timestamps[start_idx]
        flight_end = pos_timestamps[end_idx]
        lfp_mask &= ~((lfp_timestamps >= flight_start) & (lfp_timestamps <= flight_end))

    filtered_pos_indices = np.where(pos_mask)[0]
    filtered_lfp_indices = np.where(lfp_mask)[0]

    return filtered_pos_indices, filtered_lfp_indices

# Call the function with the session, pos_timestamps, and lfp_timestamps
# NOTE(review): pos_timestamps has already been filtered with valid_indices at this point,
# while the function indexes it with the flights' timebin_start_idx / timebin_end_idx —
# confirm both refer to the same indexing. The two returned index arrays are not used
# by any later cell visible in this notebook.
filtered_pos_indices, filtered_lfp_indices = exclude_cluster1_indices(session, pos_timestamps, lfp_timestamps_decimated_bins)


In [None]:
##pos_timestamps = pos_timestamps[valid_indices]
#pos_timestamps.shape

In [None]:
# Keep only the position rows whose timestamp was positive (same mask as pos_timestamps).
# NOTE(review): this overwrites cleaned_pos in place, so running the cell twice applies the
# boolean mask to an already-filtered array; run it once per kernel session.
cleaned_pos = cleaned_pos[valid_indices] # lop off the corresponding positions
cleaned_pos.shape

In [None]:
# Display the first column of the filtered positions.
cleaned_pos[:,0]

In [None]:
# Resample each position axis (columns 0, 1, 2 of cleaned_pos) onto the LFP-derived
# time bins with label_timebins, treating position as a continuous signal.
binned_pos_x, binned_pos_y, binned_pos_z = (
    label_timebins(lfp_timestamps_decimated_bins, cleaned_pos[:, axis_idx], pos_timestamps, is_discrete=False)
    for axis_idx in range(3)
)




In [None]:
# Ensure the lengths match
# min_length is the shortest of the three binned position arrays and the timestamp array;
# it is reused later to truncate the binned LFP.
min_length = min(len(binned_pos_x), len(binned_pos_y), len(binned_pos_z), len(lfp_timestamps_decimated_bins))
# Truncate arrays to the minimum length
binned_pos_x = binned_pos_x[:min_length]
binned_pos_y = binned_pos_y[:min_length]
binned_pos_z = binned_pos_z[:min_length]
# NOTE(review): lfp_timestamps_decimated_bins is passed to label_timebins as bin edges.
# The label_timebins defined later in this notebook returns one value per bin
# (len(edges) - 1), so cutting the edges down to the number of bins drops the last bin
# for every later call — confirm this is intended.
lfp_timestamps_decimated_bins = lfp_timestamps_decimated_bins[:min_length]

# Initialize binned_pos array with the truncated length
# (binned_pos is created again at the start of the next cell)
binned_pos = np.zeros((3, min_length))

In [None]:
# Stack the three binned axes into one (3, n) array: row 0 = x, row 1 = y, row 2 = z.
n_bins = lfp_timestamps_decimated_bins.shape[0]
binned_pos = np.zeros((3, n_bins))
for row, axis_values in enumerate((binned_pos_x, binned_pos_y, binned_pos_z)):
    binned_pos[row, :] = axis_values

# Quick sanity check of the assembled array
print(f"binned_pos shape: {binned_pos.shape}")
print(f"First few entries in binned_pos:\n{binned_pos[:, :5]}")


In [None]:
# Shape of the stacked position array (rows are x, y, z).
binned_pos.shape

In [None]:
# One integer label per position sample: 0 everywhere, and k for the samples of the
# k-th flight (1-based) returned by get_flights_by_cluster((2,9)).
flight_behavior = session.cortex_data
labels = np.zeros(flight_behavior.num_cortex_timebins, dtype=int)
cluster_flights = session.get_flights_by_cluster((2,9))
cluster_flights_id = [flight.cluster_id for flight in cluster_flights]
for i_flight, cluster_flight in enumerate(cluster_flights):
    s = cluster_flight.timebin_start_idx
    e = cluster_flight.timebin_end_idx
    # NOTE(review): the slice excludes sample `e`, whereas other cells in this notebook
    # treat timebin_end_idx as inclusive (end_idx + 1) — confirm which convention holds.
    labels[s:e] = i_flight + 1

In [None]:
# The notebook's explicit matplotlib import only appears in a later cell; importing here
# keeps this cell runnable after Restart Kernel -> Run All.
from matplotlib import pyplot as plt

# Flight label of every position sample (0 = not part of a selected flight).
plt.plot(labels)
plt.xlabel('Position sample index')
plt.ylabel('Flight label (0 = no flight)');

We now create per-timebin labels (`timebin_labels`) that record which flight, if any, each timebin belongs to, so that timebins can later be selected by flight.

In [None]:
# Drop the labels of position samples with non-positive timestamps (same mask as pos_timestamps).
# NOTE(review): overwrites labels in place; re-running this cell applies the mask to an
# already-filtered array, so run it once per kernel session.
labels = labels[valid_indices]

In [None]:
import scipy.interpolate as interpolate
def find_multi_label_bins(spk_timebins, labels, label_timestamps_sec):
    """Find the time bins that contain more than one distinct label.

    Each label timestamp is assigned to a bin with ``np.digitize`` using
    ``spk_timebins`` as edges; timestamps outside the edges are ignored.

    Returns
    -------
    bins_with_multi_unique_labels : np.ndarray
        Indices of bins holding at least two different labels.
    unique_combinations : np.ndarray, shape (k, 2)
        Every distinct (bin index, label) pair that occurs, sorted by row.
    counts : np.ndarray, shape (k,)
        Number of occurrences of each pair in ``unique_combinations``.
    """
    n_bins = len(spk_timebins) - 1

    # Bin index of every label sample; keep only those inside the edges.
    assigned_bin = np.digitize(label_timestamps_sec, spk_timebins) - 1
    inside = ~((assigned_bin < 0) | (assigned_bin >= n_bins))
    bin_label_pairs = np.column_stack((assigned_bin[inside], labels[inside]))

    # Distinct (bin, label) pairs together with how often each one occurs.
    unique_combinations, counts = np.unique(bin_label_pairs, axis=0, return_counts=True)

    # Each distinct pair contributes one to its bin's tally of different labels.
    distinct_labels_per_bin = np.zeros(n_bins, dtype=int)
    np.add.at(distinct_labels_per_bin, unique_combinations[:, 0], 1)

    bins_with_multi_unique_labels = np.where(distinct_labels_per_bin > 1)[0]
    return bins_with_multi_unique_labels, unique_combinations, counts

def label_timebins(spk_timebins, labels, label_timestamps_sec, is_discrete):
    """
    Resample ``labels`` (sampled at ``label_timestamps_sec``) onto the time bins whose
    edges are ``spk_timebins``; one value is returned per bin (``len(spk_timebins) - 1``).

    Values are interpolated at the bin midpoints:
      * is_discrete truthy  -> nearest-neighbour interpolation; a bin holding several
        different labels gets its most frequent label; bins with no label sample get 0.
      * is_discrete falsy   -> linear interpolation; bins with no label sample get NaN.

    Example:
    result = label_timebins([0,2,4,6,8], np.array([2,2,4,5,6,6,6,6,1,2,2,4,3,3,3]), np.array([0.5,1,1.5,1.6,1.5,1,1,1,3.5,6.5,6.6,6.5,6.4,6.5,6.6]), is_discrete=True)
    print(result)
    """
    # Work on numpy copies of the inputs
    edges = np.array(spk_timebins)
    labels = np.array(labels)
    stamps = np.array(label_timestamps_sec)

    n_bins = len(edges) - 1
    bin_centers = (edges[:-1] + edges[1:]) / 2

    # Interpolation kind and the value used for out-of-range / empty bins
    if is_discrete:
        interp_kind, empty_value = 'nearest', 0
    else:
        interp_kind, empty_value = 'linear', np.nan

    resampler = interpolate.interp1d(stamps, labels, kind=interp_kind,
                                     bounds_error=False, fill_value=empty_value)
    resampled_labels = resampler(bin_centers)

    if is_discrete:
        # Bins containing several different labels take the label seen most often
        multi_label_bins, unique_combinations, counts = find_multi_label_bins(edges, labels, stamps)
        for bin_index in multi_label_bins:
            rows_of_bin = unique_combinations[:, 0] == bin_index
            winner = np.argmax(counts[rows_of_bin])
            resampled_labels[bin_index] = unique_combinations[rows_of_bin, 1][winner]

    # Bins that received no label sample at all are reset to the empty value
    occupied_bins = np.digitize(stamps, edges) - 1
    occupied_bins = occupied_bins[(occupied_bins >= 0) & (occupied_bins < n_bins)]
    empty_bins = np.setdiff1d(np.arange(n_bins), occupied_bins)
    resampled_labels[empty_bins] = empty_value

    return resampled_labels


def label_in_timebin(timebin_edges, label_timestamps):
    """Flag every time bin that contains at least one label timestamp.

    ``timebin_edges`` are the bin edges; the result is a boolean array with one
    entry per bin (``len(timebin_edges) - 1``). Timestamps that fall outside the
    edges are ignored.
    """
    edges = np.array(timebin_edges)
    stamps = np.array(label_timestamps)

    n_bins = len(edges) - 1
    occupied = np.zeros(n_bins, dtype=bool)

    # Bin index of each timestamp; -1 or n_bins means it lies outside the edges.
    hit_bins = np.digitize(stamps, edges) - 1
    in_range = ~((hit_bins < 0) | (hit_bins >= n_bins))

    occupied[hit_bins[in_range]] = True
    return occupied





# Resample the per-sample flight labels onto the LFP-derived time bins as discrete labels
# (bins without any label sample become 0); one label per bin, i.e. len(edges) - 1 values.
timebin_labels = label_timebins(lfp_timestamps_decimated_bins, labels, pos_timestamps, is_discrete=True)

In [None]:
# Plot the x/y trajectory restricted to time bins that carry a flight label (timebin_labels > 0).
# NOTE(review): the boolean mask must have exactly as many entries as binned_pos_x / binned_pos_y.
# timebin_labels has len(lfp_timestamps_decimated_bins) - 1 entries — confirm the lengths still
# agree after the truncation cell.
plt.plot(binned_pos_x[timebin_labels >0], binned_pos_y[timebin_labels >0])

In [None]:
def _feeder_from_landing(end_x, end_y):
    """Map the final (x, y) position of a flight to a feeder code by quadrant sign."""
    if end_y > 0 and end_x < 0:
        return 0  # Perch
    if end_y > 0 and end_x > 0:
        return 1  # Feeder 1
    if end_y < 0 and end_x > 0:
        return 2  # Feeder 2
    return -1  # remaining quadrant, a coordinate of exactly 0, or NaN


def get_flightID(session, binned_pos, binned_indices):
    """Build a per-bin table of flight number, feeder visited and position.

    Parameters
    ----------
    session : object
        Must expose ``num_flights``, ``cortex_data.cluster_ids`` and
        ``flights[i].timebin_start_idx`` / ``flights[i].timebin_end_idx``.
    binned_pos : np.ndarray, shape (3, n_bins)
        Binned x, y, z position (rows 0, 1, 2).
    binned_indices : sequence
        Lookup from an original position-sample index to its bin index.

    Returns
    -------
    np.ndarray, shape (n_bins, 5)
        Columns: flight number (1-based, -1 outside flights), feeder code
        (0 = perch, 1 = feeder 1, 2 = feeder 2, -1 = none/unknown), x, y, z.

    Notes
    -----
    Flights with ``cluster_ids == 1`` are excluded. A flight is skipped (with a
    printed message) when its binned end index is beyond ``binned_pos`` or when
    its binned range is negative or reversed — negative indices would otherwise
    silently wrap around to the end of the arrays.
    """
    n_bins = binned_pos.shape[1]

    # -1 marks bins that do not belong to any (non-cluster-1) flight.
    flight_numbers = -np.ones(n_bins, dtype=int)
    feeders_visited = -np.ones(n_bins, dtype=int)

    for flight_number in range(session.num_flights):
        if session.cortex_data.cluster_ids[flight_number] == 1:
            continue  # Exclude cluster 1

        flight = session.flights[flight_number]

        # Convert original sample indices to bin indices
        binned_start_idx = binned_indices[flight.timebin_start_idx]
        binned_end_idx = binned_indices[flight.timebin_end_idx]

        if binned_end_idx >= n_bins:
            print(f"Skipping flight with end_idx {binned_end_idx} as it exceeds bounds.")
            continue
        if binned_start_idx < 0 or binned_end_idx < binned_start_idx:
            print(f"Skipping flight with invalid binned range [{binned_start_idx}, {binned_end_idx}].")
            continue

        # The feeder is decided from the x/y position in the flight's last bin.
        feeder_visited = _feeder_from_landing(binned_pos[0, binned_end_idx],
                                              binned_pos[1, binned_end_idx])

        # End index is inclusive; flight numbers are reported 1-based.
        flight_numbers[binned_start_idx:binned_end_idx + 1] = flight_number + 1
        feeders_visited[binned_start_idx:binned_end_idx + 1] = feeder_visited

    # One row per bin: [flight number, feeder, x, y, z]
    flightID = np.column_stack((flight_numbers, feeders_visited,
                                binned_pos[0], binned_pos[1], binned_pos[2]))

    return flightID
# Call get_flightID with the session and binned_pos
# NOTE(review): binned_indices is not assigned in any cell visible in this notebook; it must
# map an original position-sample index to its bin index. Define it explicitly before this
# call so the notebook survives Restart Kernel -> Run All.
flightID = get_flightID(session, binned_pos, binned_indices)

# Verify the contents of flightID
# (each row is [flight number, feeder visited, x, y, z])
print(f"flightID shape: {flightID.shape}")
print(f"First few entries in flightID:\n{flightID[:5]}")


## LFP extraction and downsampling

In [None]:
# Extract subarrays and check their structure
# The 'lfp' entry is indexed at [0, 0] and [0, 1] to obtain two separate arrays.
lfp_data_1 = lfp_mat['lfp'][0, 0]
lfp_data_2 = lfp_mat['lfp'][0, 1]

print(f"Type of lfp_data_1: {type(lfp_data_1)}, Shape of lfp_data_1: {lfp_data_1.shape}")
print(f"Type of lfp_data_2: {type(lfp_data_2)}, Shape of lfp_data_2: {lfp_data_2.shape}")

# The channel count is read from axis 0 of the first array and reused for the second one.
n_channels = lfp_data_1.shape[0] #same # of channels for lfp_data_1 and lfp_data_2 (change if not the case)


In [None]:

# bat LFP data sampled at 2500 Hz
# NOTE(review): 2500 is passed as the third argument, presumably the original sampling
# rate — confirm against get_data.get_LFP_from_mat, including any downsampling it performs.
lfp_bat_1 = get_data.get_LFP_from_mat(lfp_data_1,n_channels,2500)
lfp_bat_2 = get_data.get_LFP_from_mat(lfp_data_2,n_channels,2500)

# Join the two recordings along axis 1 (the channel axis, per the shape comments in the next cell).
lfp_bat_combined = np.concatenate((lfp_bat_1, lfp_bat_2), axis=1)


In [None]:
# Print the shapes of both recordings and of their channel-wise concatenation.
print("lfp_bat_1 shape:", lfp_bat_1.shape) # (n_samples, n_channels)
print("lfp_bat_2 shape:", lfp_bat_2.shape) # (n_samples, n_channels)
print("lfp_bat_combined shape:", lfp_bat_combined.shape) # (n_samples, 2*n_channels)

Once the LFP is loaded in, we can downsample it to 25 Hz and apply a Hilbert transform.

In [None]:
# Filter the combined LFP with helpers.filter_data (second argument 0.1, fs=25, use_hilbert=True).
# NOTE(review): presumably a 0.1 Hz cutoff at a 25 Hz sampling rate followed by a Hilbert
# transform — confirm against helpers.filter_data.
LFPs = helpers.filter_data(lfp_bat_combined, 0.1, fs=25, use_hilbert=True)

In [None]:
# Shape of the filtered LFP array.
LFPs.shape #

In [None]:
# Keep the LFP rows whose decimated timestamp was positive (lfp_indices), then cut to
# min_length so the LFP has as many rows as the binned position arrays.
# NOTE(review): assumes LFPs has exactly one row per decimated LFP timestamp, so that the
# boolean mask lfp_indices lines up with axis 0 — confirm.
bin_LFP = LFPs[lfp_indices]
bin_LFP=bin_LFP[:min_length]

In [None]:
# Shape of the LFP after masking and truncation.
bin_LFP.shape # (# of binned samples, n_channels)

In [None]:
from matplotlib import pyplot as plt
from TIMBRE import TIMBRE
from bat.helpers_bat import *

# One row of axes per network (4 hidden-layer sizes), one column per network layer.
fig, axs = plt.subplots(4, 4, figsize=(20, 5))

n_folds = 5
which_fold = 0
num_samples_at_end = 5  # Number of samples at the end of each flight to use for classification

# Split the samples into test/train indices for the chosen fold.
test_inds, train_inds = test_train_bat(flightID, n_folds, which_fold, num_samples_at_end)

# NOTE(review): presumably the whitening is estimated from the training samples only —
# confirm against helpers.whiten.
wLFPs, _, _ = helpers.whiten(bin_LFP, train_inds)

# TODO: position binning carried over from the original TIMBRE demo; `data` and
# `arm_and_pos_binned` are not defined in this notebook, so it stays disabled.
#n_bins = 20
#pos_binned = helpers.group_by_pos(data['lapID'][:, 4], n_bins, train_inds)  # Convert position along the track into discrete bins.
#arm_and_pos_binned = data['lapID'][:, 1] * n_bins + pos_binned  # Represent arm x position as integer between 0-19 (arm 1), 20-39 (arm 2), 40-59 (arm 3)

# Debug output: shapes of everything handed to TIMBRE
print(f"X (wLFPs) shape: {wLFPs.shape}")
print(f"Y (flightID[:, 1]) shape: {flightID[:, 1].shape}")
print(f"inds_train shape: {train_inds.shape}")
print(f"inds_test shape: {test_inds.shape}")
print(f"X[inds_train, :] shape: {wLFPs[train_inds, :].shape}")
print(f"Y[inds_train] shape: {flightID[train_inds, 1].shape}")
print(f"X[inds_test, :] shape: {wLFPs[test_inds, :].shape}")
print(f"Y[inds_test] shape: {flightID[test_inds, 1].shape}")


titles = ['Projection (real part)', 'Amplitude', 'Softmax 1', 'Softmax 2 (Output)']
for i in range(axs.shape[0]):
    # Hidden layer size doubles with every row: 3, 6, 12, 24. It is now actually passed to
    # TIMBRE; previously hidden_nodes was fixed at 3 while the message and the y-label
    # reported 3 * 2 ** i, so all four rows trained the same network size.
    hidden_nodes = 3 * 2 ** i
    print(f"Training network {i + 1} of {axs.shape[0]} (hidden layer size {hidden_nodes})")  # try 4 different hidden layer sizes
    # The target is column 1 of flightID (the feeder code).
    m, _, _ = TIMBRE(wLFPs, flightID[:, 1], test_inds, train_inds, hidden_nodes=hidden_nodes, learn_rate=0.001, is_categorical=True, verbosity=1)
    for j in range(axs.shape[1]):  # Loop through each layer
        p = helpers.layer_output(wLFPs[test_inds], m, j)  # Calculate layer's response to input, using only test data
        if j == 0:
            p = p[:, :p.shape[1] // 2]  # just get real component for complex-valued output
            axs[i, 0].set_ylabel(str(hidden_nodes) + ' features')
        if i == 0:
            axs[0, j].set_title((titles[j]))
        # TODO: plotting is disabled until a per-sample grouping variable exists for the bat
        # data (the original relied on arm_and_pos_binned); the axes therefore stay empty.
        #axs[i, j].plot(helpers.accumarray(arm_and_pos_binned[test_inds], p));  # plot mean response of layer to test data as a function of position
        axs[i, j].autoscale(enable=True, axis='both', tight=True)
        if i < axs.shape[0] - 1:
            axs[i, j].set_xticks([])
        else:
            axs[i, j].set_xlabel('Position')
