Diverging colormaps (matplotlib names, case-sensitive):
* PiYG, PRGn, BrBG, PuOr, RdGy, RdBu, RdYlBu, RdYlGn, Spectral, coolwarm, bwr, seismic

In [1]:
import ipypb
import logging
import warnings
import numpy as np
import matplotlib.pyplot as plt
from hdf5storage import loadmat

import torch
import radam
import torchaudio
import pytorch_lightning as pl

import dynamic_strf

# All models in this notebook are placed on the first CUDA device.
device = torch.device('cuda:0')

# Quiet the notebook output: the radam and pytorch_lightning loggers, as well
# as the root logger, only pass CRITICAL records, and all warnings are ignored.
for _logger_name in ('radam', 'pytorch_lightning'):
    logging.getLogger(_logger_name).setLevel(logging.CRITICAL)
logging.getLogger().setLevel(logging.CRITICAL)
warnings.simplefilter('ignore')

In [18]:
top_db = 70  # NOTE(review): not used in this cell; presumably consumed later — confirm or remove
out_sr = 100  # output frame rate handed to SpectrogramParser
freqbins = 64  # number of spectrogram frequency bins (also the encoder input size)
n_stims = 19  # stimuli live in Data/Sounds/stim1.flac ... stim19.flac

# --- Stimuli: one spectrogram per sound file --------------------------------
x_noisy = []
for stim_num in range(1, n_stims + 1):
    sound, in_sr = torchaudio.load(f'Data/Sounds/stim{stim_num}.flac')
    # NOTE(review): f_max is fixed at 11025/2 instead of being derived from
    # in_sr; presumably the files are sampled at 11025 Hz — confirm.
    spect = dynamic_strf.modeling.SpectrogramParser(
        in_sr, out_sr, freqbins, f_max=11025/2, normalize=False
    )(sound)
    x_noisy.append(spect)

# --- Responses: per-subject recordings, concatenated along dim 1 ------------
y_noisy = []
path_fmt = 'Data/LIJ%s_data_TrainOn1If2Records.mat'
for subj_id in ['109', '110', '112', '113', '114', '120']:
    # 'noisy_resp' is indexed per stimulus below (y[i]).
    y_noisy.append(loadmat(path_fmt % subj_id)['noisy_resp'].squeeze(0))
y_noisy = [
    torch.cat([torch.from_numpy(y[i]) for y in y_noisy], dim=1)
    for i in range(n_stims)
]

# --- Align X/Y lengths (dim 0) and trim the edges of every trial ------------
for i in range(n_stims):
    diff = len(x_noisy[i]) - len(y_noisy[i])
    if diff == 1:
        # A single surplus spectrogram frame is tolerated and dropped.
        x_noisy[i] = x_noisy[i][:-1]
    elif diff != 0:
        # Any other mismatch is an error, including a response that is
        # longer than the stimulus (diff < 0).
        raise RuntimeError(f'X and Y have different lengths for stim{i+1}!')

    # Drop the first 100 and the last 50 frames of each trial.
    x_noisy[i] = x_noisy[i][100:-50].float()
    y_noisy[i] = y_noisy[i][100:-50].float()

# NOTE(review): the last stimulus is excluded from the dataset; the reason is
# not stated here — confirm it is intentional.
x_noisy = x_noisy[:-1]
y_noisy = y_noisy[:-1]

# Total number of response channels across all subjects (size of dim 1).
channels = y_noisy[0].shape[1]

In [None]:
def builder():
    """Build a fresh DeepEncoder for this dataset and move it to ``device``.

    The encoder is sized from notebook globals: ``input_size`` is the
    spectrogram bin count (``freqbins``), and ``channels`` is the response
    channel count computed when the data was loaded. The hidden size is
    fixed at 128.
    """
    encoder = dynamic_strf.modeling.DeepEncoder(
        input_size=freqbins, hidden_size=128, channels=channels,
    )
    return encoder.to(device)

# Output directories shared by the pipeline steps below.
ckpt_dir = 'output/5x128-cv'
dstrf_dir = 'output/5x128-cv-dstrf'

# Step 1: fit models on (x_noisy, y_noisy) with crossval=True and
# jackknife=False, saving checkpoints under ckpt_dir.
dynamic_strf.modeling.fit_multiple(
    builder=builder,
    data=(x_noisy, y_noisy),
    crossval=True,
    jackknife=False,
    save_dir=ckpt_dir,
    batch_size=64,
    num_workers=4,
    gpus=1,
    precision=16,
    verbose=1,
)

# Step 2: estimate dSTRFs on x_noisy from the checkpoints in ckpt_dir
# (chunk_size=100), writing the results under dstrf_dir.
model = builder()
dynamic_strf.estimate.dSTRF_multiple(
    model=model,
    checkpoints=ckpt_dir,
    data=x_noisy,
    crossval=True,
    jackknife=False,
    save_dir=dstrf_dir,
    chunk_size=100,
    verbose=1,
)

# Step 3: score the checkpoints on (x_noisy, y_noisy) using a freshly
# built model instance.
scores = dynamic_strf.modeling.test_multiple(
    model=builder(),
    checkpoints=ckpt_dir,
    data=(x_noisy, y_noisy),
    crossval=True,
    jackknife_mode='pred',
    verbose=1,
)

# Step 4: summarise the saved dSTRFs with reduction='mean'.
nonlin = dynamic_strf.estimate.nonlinearities(
    paths=dstrf_dir,
    reduction='mean',
    verbose=0,
)