Generate NEMS FIR models whose parameters are sampled from their priors.

Pull out coefficients and reshape them the same way `FIR.evaluate` does.

In [None]:
import numpy as np
from nems import Model

# FIR hyperparameters. rank, n_outputs and filter_width are also read as
# globals by the Keras layer and convolution defined further down.
rank = 2
n_outputs = 1
filter_width = 1
n_samples = 10

model = Model.from_keywords(f'fir.{filter_width}x{rank}')
samples = model.sample_from_priors(n=n_samples)

# For every sampled model: pull out the FIR coefficients, append a trailing
# output axis, then reverse element order along every axis (time and rank).
coefs = [
    np.flip(
        sampled.get_parameter_values()['fir']['coefficients'][..., np.newaxis]
    )
    for sampled in samples
]

coefs[0].shape

Define a random input and a target that should be very easy to fit: half of the first input channel plus a small amount of Gaussian noise.

In [None]:
# Random input: 1 batch, 300 time bins, `rank` channels.
inputs = np.random.rand(1, 300, rank).astype(np.float32)
# Easy target: half of the first input channel plus small Gaussian noise.
# Noise is drawn independently for every (batch, time) bin, so it actually
# varies over time; a (batch, 1) draw would broadcast a single value over all
# time bins and act as a constant offset rather than noise.
target = inputs[..., :1] / 2 + np.random.randn(*inputs.shape[:-1], 1) / 100

Define a single Keras layer for the FIR filter, the NMSE loss function, and an early-stopping callback.

In [None]:
import tensorflow as tf

@tf.function
def trivial_broadcast_inputs(inputs):
    """Placeholder for an input-broadcasting step; currently the identity.

    Returns `inputs` unchanged.
    """
    passthrough = inputs
    return passthrough

@tf.function
def trivial_broadcast_kernel(kernel):
    """Placeholder for a kernel-broadcasting step; currently the identity.

    Returns `kernel` unchanged.
    """
    passthrough = kernel
    return passthrough

@tf.function
def convolve(inputs, kernel):
    """Causally convolve a rank-4 input with an FIR kernel via `tf.nn.conv1d`.

    Parameters
    ----------
    inputs : tf.Tensor
        Shape (batch, time, rank, n_outputs). The last two dims must have
        static sizes (SimpleFIR reshapes with Python ints, so they do).
    kernel : tf.Tensor or tf.Variable
        Passed directly to `tf.nn.conv1d`, so it is
        (filter_width, rank * n_outputs, out_channels) with a static
        first dimension.

    Returns
    -------
    tf.Tensor
        Shape (batch, time, out_channels). The time axis is left-padded with
        filter_width - 1 zeros, so the output has the same length as the
        input and output[t] only depends on inputs at times <= t.

    Notes
    -----
    Channel count and padding come from the static shapes of `inputs` and
    `kernel`, not from the notebook globals `rank`, `n_outputs` and
    `filter_width`. Python globals are frozen into a `tf.function` trace, so
    changing them after the first call would be silently ignored.
    """
    input_width = tf.shape(inputs)[1]
    n_channels = inputs.shape[2] * inputs.shape[3]
    pad_width = kernel.shape[0] - 1
    # Transpose to (batch, time, n_outputs, rank) so the reshape below lays
    # channels out grouped by output: all ranks of output 0, then output 1, ...
    transposed = tf.transpose(inputs, [0, 1, 3, 2])
    # Collapse rank and n_outputs to one dimension.
    # -1 for batch size b/c it can be None.
    reshaped = tf.reshape(transposed, [-1, input_width, n_channels])
    # Pad only the start of the time axis -> causal filter, same-length output.
    padded_input = tf.pad(reshaped, [[0, 0], [pad_width, 0], [0, 0]])
    return tf.nn.conv1d(padded_input, kernel, stride=1, padding='VALID')

class SimpleFIR(tf.keras.layers.Layer):
    """Single-weight Keras layer that applies an FIR filter via `convolve`.

    The layer is always named 'fir', which is how later cells fetch it with
    `get_layer('fir')`.

    Parameters
    ----------
    shape : tuple of int
        Shape of the trainable kernel (in this notebook, `coefs[0].shape`).
    """

    def __init__(self, shape):
        super().__init__(name='fir')

        def _unbounded(weights):
            # Clipping to (-inf, inf) leaves every value unchanged; presumably
            # a placeholder for real parameter bounds -- TODO confirm.
            return tf.clip_by_value(weights, -np.inf, np.inf)

        self.kernel = self.add_weight(
            name='kernel',
            shape=shape,
            trainable=True,
            constraint=_unbounded,
        )

    def call(self, inputs):
        # Time length is read dynamically; it may be unknown at trace time.
        n_time = tf.shape(inputs)[1]
        broadcast_inputs = trivial_broadcast_inputs(inputs)
        broadcast_kernel = trivial_broadcast_kernel(self.kernel)
        # Reshape to (batch, time, rank, n_outputs) using the notebook globals,
        # making the otherwise-None dimensions explicit.
        explicit = tf.reshape(
            broadcast_inputs, [-1, n_time, rank, n_outputs]
        )
        return convolve(explicit, broadcast_kernel)


# loss function
def tf_nmse(response, prediction, per_cell=True, n_splits=10):
    """Normalized RMS error, computed over contiguous chunks of the data.

    For each output channel, the data are split into `n_splits` equal,
    contiguous chunks. For each chunk it computes
    sqrt(mean((response - prediction)**2) / mean(response**2)).
    It returns the mean of those per-chunk values and their standard error.

    Parameters
    ----------
    response, prediction : tf.Tensor
        Same shape, with output channels on the last axis, e.g.
        (batch, time, n_outputs). The product of all other dimensions must
        be divisible by `n_splits`.
    per_cell : bool
        Currently unused; the error is always computed per output channel.
    n_splits : int, default 10
        Number of chunks. The default reproduces the previously hard-coded 10.

    Returns
    -------
    mE : tf.Tensor
        Shape (n_outputs,). Mean of the per-chunk errors.
    sE : tf.Tensor
        Shape (n_outputs,). Standard error: std over chunks / sqrt(n_splits).
    """
    # Put last dimension (number of output channels) first.
    perm = np.roll(np.arange(len(response.shape)), 1)
    _response = tf.transpose(response, perm)
    _prediction = tf.transpose(prediction, perm)
    # Split everything except the channel axis into n_splits contiguous chunks
    # so that the spread of per-chunk errors gives a standard error.
    _response = tf.reshape(_response, shape=(_response.shape[0], n_splits, -1))
    _prediction = tf.reshape(_prediction, shape=(_prediction.shape[0], n_splits, -1))

    squared_error = (_response - _prediction) ** 2
    nmses = (tf.math.reduce_mean(squared_error, axis=-1) /
             tf.math.reduce_mean(_response**2, axis=-1)) ** 0.5

    mE = tf.math.reduce_mean(nmses, axis=-1)
    sE = tf.math.reduce_std(nmses, axis=-1) / n_splits ** 0.5

    return mE, sE


# early stopping
class DelayedStopper(tf.keras.callbacks.EarlyStopping):
    """EarlyStopping that ignores all epochs up to and including `start_epoch`.

    Parameters
    ----------
    start_epoch : int, default 100
        The early-stopping check only runs for epoch indices strictly greater
        than this value.
    **kwargs
        Forwarded unchanged to `tf.keras.callbacks.EarlyStopping`.
    """

    def __init__(self, start_epoch=100, **kwargs):
        super().__init__(**kwargs)
        self.start_epoch = start_epoch

    def on_epoch_end(self, epoch, logs=None):
        # Still inside the grace period: skip the early-stopping bookkeeping.
        if epoch <= self.start_epoch:
            return
        super().on_epoch_end(epoch, logs)

Compile the model with the Adam optimizer and the NMSE loss.

In [None]:


def nmse_loss(response, prediction):
    """Keras loss: the mean chunked NMSE from `tf_nmse`, without its std error.

    `tf_nmse` returns a `(mE, sE)` tuple, but Keras expects a loss function
    to return a single tensor. Given the tuple, tf.keras stacks mE and sE and
    averages them, so the standard error would be minimised along with the
    error itself. Return only mE.
    """
    mE, _ = tf_nmse(response, prediction)
    return mE


tf_model = tf.keras.Sequential()
tf_model.add(SimpleFIR(coefs[0].shape))
tf_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-2),
    loss=nmse_loss
)

# Early stopping on training loss; ignored until after epoch index 10.
callbacks = [
    DelayedStopper(
        monitor='loss', patience=10, min_delta=1e-3, verbose=1,
        restore_best_weights=True, start_epoch=10,
        )
    ]

Iterate over the sampled initial conditions: fit the model starting from each one and record the correlation between its prediction and the target.

In [None]:
ccs = []
for c in coefs:
    # Recompile with a fresh copy of the optimizer. Otherwise Adam's moment
    # estimates and step count from the previous fit carry over, and every
    # initial condition after the first starts from a different optimizer state.
    tf_model.compile(
        optimizer=type(tf_model.optimizer).from_config(
            tf_model.optimizer.get_config()
        ),
        loss=tf_model.loss,
    )
    tf_model.get_layer('fir').set_weights([c])
    tf_model.fit(
        inputs, target, epochs=200, callbacks=callbacks,
        )
    prediction = tf_model.predict(inputs)
    # Flatten both series to 1-D so np.corrcoef compares them directly.
    ccs.append(
        np.corrcoef(prediction[..., 0].ravel(), target[..., 0].ravel())[0, 1]
    )


Print results.

In [None]:
# One prediction/target correlation per sampled initial condition.
print(repr(ccs))