# Setup

In [1]:
import ipypb
import logging
import warnings
import numpy as np
import matplotlib.pyplot as plt
from hdf5storage import loadmat

import torch
import torchaudio
import pytorch_lightning as pl

import dynamic_strf

# Device that the model builders move their models to.
device = torch.device('cuda:0')

# Keep notebook output quiet: raise the root logger and the
# pytorch_lightning logger to CRITICAL, and ignore all Python warnings.
# NOTE: this also hides deprecation warnings — re-enable when debugging.
for _logger_name in (None, 'pytorch_lightning'):
    logging.getLogger(_logger_name).setLevel(logging.CRITICAL)
warnings.simplefilter('ignore')

# Data preparation

In [None]:
# Spectrogram settings, forwarded as keyword arguments to SpectrogramParser.
spect_cfg = dict(
    out_sr = 100,      # output rate of the spectrogram (presumably frames/s — confirm)
    freqbins = 64,     # number of frequency bins; also the model input size
    f_max = 11025/2    # upper frequency limit
)

n_stims = 18
subject_ids = ['109', '110', '112', '113', '114', '120']
# Frames dropped from the start / end of every stimulus. At out_sr=100 this is
# presumably 1 s and 0.5 s — confirm. trim_end must stay > 0 for the slice below.
trim_start, trim_end = 100, 50

# Inputs: one spectrogram per stimulus file.
X = []
for i in range(n_stims):
    sound, in_sr = torchaudio.load(f'Data/Sounds/stim{i+1}.flac')
    spect = dynamic_strf.modeling.SpectrogramParser(in_sr, **spect_cfg)(sound)
    X.append(spect)

# Targets: load each subject's per-stimulus responses, then for every stimulus
# concatenate all subjects along dim=1 (the channel axis, see `channels` below).
path_fmt = 'Data/LIJ%s_data_TrainOn1If2Records.mat'
Y_by_subject = [
    loadmat(path_fmt % subj_id)['noisy_resp'].squeeze(0)
    for subj_id in subject_ids
]
Y = [
    torch.cat([torch.from_numpy(y[i]) for y in Y_by_subject], dim=1)
    for i in range(n_stims)
]

# Align X and Y frame-by-frame, then trim both identically.
for i in range(n_stims):
    diff = len(X[i]) - len(Y[i])
    if diff == 1:
        # The spectrogram has exactly one extra frame: drop it.
        X[i] = X[i][:-1]
    elif diff != 0:
        # Any other mismatch (including Y longer than X) would silently
        # misalign inputs and targets, so fail loudly instead.
        raise RuntimeError(f'X and Y have different lengths for stim{i+1}!')

    X[i] = X[i][trim_start:-trim_end].float()
    Y[i] = Y[i][trim_start:-trim_end].float()

channels = Y[0].shape[1]

# Training configuration

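## 1. Dictionaries

Run either this cell or the builder-function cell in section 2, not both. Both cells bind `optimizer_cfg`, `scheduler_cfg` and `trainer`, so whichever runs last silently replaces the other.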

In [None]:
# Option 1: plain configuration dictionaries.
# The optimizer and scheduler dicts are placeholders — TODO fill in their
# keyword arguments before training.
optimizer_cfg = {}

scheduler_cfg = {}

# Trainer options: one GPU, 16-bit precision.
trainer = {
    'gpus': 1,
    'precision': 16,
}

## 2. Builder functions

In [None]:
def optimizer_cfg(params, lr=1e-3, weight_decay=1e-2):
    """Build the AdamW optimizer for `params`.

    Args:
        params: iterable of parameters (or parameter groups) to optimize.
        lr: learning rate. The default matches torch.optim.AdamW's own default.
        weight_decay: decoupled weight decay. The default matches
            torch.optim.AdamW's own default.

    Returns:
        A torch.optim.AdamW instance.
    """
    return torch.optim.AdamW(
        params,
        lr=lr,
        weight_decay=weight_decay,
    )

def scheduler_cfg(optimizer, gamma=0.99):
    """Build an exponential learning-rate decay schedule for `optimizer`.

    Args:
        optimizer: the torch optimizer whose learning rate is decayed.
        gamma: multiplicative factor applied to the learning rate on each
            scheduler step. 0.99 is a placeholder, not a tuned value — TODO tune.

    Returns:
        A torch.optim.lr_scheduler.ExponentialLR instance.
    """
    return torch.optim.lr_scheduler.ExponentialLR(
        optimizer,
        gamma=gamma,
    )

def trainer(**overrides):
    """Build the pytorch_lightning Trainer used for fitting.

    Defaults to `gpus=1` and `precision=16`. Any keyword argument passed here
    is forwarded to `pl.Trainer` and takes precedence over those defaults.

    Returns:
        A `pl.Trainer` instance.
    """
    options = dict(gpus=1, precision=16)
    # Merge rather than pass both, so overriding `gpus`/`precision` does not
    # raise a duplicate-keyword TypeError.
    options.update(overrides)
    return pl.Trainer(**options)

# Model definition

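## 1. Deep Convolutional

Run only one of the two model cells in this section. Each defines `builder()`, and every later training, evaluation and dSTRF cell uses whichever definition ran last.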

In [None]:
def builder():
    """Return a new DeepEncoder instance moved to `device`.

    The input size matches the spectrogram's frequency-bin count, and the
    output `channels` matches the response tensors built during data
    preparation. The optimizer and scheduler come from the "Training
    configuration" section (`optimizer_cfg` / `scheduler_cfg`).
    """
    return dynamic_strf.modeling.DeepEncoder(
        input_size=spect_cfg['freqbins'],
        hidden_size=128,
        channels=channels,
        optimizer=optimizer_cfg,
        scheduler=scheduler_cfg,
    ).to(device)

## 2. Linear

In [None]:
def builder():
    """Return a new LinearEncoder instance moved to `device`.

    The input size matches the spectrogram's frequency-bin count, and the
    output `channels` matches the response tensors built during data
    preparation. The optimizer and scheduler come from the "Training
    configuration" section (`optimizer_cfg` / `scheduler_cfg`).
    """
    return dynamic_strf.modeling.LinearEncoder(
        input_size=spect_cfg['freqbins'],
        channels=channels,
        optimizer=optimizer_cfg,
        scheduler=scheduler_cfg,
    ).to(device)

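# Training

Each variant below saves its checkpoints to its own `save_dir`. The dSTRF-estimation section loads checkpoints from `output/models-jackknife-cv`, so run variant 4 (cross-validation and jackknifing) before it.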

In [None]:
# Batch size and number of loader workers used for training.
dataset_cfg = {'batch_size': 64, 'num_workers': 4}

## 1. Basic

In [None]:
# Fit without the crossval/jackknife flags used by the variants below.
dynamic_strf.modeling.fit_multiple(
    builder=builder,
    data=(X, Y),
    save_dir='output/models-basic',
    **dataset_cfg,
)

## 2. With cross-validation

In [None]:
# Fit with cross-validation enabled; checkpoints go to output/models-cv.
dynamic_strf.modeling.fit_multiple(
    builder=builder,
    data=(X, Y),
    crossval=True,
    save_dir='output/models-cv',
    **dataset_cfg,
)

## 3. With jackknifing

In [None]:
# Fit with jackknifing enabled; checkpoints go to output/models-jackknife.
dynamic_strf.modeling.fit_multiple(
    builder=builder,
    data=(X, Y),
    jackknife=True,
    save_dir='output/models-jackknife',
    **dataset_cfg,
)

## 4. With cross-validation and jackknifing

In [None]:
# Fit with both cross-validation and jackknifing. These checkpoints
# (output/models-jackknife-cv) are the ones the dSTRF section loads.
dynamic_strf.modeling.fit_multiple(
    builder=builder,
    data=(X, Y),
    crossval=True,
    jackknife=True,
    save_dir='output/models-jackknife-cv',
    **dataset_cfg,
)

# Evaluation

In [None]:
# Score the models trained with cross-validation + jackknifing, loading the
# checkpoints from the save_dir used by that training variant.
scores = dynamic_strf.modeling.test_multiple(
    model=builder(),
    checkpoints='output/models-jackknife-cv',
    data=(X, Y),
    crossval=True,
    jackknife=True,
    verbose=1
)

In [None]:
# Predict responses from the inputs alone, using the cross-validation +
# jackknifing checkpoints (same save_dir as that training variant).
preds = dynamic_strf.modeling.infer_multiple(
    model=builder(),
    checkpoints='output/models-jackknife-cv',
    data=X,
    crossval=True,
    jackknife_mode='pred',
    verbose=1,
)

# Estimating dSTRFs

In [None]:
# Estimate dSTRFs from the cross-validation + jackknifing checkpoints and
# write them to output/models-jackknife-cv-dstrf (read by the cells below).
dynamic_strf.estimate.dSTRF_multiple(
    model=builder(),
    checkpoints='output/models-jackknife-cv',
    data=(X, Y),
    crossval=True,
    jackknife=True,
    save_dir='output/models-jackknife-cv-dstrf',
    chunk_size=100,
    verbose=1
)

# Estimating dSTRF nonlinearities

In [None]:
# Estimate nonlinearities from the dSTRFs saved in
# output/models-jackknife-cv-dstrf, combined with reduction='mean'.
nonlin = dynamic_strf.estimate.nonlinearities(
    verbose=0,
    reduction='mean',
    paths='output/models-jackknife-cv-dstrf',
)

# Visualizing dSTRFs

In [None]:
# Visualize channels 8-11 over time indices 1000-2499 of one saved dSTRF file
# from the directory written by the dSTRF-estimation cell. Outputs use the
# 'mov/trial-1-' prefix (the 'mov' directory may need to exist — confirm).
dynamic_strf.visualize.dSTRF(
    'output/models-jackknife-cv-dstrf/dSTRF-000.pt',
    channels=slice(8, 12),
    time_range=slice(1000, 2500),
    aspect=2,
    output_prefix='mov/trial-1-',
    verbose=2
)