In [1]:
import sys
sys.path.append("..")

import tensorflow as tf

In [2]:
import layers
import schema

In [3]:
layers.MLP

<function layers.MLP(hidden_dim: int, output_dim: int, hidden_activation: Union[str, Callable] = None, output_activation: Union[str, Callable] = None, kernel_initializer: Union[str, Callable] = 'glorot_uniform', bias_initializer: Union[str, Callable] = 'zeros', kernel_regularizer: Union[str, Callable] = None, bias_regularizer: Union[str, Callable] = None, activity_regularizer: Union[str, Callable] = None, kernel_constraint: Union[str, Callable] = None, bias_constraint: Union[str, Callable] = None, **kwargs) -> keras.engine.training.Model>

In [4]:
BATCH_SIZE = 32
INPUT_DIM = 78

mlp = layers.MLP(
    hidden_dim=128,
    output_dim=64,
    hidden_activation=tf.nn.gelu,
    output_activation=tf.nn.softmax
)

batch_mlp = tf.random.normal((BATCH_SIZE, INPUT_DIM))
output_mlp = mlp(batch_mlp)

output_mlp.shape

TensorShape([32, 64])

In [32]:
stacked_batch = tf.stack([output_mlp, output_mlp], axis=1)
stacked_batch.shape

TensorShape([32, 2, 64])

In [44]:
cls_embedding = tf.tile(tf.reshape(tf.random.normal((64,)), (1, 1, 64)), tf.constant([32, 1, 1], dtype=tf.int32))
cls_embedding.shape

TensorShape([32, 1, 64])

In [50]:
embeddings = tf.concat([cls_embedding, stacked_batch], axis=1)
embeddings.shape

TensorShape([32, 3, 64])

In [49]:
embeddings[:,0,:].shape

TensorShape([3, 64])

In [30]:
tf.stack([tf.tile(tf.expand_dims(tf.random.normal((64,)), axis=0), tf.constant([32, 1], tf.int32))] + [output_mlp, output_mlp], axis=1).shape

TensorShape([32, 3, 64])

In [17]:
mlp.name, getattr(mlp, "name")

('sequential', 'sequential')

In [25]:
tf.repeat(tf.random.normal((1, 64)), (32,))

<tf.Tensor: shape=(2048,), dtype=float32, numpy=
array([-0.37262163, -0.37262163, -0.37262163, ...,  0.44976303,
        0.44976303,  0.44976303], dtype=float32)>

In [28]:
tf.tile(tf.expand_dims(tf.random.normal((64,)), axis=0), tf.constant([32, 1], tf.int32))

<tf.Tensor: shape=(32, 64), dtype=float32, numpy=
array([[ 0.20502977,  0.27224723, -1.043306  , ...,  1.8665898 ,
        -0.65770984, -0.5635294 ],
       [ 0.20502977,  0.27224723, -1.043306  , ...,  1.8665898 ,
        -0.65770984, -0.5635294 ],
       [ 0.20502977,  0.27224723, -1.043306  , ...,  1.8665898 ,
        -0.65770984, -0.5635294 ],
       ...,
       [ 0.20502977,  0.27224723, -1.043306  , ...,  1.8665898 ,
        -0.65770984, -0.5635294 ],
       [ 0.20502977,  0.27224723, -1.043306  , ...,  1.8665898 ,
        -0.65770984, -0.5635294 ],
       [ 0.20502977,  0.27224723, -1.043306  , ...,  1.8665898 ,
        -0.65770984, -0.5635294 ]], dtype=float32)>

In [29]:
tf.expand_dims(tf.random.normal((64,)), axis=0)

<tf.Tensor: shape=(1, 64), dtype=float32, numpy=
array([[ 1.8356962 , -0.2577597 ,  1.7000769 , -1.1901748 , -0.6959799 ,
        -0.16714376,  0.57936007,  1.05844   , -0.32515067, -0.26533476,
         0.05391934, -1.5660919 ,  0.62879413,  0.4333685 ,  0.708092  ,
         0.28924763, -0.46051058,  0.412034  , -0.06868848,  0.01341627,
         0.96087265, -1.0207053 , -0.9466133 ,  0.5527718 ,  1.2582401 ,
        -0.80828524, -0.81083   ,  0.42155948,  0.09648564, -1.2208761 ,
        -0.4809261 ,  2.0067809 , -0.28203112, -0.6730093 ,  0.09928842,
        -0.11531276, -0.85903454, -2.4250076 ,  0.58707553, -0.12709561,
         1.0053155 , -2.2924817 ,  0.13307332, -0.28513092,  0.15451261,
        -0.26090774, -0.01895636, -1.8103921 ,  0.2175087 ,  0.08694499,
         1.9910007 ,  0.9818912 , -0.673792  ,  1.3917378 ,  0.50340515,
         0.16332895, -1.5576229 ,  0.58508766,  1.2754598 ,  2.1077392 ,
        -0.26249802, -0.9809261 , -0.58303297, -1.8642876 ]],
      dtype=f

In [52]:
BATCH_SIZE = 8
N_FEATURES = 5
FEATURE_DIM = 16

batch = tf.tile(
    tf.expand_dims(
        tf.expand_dims(tf.range(BATCH_SIZE), axis=1),
        axis=2
    ),
    tf.constant([1, N_FEATURES, FEATURE_DIM], tf.int32)
)
batch

<tf.Tensor: shape=(8, 5, 16), dtype=int32, numpy=
array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

       [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],

       [[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]],

       [[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [3, 3, 3

In [71]:
batch[:, 0, :]

<tf.Tensor: shape=(8, 16), dtype=int32, numpy=
array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
       [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
       [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
       [5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
       [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6],
       [7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7]], dtype=int32)>

In [75]:
batch + tf.ones((5, 16), dtype=tf.int32)

<tf.Tensor: shape=(8, 5, 16), dtype=int32, numpy=
array([[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],

       [[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]],

       [[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]],

       [[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
        [4, 4, 4

In [78]:
tf.concat([tf.zeros((1, 16)), tf.ones((4, 16))], axis=0)

<tf.Tensor: shape=(5, 16), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],
      dtype=float32)>

In [82]:
added_batch = tf.concat([tf.zeros((1, 16), dtype=tf.int32), tf.ones((4, 16), dtype=tf.int32)], axis=0)
added_batch

<tf.Tensor: shape=(5, 16), dtype=int32, numpy=
array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)>

In [84]:
new_batch = batch + added_batch
new_batch

<tf.Tensor: shape=(8, 5, 16), dtype=int32, numpy=
array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],

       [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]],

       [[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]],

       [[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [4, 4, 4

In [93]:
new_batch[:, 0, :]

<tf.Tensor: shape=(8, 16), dtype=int32, numpy=
array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
       [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
       [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
       [5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
       [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6],
       [7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7]], dtype=int32)>

In [53]:
tf.keras.layers.Flatten()(batch).shape

TensorShape([8, 80])

In [62]:
class A:
    def __init__(self, a):
        self.A = a

In [63]:
test = A(6)

In [64]:
test.A

6