In [4]:
import logging
from pathlib import Path
import os
import sys
import mne
import torch
import numpy as np
import bm
from bm import play
from bm.train import main
from bm.events import Word
from matplotlib import pyplot as plt
from IPython import display as disp

# Notebook-wide setup; the statements below have process-level side effects.

# Quiet MNE's own logger.
# NOTE(review): per the MNE docs a boolean `False` should map to WARNING level — confirm.
mne.set_log_level(False)
# Route INFO-level (and above) log records to stdout.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# Change the working directory to the parent of the Dora directory attached to `bm.train.main`.
# NOTE(review): presumably the repository root, so that relative imports/paths resolve — confirm.
os.chdir(main.dora.dir.parent)
# NOTE(review): presumably read by `bm` to skip dataset downloads — verify where it is consumed.
os.environ['NO_DOWNLOAD'] = '1'

In [5]:
# Experiment (XP) signatures to summarize; each one is passed to
# `play.get_solver_from_sig` when the solvers are loaded.
sigs = ['34219380', '6e3bf7d7', '557f5f8a', '4395629c']
# Alternative single-signature selection, kept for quick switching:
# sigs = ['889cf0fa']

In [7]:
def _get_segments_and_vocabs(solver):
    """Collect the unique segments, words and word sequences seen in each split.

    Makes one pass over the unshuffled `train` and `test` loaders of `solver`.

    Args:
        solver: object exposing `datasets.train` / `datasets.test`,
            `make_loader(dset, shuffle=...)` and the config values
            `args.dset.sample_rate` and `args.dset.tmin`.

    Returns:
        dict mapping the split name (`'train'`, `'test'`) to a tuple
        `(segments, vocab, sentences)` of sets, where `segments` holds
        `(word_index, sequence_id)` pairs, `vocab` holds `event.word` values and
        `sentences` holds `event.word_sequence` values.
    """
    # Local import, as in the original cell (the helper lives in a script module).
    from scripts.run_eval_probs import _get_extra_info

    sample_rate = solver.args.dset.sample_rate
    # Location of the main word relative to segment start, e.g. with MNE we have
    # tmin=-0.5 so the main word is 0.5 seconds after the start of the MNE Epoch.
    time_to_main_word = 0 - solver.args.dset.tmin
    # We need to look a bit after the onset due to rounding error; this is in time steps.
    margin = 2
    # These only depend on the config, so compute the lookup index once rather than per batch.
    look_at_index = int(time_to_main_word * sample_rate + margin)
    # Due to rounding errors, retrieval of related events can sometimes overlap in a
    # non meaningful way; we only consider an event if it overlaps for at least 20ms.
    min_overlap = 0.02

    per_split = {}
    for split in ['train', 'test']:
        segments = set()
        sentences = set()
        vocab = set()
        dset = getattr(solver.datasets, split)
        loader = solver.make_loader(dset, shuffle=False)
        for batch in loader:
            # NOTE(review): `data` is assumed to be indexable as [batch, feature, time], with
            # feature 0 the word index and feature 1 the sequence id — confirm in `_get_extra_info`.
            data, *_ = _get_extra_info(batch, sample_rate)
            word_index = data[:, 0, look_at_index]
            sequence_id = data[:, 1, look_at_index]
            segments.update(zip(word_index.tolist(), sequence_id.tolist()))

            segment_duration = data.shape[-1] / sample_rate
            for events in batch._event_lists:
                for event in events:
                    if not isinstance(event, Word):
                        continue
                    # Times are measured relative to the first event of the list.
                    start = event.start - events[0].start
                    end = start + event.duration
                    if end > min_overlap and start < segment_duration - min_overlap:
                        sentences.add(event.word_sequence)
                        vocab.add(event.word)
        per_split[split] = (segments, vocab, sentences)
    return per_split


def print_table_line(solver):
    """Print one LaTeX table row of dataset statistics for `solver`, then the vocab overlap.

    The row contains the dataset name, the number of channels, the number of distinct
    subjects, the summed duration in hours and, for the train and test splits, the number
    of unique segments and the vocabulary size. A second line reports the fraction of the
    test vocabulary that also occurs in the train vocabulary (`nan%` when the test
    vocabulary is empty).

    Args:
        solver: solver whose `args.dset.selections` contains exactly one dataset name.

    Raises:
        AssertionError: if `solver.args.dset.selections` does not have exactly one entry.
    """
    channels = solver.datasets.train[0].meg.shape[0]
    n_subjects = len({dataset.recording.subject_uid for dataset in solver.datasets.train.datasets})
    per_split = _get_segments_and_vocabs(solver)
    assert len(solver.args.dset.selections) == 1
    name = solver.args.dset.selections[0]

    # Total duration: sum over recordings of the latest event end time.
    # NOTE(review): assumes event times are in seconds (the sum is divided by 3600 below) — confirm.
    duration = 0.
    for dset in solver.datasets.train.datasets:
        events = dset.recording.events()
        duration += (events.start + events.duration).max()

    vocab_train = per_split['train'][1]
    vocab_test = per_split['test'][1]
    if vocab_test:
        vocab_overlap = len(vocab_train & vocab_test) / len(vocab_test)
    else:
        # An empty test vocabulary has no defined overlap; report NaN instead of
        # raising ZeroDivisionError in the middle of the table row.
        vocab_overlap = float('nan')

    # NOTE(review): there is no '&' between `name` and `channels`, so both land in the
    # same LaTeX cell; kept as-is — confirm whether a column separator is missing.
    print(name, channels, '&', n_subjects, '&', format(duration / 3600, '.1f') + ' h', end='')
    for split in ('train', 'test'):
        segments, vocab, _ = per_split[split]
        print('&', len(segments), '&', len(vocab), end='')
    print(r'\\')
    print("Vocab overlap:", format(vocab_overlap, '.1%'))
    
# Load one solver per signature up front, then print one table row per solver.
solvers = [play.get_solver_from_sig(sig) for sig in sigs]
print("ALL SOLVERS LOADED")
print("now the table.")
for solver in solvers:
    print_table_line(solver)

INFO:bm.play:Loading solver from XP 34219380. Overrides used: ['model=clip_conv', 'dset.selections=["audio_mous"]', 'seed=2036', 'dset.force_uid_assignement=true']
INFO:bm.dataset:Loading Subjects | 19/96 | 0.85 it/sec
INFO:bm.dataset:Loading Subjects | 38/96 | 1.00 it/sec
INFO:bm.dataset:Loading Subjects | 57/96 | 1.08 it/sec
INFO:bm.dataset:Loading Subjects | 76/96 | 1.12 it/sec
INFO:bm.dataset:Loading Subjects | 95/96 | 1.10 it/sec
INFO:bm.train:Model hash: 3502acedd4c0aad6ce5666c554cf3c70065bec93
INFO:bm.play:Loading solver from XP 6e3bf7d7. Overrides used: ['model=clip_conv', 'dset.selections=["gwilliams2022"]', 'seed=2036', 'optim.lr=0.0003', 'optim.batch_size=128']
INFO:bm.dataset:Loading Subjects | 39/196 | 7.01 it/sec
INFO:bm.dataset:Loading Subjects | 78/196 | 7.13 it/sec
INFO:bm.dataset:Loading Subjects | 117/196 | 7.20 it/sec
INFO:bm.dataset:Loading Subjects | 156/196 | 7.19 it/sec
INFO:bm.dataset:Loading Subjects | 195/196 | 7.23 it/sec
INFO:bm.train:Model hash: 2e99ee0c09

In [None]:
def get_attention_map(solver):
    """Compute the merger's channel weights for the first example of one train batch.

    Takes the first batch of a train loader, moves it to `solver.device`, and recomputes
    the softmax weights of `solver.model.merger`: the position embedding of every input
    channel is scored against `merger.heads`, and channels flagged by
    `merger.position_getter.is_invalid` get a score of -inf, hence a weight of zero.

    The computation runs under `torch.no_grad()`, since the result is only meant for
    inspection: the returned tensors carry no autograd graph.

    Args:
        solver: solver exposing `make_loader`, `datasets.train`, `device` and
            `model.merger`.

    Returns:
        Tuple `(weights, positions)` for the first item of the batch.
        Weights is of shape [Virtual Channels, Input Channels]; each Virtual Channel is a
        weighted sum over the input channels, so every row sums to one.
        Positions give the normalized 2d position for each Input channel.
        To get an overall weight for a given input sensor you can for instance do
        `weights.sum(dim=0)`.
    """
    loader = solver.make_loader(solver.datasets.train)
    batch = next(iter(loader)).to(solver.device)
    merger = solver.model.merger

    with torch.no_grad():
        positions = merger.position_getter.get_positions(batch)
        embedding = merger.embedding(positions)
        meg = batch.meg
        B, C, _ = meg.shape  # the unpacking also checks that `meg` has three dimensions

        # Additive mask: -inf on invalid channels so that softmax gives them zero weight.
        score_offset = torch.zeros(B, C, device=meg.device)
        score_offset[merger.position_getter.is_invalid(positions)] = float('-inf')

        heads = merger.heads[None].expand(B, -1, -1)
        # scores[b, o, c]: affinity between virtual channel `o` and input channel `c`.
        scores = torch.einsum("bcd,bod->boc", embedding, heads)
        scores += score_offset[:, None]
        # Normalize over the input channels.
        weights = torch.softmax(scores, dim=2)

    return weights[0], positions[0]


In [None]:
# NOTE(review): `weights` and `positions` are only local names inside `get_attention_map`
# in the cells visible here; unless another cell assigns them, this raises NameError on a
# fresh kernel. Presumably a call such as
# `weights, positions = get_attention_map(<solver>)` is meant to run first — confirm.
weights.shape, positions.shape