# Experiment 2 data extraction
This notebook contains code to extract the audio data from the NSX file, generate an automatic
transcript, calculate the spike rates from recorded spike events
and align the spikerates with the transcribed sentences. The goal is to create a dataset consisting of feature (spike rates) / target (sentences) pairs that can be used to train a decoder model.

In [None]:
import pickle
from datetime import datetime
from math import ceil

import brpylib
import numpy as np
import pandas as pd
import sounddevice as sd
from scipy.io.wavfile import write
from tqdm import tqdm

from thesis_project.settings import DATA_DIR, EXPERIMENT2_DIR
from thesis_project.data_loading import construct_spikerates_filename

## Load data

### Neural and audio data

In [4]:
nev_file1 = brpylib.NevFile(EXPERIMENT2_DIR + "/Experiment/20240708-141522/Hub1-20240708-141522-001.nev") # contains the audio recording
nsx_file = brpylib.NsxFile(EXPERIMENT2_DIR+ "/Experiment/20240708-141522/NSP-20240708-141522-001.ns6") # contains the spike events
spike_events = nev_file1.getdata(wave_read=True)['spike_events']


Hub1-20240708-141522-001.nev opened

NSP-20240708-141522-001.ns6 opened


In [5]:
sampling_rate = nsx_file.basic_header['SampleResolution']
sorted_timestamps = spike_events['TimeStamps']
sorted_timestamps.sort() # spike event timestamps in nanoseconds

In [6]:
audio_data = nsx_file.getdata()['data'][0][0]

## Create audio transcript

In [None]:
AUDIO_PATH = f"{DATA_DIR}/audio/recording.wav"
WHISPER_RESULT_PATH = f"{DATA_DIR}/audio/whisper_output.pkl"
WHISPER_RESULT_SENTENCES_PATH = f"{DATA_DIR}/audio/whisper_result_sentences.pkl"

# save audio recording in a .wav file
factor = 50000
scaled = np.int16(audio_data / np.max(np.abs(audio_data)) * factor)
write(AUDIO_PATH, sampling_rate, scaled)

### Transcribe with whisperx

In [None]:
! pip install whisperx

import whisperx
import gc

device = "cuda"
audio_file = "clean_audio_end.wav"
language="de"
batch_size = 16 # reduce if low on GPU memory
compute_type = "float16" # change to "int8" if low on GPU memory (may reduce accuracy)

# transcribe with original whisper (batched)
model = whisperx.load_model("large-v2", device, compute_type=compute_type)
audio = whisperx.load_audio("recording.wav")
result = model.transcribe(audio, batch_size=batch_size)

# align
model_a, metadata = whisperx.load_align_model(language_code=language, device=device)
result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)

# save file
with open(WHISPER_RESULT_PATH, "wb") as file:
  pickle.dump(result, file)

### Convert timestamps
Extract whisperx timestamps from the transcript (in seconds) and convert them to indices in the audio recording np array.


In [None]:
with open(WHISPER_RESULT_PATH, "rb") as file:
  whisper_result = pickle.load(file)


# multiply with sampling rate to get timestamps

result_sentences = []

for segment in whisper_result["segments"]:
    start = round(segment["start"] * sampling_rate)
    end = round(segment["end"] * sampling_rate)
    result_sentences.append((segment["text"], (start, end)))

# uncomment code to overwrite file
# with open(WHISPER_RESULT_SENTENCES_PATH, "wb") as file:
#     pickle.dump(result_sentences, file)

## Compute spikerates

### Load sentences

In [None]:
# ### Load sentences
# with open(WHISPER_RESULT_SENTENCES_PATH, "rb") as file:
#     audio_segments = pickle.load(file)

In [17]:
### Load sentences
with open(f"{EXPERIMENT2_DIR}/whisper_sentences.pkl", "rb") as file:
    audio_segments = pickle.load(file)

all_sentences = pd.read_csv(f"{EXPERIMENT2_DIR}/labels.csv")
all_sentences = [(row["word"], (row["start"], row["end"])) for _, row in all_sentences.iterrows()]

In [27]:
# check if calculated indices correspond to the correct intervals in the audio recording
sentence_idx = 75
start_time, end_time = audio_segments[sentence_idx][1][0], audio_segments[sentence_idx][1][1]
sentence = audio_segments[sentence_idx][0]
sd.play(audio_data[start_time:end_time], samplerate=sampling_rate)
print(sentence)

Zum Beispiel, der Sprit für die Traktoren kostet bald mehr Geld.


### Audio / spike alignment
The spike events' audio timestamps do not correspond to the time origin, so it is unclear how
much time passes between the time origin and the first spike.

In [11]:
# spike file time origin
nsx_file.basic_header['TimeOrigin'], datetime.strftime(nsx_file.basic_header['TimeOrigin'], '%Y-%m-%d %H:%M:%S')

(datetime.datetime(2024, 7, 8, 12, 15, 22, 950000), '2024-07-08 12:15:22')

In [12]:
# first spike timestamp
first_spike = datetime.fromtimestamp(sorted_timestamps[0] / 1e9)
first_spike, datetime.strftime(first_spike, '%Y-%m-%d %H:%M:%S')

(datetime.datetime(2024, 4, 14, 15, 23, 13, 72738), '2024-04-14 15:23:13')

I am assuming that the first spike's timestamp corresponds to the beginning of the audio file,
since they are closely aligned (roughly 1ms difference between the duration of the audio file
and the timespan between the maximum and minimum spike timestamp).

In [13]:
# max timestamp - min timestamp (duration) in minutes
(sorted_timestamps[-1] - sorted_timestamps[0]) / 1e9 / 60

115.3453918995

In [20]:
# audio duration in minutes
audio_data = nsx_file.getdata()
len(audio_data['data'][0][0]) / sampling_rate / 60

115.34549944444444

In [21]:
# recording delay in minutes
recording_delay = nsx_file.basic_header['TimeOrigin'] - nev_file1.basic_header['TimeOrigin'] # audio recording is started after neural recording
recording_delay = recording_delay.microseconds * 1000 # convert to nanoseconds
recording_delay / 1e9 / 60

0.00013333333333333334

### Compute sentence spikerates

In [28]:
min_timestamp, max_timestamp = sorted_timestamps[0], sorted_timestamps[-1]

audio_segments = all_sentences

audio_time_factor = 1e9 / sampling_rate # convert audio timestamp in seconds to nanoseconds
buffer_before = 100 * 1e6 # optional buffer times in ms to account for delay in neural processing (~ 100ms)
buffer_after = 100 * 1e6

In [29]:
def audio_to_spike_timestamp(start_timestamp: int, end_timestamp: int) -> tuple[int, int]:
    """
    Convert start and end timestamp of an audio segment (e.g. a sentence or word) in seconds
    to the corresponding index in the spike event array.
    """
    return int(start_timestamp * audio_time_factor + buffer_before - recording_delay + min_timestamp), \
        int(end_timestamp * audio_time_factor + buffer_after - recording_delay + min_timestamp)

The following script assigns each spike event to the sentence during which it occurs.

In [None]:
# start and end timestamps of current audio audio segment
# each audio segment corresponds to a sentence
current_segment_start_timestamp, current_segment_end_timestamp = \
    audio_to_spike_timestamp(audio_segments[0][1][0], audio_segments[0][1][1])

# index of current audio segment
current_segment_index = 0

# list lists of spike timestamp indices per segment index (i.e. index in the ´audio_segments list´)
segment_timestamps = []

# list of spike timestamps indices (i.e. indices in the ´sorted_timestamp´ list) corresponding to the current audio segment
current_segment_timestamps = []

# go over all spike timestamps and assign them to the corresponding audio segments, if applicable
# assumes that spike timestamps and audio timestamps are in ascending order

for i, event_timestamp in tqdm(enumerate(sorted_timestamps)):

    if event_timestamp > current_segment_end_timestamp:
        # increment the segment index if the current spike happens after the segment

        while event_timestamp > current_segment_end_timestamp:
            # skip all audio segments where no spikes occur

            segment_timestamps.append(current_segment_timestamps)
            current_segment_index += 1
        

            if current_segment_index >= len(audio_segments):
                # stop when the last audio segment has been added
                break

            current_segment_start_timestamp, current_segment_end_timestamp = \
                audio_to_spike_timestamp(audio_segments[current_segment_index][1][0], audio_segments[current_segment_index][1][1])

            # create empty spike timestamp array for current audio segment
            current_segment_timestamps = []


    if current_segment_index >= len(audio_segments):
        # stop when the last audio segment has been added
        break

    if event_timestamp < current_segment_start_timestamp:
        # spike has occurred between audio segments
        continue

    # append the current spike index to the current segment's timestamp list
    current_segment_timestamps.append(i)



0it [00:00, ?it/s]

1954676it [00:02, 772360.48it/s]


In [26]:
max_audio_duration = 0 # in nanoseconds
min_audio_timestamps = []

# audio recording durations in ms
for _, (start_time, end_time) in audio_segments:
    duration = (end_time - start_time) * audio_time_factor
    min_audio_timestamps.append(start_time)

    if duration > max_audio_duration:
        max_audio_duration = duration

In [27]:
"Maximum duration of a sentence in seconds:", max_audio_duration / 1e9

('Maximum duration of a sentence in seconds:', 22.455)

Calculate the spike rate timeseries for each sentence by binning the spike events.

In [28]:
n_channels = 256
bin_size = 100 * 1e6 # small bin size in ms because of short sentence duration
n_bins = ceil(max_audio_duration / bin_size) + 1

spikerates = np.zeros((len(audio_segments), n_bins, n_channels))

for i, (segment_idx, audio_segment) in tqdm(enumerate(zip(segment_timestamps, audio_segments))):

    min_audio_timestamp = audio_segment[1][0] * audio_time_factor

    for j in segment_idx:

        timestamp = spike_events['TimeStamps'][j] - spike_events['TimeStamps'][0]
        channel = spike_events['Channel'][j]

        idx = int((timestamp - min_audio_timestamp) / bin_size)
        spikerates[i][idx][channel - 1] += 1

650it [00:02, 318.90it/s]


In [29]:
spikerates.shape

(650, 226, 256)

In [30]:
# Sanity check:
# Check the number of unique spikerate timeseries

n_uniques = list(zip(*np.unique(spikerates, return_counts = True, axis = 0)))
print(f"Number of unique spikerate timeseries: {len(n_uniques)} of {len(spikerates)}")

Number of unique spikerate timeseries: 650 of 650


### Save spikerates

In [101]:
session_id = nsx_file.basic_header['TimeOrigin'].strftime("%Y%m%d")

spikerates_filename = construct_spikerates_filename(session_id=session_id,
                                                    path=f"{EXPERIMENT2_DIR}/binned_spikerates",
                                                    bin_size=int(bin_size / 1e6),
                                                    experiment="experiment2")

with open(file=spikerates_filename, mode="wb") as file:
    pickle.dump(spikerates, file)

sentences = [audio_segment[0] for audio_segment in audio_segments]
with open(file=f"{EXPERIMENT2_DIR}/sentences_new.pkl", mode="wb") as file:
    pickle.dump(sentences, file)