In [25]:
import mne
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tqdm

from scipy import signal
from scipy.io import wavfile
from scipy.stats import pearsonr, zscore
from mne_bids import BIDSPath
from functools import partial
from nilearn.plotting import plot_markers

import torch
from torch import nn
import torchaudio
from transformers import WhisperProcessor, WhisperModel, AutoFeatureExtractor

  from .autonotebook import tqdm as notebook_tqdm


In [26]:
# Load three objects from the pretrained "openai/whisper-base" checkpoint:
# the bare model, its feature extractor, and the processor.
# NOTE(review): in the cells shown here only the commented-out embedding sketch
# refers to `model`; `feature_extractor` and `processor` are not referenced.
model = WhisperModel.from_pretrained("openai/whisper-base")
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
processor = WhisperProcessor.from_pretrained("openai/whisper-base")



In [35]:
# PARAMS
# Notebook-wide settings read by the cells below.

# Absolute path to the dataset root; the stimulus, transcript and derivative
# paths are all built from it.
bids_root = "/srv/nfs-data/sisko/storage/ECoG_podcast/ds005574-1.0.2" 
# Subject label handed to BIDSPath when locating the recording.
subject = '03'
# z-score along axis 1. In the cells shown here it is only referenced by a
# commented-out line inside get_stimuli_and_brain.
func = partial(zscore, axis=1)
# Default value of the `to_fs` argument of preprocess_raw_audio
# (presumably the ECoG sampling rate in Hz — TODO confirm).
ecog_sr = 512
# Only referenced by the commented-out Whisper embedding sketch
# (presumably the audio rate Whisper expects, in Hz — TODO confirm).
whisper_sr = 16000

## Get Data

In [34]:
def preprocess_raw_audio(x, fs, to_fs=ecog_sr, lowcut=200, highcut=5000):
    """Band-pass filter a 1-D audio signal and return its amplitude envelope.

    Parameters
    ----------
    x : 1-D array
        Audio samples. Anything with ``ndim != 1`` trips the assertion.
    fs : number
        Sampling rate of ``x``; used to normalise the cut-off frequencies.
    to_fs : number, optional
        Accepted but never read by this function: no resampling is performed,
        so the returned envelope has one value per sample of ``x``.
    lowcut, highcut : number, optional
        Lower and upper band edges, in the same units as ``fs``.

    Returns
    -------
    numpy.ndarray
        Magnitude of the analytic signal (via ``signal.hilbert``) of the
        mean-removed, band-passed audio.
    """
    assert x.ndim == 1

    # Order-5 Butterworth band-pass, following
    # https://scipy-cookbook.readthedocs.io/items/ButterworthBandpass.html
    filter_order = 5
    nyquist = 0.5 * fs
    band_edges = [lowcut / nyquist, highcut / nyquist]
    numer, denom = signal.butter(filter_order, band_edges, btype="band")
    filtered = signal.lfilter(numer, denom, x)

    # Envelope = |analytic signal| of the demeaned, filtered waveform.
    return np.abs(signal.hilbert(filtered - filtered.mean()))

In [30]:
# Location of the podcast stimulus audio inside the dataset.
audio_path = f"{bids_root}/stimuli/podcast.wav"

# wavfile.read returns (sample_rate, samples).
audio_sf, audio_wave = wavfile.read(audio_path)
# Keep only the first channel when the file has more than one:
# preprocess_raw_audio asserts a 1-D input.
if audio_wave.ndim > 1:
    audio_wave = audio_wave[:, 0]
# Band-pass + envelope. `to_fs` is left at its default and no resampling
# happens inside preprocess_raw_audio, so the result is still sampled at audio_sf.
audio_wave_clean = preprocess_raw_audio(audio_wave, audio_sf)

  audio_sf, audio_wave = wavfile.read(audio_path)


In [31]:
# Describe the recording to load: the "highgamma" iEEG derivative (.fif) of the
# "podcast" task for `subject`, under <bids_root>/derivatives/ecogprep.
# This object is later passed to get_stimuli_and_brain.
file_path = BIDSPath(root=bids_root+"/derivatives/ecogprep",
                     subject=subject,
                     task="podcast",
                     datatype="ieeg",
                     description="highgamma",
                     suffix="ieeg",
                     extension="fif")

  file_path = BIDSPath(root=bids_root+"/derivatives/ecogprep",


In [None]:
# def get_whisper_embedding(audio, T):

#     model.encoder.embed_positions = nn.Embedding(T, 512)
#     model.eval()
#     audio = torchaudio.transforms.Resample(audio_sf, whisper_sr)(torch.tensor(audio).float())

    


In [6]:
transcript_path = f"{bids_root}/stimuli/podcast_transcript.csv"

# Read the transcript, keep only rows that have a `start` value, and order
# them by `start`.
df = (
    pd.read_csv(transcript_path)
    .dropna(subset=['start'])
    .sort_values("start")
)

# (n_rows, 3) array in the layout mne.Epochs expects: column 0 carries the
# `start` values, the other two columns stay at zero.
events = np.zeros((len(df), 3))
events[:, 0] = df["start"].to_numpy()

In [None]:
def get_stimuli_and_brain(file_path, audio_wave_clean, audio_sf, df, events, tmax=2.0):
    """Cut the recording and the audio envelope into per-event snippets.

    Parameters
    ----------
    file_path : path-like
        FIF file read with ``mne.io.read_raw_fif``.
    audio_wave_clean : 1-D array
        Audio envelope sampled at ``audio_sf``.
    audio_sf : number
        Sampling rate of ``audio_wave_clean``.
    df : pandas.DataFrame
        Transcript with a ``start`` column. Row ``i`` must describe the same
        event as ``events[i]``.
    events : numpy.ndarray
        ``(n_events, 3)`` array whose values are multiplied by
        ``raw.info['sfreq']`` and cast to ``int`` before being given to
        ``mne.Epochs``.
    tmax : float, optional
        End of each epoch relative to the event, and length of each audio
        snippet (``int(tmax * audio_sf)`` samples).

    Returns
    -------
    epochs_snippet : numpy.ndarray
        Epoched recording for the events MNE kept; epochs span ``-1.0`` to
        ``tmax`` around each event.
    audio_snippet : numpy.ndarray
        ``(n_kept_events, int(tmax * audio_sf))`` array. Row ``k`` starts at the
        onset of the event behind ``epochs_snippet[k]`` and is zero-padded when
        the audio ends early. Note that the audio window begins at the onset,
        whereas the epochs begin 1.0 before it.

    Raises
    ------
    ValueError
        If ``df`` and ``events`` do not have the same number of rows.
    """
    if len(df) != len(events):
        raise ValueError(
            f"df has {len(df)} rows but events has {len(events)}; "
            "they must describe the same events in the same order."
        )

    raw = mne.io.read_raw_fif(file_path, verbose=False)
    raw.load_data(verbose=False)
    # raw = raw.apply_function(func, channel_wise=False, verbose=False)

    epochs = mne.Epochs(
        raw,
        (events * raw.info['sfreq']).astype(int),
        tmin=-1.0,
        tmax=tmax,
        baseline=None,
        proj=None,
        event_id=None,
        preload=True,
        event_repeated="merge",
        verbose=False
    )
    # Read the preloaded array directly, as the original did.
    epochs_snippet = epochs._data
    # Indices (into `events`) of the epochs that survived merging/dropping.
    good_idx = epochs.selection
    print(f"Epochs object has a shape of: {epochs_snippet.shape}")

    # Look up onsets through `good_idx`, not through the running loop counter:
    # once any epoch has been dropped or merged, counter k no longer refers to
    # transcript row k, and the audio would be paired with the wrong epoch.
    kept_onsets = df['start'].to_numpy()[good_idx]

    snippet_len = int(tmax * audio_sf)
    # Pre-filled with zeros, so short snippets at the end are zero-padded.
    audio_snippet = np.zeros((len(good_idx), snippet_len))
    for out_row, onset in tqdm.tqdm(enumerate(kept_onsets), total=len(kept_onsets)):
        start_sample = int(onset * audio_sf)
        snippet = audio_wave_clean[start_sample:start_sample + snippet_len]
        audio_snippet[out_row, :len(snippet)] = snippet
    print(f"Audio object has a shape of: {audio_snippet.shape}")

    return epochs_snippet, audio_snippet

In [24]:
# Build the paired arrays: epoched recording and per-event audio envelope
# snippets (tmax is left at its default of 2.0).
brain_data, audio_data = get_stimuli_and_brain(file_path, audio_wave_clean, audio_sf, df, events)

Epochs object has a shape of: (5130, 235, 1537)


5130it [00:01, 3075.42it/s]

Audio object has a shape of: (5130, 88200)



