In [None]:
import os
from pathlib import Path

import einops
import librosa
import matplotlib.pyplot as plt
# import mne
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from IPython.display import Audio, display

# from datasets import DoubleDataset
from nn_modules import BrainModule
# from trainer import Trainer

In [None]:
# Root folder of the preprocessed data and the split inspected in this notebook.
dirprocess = '../Spatial-Attention/data/process_v2/'
split_type = 'test'

# Audio features are stored as (b, t, f); rearrange them to (b, f t) layout
# '(b, f, t)' in a single expression so `hidden_test` is bound exactly once.
hidden_test = einops.rearrange(
    torch.tensor(
        np.load(os.path.join(dirprocess, f'audio/extract_features_{split_type}4.npy')),
        dtype=torch.float32,
    ),
    'b t f -> b f t',
)

# One row per segment of the chosen split.
df_test = pd.read_csv(os.path.join(dirprocess, f'dataframe/df_{split_type}27.csv'))

# np.load on an .npz archive returns a lazy NpzFile that keeps the file open.
# Materialise every array inside a context manager so the handle is closed
# deterministically instead of whenever the temporary object is collected.
with np.load(os.path.join(dirprocess, 'meg/meg27_sr100_default_v1.npz')) as meg_npz:
    meg = {name: meg_npz[name] for name in meg_npz.files}

In [3]:
# Preview the first rows of the segment table loaded for the chosen split.
df_test.head(15)

Unnamed: 0,subject_id,session_id,story_id,sound_id,onset_meg,sound_length,sound_fname,segment_start,words,meg100_start,meg100_stop,meg120_start,meg120_stop,meg200_start,meg200_stop,wav_start,wav_stop,wav_index
0,1,0,3,5,739.549,81.943311,the_black_willow_5.wav,0,['hard' 'current' 'cheeks' 'understand'],73955,74255,88746,89106,147910,148510,0,48000,0
1,1,0,3,5,739.549,81.943311,the_black_willow_5.wav,1,['understand' 'You'],74055,74355,88866,89226,148110,148710,16000,64000,1
2,1,0,3,5,739.549,81.943311,the_black_willow_5.wav,2,['You' 'mean' 'finished'],74155,74455,88986,89346,148310,148910,32000,80000,2
3,1,0,3,5,739.549,81.943311,the_black_willow_5.wav,3,['You' 'mean' 'finished' 'another' 'story'],74255,74555,89106,89466,148510,149110,48000,96000,3
4,1,0,3,5,739.549,81.943311,the_black_willow_5.wav,4,['finished' 'another' 'story'],74355,74655,89226,89586,148710,149310,64000,112000,4
5,1,0,3,5,739.549,81.943311,the_black_willow_5.wav,5,['story' 'Yes'],74455,74755,89346,89706,148910,149510,80000,128000,5
6,1,0,3,5,739.549,81.943311,the_black_willow_5.wav,6,['Yes' 'it' 'is' 'finished'],74555,74855,89466,89826,149110,149710,96000,144000,6
7,1,0,3,5,739.549,81.943311,the_black_willow_5.wav,7,['Yes' 'it' 'is' 'finished' 'Ended'],74655,74955,89586,89946,149310,149910,112000,160000,7
8,1,0,3,5,739.549,81.943311,the_black_willow_5.wav,8,['it' 'is' 'finished' 'Ended'],74755,75055,89706,90066,149510,150110,128000,176000,8
9,1,0,3,5,739.549,81.943311,the_black_willow_5.wav,9,['Ended' 'tears' 'birch'],74855,75155,89826,90186,149710,150310,144000,192000,9


In [4]:
# Pick one segment (positional row 500) and listen to its audio excerpt.
row = df_test.iloc[500]
print(row)

# Pull the fields that locate the excerpt inside its stimulus file.
sound_fname, wav_start, wav_stop, wav_index = (
    row[column]
    for column in ('sound_fname', 'wav_start', 'wav_stop', 'wav_index')
)

# Load the whole stimulus file, then keep only the [wav_start, wav_stop) samples.
audio, sr = librosa.load(
    f'../Spatial-Attention/data/bids_anonym/stimuli/audio/{sound_fname}',
    sr=None,
)
sample = audio[wav_start:wav_stop]
Audio(sample, rate=sr)

subject_id                            1
session_id                            0
story_id                              3
sound_id                              8
onset_meg                       1170.58
sound_length                  142.41424
sound_fname      the_black_willow_8.wav
segment_start                       127
words               ['Nathan' 'aghast']
meg100_start                     129758
meg100_stop                      130058
meg120_start                     155710
meg120_stop                      156070
meg200_start                     259516
meg200_stop                      260116
wav_start                       2032000
wav_stop                        2080000
wav_index                           500
Name: 500, dtype: object


In [5]:
# Compare the shape of the audio-feature tensor with the shape of the dataframe.
hidden_test.shape, df_test.shape

(torch.Size([999, 512, 149]), (48951, 18))

In [6]:
# Number of distinct wav_index values present in the dataframe.
len(set(df_test['wav_index']))

999

In [7]:
# Inspect the rows at positions 998 and 999 of the dataframe.
df_test.iloc[998:1000]

Unnamed: 0,subject_id,session_id,story_id,sound_id,onset_meg,sound_length,sound_fname,segment_start,words,meg100_start,meg100_stop,meg120_start,meg120_stop,meg200_start,meg200_stop,wav_start,wav_stop,wav_index
998,1,0,3,11,1735.437,119.185533,the_black_willow_11.wav,115,['the' 'water' 'of' 'the' 'stream' 'flowing' '...,185044,185344,222052,222412,370087,370687,1840000,1888000,998
999,1,1,3,5,698.994,81.943311,the_black_willow_5.wav,0,['hard' 'current' 'cheeks' 'understand'],69899,70199,83879,84239,139799,140399,0,48000,0


In [8]:
# Instantiate the brain encoder with the hyper-parameters used in this notebook.
# NOTE(review): these values presumably have to match the configuration the
# checkpoint below was trained with — confirm against the training run.
model = BrainModule(
    n_channels_input=208,
    n_channels_attention=270,
    n_channels_unmix=10,
    use_spatial_attention='3D',
    n_spatial_harmonics=24,
    dirprocess=dirprocess,
    spatial_dropout_number=0,
    spatial_dropout_radius=0.1,
    use_unmixing_layer=True,
    use_subject_layer=True,
    n_subjects=27,
    regularize_subject_layer=False,
    bias_subject_layer=False,
    n_channels_block=320,
    n_features=512,
    head_pool='conv',
    head_stride=2,
    meg_sr=100,
    use_temporal_filter=True,
    use_temporal_activation=False,
    use_unmixing_bias=False,
    use_temporal_bias=False
)
# Load the saved weights onto the CPU; weights_only=True restricts unpickling
# to tensors and plain containers.
checkpoint = torch.load('checkpoint.pt', weights_only=True, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
# Switch to evaluation mode; as the last expression this also displays the module.
model.eval()

BrainModule(
  (spatial_module): SpatialModule(
    (self_attention): Spatial3DAttentionLayer()
    (unmixing_layer): Conv1d(270, 270, kernel_size=(1,), stride=(1,), bias=False)
    (subject_layer): SubjectPlusLayer()
  )
  (depthwise_conv): Conv1d(10, 10, kernel_size=(51,), stride=(1,), padding=same, groups=10, bias=False)
  (temporal_module): TemporalModule(
    (conv_blocks): ModuleList(
      (0): ConvBlock(
        (conv1): Conv1d(10, 320, kernel_size=(3,), stride=(1,), padding=same)
        (conv2): Conv1d(320, 320, kernel_size=(3,), stride=(1,), padding=same, dilation=(2,))
        (conv3): Conv1d(320, 640, kernel_size=(3,), stride=(1,), padding=same, dilation=(2,))
        (batchnorm1): BatchNorm1d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (batchnorm2): BatchNorm1d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation1): GELU(approximate='none')
        (activation2): GELU(approximate='none')
        (activat

In [36]:
def get_sound(row, audio_dir='../Spatial-Attention/data/bids_anonym/stimuli/audio'):
    """Load the audio excerpt described by one dataframe row.

    Args:
        row: Mapping-like row (e.g. one row of ``df_test``) providing
            ``'sound_fname'``, ``'wav_start'`` and ``'wav_stop'``. No other
            key is required.
        audio_dir: Directory containing the stimulus audio files. Defaults to
            the location this notebook uses.

    Returns:
        Tuple ``(sample, sr)`` where ``sample`` is ``audio[wav_start:wav_stop]``
        and ``sr`` is the sampling rate returned by ``librosa.load``.
    """
    audio_path = os.path.join(audio_dir, row['sound_fname'])
    # sr=None is forwarded so librosa does not apply its default target rate.
    audio, sr = librosa.load(audio_path, sr=None)
    return audio[row['wav_start']:row['wav_stop']], sr

In [37]:
def get_meg_data(row_df, meg, hidden, meg_sr, meg_offset=0.15):
    """Assemble the MEG window, subject id, audio features and waveform for one row.

    Args:
        row_df: Mapping-like row providing ``'subject_id'``, ``'session_id'``,
            ``'story_id'``, ``'wav_index'``, ``f'meg{meg_sr}_start'``,
            ``f'meg{meg_sr}_stop'`` and the keys read by ``get_sound``.
        meg: Mapping from ``'subjectXX_sessionY_storyZ'`` keys to 2-D arrays;
            the window is cut along the second axis.
        hidden: Audio features indexable by ``wav_index`` (a tensor or an
            array-like).
        meg_sr: MEG sampling rate; selects the start/stop columns and converts
            ``meg_offset`` to samples.
        meg_offset: Shift, in seconds, applied to both window bounds.
            Defaults to the 0.15 s this function always used.

    Returns:
        Tuple ``(meg, sbj, hid, widx, sample)``: float32 MEG window, subject id
        as a long tensor, float32 feature tensor, ``wav_index`` as a long
        tensor, and the waveform returned by ``get_sound``.
    """
    offset_samples = int(meg_offset * meg_sr)

    subject_id = row_df['subject_id']
    session_id = row_df['session_id']
    story_id = row_df['story_id']
    sbj = torch.tensor(subject_id, dtype=torch.long)
    # Subject ids are zero-padded to two digits in the MEG archive keys.
    subset = f'subject{str(subject_id).zfill(2)}_session{session_id}_story{story_id}'

    meg_start = row_df[f'meg{meg_sr}_start'] + offset_samples
    meg_stop = row_df[f'meg{meg_sr}_stop'] + offset_samples
    wav_index = row_df['wav_index']

    # Use a separate name so the `meg` mapping argument is not rebound.
    meg_window = torch.tensor(meg[subset][:, meg_start:meg_stop], dtype=torch.float32)

    if isinstance(hidden, torch.Tensor):
        # torch.tensor(<tensor>) emits a UserWarning; detach().clone() is the
        # recommended equivalent and still returns an independent copy.
        hid = hidden[wav_index].detach().clone().to(torch.float32)
    else:
        hid = torch.tensor(hidden[wav_index], dtype=torch.float32)
    widx = torch.tensor(wav_index, dtype=torch.long)

    sample, _ = get_sound(row=row_df)

    return meg_window, sbj, hid, widx, sample


In [38]:
# Gather MEG window, subject id, audio features, wav index and waveform
# for the first row of the test dataframe.
meg_i, sbj_i, hid_i, widx_i, sample_i = get_meg_data(
    row_df=df_test.iloc[0],
    meg=meg,
    hidden=hidden_test,
    meg_sr=100,
)

  hid = torch.tensor(hidden[wav_index], dtype=torch.float32)


In [39]:
# Check the sizes/values of everything returned for the selected row.
meg_i.size(), sbj_i, hid_i.size(), widx_i, sample_i.shape

(torch.Size([208, 300]),
 tensor(1),
 torch.Size([512, 149]),
 tensor(0),
 (48000,))

In [40]:
# Inference only: skip autograd bookkeeping for the forward pass.
# A batch dimension is added because a single sample is fed to the model.
with torch.no_grad():
    _, predicted_hidden = model((meg_i.unsqueeze(dim=0), sbj_i.unsqueeze(dim=0)))
print(predicted_hidden.size())

torch.Size([1, 512, 149])


In [48]:
def get_similarity(brainwave_embeddings, audio_embeddings):
    """Score every brain embedding against every audio embedding.

    Each item of both batches is normalised jointly over its last two axes,
    then the normalised items are multiplied element-wise and summed over
    those axes for every (brain, audio) pair.

    Args:
        brainwave_embeddings: Tensor of shape ``(B, e, f)``.
        audio_embeddings: Tensor of shape ``(b, e, f)``.

    Returns:
        Tensor of shape ``(B, b)`` holding the pairwise similarities.
    """
    joint_axes = (-2, -1)
    unit_brain = F.normalize(brainwave_embeddings, dim=joint_axes)
    unit_audio = F.normalize(audio_embeddings, dim=joint_axes)
    return torch.einsum('Bef, bef -> Bb', unit_brain, unit_audio)

In [50]:
# Compare the predicted features with every audio feature tensor of the split.
similarity = get_similarity(predicted_hidden, hidden_test)

In [51]:
# Indices of the ten highest similarities for the first (only) query row.
index_top10 = similarity.topk(10, dim=-1).indices[0].detach().cpu().tolist()

In [52]:
# Display the retrieved indices, best match first.
index_top10

[0, 302, 450, 91, 79, 49, 660, 254, 913, 387]

In [53]:
tracks = []
for wav_idx in index_top10:
    row = df_test.loc[df_test['wav_index'] == wav_idx].iloc[0]
    sample, sr = get_sound(row)
    if wav_idx == widx_i.item():
        print('Correct id')
    else:
        print('Incorrect id')
    print(row['words'])
    display(Audio(sample, rate=sr))

Correct id
['hard' 'current' 'cheeks' 'understand']


Incorrect id
['why' 'Arthur' 'on' 'do' 'you' 'ro']


Incorrect id
['himself' 'a' 'patchwork' 'design' 'of' 'fabric']


Incorrect id
['disappeared' 'thank' 'pander' 'watched' 'period' 'in']


Incorrect id
['I' 'suppose' 'it' 'is' 'because' 'I' 'enjoy' 'my' 'craft']


Incorrect id
['characters' 'who' 'might' 'live' 'therein' 'not']


Incorrect id
['the' 'hill' 'puffing' 'to' 'reveal' 'Arthur' 'below']


Incorrect id
['a' 'creation' 'less' 'fit']


Incorrect id
['Allan' 'continued' 'slowing' 'his' 'pace' 'to' 'climb' 'across']


Incorrect id
['know' 'what' 'this' 'is' 'like' 'said']
