In [242]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [243]:
class AdditiveDenseLinear(tf.keras.layers.Layer):
    """Fully connected layer with no activation: ``inputs @ kernel + bias``.

    Args:
        units: Number of output features (size of the last output axis).
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        """Create a ``(in_features, units)`` kernel and a ``(units,)`` bias."""
        # Small normal init keeps the layer's initial outputs close to zero.
        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_shape[-1], self.units),
            initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.01),
            trainable=True,
        )
        # Must be a 1-tuple: `(self.units)` is a bare int, which tf.keras
        # happens to tolerate but Keras 3 rejects as a shape.
        self.bias = self.add_weight(
            name="bias",
            shape=(self.units,),
            initializer=tf.keras.initializers.Zeros(),
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        """Return ``inputs @ kernel + bias`` (bias broadcast over leading axes)."""
        return tf.matmul(inputs, self.kernel) + self.bias

    def get_config(self):
        """Serialize ``units`` alongside the base layer config."""
        config = super().get_config()
        config.update({"units": self.units})
        return config

In [244]:
class MultiplicativeDenseLinear(tf.keras.layers.Layer):
    """Layer whose outputs are scaled products of weighted inputs.

    For each output unit ``j``::

        output[..., j] = bias[j] * prod_i(inputs[..., i] * kernel[j, i])

    Despite its name, ``bias`` is a per-unit multiplicative scale, not an
    additive offset.

    Args:
        units: Number of output features (size of the last output axis).
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        """Create a ``(units, in_features)`` kernel and a ``(units,)`` scale."""
        # NOTE(review): with stddev=1e-4 the product over in_features factors
        # shrinks roughly like 1e-4**in_features, so outputs and gradients start
        # extremely small — confirm this init is intentional.
        self.kernel = self.add_weight(
            name="kernel",
            shape=(self.units, input_shape[-1]),
            initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.0001),
            trainable=True,
        )
        # Must be a 1-tuple: `(self.units)` is a bare int, which tf.keras
        # happens to tolerate but Keras 3 rejects as a shape.
        self.bias = self.add_weight(
            name="bias",
            shape=(self.units,),
            initializer=tf.keras.initializers.Constant(0.0001),
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        """Return the scaled product of weighted inputs for every unit."""
        # (..., 1, in) * (units, in) -> (..., units, in); multiply over `in`.
        weighted = inputs[..., tf.newaxis, :] * self.kernel
        return tf.reduce_prod(weighted, axis=-1) * self.bias

    def get_config(self):
        """Serialize ``units`` alongside the base layer config."""
        config = super().get_config()
        config.update({"units": self.units})
        return config

In [245]:
class AdditiveDenseRelu(tf.keras.layers.Layer):
    """Fully connected layer followed by ReLU: ``relu(inputs @ kernel + bias)``.

    Args:
        units: Number of output features (size of the last output axis).
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        """Create a He-initialized kernel and a small positive bias."""
        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_shape[-1], self.units),
            initializer=tf.keras.initializers.he_uniform(),
            trainable=True,
        )
        # Must be a 1-tuple: `(self.units)` is a bare int, which tf.keras
        # happens to tolerate but Keras 3 rejects as a shape.
        # The positive constant presumably keeps units active at the start of
        # training.
        self.bias = self.add_weight(
            name="bias",
            shape=(self.units,),
            initializer=tf.keras.initializers.Constant(0.1),
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        """Return ``relu(inputs @ kernel + bias)``."""
        return tf.nn.relu(tf.matmul(inputs, self.kernel) + self.bias)

    def get_config(self):
        """Serialize ``units`` alongside the base layer config."""
        config = super().get_config()
        config.update({"units": self.units})
        return config

In [246]:
def train_model(
    inputs,
    outputs,
    hidden_layers=2,
    hidden_units=8,
    epochs=20,
    include_skip=False,
    multiplicative_units=0,
):
    """Build, train and summarize a small regression network.

    Each hidden stage concatenates a linear branch and a ReLU branch (each
    ``hidden_units`` wide) computed from the previous stage's output.
    Optionally, the previous stage itself (``include_skip``) and a
    ``MultiplicativeDenseLinear`` branch (``multiplicative_units > 0``) are
    concatenated too. A final linear layer maps to the target width. The
    model is trained with Adam on mean squared error, and its summary is
    printed.

    Args:
        inputs: Training features; axis 0 is the sample axis.
        outputs: Training targets; axis 0 is the sample axis and the last
            axis gives the number of output units.
        hidden_layers: Number of hidden stages (default 2).
        hidden_units: Width of each linear / ReLU branch (default 8).
        epochs: Number of epochs passed to ``Model.fit`` (default 20).
        include_skip: If True, also concatenate each stage's input into its
            output. Default False.
        multiplicative_units: Width of an optional multiplicative branch per
            stage; 0 (default) disables it.

    Returns:
        The trained ``tf.keras.Model``.
    """
    input_layer = tf.keras.layers.Input(shape=inputs.shape[1:])
    hidden_layer = input_layer
    for _ in range(hidden_layers):
        branches = [
            AdditiveDenseLinear(hidden_units)(hidden_layer),
            AdditiveDenseRelu(hidden_units)(hidden_layer),
        ]
        if include_skip:
            # Stage input goes first, matching the originally sketched order.
            branches.insert(0, hidden_layer)
        if multiplicative_units > 0:
            branches.append(
                MultiplicativeDenseLinear(multiplicative_units)(hidden_layer)
            )
        hidden_layer = tf.keras.layers.concatenate(branches)
    # One output unit per target feature (last axis of `outputs`).
    output_layer = AdditiveDenseLinear(outputs.shape[-1])(hidden_layer)
    model = tf.keras.Model(inputs=input_layer, outputs=output_layer)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.MeanSquaredError(),
    )
    model.fit(inputs, outputs, epochs=epochs)
    model.summary()
    return model

In [247]:
def train_model_on_function(
    input_shape, f, num_samples=10000, minval=-10, maxval=10, seed=None
):
    """Train a model to imitate a per-sample function on random inputs.

    Draws ``num_samples`` float32 inputs uniformly from ``[minval, maxval)``,
    computes targets by applying ``f`` to each sample with ``tf.map_fn``, and
    fits a model with ``train_model``.

    Args:
        input_shape: Per-sample input shape, e.g. ``(2,)``. Lists are accepted.
        f: Function mapping one input sample to one target tensor. Because
            ``tf.map_fn`` is called without ``fn_output_signature``, ``f``
            is expected to return the same dtype as its input (float32).
        num_samples: Number of training samples (default 10000).
        minval: Inclusive lower bound of the sampled inputs (default -10).
        maxval: Exclusive upper bound of the sampled inputs (default 10).
        seed: Optional integer. When given, it seeds Python, NumPy and
            TensorFlow before sampling, so data and weight init repeat.

    Returns:
        The trained ``tf.keras.Model``.
    """
    if seed is not None:
        tf.keras.utils.set_random_seed(seed)
    inputs = tf.random.uniform(
        shape=(num_samples,) + tuple(input_shape),
        minval=minval,
        maxval=maxval,
        dtype=tf.float32,
    )
    outputs = tf.map_fn(f, inputs)
    return train_model(inputs, outputs)

In [248]:
# Purely linear targets [x0 + x1 + 3, x0 - x1]: exact answers are
# [[6, -1], [10, -1]]. Query points are float32 to match the model's Input dtype.
train_model_on_function(
    (2,), lambda x: tf.stack([x[0] + x[1] + 3, x[0] - x[1]])
).predict(tf.constant([[1, 2], [3, 4]], dtype=tf.float32))

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Model: "model_33"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_37 (InputLayer)       [(None, 2)]               0         
                                                                 
 additive_dense_linear_32 (  (None, 8)                 24        
 AdditiveDenseLinear)                                            
                                                                 
 concatenate_34 (Concatenat  (None, 8)                 0         
 e)                                                              
                                                                 
 additive_dense_linear_33 (  (None, 8)                 72        
 AdditiveDenseLinear)    

array([[ 6.       , -1.       ],
       [10.       , -0.9999998]], dtype=float32)

In [249]:
# Targets [x0 + x1 + 3, x0 * x1]: exact answers are [[8, 6], [12, 20]].
# The product term is nonlinear, so the linear/ReLU network can only
# approximate it. Query points are float32 to match the model's Input dtype.
train_model_on_function(
    (2,), lambda x: tf.stack([x[0] + x[1] + 3, x[0] * x[1]])
).predict(tf.constant([[2, 3], [4, 5]], dtype=tf.float32))

In [None]:
# Both targets are the same function written two ways (x0*x2 + x1*x2 ==
# (x0 + x1)*x2); the exact answer for [2, 3, 4] is [[20, 20]].
# Query points are float32 to match the model's Input dtype.
train_model_on_function(
    (3,), lambda x: tf.stack([x[0] * x[2] + x[1] * x[2], (x[0] + x[1]) * x[2]])
).predict(tf.constant([[2, 3, 4]], dtype=tf.float32))