In [1]:
import trax
from trax import layers as tl

INFO:tensorflow:tokens_length=568 inputs_length=512 targets_length=114 noise_density=0.15 mean_noise_span_length=3.0 


In [2]:
# A small multilayer perceptron: a 128-unit hidden Dense layer with ReLU,
# then a 10-unit Dense layer whose outputs are passed through LogSoftmax.
mlp = tl.Serial(tl.Dense(128), tl.Relu(), tl.Dense(10), tl.LogSoftmax())
print(mlp)

Serial[
  Dense_128
  Relu
  Dense_10
  LogSoftmax
]


## GRU MODEL
- **ShiftRight**: Shifts the tensor to the right by padding on axis 1. The `mode` argument specifies the context in which the model is being used. Possible values are `train`, `eval`, or `predict`; `predict` mode is for fast inference. Defaults to `train`.
- **Embedding**: Maps discrete tokens to vectors. Its weight matrix has shape `(vocabulary size × dimension of output vectors)`, where the dimension of the output vectors is the number of elements in each word embedding.
- **GRU**: The GRU layer, built on another Trax layer called `GRUCell`. The number of GRU units must be specified and should match the number of elements in the word embedding. To stack several consecutive GRU layers, build them with a Python list comprehension inside the `Serial` combinator.
- **Dense** Vanilla Dense Layer
- **LogSoftmax**: Applies the log-softmax function to the Dense layer's outputs.

In [3]:
mode = 'train'         # one of 'train', 'eval', 'predict' ('predict' is for fast inference)
vocab_size = 256       # number of distinct tokens the Embedding layer can map
model_dimension = 512  # embedding size; also used as the number of units in each GRU
n_layers = 2           # how many GRU layers to stack

# The comprehension calls tl.GRU once per iteration, so the stacked layers are
# distinct objects (multiplying a one-element list would repeat a single instance).
# Unpacking with * hands them to Serial as individual sublayers.
GRU = tl.Serial(
    tl.ShiftRight(mode=mode),  # the only layer here that receives `mode`
    tl.Embedding(vocab_size=vocab_size, d_feature=model_dimension),
    *[tl.GRU(n_units=model_dimension) for _ in range(n_layers)],
    tl.Dense(n_units=vocab_size),  # one output per vocabulary entry
    tl.LogSoftmax(),
)

In [4]:
def show_layers(model, layer_prefix="Serial.sublayers"):
    """Print a numbered listing of a model's direct sublayers.

    Args:
        model: Any object exposing a ``sublayers`` sequence, such as a
            ``tl.Serial`` combinator.
        layer_prefix: Label printed before each sublayer's index.

    Prints a ``Total layers: N`` header, then, for every sublayer, a separator
    line followed by ``{layer_prefix}_{i}: {layer}``. Returns None.
    """
    sublayers = model.sublayers
    print(f'Total layers: {len(sublayers)}\n')
    # enumerate yields index and layer together instead of re-indexing by position
    for i, layer in enumerate(sublayers):
        print('=============')
        print(f'{layer_prefix}_{i}: {layer}\n')
        
# List the top-level sublayers of the GRU language model built above.
show_layers(model=GRU)

Total layers: 6

Serial.sublayers_0: ShiftRight(1)

Serial.sublayers_1: Embedding_256_512

Serial.sublayers_2: GRU_512

Serial.sublayers_3: GRU_512

Serial.sublayers_4: Dense_256

Serial.sublayers_5: LogSoftmax



In [None]:
# NOTE(review): `parallelism` and `dag_concurrency` were left in this cell as
# bare names. Neither is defined in any visible cell, so evaluating them raised
# NameError on Restart & Run All. They are kept as a note until their purpose is
# confirmed. TODO: define them or remove this cell.
# parallelism
# dag_concurrency