In [16]:
from param_funct import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from transformers import AutoProcessor, Wav2Vec2Model, ClapTextModelWithProjection
from transformers import GPT2Tokenizer, GPT2Model, AutoTokenizer

In [17]:
# Speech encoder: Wav2Vec2 feature extractor/processor and backbone.
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

# Text encoder (GPT-2) and its tokenizer.
model_text = GPT2Model.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# CLAP text tower with projection head, plus its tokenizer.
clap_text = ClapTextModelWithProjection.from_pretrained("laion/clap-htsat-unfused")
clap_tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")

# NOTE(review): the GPU index is hard-coded to 3 — confirm this matches the target machine.
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')

# The extraction loop moves the Wav2Vec2 inputs to `device` before calling
# `model`, so the model has to live on the same device; otherwise the forward
# pass raises a device-mismatch error as soon as CUDA is available.
model = model.to(device)
clap_model = clap_text.to(device)
model_text = model_text.to(device)

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Extraction settings. `meg_path` is expected to come from `param_funct`.
stimuli_path = f"{meg_path}/stimuli/audio"
patient = ['01']
sub_decim = 10

# Accumulators filled by the extraction loops: one entry per (story, sound_id).
brain_signal_data, audio_spect_data, audio_w2v_data, text_gpt_data = [], [], [], []

In [None]:
# Extract, for every (story, sound_id) recording of the first session:
#   * the epoched MEG signal            -> brain_signal_data
#   * a dB spectrogram per epoch        -> audio_spect_data
#   * Wav2Vec2 last hidden states       -> audio_w2v_data
# and finally concatenate each list along the epoch axis.
#
# The transforms and the target chunk length do not depend on the loop
# variables, so they are built once instead of once per epoch.
spectrogram_tf = torchaudio.transforms.Spectrogram(
    n_fft=n_fft_speech, hop_length=hop_len_speech
).to(device)
amplitude_to_db_tf = torchaudio.transforms.AmplitudeToDB()
expected_len = int(duration * sampling_audio)

for subject in tqdm(patient):
    print('PATIENT: ', subject)
    for sess, session_label in enumerate(session):
        print("SESSION: ", session_label)
        # Only the first session is processed; later sessions are just logged.
        if sess != 0:
            continue

        for story in task_list:
            print('AUDIO_NAME: ', story)
            selected_sound_ids = tasks_with_sound_ids[story]
            story_uid = int(task[story])
            print("STORY_UID: ", story_uid)
            raw = get_bids_raw(meg_path, subject, session_label, str(story_uid))

            for z, sound_id in enumerate(selected_sound_ids):
                print("SOUND_ID: ", float(sound_id))
                epochs_data = get_epochs(raw, float(story_uid), float(sound_id), sub_decim)
                epoch_signal = get_meg_from_raw_epochs(epochs_data)
                print('MEG_SHAPE: ', epoch_signal.shape)
                brain_signal_data.append(epoch_signal)

                # Audio features are computed for subject '01' only.
                if subject != '01':
                    continue

                # NOTE(review): the stimulus file is selected by the enumerate
                # index `z`, not by `sound_id` — confirm the two always agree.
                audio_path = f"{stimuli_path}/{story}_{z}.wav"
                waveform, sr = torchaudio.load(audio_path)
                if sr != sampling_audio:
                    waveform = torchaudio.functional.resample(waveform, sr, sampling_audio)
                # presumably mono audio, so squeeze(0) yields a 1-D signal — TODO confirm
                waveform = waveform.squeeze(0).to(device)

                data_audio_chunks = []
                data_audio_spect = []
                for j in range(epoch_signal.shape[0]):
                    # Epoch onset (taken from the epoch metadata) converted to audio samples.
                    start = epochs_data[j]._metadata["start"].item()
                    sample_start = round(start * sampling_audio)
                    sample_end = round((start + duration) * sampling_audio)
                    y = waveform[sample_start:sample_end]

                    # Force every chunk to exactly `expected_len` samples:
                    # zero-pad short tails, trim rounding overshoot.
                    if y.shape[0] < expected_len:
                        y = F.pad(y, (0, expected_len - y.shape[0]), value=0.0)
                    elif y.shape[0] > expected_len:
                        y = y[:expected_len]
                    data_audio_chunks.append(y)

                    spec_db = amplitude_to_db_tf(spectrogram_tf(y.unsqueeze(0)))
                    data_audio_spect.append(spec_db.squeeze(0))

                audio_tensor_chunk = torch.stack(data_audio_chunks)  # [batch, T]
                audio_tensor_spect = torch.stack(data_audio_spect)

                # NOTE(review): a 2-D tensor is handed to the processor and the
                # leading axis of its output is squeezed away afterwards, which
                # suggests the whole batch is treated as a single example (so
                # any normalisation would be batch-wide) — confirm this is intended.
                inputs_w2v = processor(
                    audio_tensor_chunk,
                    sampling_rate=sampling_audio,
                    return_tensors="pt",
                    padding=True,
                )
                w2v_input = inputs_w2v.input_values.squeeze(0).to(device)
                with torch.no_grad():
                    outputs = model(w2v_input)
                last_hidden_w2v = outputs.last_hidden_state.cpu()

                print('AUDIO_SHAPE: ', audio_tensor_spect.shape)
                print('AUDIO_W2V: ', last_hidden_w2v.shape)
                audio_spect_data.append(audio_tensor_spect)
                audio_w2v_data.append(last_hidden_w2v)


# Concatenate the per-recording tensors along the epoch axis.
brain_signal_data_tensor = torch.cat(brain_signal_data, dim=0)
audio_spect_data_tensor = torch.cat(audio_spect_data, dim=0)
audio_w2v_data_tensor = torch.cat(audio_w2v_data, dim=0)

  0%|          | 0/1 [00:00<?, ?it/s]

PATIENT:  01
SESSION:  0
AUDIO_NAME:  lw1
STORY_UID:  0
Reading 0 ... 395999  =      0.000 ...   395.999 secs...


task_order: "[0, 1, 2, 3]"
n_sessions: 2
mri: fsaverage
native_english_speaker: y
  raw = mne_bids.read_raw_bids(bids_path, verbose=False)


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 6601 samples (6.601 s)

NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
SOUND_ID:  0.0
Applying baseline correction (mode: mean)
MEG_SHAPE:  torch.Size([180, 208, 401])
AUDIO_SHAPE:  torch.Size([180, 257, 126])
AUDIO_W2V:  torch.Size([180, 99, 768])
SOUND_ID:  1.0
Applying baseline correction (mode: mean)
MEG_SHAPE:  torch.Size([139, 208, 401])
AUDIO_SHAPE:  torch.Size([139, 257, 126])
AUDIO_W2V:  torch.Size([139, 99, 768])
SOUN

task_order: "[0, 1, 2, 3]"
n_sessions: 2
mri: fsaverage
native_english_speaker: y
  raw = mne_bids.read_raw_bids(bids_path, verbose=False)


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 6601 samples (6.601 s)

NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
SOUND_ID:  0.0
Applying baseline correction (mode: mean)
MEG_SHAPE:  torch.Size([300, 208, 401])
AUDIO_SHAPE:  torch.Size([300, 257, 126])
AUDIO_W2V:  torch.Size([300, 99, 768])
SOUND_ID:  1.0
Applying baseline correction (mode: mean)
MEG_SHAPE:  torch.Size([325, 208, 401])
AUDIO_SHAPE:  torch.Size([325, 257, 126])
AUDIO_W2V:  torch.Size([325, 99, 768])
SOUN

task_order: "[0, 1, 2, 3]"
n_sessions: 2
mri: fsaverage
native_english_speaker: y
  raw = mne_bids.read_raw_bids(bids_path, verbose=False)


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 6601 samples (6.601 s)

NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
SOUND_ID:  0.0
Applying baseline correction (mode: mean)
MEG_SHAPE:  torch.Size([257, 208, 401])
AUDIO_SHAPE:  torch.Size([257, 257, 126])
AUDIO_W2V:  torch.Size([257, 99, 768])
SOUND_ID:  1.0
Applying baseline correction (mode: mean)
MEG_SHAPE:  torch.Size([190, 208, 401])
AUDIO_SHAPE:  torch.Size([190, 257, 126])
AUDIO_W2V:  torch.Size([190, 99, 768])
SOUN

task_order: "[0, 1, 2, 3]"
n_sessions: 2
mri: fsaverage
native_english_speaker: y
  raw = mne_bids.read_raw_bids(bids_path, verbose=False)


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 6601 samples (6.601 s)

NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
SOUND_ID:  0.0
Applying baseline correction (mode: mean)
MEG_SHAPE:  torch.Size([358, 208, 401])
AUDIO_SHAPE:  torch.Size([358, 257, 126])
AUDIO_W2V:  torch.Size([358, 99, 768])
SOUND_ID:  1.0
Applying baseline correction (mode: mean)
MEG_SHAPE:  torch.Size([262, 208, 401])
AUDIO_SHAPE:  torch.Size([262, 257, 126])
AUDIO_W2V:  torch.Size([262, 99, 768])
SOUN

100%|██████████| 1/1 [23:13<00:00, 1393.58s/it]

AUDIO_SHAPE:  torch.Size([338, 257, 126])
AUDIO_W2V:  torch.Size([338, 99, 768])
SESSION:  1





In [34]:
# Sanity check: the three tensors must share the same number of epochs (dim 0).
tuple(
    t.shape
    for t in (brain_signal_data_tensor, audio_spect_data_tensor, audio_w2v_data_tensor)
)

(torch.Size([8561, 208, 401]),
 torch.Size([8561, 257, 126]),
 torch.Size([8561, 99, 768]))

## TEXT

In [None]:
# Build word-level epochs for every session/task, turn the word sequence into
# sentences and encode those sentences with the CLAP text model.

# Sound ids to epoch for each task index (the lists come from `param_funct`).
sound_ids_by_task = {
    0: lw1,
    1: cable_spool_fort,
    2: easy_money,
    3: the_black_willow,
}

for subject in tqdm(patient):
    epochs_list = []
    for sess in tqdm(session):
        print('---------', str(sess), '----------')
        # The loop variable is deliberately NOT called `task`: the audio/MEG
        # extraction cell indexes a mapping with that name (`task[story]`),
        # and overwriting it here would break re-running that cell.
        for task_id, sound_ids in sound_ids_by_task.items():
            print('---------', str(task_id), '----------')
            bids_path = mne_bids.BIDSPath(
                subject=subject,
                session=str(sess),
                task=str(task_id),
                datatype="meg",
                root=meg_path,
            )
            try:
                raw = mne_bids.read_raw_bids(bids_path, verbose=False)
            except FileNotFoundError:
                # Skip this recording; falling through would reuse the `raw`
                # of the previous iteration (or hit an undefined name).
                print("missing", subject, sess, task_id)
                continue

            raw = raw.pick_types(
                meg=True, misc=False, eeg=False, eog=False, ecg=False
            )
            raw.load_data().filter(0.5, 30.0, n_jobs=1)

            for sound_id in sound_ids:
                epochs_list.append(
                    get_epochs(raw, float(task_id), float(sound_id), sub_decim)
                )

    # Text features are computed for subject '01' only.
    if subject == '01':
        concat_epochs = mne.concatenate_epochs(epochs_list)
        y_text = concat_epochs.metadata.word.to_numpy()
        X_brain = concat_epochs.get_data()
        _, y_text_sentence = generate_sent_matrix(X_brain, y_text)

        inputs_tr = clap_tokenizer(list(y_text_sentence), padding=True, return_tensors="pt")
        # The tokenizer output stays on CPU, so run the CLAP text model on CPU too.
        clap_model.eval().to('cpu')

        with torch.no_grad():
            outputs_tr = clap_model(**inputs_tr)
        last_hidden_states_tr = outputs_tr.last_hidden_state.cpu()

        

In [18]:
# Encode the sentences with the CLAP text tower and keep the projected
# embeddings (`text_embeds`) rather than the per-token hidden states.
clap_model.eval()
clap_model.to('cpu')  # the tokenizer output below is created on CPU

inputs_tr = clap_tokenizer(
    list(y_text_sentence),
    padding=True,
    return_tensors="pt",
)

with torch.no_grad():
    outputs_tr = clap_model(**inputs_tr)

# Despite its name, this variable now holds the projected text embeddings.
last_hidden_states_tr = outputs_tr.text_embeds

In [None]:
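# Text-feature shapes recorded for each encoder output:
# GPT-2 (last_hidden_state): torch.Size([17122, 35, 768])
# CLAP (last_hidden_state):  torch.Size([17122, 35, 768])
# CLAP (text_embeds):        torch.Size([17122, 512])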

In [19]:
# Shape of the CLAP text features computed above.
last_hidden_states_tr.size()

torch.Size([17122, 512])

In [20]:
# Persist the CLAP text embeddings next to the MEG dataset.
# NOTE(review): this reassigns `meg_path` with a hard-coded absolute path —
# confirm it matches the value used by the extraction cells.
meg_path = '/srv/nfs-data/sisko/matteoc/meg'
save_stimulus_dir = os.path.join(meg_path, 'save_stimulus')

# torch.save does not create missing directories; make sure the target exists.
os.makedirs(save_stimulus_dir, exist_ok=True)

torch.save(last_hidden_states_tr, os.path.join(save_stimulus_dir, "clap_512_tensor.pt"))