In [None]:
import numpy as np

from nems import Model

# Synthetic data for exercising axis handling.
# These are in old-NEMS order (i.e. channels first, time second).
SEED = 42
N_CHANNELS = 18
N_TIMES = 10000

# Seed NumPy's legacy global RNG so Restart & Run All reproduces the same data.
# NOTE(review): this presumably also makes `sample_from_priors` reproducible if
# it draws from NumPy's global RNG — TODO confirm.
np.random.seed(SEED)

spectrogram = np.random.rand(N_CHANNELS, N_TIMES)
response = np.random.rand(1, N_TIMES)
single_channel_input = np.random.rand(N_TIMES)
data = {'stimulus': spectrogram, 'response': response}

# The WeightChannels input size in the keyword must match the number of
# spectrogram channels, so build it from N_CHANNELS ('wc.18x1.g-fir.1x15').
model = Model.from_keywords(f'wc.{N_CHANNELS}x1.g-fir.1x15')
_ = model.sample_from_priors()
# Snapshot of the starting parameters sampled above.
initial_parameters = model.get_parameter_vector()

In [None]:
# This should raise an error, because WeightChannels and FIR expect time on the
# first axis, while `spectrogram` is laid out (channels, time).
# The error is caught and displayed so that Restart & Run All can continue past
# this demonstration. If no error is raised, the expectation is broken and the
# cell fails loudly instead.
# TODO: make error message more informative.
try:
    model.evaluate(spectrogram)
except Exception as err:  # exact exception type is not pinned down yet
    print(f'Raised as expected -> {type(err).__name__}: {err}')
else:
    raise AssertionError(
        'Expected model.evaluate(spectrogram) to raise for channels-first input.'
    )

In [None]:
# This should not raise an error, because `Model.evaluate` re-orders the data.
# But the final output should have the same axis order as the input (channels
# first, time second), b/c evaluate should switch the order back. Note that the
# shapes themselves differ: this model maps 18 input channels to 1 output
# channel. I *think* this is the most intuitive approach
# since users with re-ordered data would presumably write their other code
# (for plotting, preprocessing, etc) to match the layout of their data.
# And of course, they can always skip the time_axis/channel_axis kwargs and
# just re-order the data beforehand, but I think this is more convenient if
# they're not too familiar with Numpy operations.
# NOTE: Re-ordering the data before fitting will be faster, but it shouldn't be
#       a big difference (an extra ~10 microseconds per loop on some random
#       3D data). Tests on low iteration count showed a ~5% increase in time
#       to fit.
# TODO: add docs about this in `scripts/simple_fit`.
data = model.evaluate(spectrogram, time_axis=1, channel_axis=0)
data['input'].shape

In [None]:
# Shape of the model output from the previous cell. Per that cell's comment, it
# should follow the input's axis order (channels first, time second).
data['output'].shape

In [None]:
# This should raise a ValueError, b/c WeightChannels expects the input to
# have shape (T, 1) instead of (T,).
# The error is caught and displayed so Restart & Run All can continue. Any other
# exception type still propagates, and a missing exception fails the cell.
wc = model['wc']
try:
    wc.evaluate(single_channel_input)
except ValueError as err:
    print(f'Raised as expected -> ValueError: {err}')
else:
    raise AssertionError(
        'Expected WeightChannels.evaluate to raise ValueError for 1-D input.'
    )

In [None]:
# But here there should be no error, because `Model.evaluate` adds a dummy axis.
# NOTE: unlike `model` in the first cell, `just_wc` is never passed through
# `sample_from_priors`, so it runs with whatever parameter values
# `Model.from_keywords` sets up.
just_wc = Model.from_keywords('wc.1x1')
# Overwrites `data` from the earlier evaluate cell.
data = just_wc.evaluate(single_channel_input)
data['output'].shape

In [None]:
# Fitting should pass the appropriate options through to `Model.evaluate`.
# TODO: add option to reset model parameters between fits,
#       e.g. Model.initial_parameters()
# Keep the fit short: at most 5 iterations. Assuming these options are forwarded
# to `scipy.optimize.minimize` (L-BFGS-B) — TODO confirm — `ftol` is a
# *relative* tolerance on the cost reduction per iteration. Any value >= 1 would
# be met after the first iteration for a non-negative cost, making `maxiter`
# meaningless, so use a small tolerance.
fitter_options = {'options': {'maxiter': 5, 'ftol': 1e-3}}
model.fit(spectrogram, target=response, time_axis=1, channel_axis=0,
          fitter_options=fitter_options)