In [2]:
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn

In [3]:
class ExplicitMLP(nn.Module):
    features: Sequence[int]

    def setup(self):
        # we automatically know what to do with lists, dicts of submodules
        self.layers = [nn.Dense(feat) for feat in self.features]
        # for single submodules, we would just write:
        # self.layer1 = nn.Dense(feat1)

    def __call__(self, inputs):
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i != len(self.layers) - 1:
                x = nn.relu(x)
        return x


In [9]:
key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))

In [13]:
model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))

initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}


In [14]:
y = model.apply(params, x)
print('output:\n', y)

output:
 [[ 0.          0.          0.          0.          0.        ]
 [ 0.0072379  -0.00810347 -0.02550939  0.02151716 -0.01261241]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]


In [20]:
x.shape + (2, 3)

(4, 4, 2, 3)

In [16]:
class MLP(nn.Module):
    n_embd: int
    n_inner: int
    p_dropout: float = 0.5
    
    def setup(self):
        self.dense1 = nn.Dense(self.n_inner)
        self.drop1 = nn.Dropout(rate=self.p_dropout)
        self.dense2 = nn.Dense(self.n_embd)
        self.drop2 = nn.Dropout(rate=self.p_dropout)
        
    def __call__(self, x, training: bool):
        x = self.dense1(x)
        x = nn.gelu(x)
        x = self.drop1(x, deterministic=not training)
        x = self.dense2(x)
        x = self.drop2(x, deterministic=not training)
        return x

In [17]:
def f(x, add_one: bool):
    if add_one:
        return x + 1
    else:
        return x

jit_f = jax.jit(f)

In [18]:
jit_f(1, True)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at /var/folders/t8/dfjnqkz511g38qs8fqdv507w0000gn/T/ipykernel_73814/3259198140.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument add_one.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

In [21]:
import numpy as np

In [25]:
a = np.zeros((2, 4, 3, 5))

In [26]:
a[..., :] = np.arange(5)

In [34]:
a[-1, -1, -1, :] = 0

In [35]:
a

array([[[[0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.]],

        [[0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.]],

        [[0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.]],

        [[0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.]]],


       [[[0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.]],

        [[0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.]],

        [[0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.]],

        [[0., 1., 2., 3., 4.],
         [0., 1., 2., 3., 4.],
         [0., 0., 0., 0., 0.]]]])

In [28]:
b = np.zeros((2, 5, 3, 4))

In [29]:
b[:] = np.arange(5).reshape((1, 5, 1, 1))

In [30]:
b

array([[[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[2., 2., 2., 2.],
         [2., 2., 2., 2.],
         [2., 2., 2., 2.]],

        [[3., 3., 3., 3.],
         [3., 3., 3., 3.],
         [3., 3., 3., 3.]],

        [[4., 4., 4., 4.],
         [4., 4., 4., 4.],
         [4., 4., 4., 4.]]],


       [[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[2., 2., 2., 2.],
         [2., 2., 2., 2.],
         [2., 2., 2., 2.]],

        [[3., 3., 3., 3.],
         [3., 3., 3., 3.],
         [3., 3., 3., 3.]],

        [[4., 4., 4., 4.],
         [4., 4., 4., 4.],
         [4., 4., 4., 4.]]]])

In [36]:
c = np.einsum('bthd,bdhT->bhtT', a, b)

In [37]:
c

array([[[[30., 30., 30., 30.],
         [30., 30., 30., 30.],
         [30., 30., 30., 30.],
         [30., 30., 30., 30.]],

        [[30., 30., 30., 30.],
         [30., 30., 30., 30.],
         [30., 30., 30., 30.],
         [30., 30., 30., 30.]],

        [[30., 30., 30., 30.],
         [30., 30., 30., 30.],
         [30., 30., 30., 30.],
         [30., 30., 30., 30.]]],


       [[[30., 30., 30., 30.],
         [30., 30., 30., 30.],
         [30., 30., 30., 30.],
         [30., 30., 30., 30.]],

        [[30., 30., 30., 30.],
         [30., 30., 30., 30.],
         [30., 30., 30., 30.],
         [30., 30., 30., 30.]],

        [[30., 30., 30., 30.],
         [30., 30., 30., 30.],
         [30., 30., 30., 30.],
         [ 0.,  0.,  0.,  0.]]]])

In [33]:
1 + 4 + 9 + 16

30