Derived from the <a href="https://trax-ml.readthedocs.io/en/latest/notebooks/trax_intro.html">introductory tutorial</a> in the Trax documentation.

In [2]:
#@title
# Copyright 2020 Google LLC.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [10]:
import numpy as np
import trax
from   trax import layers as tl
from   trax.fastmath import numpy as fastnp

# `use_backend` is a context manager: calling it outside a `with` block builds
# an object that is never entered, so no backend is actually selected.
# `set_backend` sets the default backend for the rest of the session.
trax.fastmath.set_backend('jax')  # or 'tensorflow-numpy'

In [5]:
# Inputs: a 3x3 matrix and a length-3 vector of ones.
M = fastnp.array([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
])
v = fastnp.ones(3)

# Vector-matrix product; since v is all ones this adds up each column of M.
dot_prod = fastnp.dot(v, M)
# Element-wise tanh of the product.
tanh = fastnp.tanh(dot_prod)

# Show the inputs first, then the derived values.
print(f'Matrix:\n{M}')
print(f'Vector: {v}')
print(f'Dot product: {dot_prod}')
print(f'tanh(prod): {tanh}')

Matrix:
[[1 2 3]
 [4 5 6]
 [7 8 9]]
Vector: [1. 1. 1.]
Dot product: [12. 15. 18.]
tanh(prod): [1. 1. 1.]




In [6]:
def f(x):
    """Return two times the square of `x`, i.e. f(x) = 2x^2.

    The product is evaluated as (2. * x) * x, so integer inputs are
    promoted to float by the first multiplication.
    """
    doubled = 2. * x
    return doubled * x

In [7]:
# Build a function that evaluates the derivative of f.
grad_f = trax.fastmath.grad(f)

# Evaluate the derivative at one positive and one negative point and
# report both results, one per line.
print(
    f'grad(2x^2) at 1: {grad_f(1.)}',
    f'grad(2x^2) at -2: {grad_f(-2.)}',
    sep='\n',
)

grad(2x^2) at 1: 4.0
grad(2x^2) at -2: -8.0


In [11]:
# Input: the integers 0..14, used as indices into the embedding table.
x = np.arange(15)
print(f'x: {x}')

# Build the embedding layer, then initialise it from the
# signature (shape and dtype) of the sample input.
embedding = tl.Embedding(
    vocab_size=20,
    d_feature=32,
)
embedding.init(trax.shapes.signature(x))

# Apply the layer to the input and report the shape of the result.
y = embedding(x)
print(f'y.shape: {y.shape}')

x: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14]
y.shape: (15, 32)
