In [None]:
import numpy as np

from nems import Model
from nems.layers import FIR, STRF

# Fix NumPy's global seed so the fake data is identical on every
# Restart & Run All. NOTE(review): this presumably also makes
# `sample_from_priors` reproducible, assuming nems draws from NumPy's global
# RNG -- confirm.
SEED = 42
np.random.seed(SEED)

# Fake spectrogram: 10000 rows x 18 columns (18 matches the wc.18x4 input).
x = np.random.rand(10000, 18)
m = Model.from_keywords('wc.18x4.g-fir.4x15x10')
# TODO: currently all the sample/mean methods still return the samples even if
#       inplace=True (the default for model and layer methods). Did that so that
#       other code wouldn't need to check if something got returned or not, but
#       maybe it would be more intuitive if nothing got returned. For reference,
#       pandas defaults to inplace=False and returns None when inplace=True.
_ = m.sample_from_priors()
m

In [None]:
# Run the model's two layers by hand. The first layer's output feeds the FIR
# layer, which is evaluated with both the current and the old implementation.
wc, fir = m.layers[0], m.layers[1]
wc_out = wc.evaluate(x)
z1, z2 = fir.evaluate(wc_out), fir.old_evaluate(wc_out)

# NaN never compares equal to anything (including NaN), so any NaN would show
# up as a spurious mismatch when the outputs are compared; expect 0 here.
np.isnan(z1).sum() + np.isnan(z2).sum()

In [None]:
# Count elements where the new and old FIR outputs disagree.
# Use an absolute tolerance rather than rounding: two values that differ by
# ~1e-16 can still round to different 9th decimals if they straddle a rounding
# boundary, which would be reported as a false mismatch.
# NaNs are counted as mismatches (equal_nan=False), as with `!=`.
assert z1.shape == z2.shape, f"shape mismatch: {z1.shape} vs {z2.shape}"
np.sum(~np.isclose(z1, z2, rtol=0, atol=1e-9))

In [None]:
import matplotlib.pyplot as plt

# Overlay the first few samples of one output channel from both FIR
# implementations; the dashed red line should sit exactly on the black one.
N_PLOT = 100   # number of leading samples to show
CHANNEL = 0    # output channel to inspect

fig, ax = plt.subplots(1, 1)
ax.plot(z1[:N_PLOT, CHANNEL], c='black', label='fir.evaluate')
ax.plot(z2[:N_PLOT, CHANNEL], c='red', linestyle='dashed',
        label='fir.old_evaluate')
ax.set(xlabel='Sample index', ylabel=f'Output channel {CHANNEL}',
       title='FIR: new vs. old evaluate')
ax.legend();

In [None]:
# Should work directly on the fake spectrogram
# (FIR would still work too, STRF is just an alias).
# shape = (input channels, time bins, filters); 18 matches x's columns.
strf = STRF(shape=(18, 25, 4))
out = strf.evaluate(x)
# Check the full output shape: `out[0]` would only be the first row.
# Expected: one row per input sample, one column per filter.
assert out.shape == (x.shape[0], 4), out.shape
out.shape


In [None]:
# TODO: well... testing this case, now I'm thinking it might be better to
#       switch the time-channel specification after all, to match the data.
#       It would certainly be much more intuitive for this example to
#       use shape=(25, 10, 10, 5). Maybe not that big of an issue, but we should
#       definitely decide on that soon, before the amount of documentation to
#       update gets a lot larger.
# NOTE: this starts to take a lot longer as the dimensionality increases, but
#       that must be in the scipy implementation (and is probably unavoidable).
#       I did test with 4-D data as well, but for the sake of time it's probably
#       fine to just stick with 3D for test scripts.
x3d = np.random.rand(10000, 10, 10)
fir3d = FIR(shape=(10, 25, 10, 5))  # 25 is still time bins, 5 is n filters
out3d = fir3d.evaluate(x3d)
# Check the full output shape (`out3d[0]` would only be the first row).
assert out3d.shape == (10000, 5), out3d.shape
out3d.shape