In [None]:
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np

from nems import Model
from nems.tf.model import build_model
from nems.layers import FIR

# Smoke test: build a standalone FIR layer and convert it to TensorFlow.
# as_tensorflow_layer() returns a callable (presumably a layer class) that is
# instantiated here with no arguments.
# NOTE: `fir_tf` is reassigned from the full NEMS model further down; this
# instance only confirms that the standalone conversion runs.
fir = FIR(name='fir', shape=(1, 15))
fir_tf_factory = fir.as_tensorflow_layer()
fir_tf = fir_tf_factory()

In [None]:
# Synthetic data dimensions. The wc keyword below reads n_channels, so the
# stimulus array and the model's expected input width cannot drift apart.
n_stim = 10        # number of stimuli
n_bins = 1000      # time bins per stimulus
n_channels = 18    # spectral channels per time bin
n_units = 4
n_out = 4
# NOTE(review): the fir keyword below ends in x{n_units}, which presumably sets
#       the number of output channels. The MSE loss compares that output to y,
#       so n_out is expected to equal n_units; confirm before changing either.

# Seed numpy and TensorFlow so the random data (and any random weight
# initialization) is reproducible under Restart Kernel -> Run All.
SEED = 0
np.random.seed(SEED)
tf.random.set_seed(SEED)

x = np.random.rand(n_stim, n_bins, n_channels)  # stimuli (as spectrogram)
y = np.random.rand(n_stim, n_bins, n_out)       # corresponding responses
# Only model inputs (stim, state, etc.) go in `data`; targets are passed to
# fit() separately.
data = {'input': x}

# TODO: old NEMS fir would have 1x15xN, but I think that only worked b/c of
#       two of the dimensions being multiplied together for coefficients.
#       Pretty sure this should be the same effect?
nems_model = Model.from_keywords(
    f'wc.{n_channels}x{n_units}.g-fir.{n_units}x15x{n_units}'
    )
fir_tf = nems_model['fir'].as_tensorflow_layer()()
# Check that the input shape gets evaluated correctly. fir follows wc in the
# keyword string, so its input presumably has n_units channels; use the
# variable instead of a hard-coded width.
fir_tf.call(keras.Input(shape=(n_bins, n_units), name='test', dtype='float32'))
# TODO: not sure what the Lambda layer warning is about, but it's not happening
#       during the actual model evaluation so maybe nothing to worry about.

In [None]:
# Convert each NEMS layer, in order, to its TensorFlow counterpart.
# as_tensorflow_layer() returns a callable (presumably a layer class) that is
# instantiated with no arguments. Then assemble the Keras model with
# build_model, telling it the input lives under the 'input' key of `data`.
tf_layers = []
for nems_layer in nems_model.layers:
    tf_layer_factory = nems_layer.as_tensorflow_layer()
    tf_layers.append(tf_layer_factory())

tf_model = build_model(
    nems_model,
    tf_layers,
    data,
    eval_kwargs={'input_name': 'input'},
)

In [None]:
# Show the Keras summary of the converted model.
# TODO: move NEMS Model.__repr__ to Model.summary() instead of Model.__str__?
#       (or maybe both, but still switch __repr__ to compact version)
tf_model.summary()

In [None]:
# Adam with default settings; mean-squared-error loss applied to the model
# output named 'fir'.
adam = keras.optimizers.Adam()
losses_by_output = {'fir': keras.losses.MeanSquaredError()}
tf_model.compile(optimizer=adam, loss=losses_by_output)

In [None]:
# Fit to the synthetic responses, keyed by the 'fir' output name. No epochs or
# batch_size are passed, so Keras' defaults apply.
targets = {'fir': y}
tf_model.fit(data, targets)

In [None]:
# Predict with the same input dict that fit() received, rather than the bare
# array. That keeps working if more named inputs (e.g. state) are added to
# `data`.
tf_model.predict(data).shape

In [None]:
# Look up the FIR layer by name instead of by position in tf_model.layers.
# A positional index depends on how build_model orders layers (e.g. whether an
# Input layer comes first). 'fir' is the same name used as the output key in
# compile()/fit().
# NOTE(review): assumes layers[2] was meant to be the fir layer; confirm.
tf_model.get_layer('fir').weights_to_values().shape