In [1]:
import tensorflow as tf


In [3]:
# Seed Python's `random`, NumPy and TensorFlow in one call so that the input batch
# below, FixedHiddenMLP's fixed random matrix and the layers' initial weights are
# reproducible under Restart Kernel -> Run All.
tf.keras.utils.set_random_seed(42)
X = tf.random.uniform((2, 20))  # toy batch: 2 rows of 20 features, values drawn from [0, 1)

In [5]:
class MLP(tf.keras.Model):
    """Two-layer perceptron built by subclassing ``tf.keras.Model``.

    A 256-unit ReLU hidden layer feeds a 10-unit output layer that has no
    activation, so each input row maps to 10 raw scores.
    """

    def __init__(self) -> None:
        super().__init__()
        # The hidden layer is created before the output layer on purpose; keeping
        # this order keeps Keras' auto-generated layer names stable.
        self.hidden = tf.keras.layers.Dense(256, activation=tf.nn.relu)
        self.out = tf.keras.layers.Dense(10)

    def call(self, X):
        """Forward pass: hidden layer first, then the output layer."""
        hidden_activations = self.hidden(X)
        return self.out(hidden_activations)

In [6]:
# Build the subclassed MLP and show its output for the random batch X.
net = MLP()
net(X)

<tf.Tensor: shape=(2, 10), dtype=float32, numpy=
array([[ 0.17279005,  0.04906879, -0.03388657,  0.09934852,  0.15473732,
         0.18375075, -0.26239842, -0.032392  , -0.08150564,  0.0077733 ],
       [ 0.14707191,  0.05371407, -0.33504188,  0.16726816,  0.06553102,
        -0.09037317, -0.0247557 ,  0.38117188, -0.10275548,  0.04035759]],
      dtype=float32)>

In [13]:
class MySequential(tf.keras.Model):
    """Minimal stand-in for ``tf.keras.Sequential``.

    Every positional argument is a callable block (e.g. a Keras layer). They are
    applied one after another, in the order given, during the forward pass.
    """

    def __init__(self, *args):
        super().__init__()
        # Stored as a list attribute, which Keras tracks, so the blocks' weights
        # belong to this model.
        self.modules = list(args)

    def call(self, X):
        """Thread ``X`` through each stored block and return the final result."""
        out = X
        for block in self.modules:
            out = block(out)
        return out

In [14]:
# Same architecture as MLP, but assembled with the hand-written MySequential container.
net = MySequential(
    tf.keras.layers.Dense(256, activation=tf.nn.relu),
    tf.keras.layers.Dense(units=10),
)
net(X)

<tf.Tensor: shape=(2, 10), dtype=float32, numpy=
array([[-0.03313748, -0.07171223, -0.14724204,  0.02923996,  0.16463302,
         0.31647795,  0.00710964,  0.01459003,  0.19680715,  0.478291  ],
       [-0.06661346,  0.02348321, -0.02184349,  0.1182263 ,  0.1842809 ,
        -0.00726774, -0.06038144,  0.02656977,  0.02565152,  0.22776687]],
      dtype=float32)>

In [15]:
class FixedHiddenMLP(tf.keras.Model):
    """Toy model mixing a frozen random matrix, a trainable layer and Python control flow.

    Forward pass:
      1. flatten the input,
      2. multiply by a fixed 20x20 random matrix, add 1, apply ReLU,
      3. apply a trainable 20-unit ReLU dense layer,
      4. halve the activations while the sum of their absolute values exceeds 1,
      5. return the sum of all entries as a scalar tensor.
    """

    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        # Held as a tf.constant, not a variable, so training never updates it.
        # Its 20 rows require the flattened input to have exactly 20 features.
        self.rand_weight = tf.constant(tf.random.uniform(shape=(20, 20)))
        self.dense = tf.keras.layers.Dense(20, activation=tf.nn.relu)

    def call(self, inputs):
        """Return a scalar summarizing ``inputs`` (see the class docstring)."""
        h = self.flatten(inputs)
        h = tf.matmul(h, self.rand_weight) + 1
        h = tf.nn.relu(h)
        h = self.dense(h)
        # Data-dependent loop: keep halving until sum(|h|) <= 1.
        while tf.reduce_sum(tf.abs(h)) > 1:
            h = h / 2
        return tf.reduce_sum(h)

In [16]:
net = FixedHiddenMLP()
net(X)

<tf.Tensor: shape=(), dtype=float32, numpy=0.76809734>

In [17]:
class NestMLP(tf.keras.Model):
    """Nested block: a ``Sequential`` of 64- and 32-unit ReLU layers, then a 16-unit ReLU layer.

    Any positional/keyword arguments are forwarded to ``tf.keras.Model``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.net = tf.keras.Sequential([
            tf.keras.layers.Dense(64, activation=tf.nn.relu),
            tf.keras.layers.Dense(32, activation=tf.nn.relu),
        ])
        self.dense = tf.keras.layers.Dense(16, activation=tf.nn.relu)

    def call(self, inputs):
        """Run ``inputs`` through the inner Sequential, then the final dense layer."""
        features = self.net(inputs)
        return self.dense(features)

# Mix the custom blocks inside a built-in Sequential container and show the scalar result.
# NOTE(review): statements are kept in this order deliberately. Keras appears to
# auto-name layers in construction order, so creating the layers differently
# (e.g. passing a list to the constructor) could rename this Sequential and the
# one inside NestMLP.
chimera = tf.keras.Sequential()
chimera.add(NestMLP())
# Width 20 matches the 20-row rand_weight that FixedHiddenMLP multiplies its input by.
chimera.add(tf.keras.layers.Dense(20))
chimera.add(FixedHiddenMLP())  # reduces everything to a single scalar
chimera(X)

<tf.Tensor: shape=(), dtype=float32, numpy=0.76128745>