Diverging colormaps (matplotlib):
* PiYG, PRGn, BrBG, PuOr, RdGy, RdBu, RdYlBu, RdYlGn, Spectral, coolwarm, bwr, seismic

In [1]:
import ipypb
import logging
import warnings
import numpy as np
import matplotlib.pyplot as plt
from hdf5storage import loadmat

import torch
import radam
import torchaudio
import pytorch_lightning as pl

import dynamic_strf

# All models and tensors in this notebook go to the first CUDA device.
device = torch.device('cuda', 0)

# Keep notebook output readable: only CRITICAL records from these loggers
# (a name of None selects the root logger), and silence Python warnings.
for _logger_name in ('radam', 'pytorch_lightning', None):
    logging.getLogger(_logger_name).setLevel(logging.CRITICAL)
del _logger_name
warnings.simplefilter('ignore')

In [2]:
# Preprocessing configuration.
top_db = 70       # passed to AmplitudeToDB as its dB cut-off below the peak
out_sr = 100      # target frame rate (Hz); hop_length = in_sr / out_sr
freqbins = 64     # number of mel bands (model input size)
n_stim = 19       # stimuli are Data/Sounds/stim1.flac ... stim{n_stim}.flac
trim_start = 100  # frames dropped from the start of each stimulus (1 s at out_sr=100)
trim_end = 50     # frames dropped from the end of each stimulus (0.5 s at out_sr=100)

# --- Stimuli: one log-mel spectrogram of shape (time, freqbins) per stimulus ---
x_noisy = []
for i in range(n_stim):
    sound, in_sr = torchaudio.load(f'Data/Sounds/stim{i+1}.flac')
    spect = torch.nn.Sequential(
        # f_max is fixed at 11025/2 Hz regardless of the file's sample rate.
        torchaudio.transforms.MelSpectrogram(in_sr, n_fft=1024, hop_length=int(in_sr/out_sr), f_min=20, f_max=11025/2, n_mels=freqbins, power=2.0),
        torchaudio.transforms.AmplitudeToDB('power', top_db=top_db),
    )(sound)
    # (channel, freq, time) -> (time, freq) float32.
    # NOTE(review): assumes mono audio so squeeze(0) drops the channel axis — TODO confirm.
    spect = spect.squeeze(0).T.float()
    x_noisy.append(spect)

# --- Neural responses: concatenate every subject's channels per stimulus ---
subject_ids = ['109', '110', '112', '113', '114', '120']
path_fmt = 'Data/LIJ%s_data_TrainOn1If2Records.mat'
# One entry per subject; each entry is indexed by stimulus number.
resp_by_subject = [loadmat(path_fmt % subj_id)['noisy_resp'].squeeze(0) for subj_id in subject_ids]
# Concatenate along dim=1 (channels). Assumes each subject's response to a stimulus
# is (time, channels) with the same time length across subjects — TODO confirm.
y_noisy = [
    torch.cat([torch.from_numpy(resp[i]) for resp in resp_by_subject], dim=1)
    for i in range(n_stim)
]

# --- Align X and Y in time, then trim the edges of every stimulus ---
for i in range(n_stim):
    diff = len(x_noisy[i]) - len(y_noisy[i])
    if diff == 1:
        # The spectrogram may carry one extra trailing frame (presumably from
        # centered STFT framing); drop it so X and Y line up.
        x_noisy[i] = x_noisy[i][:-1]
    elif diff != 0:
        # Any other mismatch, including Y longer than X, means misaligned data.
        raise RuntimeError(f'X and Y have different lengths for stim{i+1}!')

    # Both now have the same length; compute the end index explicitly so that
    # trim_end = 0 keeps the tail instead of producing an empty slice.
    stop = len(y_noisy[i]) - trim_end
    x_noisy[i] = x_noisy[i][trim_start:stop].float()
    y_noisy[i] = y_noisy[i][trim_start:stop].float()

# Drop the last stimulus; only the first n_stim - 1 stimuli are used below.
x_noisy = x_noisy[:-1]
y_noisy = y_noisy[:-1]

# Total number of response channels across all subjects.
channels = y_noisy[0].shape[1]

In [None]:
ckpt_dir = '5x128-cv'         # fit_multiple saves here; dSTRF_multiple / test_multiple load from here
dstrf_dir = '5x128-cv-dstrf'  # dSTRF_multiple saves here; nonlinearities() reads from here

def builder():
    """Build a fresh, untrained SharedEncoder on `device`.

    The input size is `freqbins` (mel bands) and the output size is `channels`
    (response channels). Both come from the data-loading cell.
    """
    return dynamic_strf.modeling.SharedEncoder(
        input_size=freqbins,
        hidden_size=128,
        channels=channels
    ).to(device)

# Cross-validated fits, one model per left-out stimulus (the output below shows
# already-fitted folds being skipped).
dynamic_strf.modeling.fit_multiple(
    builder=builder,
    data=(x_noisy, y_noisy),
    crossval=True,
    jackknife=False,
    save_dir=ckpt_dir,
    batch_size=64,
    num_workers=4,
    gpus=1,
    precision=16,
    verbose=1
)

model = builder()

# Compute dSTRFs from the saved checkpoints.
dynamic_strf.estimate.dSTRF_multiple(
    model=model,
    checkpoints=ckpt_dir,
    data=x_noisy,
    crossval=True,
    jackknife=False,
    save_dir=dstrf_dir,
    chunk_size=100,
    verbose=1
)

# Score the saved checkpoints. The keyword is `model`, matching dSTRF_multiple
# above; the earlier spelling `modle` never passed the model through.
scores = dynamic_strf.modeling.test_multiple(
    model=builder(),
    checkpoints=ckpt_dir,
    data=(x_noisy, y_noisy),
    crossval=True,
    jackknife=False,
    verbose=1
)

nonlin = dynamic_strf.estimate.nonlinearities(
    paths=dstrf_dir,
    reduction='mean',
    verbose=0
)

Directory "5x128-cv" already exists.
Fitting model for leave out: [0]... Skip.
Fitting model for leave out: [1]... Skip.
Fitting model for leave out: [2]... Skip.
Fitting model for leave out: [3]... Skip.
Fitting model for leave out: [4]... Skip.
Fitting model for leave out: [5]... Skip.
Fitting model for leave out: [6]... Skip.
Fitting model for leave out: [7]... Skip.
Fitting model for leave out: [8]... Skip.
Fitting model for leave out: [9]... Skip.
Fitting model for leave out: [10]... Skip.
Fitting model for leave out: [11]... Skip.
Fitting model for leave out: [12]... Skip.
Fitting model for leave out: [13]... Skip.
Fitting model for leave out: [14]... Skip.
Fitting model for leave out: [15]... Skip.
Fitting model for leave out: [16]... Skip.
Fitting model for leave out: [17]... Skip.
Directory "5x128-cv-dstrf" already exists.
Found 18 model checkpoints in specified directory.
Computing dSTRFs for stimulus 01/18... 

Done.
Computing dSTRFs for stimulus 02/18... 

Done.
Computing dSTRFs for stimulus 03/18... 

Done.
Computing dSTRFs for stimulus 04/18... 

Done.
Computing dSTRFs for stimulus 05/18... 

Done.
Computing dSTRFs for stimulus 06/18... 

Done.
Computing dSTRFs for stimulus 07/18... 

Done.
Computing dSTRFs for stimulus 08/18... 

Done.
Computing dSTRFs for stimulus 09/18... 