In [1]:
import os
import matplotlib.pyplot as plt
import numpy as np
import pynwb

# Folder that holds the NWB session files for this subject
folder_path = '000017/sub-Cori/'

# List the folder's entries when it exists; otherwise fall back to None
contents = os.listdir(folder_path) if os.path.exists(folder_path) else None

contents

['sub-Cori_ses-20161214T120000.nwb',
 'sub-Cori_ses-20161218T120000.nwb',
 'sub-Cori_ses-20161217T120000.nwb']

In [2]:
def open_nwb_data(nwb_file_path):
    """
    Opens an NWB file and returns the NWB data object.

    The reader is intentionally NOT wrapped in a ``with`` block: pynwb reads
    datasets lazily from the HDF5 file, so closing the reader before returning
    would hand the caller an object whose data can no longer be accessed.
    The caller owns the open file handle and should close it when done
    (NOTE(review): recent pynwb versions expose the reader through
    ``nwb_data.get_read_io()`` - confirm against the installed version).

    Parameters:
    nwb_file_path (str): The file path of the NWB file.

    Returns:
    NWBData: An object containing the NWB data, backed by a still-open file.
    """
    io = pynwb.NWBHDF5IO(nwb_file_path, 'r')
    try:
        return io.read()
    except Exception:
        # Do not leak the file handle when the read itself fails
        io.close()
        raise


In [3]:
# Fail early with a clear message when the folder is missing or empty
# (contents is None in the first case), instead of a TypeError/IndexError below.
if not contents:
    raise FileNotFoundError(f"No NWB files found in '{folder_path}'")

# os.listdir() returns entries in arbitrary order, so sort before picking the
# first session; os.path.join also avoids the doubled '/' in the path.
nwb_file_path = os.path.join(folder_path, sorted(contents)[0])  # Replace with your NWB file path
nwb_data = open_nwb_data(nwb_file_path)
nwb_data

  warn("%s '%s': Length of data does not match length of timestamps. Your data may be transposed. "


In [67]:
def get_spiking_data(nwb_data, start_time, end_time):
    """
    Retrieves spike times for all neurons during a specific time interval.

    Parameters:
    nwb_data (NWBData): The NWB data object; only its ``units`` table is used.
    start_time (float): Start of the interval (inclusive).
    end_time (float): End of the interval (inclusive).

    Returns:
    numpy.ndarray: A flat NumPy array with the spike times of all neurons that
    fall inside the interval, concatenated unit after unit. Empty when there
    are no units.
    """
    if not nwb_data.units:
        return np.array([])

    per_unit = []
    for i in range(len(nwb_data.units)):
        # Assumes units[i] is a one-row table whose single 'spike_times' cell
        # holds the unit's spike-time array (same access as the binned version
        # of this function later in the notebook). Comparing the Series itself
        # against a scalar does not filter the spikes.
        spike_times = np.asarray(nwb_data.units[i]['spike_times'].values[0])
        in_window = (spike_times >= start_time) & (spike_times <= end_time)
        per_unit.append(spike_times[in_window])

    return np.concatenate(per_unit)

In [5]:
def get_interval_times(nwb_sub_data, interval_type):
    """
    Retrieves start and stop times from an already-loaded interval table.

    Parameters:
    nwb_sub_data: Interval table supporting ``len()`` and column lookup by
        name, whose 'start_time' and 'stop_time' columns expose a ``.data``
        attribute (e.g. an entry of ``nwb_data.intervals``).
    interval_type (str): Type of interval ('spontaneous' or 'trials'); used
        only to label the error message.

    Returns:
    tuple: ``(start_times, stop_times)`` for each interval.

    Raises:
    ValueError: If the table is None or empty.
    RuntimeError: If the start/stop columns cannot be read; the original
        exception is attached as ``__cause__``.
    """
    if nwb_sub_data is None or len(nwb_sub_data) == 0:
        raise ValueError(f"No data found for interval type '{interval_type}'")

    try:
        start_times = nwb_sub_data['start_time'].data[:]
        stop_times = nwb_sub_data['stop_time'].data[:]
    except Exception as e:
        # Chain the cause so the underlying KeyError/IOError stays visible
        raise RuntimeError(f"Error extracting data: {e}") from e

    return start_times, stop_times


In [None]:
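# Plan:
# 1. Load the NWB data.
# 2. Select all missed trials.
# 3. Get the start and stop times of the missed trials.
# 4. Get the spike data for those times.
# 5. Save the result in the data folder, under "missed".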

In [29]:
def get_missed_time_data(nwb_file_path):
    """
    Reads the 'trials' interval table of an NWB file and returns the time
    bounds of the trials whose ``response_choice`` equals 0.

    Parameters:
    nwb_file_path (str): The file path of the NWB file.

    Returns:
    numpy.ndarray: Array with two rows - row 0 holds the start times and
    row 1 the stop times of the selected trials.
    """
    with pynwb.NWBHDF5IO(nwb_file_path, 'r') as io:
        nwb_data = io.read()
        trials = nwb_data.intervals['trials']

        # Pull the three columns into memory while the file is open
        choices = trials['response_choice'].data[:]
        starts = trials['start_time'].data[:]
        stops = trials['stop_time'].data[:]

        # Keep only the trials with response_choice == 0
        # (these are what this notebook calls "missed" trials)
        is_missed = choices == 0
        filtered_data = np.array([starts[is_missed], stops[is_missed]])

    return filtered_data

# Start/stop times of the trials with response_choice == 0 in the selected session file
missed_times = get_missed_time_data(nwb_file_path)


In [None]:
def get_missed_time_data(nwb_file_path):
    """
    Returns the start and stop times of the trials whose ``response_choice``
    is 0, read from the 'trials' interval table of an NWB file.

    NOTE(review): this cell redefines the function from the previous cell with
    the same behaviour; one of the two definitions can be removed.

    Parameters:
    nwb_file_path (str): The file path of the NWB file.

    Returns:
    numpy.ndarray: Two-row array; row 0 = start times, row 1 = stop times.
    """
    with pynwb.NWBHDF5IO(nwb_file_path, 'r') as io:
        nwb_data = io.read()
        trial_table = nwb_data.intervals['trials']

        # Read the columns while the file handle is still valid
        response_choice = trial_table['response_choice'].data[:]
        start_time = trial_table['start_time'].data[:]
        stop_time = trial_table['stop_time'].data[:]

        # Boolean mask of the trials with response_choice == 0
        missed_mask = response_choice == 0
        filtered_data = np.array([start_time[missed_mask], stop_time[missed_mask]])

    return filtered_data

# Rebuilds missed_times with the function defined in this cell
missed_times = get_missed_time_data(nwb_file_path)


In [45]:
import pynwb
import numpy as np

def get_passive_time_data(nwb_file_path):
    """
    Reads the 'spontaneous' interval table of an NWB file and returns its
    start and stop times.

    Parameters:
    nwb_file_path (str): The file path of the NWB file.

    Returns:
    numpy.ndarray: Two-row array; row 0 = start times, row 1 = stop times.
    """
    with pynwb.NWBHDF5IO(nwb_file_path, 'r') as io:
        nwb_data = io.read()
        spontaneous = nwb_data.intervals['spontaneous']

        # Read both columns while the file is open, then stack them row-wise
        starts = spontaneous['start_time'].data[:]
        stops = spontaneous['stop_time'].data[:]
        data = np.array([starts, stops])
    return data

# Start/stop times of the 'spontaneous' intervals in the selected session file
passive_times = get_passive_time_data(nwb_file_path)


In [51]:
# Both arrays have two rows: row 0 = start times, row 1 = stop times
for interval_times in (missed_times, passive_times):
    print(interval_times.shape)

(2, 74)
(2, 4)


In [55]:
# Duration of every missed trial (stop - start), summarised as mean / min / max
missed_durations = missed_times[1] - missed_times[0]
for summarise in (np.mean, np.min, np.max):
    print(summarise(missed_durations))

4.65988056195935
3.396688497112109
11.212790198471794


In [48]:
# Mean duration (stop - start) of the passive intervals
passive_durations = passive_times[1] - passive_times[0]
np.mean(passive_durations)

58.422704880107936

Because there is very little passive data, we will take 5 seconds of data from each missed trial and segment the passive data into 5 s intervals. This will give 48 data points. The trial durations also vary massively, so I will take 5-second intervals and zero-pad the end of every sample to keep the maximum amount of data, although this may bias the model because padding will be more frequent in the missed-trial data.

In [58]:
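# Loop through the array of interval times.
# Take 5 s of data from each interval.
# Pad with zeros where needed so that every sample is exactly 5 s long.
# Move on (to the next trial for missed data, or to the next 5 s window for passive data) until the end.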

SyntaxError: invalid decimal literal (367469315.py, line 3)

In [126]:
import numpy as np

def _count_bins(duration, bin_size):
    """Number of bin_size-wide bins needed to cover `duration`."""
    # Round first so floating-point noise in the division (e.g.
    # 100.00000000000001) does not create a spurious extra bin.
    return int(np.ceil(np.round(duration / bin_size, 9)))


def bin_spikes(spike_times, bin_size=0.1, duration=None):
    """
    Bins the spike times into fixed intervals of width ``bin_size``.

    Parameters:
    spike_times (np.array): Array of spike times for a neuron, measured from
        the start of the window (t = 0).
    bin_size (float): The size of each time bin in seconds.
    duration (float, optional): Length of the window to bin. When given, the
        result always has ceil(duration / bin_size) entries, independent of
        where the spikes fall. When omitted, the window ends at the last
        spike, and an empty input yields an empty array.

    Returns:
    np.array: An array representing the number of spikes in each time bin.
    Spikes outside [0, duration] are not counted.

    Raises:
    ValueError: If ``bin_size`` is not positive.
    """
    if bin_size <= 0:
        raise ValueError("bin_size must be positive")

    spike_times = np.asarray(spike_times, dtype=float)

    if duration is None:
        if spike_times.size == 0:
            return np.array([])
        duration = np.max(spike_times)
        # Spikes exactly at t = 0 still need one bin to land in
        num_bins = max(_count_bins(duration, bin_size), 1)
    else:
        num_bins = _count_bins(duration, bin_size)
        if num_bins <= 0:
            return np.zeros(0, dtype=int)

    # Edges are exact multiples of bin_size; deriving the range from the last
    # spike instead would stretch or shrink every bin.
    upper_edge = num_bins * bin_size
    in_window = spike_times[(spike_times >= 0) & (spike_times <= duration)]
    # np.histogram ignores samples beyond `range`; pull a spike that overshoots
    # the last edge only through rounding back onto it.
    in_window = np.minimum(in_window, upper_edge)

    binned_spikes, _ = np.histogram(in_window, bins=num_bins, range=(0, upper_edge))

    return binned_spikes

def get_spiking_data(nwb_data, trial_start_time, trial_end_time, bin_size=0.1):
    """
    Retrieves and bins spike times for all neurons during a specific trial.

    Parameters:
    nwb_data (NWBData): The NWB data object; only its ``units`` table is used.
    trial_start_time (float): The start time of the trial.
    trial_end_time (float): The end time of the trial.
    bin_size (float): Width of each time bin in seconds.

    Returns:
    np.array: A 2D array of shape (number of bins, number of neurons) with
    binned spike counts for each neuron across the trial duration.
    """
    # Calculate the number of bins for the given trial duration
    trial_duration = trial_end_time - trial_start_time
    bin_no = max(_count_bins(trial_duration, bin_size), 0)

    units = nwb_data.units
    no_neurons = 0 if units is None else len(units)
    spike_data = np.zeros((bin_no, no_neurons))

    for i in range(no_neurons):
        # units[i] is indexed as a one-row table whose 'spike_times' cell holds
        # the unit's spike-time array
        spike_times = np.asarray(units[i]['spike_times'].values[0], dtype=float)
        in_trial = (spike_times >= trial_start_time) & (spike_times <= trial_end_time)
        spike_times_normed = spike_times[in_trial] - trial_start_time
        # Binning against the trial duration guarantees bin_no counts per
        # neuron, so no column is silently left at zero or broadcast.
        spike_data[:, i] = bin_spikes(spike_times_normed, bin_size, duration=trial_duration)

    return spike_data


    
def get_missed_data(nwb_file_path, missed_times, duration=10, out_dir='data10/missed'):
    """
    Saves one binned spike-count array per missed trial.

    For every start time in ``missed_times[0]`` the spikes in
    [start, start + duration] are binned with ``get_spiking_data`` and written
    with ``np.save`` to ``<out_dir>/<start_time>.npy``.

    Parameters:
    nwb_file_path (str): The file path of the NWB file.
    missed_times (np.array): Two-row array; only row 0 (start times) is used.
    duration (float): Length in seconds of the window saved for each trial.
    out_dir (str): Directory that receives the .npy files; created if missing.

    Returns:
    None
    """
    # np.save does not create missing directories
    os.makedirs(out_dir, exist_ok=True)

    with pynwb.NWBHDF5IO(nwb_file_path, 'r') as io:
        nwb_data = io.read()
        for start_time in missed_times[0]:
            data = get_spiking_data(nwb_data, start_time, start_time + duration)
            np.save(os.path.join(out_dir, f'{start_time}'), data)
    return None
    
# Write one binned spike-count array per missed-trial start time under data10/missed/
get_missed_data(nwb_file_path, missed_times)

  warn("%s '%s': Length of data does not match length of timestamps. Your data may be transposed. "


In [127]:
def create_more_passive(passive_times, time_bins = 10):
    """
    Splits every passive interval into consecutive windows of fixed length.

    Parameters:
    passive_times (np.array): Two-row array; row 0 holds the interval start
        times and row 1 the interval stop times.
    time_bins (float): Length of each window, in the same unit as the times.

    Returns:
    list: Start time of every complete window, interval after interval.
    Leftover time shorter than ``time_bins`` at the end of an interval is
    dropped.

    Raises:
    ValueError: If ``time_bins`` is not positive.
    """
    if time_bins <= 0:
        raise ValueError("time_bins must be positive")

    times = []
    tdiff = passive_times[1] - passive_times[0]
    for i in range(passive_times.shape[1]):
        # Count whole windows directly: an interval that is an exact multiple
        # of time_bins keeps its last window, which a strict `>` comparison
        # on the remaining time would discard.
        n_windows = int(np.floor(tdiff[i] / time_bins))
        for j in range(n_windows):
            times.append(passive_times[0][i] + (j * time_bins))
    return times
# Window start times carved out of the passive intervals (default time_bins = 10)
passive_times_added = create_more_passive(passive_times)

In [128]:
def get_passive_data(nwb_file_path, passive_times, duration=10, out_dir='data10/passive'):
    """
    Saves one binned spike-count array per passive window.

    For every start time in ``passive_times`` the spikes in
    [start, start + duration] are binned with ``get_spiking_data`` and written
    with ``np.save`` to ``<out_dir>/<start_time>.npy``.

    Parameters:
    nwb_file_path (str): The file path of the NWB file.
    passive_times (iterable of float): Window start times.
    duration (float): Length in seconds of the window saved for each start.
        NOTE(review): this used to be a hard-coded 5 s, while the window
        starts are generated 10 s apart and missed trials are saved as 10 s
        windows; 10 keeps both classes the same shape - confirm this is the
        intended length.
    out_dir (str): Directory that receives the .npy files; created if missing.

    Returns:
    None
    """
    # np.save does not create missing directories
    os.makedirs(out_dir, exist_ok=True)

    with pynwb.NWBHDF5IO(nwb_file_path, 'r') as io:
        nwb_data = io.read()
        for start_time in passive_times:
            data = get_spiking_data(nwb_data, start_time, start_time + duration)
            np.save(os.path.join(out_dir, f'{start_time}'), data)
    return None
    
# Write one binned spike-count array per passive window start time under data10/passive/
get_passive_data(nwb_file_path, passive_times_added)

  warn("%s '%s': Length of data does not match length of timestamps. Your data may be transposed. "
