In [1]:
import hdf5storage
from helpers import *
from get_data import *
from synchrony import *
import numpy as np


%load_ext autoreload
%autoreload 2


### Loading in the bat data (LFP and positional data)
#### We first need to load in the LFP data, which in this case is stored in a MATLAB file. We can do this using `hdf5storage`. The bat's positional data is also stored in a MATLAB file (not accessible for public use), but luckily the accessors for this data can be found in `dataset.py`, thanks to the Yartsev Lab.

In [None]:
data_path = './bat/data'
bat_id = '32622'
date = '231007'
# Build the LFP path from the settings above so that changing the bat or the
# session date cannot leave a stale, mismatched file name behind.
# With the values above this resolves to './bat/data/ephys/32622_231007_lfp.mat'.
lfp_file_path = f'{data_path}/ephys/{bat_id}_{date}_lfp.mat'


# Clean up position data (remove NaNs, etc.) and load in LFP from given file path
lfp_mat, cleaned_pos, session = load_and_clean_bat_data(data_path, bat_id, date, lfp_file_path)

### Time synchronization
#### Before we get to the main attraction (the LFP data), we need to ensure our data is synchronized. To do this, we need to extract global timestamps from both the LFP and positional data and make sure they start at the same time.

In [None]:
# Synchronize the LFP and position streams and bin the position data (see sync_and_bin_data in the project helpers).
lfp_timestamps_edges, binned_pos, pos_timestamps, lfp_indices, valid_indices = sync_and_bin_data(lfp_mat, session)

# lfp_timestamps_edges stores the edges between time bins; this is useful for aligning the LFP data with the position data
# binned_pos is the cleaned position averaged over the time bins
# valid_indices is a boolean array that marks the non-negative position timestamps
# pos_timestamps is the cleaned and filtered timestamps of the position data
# lfp_indices is a boolean array that marks the non-negative, decimated LFP timestamps


#### Inside of `lfp_timestamps_edges`, we store the *edges* between time bins. These edges are used to bin the position data: instead of downsampling the position data the way we downsample the LFP, we average it within each LFP time bin (between two consecutive edges) to get synchronized data streams.

In [None]:
# NaN values at the beginning and end are expected; position is not recorded when the bat is not visible.
pos_preview = binned_pos[:, :5]
print("First few elements of binned_pos:\n", pos_preview)

edge_preview = lfp_timestamps_edges[:5]
print("First few LFP bins:", edge_preview)

##### Notice above that `lfp_timestamps_edges` has one more entry than `binned_pos` has time bins (N+1 edges for N bins). This is expected: each pair of consecutive edges delimits one bin (e.g. the first bin is [0, 4514.4426], and so on), and the position data was averaged within each of those bins.

### Organizing behavioral data
#### To better organize the binned flight data, we need to construct a flightID array which will contain all the binned positions for each flight, accounting for which feeder (or the perch) was visited for each data point entered in that flight.

In [20]:
from get_data import get_flightID

# Build the per-flight array from the binned positions.
# NOTE(review): off_samples=125 corresponds to 5 seconds only at a 25 samples/s bin rate
# (the fs used for the LFP filter later in this notebook) — confirm inside get_flightID.
flightID = get_flightID(session, binned_pos, valid_indices, lfp_timestamps_edges, pos_timestamps, off_samples = 125) #includes the 5 seconds before and after flight

## LFP extraction and downsampling

In [None]:
# Extract the LFP from the loaded MATLAB structure and downsample it.
# NOTE(review): the target rate is presumably 25 Hz (the fs passed to filter_data later) — confirm in the helper.
lfp_bat_combined = extract_and_downsample_lfp_data(lfp_mat)

##### Imported raw LFP

In [None]:
from synchrony import plot_raw_lfp

# Visual sanity check of the downsampled, still unfiltered LFP: 192 channels over the 0-100 window.
# NOTE(review): the units of start_time/end_time (seconds vs. samples) are defined by plot_raw_lfp — confirm.
plot_raw_lfp(lfp_bat_combined, n_channels=192, start_time=0, end_time=100)


#### Once LFP is loaded in and downsampled, we can apply a filter and Hilbert transform to get our complex-valued LFP!

#### *Note: At a 25 Hz sampling rate, a signal can only carry frequencies up to 12.5 Hz (the Nyquist frequency). The upper end of the band is therefore already capped, so we don't need a bandpass filter (which would cut off both the high and the low range); a high-pass filter at 1 Hz is all we need.*

In [None]:
# High-pass filter the downsampled LFP at 1 Hz (fs=25 Hz); use_hilbert=True requests the
# Hilbert-transformed, complex-valued signal described in the text above.
LFPs = filter_data(lfp_bat_combined, 1, fs=25, filt_type='high', use_hilbert=True) 
# Display the resulting array shape as the cell output.
LFPs.shape

In [None]:
# Keep only the LFP samples selected by lfp_indices (returned by sync_and_bin_data) so the
# LFP lines up with the binned position data.
# NOTE(review): this overwrites LFPs, so the cell is not idempotent — re-running it without
# first re-running the filtering cell indexes the already-subset array again.
LFPs = LFPs[lfp_indices]
LFPs.shape


#### We now have our processed LFP. `LFPs` contains the filtered and (Hilbert) transformed LFP data for all of the valid `binned_pos` entries. However, we are mainly interested in the bat flights, which are just a *fraction* of the total of `binned_pos`. To filter out the non-flight entries from the LFP, we will apply a similar filtering method as we did in `get_flightID` with a `get_flightLFP` function:

In [None]:
from get_data import get_flightLFP

# Restrict the processed LFP to the flight samples, mirroring what get_flightID does for the positions.
flightLFP = get_flightLFP(session, LFPs, valid_indices, lfp_timestamps_edges, pos_timestamps, off_samples=125) # Make sure off_samples is the same for flightID and flightLFP.
# Display the resulting array shape as the cell output.
flightLFP.shape

# Applying TIMBRE


In [None]:
import numpy as np
from matplotlib import pyplot as plt
from TIMBRE import TIMBRE
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Cross-validated flight-type classification with TIMBRE.
# For every hidden-layer size, a model is trained on each of n_folds train/test splits;
# per-fold accuracy, a classification report and a confusion matrix are collected, then
# the fold-averaged accuracy and the row-normalized confusion matrices are plotted.

# Define class names matching your labels (labels from 0 to 5)
class_names = [
    'Perch to Feeder 1',       # Label 0
    'Feeder 1 to Perch',       # Label 1
    'Perch to Feeder 2',       # Label 2
    'Feeder 2 to Perch',       # Label 3
    'Feeder 1 to Feeder 2',    # Label 4
    'Feeder 2 to Feeder 1'     # Label 5
]
n_classes = len(class_names)

# Initialize parameters
n_folds = 5
hidden_sizes = [3, 6, 12, 24]
all_accuracies = []
all_cm = {}

# Labels do not depend on the fold or the hidden size, so compute them once.
# Adjust labels to start from 0.
# NOTE(review): assumes column 1 of flightID holds the flight-type code numbered from 1 — confirm in get_flightID.
labels = flightID[:, 1].astype(int) - 1
labels_list = np.arange(n_classes)  # Labels from 0 to 5

for hidden_size in hidden_sizes:
    print(f"\nEvaluating hidden size: {hidden_size} nodes")
    fold_accuracies = []
    cm_total = np.zeros((n_classes, n_classes), dtype=np.int64)

    for which_fold in range(n_folds):
        print(f"  Fold {which_fold + 1}/{n_folds}")

        # Get train and test indices (these index rows of flightID)
        test_inds, train_inds = test_train_bat(flightID, n_folds, which_fold)

        # Whiten the flight-only LFP. The labels and the fold indices are defined on the
        # rows of flightID, so the LFP must be the flight-restricted array built by
        # get_flightLFP rather than the full `LFPs`; otherwise the samples and the labels
        # would not refer to the same time bins.
        # NOTE(review): assumes flightLFP rows align one-to-one with flightID rows
        # (both were built with the same off_samples).
        wLFPs, _, _ = whiten(flightLFP, train_inds)

        # Train the TIMBRE model
        m, _, _ = TIMBRE(
            wLFPs, labels, test_inds, train_inds,
            hidden_nodes=hidden_size, is_categorical=True
        )

        # Get predictions on test data
        output_layer = layer_output(wLFPs[test_inds], m, -1)
        predictions = np.argmax(output_layer, axis=1)
        true_labels = labels[test_inds]

        # Calculate accuracy
        accuracy = np.mean(predictions == true_labels)
        fold_accuracies.append(accuracy)

        # Accumulate the confusion matrix; passing `labels` keeps the matrix at
        # n_classes x n_classes even when a class is absent from this fold.
        cm_total += confusion_matrix(
            true_labels, predictions, labels=labels_list
        )

        # Classification report
        report = classification_report(
            true_labels, predictions,
            labels=labels_list,
            target_names=class_names, zero_division=0
        )
        print(f"Classification Report for Fold {which_fold + 1}:\n{report}")

    # Average accuracy
    avg_accuracy = np.mean(fold_accuracies)
    all_accuracies.append(avg_accuracy)
    all_cm[hidden_size] = cm_total / n_folds

    print(f"Average accuracy for hidden size {hidden_size}: {avg_accuracy:.4f}")

# Plot average accuracy vs hidden layer size
plt.figure(figsize=(10, 5))
plt.plot(hidden_sizes, all_accuracies, marker='o')
plt.title('Model Accuracy vs Hidden Layer Size')
plt.xlabel('Number of Hidden Nodes')
plt.ylabel('Average Accuracy over Folds')
plt.grid(True)
plt.show()

# Plot average confusion matrices for each hidden size.
# The grid is sized from hidden_sizes so that editing that list does not break the indexing.
n_cols = 2
n_rows = int(np.ceil(len(hidden_sizes) / n_cols))
fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 7.5 * n_rows), squeeze=False)
fig.suptitle('Average Confusion Matrices for Bat Flight End Position Prediction', fontsize=16)

for idx, hidden_size in enumerate(hidden_sizes):
    cm_avg = all_cm[hidden_size]

    # Normalize each row to fractions of the true class. Rows with no samples stay at
    # zero instead of triggering a divide-by-zero warning and producing NaNs.
    row_totals = cm_avg.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(
        cm_avg, row_totals,
        out=np.zeros_like(cm_avg, dtype=float), where=row_totals > 0
    )

    ax = axs[idx // n_cols, idx % n_cols]
    sns.heatmap(
        cm_normalized, annot=True, fmt='.2f', ax=ax,
        xticklabels=class_names, yticklabels=class_names, cmap='Blues'
    )
    ax.set_title(f'Confusion Matrix (Hidden Nodes: {hidden_size})')
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')

# Hide any unused panel when the number of hidden sizes does not fill the grid.
for ax in axs.ravel()[len(hidden_sizes):]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()


In [35]:
def group_by_pos_bat(positions, n_bins, train_inds):
    """Assign each position to one of ``n_bins`` equal-width bins per axis.

    The bin range is taken from the training samples only (``positions[train_inds]``),
    and every returned index lies in ``[0, n_bins - 1]``.

    Parameters
    ----------
    positions : array-like
        Position samples; the first axis indexes samples.
    n_bins : int
        Number of bins per axis.
    train_inds : index array or boolean mask
        Selects the training samples along the first axis of ``positions``.

    Returns
    -------
    numpy.ndarray of int
        Bin indices with the same shape as ``positions``.
    """
    positions = np.asarray(positions, dtype=float)
    train_positions = positions[train_inds]
    min_pos = np.min(train_positions, axis=0)
    max_pos = np.max(train_positions, axis=0)

    # An axis with zero range would divide by zero; use a span of 1 there so that
    # every sample on that axis lands in bin 0.
    span = max_pos - min_pos
    safe_span = np.where(span > 0, span, 1.0)

    bins = np.floor((positions - min_pos) / safe_span * n_bins)
    # The training maximum maps to exactly n_bins, and samples outside the training
    # range map below 0 or above n_bins; clamp them all into the valid bin range.
    return np.clip(bins, 0, n_bins - 1).astype(int)

In [None]:
from matplotlib import pyplot as plt
from TIMBRE import TIMBRE
import numpy as np
import helpers  # Assuming you have a helpers module
from sklearn.decomposition import PCA

# Position decoding with TIMBRE: the 3-D flight position is reduced to one dimension
# with PCA, split into n_bins bins, and used as the label. One network is trained per
# hidden-layer size, and the mean test-set response of each layer is plotted per bin.

fig, axs = plt.subplots(4, 4, figsize=(20, 15))
fig.suptitle('TIMBRE Model Performance for Bat Flight Position Prediction', fontsize=16)

n_folds = 5
which_fold = 0
n_bins = 20  # Adjust as needed

# Step 1: Obtain test and train indices (these index rows of flightID)
test_inds, train_inds = test_train_bat(flightID, n_folds, which_fold)

# Step 2: Whiten the LFPs.
# The labels and the fold indices are defined on the rows of flightID, so the LFP must be
# the flight-restricted array built by get_flightLFP rather than the full `LFPs`.
# NOTE(review): assumes flightLFP rows align one-to-one with flightID rows
# (both were built with the same off_samples).
wLFPs, _, _ = helpers.whiten(flightLFP, train_inds)

# Step 3: Extract positions and apply PCA
positions = flightID[:, 2:5]  # X, Y, Z positions
pca = PCA(n_components=1)
positions_1d = pca.fit_transform(positions).flatten()

# Step 4: Bin the 1D positions.
# Digitizing against the interior edges only yields indices 0 .. n_bins - 1. Digitizing
# against all n_bins + 1 edges and subtracting 1 would put the maximum sample in an
# extra bin with index n_bins.
pos_bins = np.linspace(positions_1d.min(), positions_1d.max(), n_bins + 1)
pos_binned = np.digitize(positions_1d, bins=pos_bins[1:-1])
labels = pos_binned

# Step 5: Training and Plotting
titles = ['Projection (real part)', 'Amplitude', 'Softmax 1', 'Softmax 2 (Output)']
for i in range(axs.shape[0]):
    # Hidden layer size doubles on every row: 3, 6, 12, 24
    hidden_nodes = 3 * 2 ** i
    print(f"Training network {i + 1} of {axs.shape[0]} (hidden layer size {hidden_nodes})")

    # Train the TIMBRE model
    m, _, _ = TIMBRE(wLFPs, labels, test_inds, train_inds, hidden_nodes=hidden_nodes)

    for j in range(axs.shape[1]):
        # Calculate layer's response to input, using only test data
        p = helpers.layer_output(wLFPs[test_inds], m, j)

        if j == 0:
            # Keep only the first half of the columns of the first layer's output
            # (titled 'Projection (real part)').
            p = p[:, :p.shape[1] // 2]
            axs[i, 0].set_ylabel(f'{hidden_nodes} features')

        if i == 0:
            axs[0, j].set_title(titles[j])

        # Compute mean response per position bin
        mean_response = helpers.accumarray(labels[test_inds], p)

        # Plot the mean response
        axs[i, j].plot(mean_response)
        axs[i, j].autoscale(enable=True, axis='both', tight=True)

        if i < axs.shape[0] - 1:
            axs[i, j].set_xticks([])
        else:
            axs[i, j].set_xlabel('Position along Principal Component')

# Leave room for the suptitle and render the figure explicitly.
fig.tight_layout(rect=[0, 0, 1, 0.97])
plt.show()
