In [3]:
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn

In [4]:
class ExplicitMLP(nn.Module):
    features: Sequence[int]

    def setup(self):
        # we automatically know what to do with lists, dicts of submodules
        self.layers = [nn.Dense(feat) for feat in self.features]
        # for single submodules, we would just write:
        # self.layer1 = nn.Dense(feat1)

    def __call__(self, inputs):
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i != len(self.layers) - 1:
                x = nn.relu(x)
        return x


In [5]:
key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))

2024-04-13 21:36:02.765949: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [6]:
model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))

initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}


In [7]:
y = model.apply(params, x)
print('output:\n', y)

output:
 [[ 0.          0.          0.          0.          0.        ]
 [ 0.00723789 -0.00810346 -0.02550935  0.02151712 -0.01261239]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]


In [8]:
x.shape + (2, 3)

(4, 4, 2, 3)

In [9]:
class MLP(nn.Module):
    n_embd: int
    n_inner: int
    p_dropout: float = 0.5
    
    def setup(self):
        self.dense1 = nn.Dense(self.n_inner)
        self.drop1 = nn.Dropout(rate=self.p_dropout)
        self.dense2 = nn.Dense(self.n_embd)
        self.drop2 = nn.Dropout(rate=self.p_dropout)
        
    def __call__(self, x, training: bool):
        x = self.dense1(x)
        x = nn.gelu(x)
        x = self.drop1(x, deterministic=not training)
        x = self.dense2(x)
        x = self.drop2(x, deterministic=not training)
        return x

In [10]:
def f(x, add_one: bool):
    if add_one:
        return x + 1
    else:
        return x

jit_f = jax.jit(f)

In [11]:
jit_f(1, True)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at /tmp/ipykernel_4386/3259198140.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument add_one.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

In [None]:
import numpy as np

In [14]:
from transformer import Decoder
from data import Dataloader
import jax
from jax import random
import optax


N_TRAIN = 10000
LEARNING_RATE = 1e-5
BLOCK_SIZE = 32
BATCH_SIZE = 64
N_LAYERS = 3
N_EMBD = 256
HEADS = 8
N_INNER = 512

In [15]:

data = Dataloader(batch_size=BATCH_SIZE, block_size=BLOCK_SIZE)
decoder = Decoder(
    n_layers=N_LAYERS,
    n_vocab=data.n_vocab,
    block_size=BLOCK_SIZE,
    n_embd=N_EMBD,
    heads=HEADS,
    n_inner=N_INNER,
)

key1, key2, dropout_key = random.split(random.key(0), 3)
x = random.randint(key1, (BATCH_SIZE, BLOCK_SIZE), minval=0, maxval=data.n_vocab)
params = decoder.init(key2, x, training=False)

In [18]:
jax.tree_map(lambda x: x.shape, params)

  jax.tree_map(lambda x: x.shape, params)


{'params': {'blocks_0': {'attn': {'qkv': {'bias': (768,),
     'kernel': (256, 768)}},
   'ln1': {'bias': (256,), 'scale': (256,)},
   'ln2': {'bias': (256,), 'scale': (256,)},
   'mlp': {'dense1': {'bias': (512,), 'kernel': (256, 512)},
    'dense2': {'bias': (256,), 'kernel': (512, 256)}}},
  'blocks_1': {'attn': {'qkv': {'bias': (768,), 'kernel': (256, 768)}},
   'ln1': {'bias': (256,), 'scale': (256,)},
   'ln2': {'bias': (256,), 'scale': (256,)},
   'mlp': {'dense1': {'bias': (512,), 'kernel': (256, 512)},
    'dense2': {'bias': (256,), 'kernel': (512, 256)}}},
  'blocks_2': {'attn': {'qkv': {'bias': (768,), 'kernel': (256, 768)}},
   'ln1': {'bias': (256,), 'scale': (256,)},
   'ln2': {'bias': (256,), 'scale': (256,)},
   'mlp': {'dense1': {'bias': (512,), 'kernel': (256, 512)},
    'dense2': {'bias': (256,), 'kernel': (512, 256)}}},
  'final_ln': {'bias': (256,), 'scale': (256,)},
  'logits': {'bias': (65,), 'kernel': (256, 65)},
  'timestep_embd': {'embedding': (32, 256)},
  't

In [19]:
x, y = data.get_batch()

In [20]:
x

Array([[53, 56,  1, ..., 50, 43, 39],
       [50,  1, 63, ..., 40, 43, 44],
       [58, 53, 45, ..., 58,  1, 58],
       ...,
       [43, 58,  1, ..., 43,  1, 58],
       [ 1, 39, 60, ..., 44,  1, 58],
       [46, 43,  1, ..., 41, 53, 52]], dtype=int32)

In [21]:
y.shape

(64, 32)

In [22]:
# training loss fn
logits = decoder.apply(params, x, training=True, rngs={"dropout": dropout_key})
logits = logits.reshape(-1, data.n_vocab)
y = y.reshape(-1, 1)
optax.softmax_cross_entropy(logits, y).mean()

Array(nan, dtype=float32)

In [23]:
y

Array([[56],
       [ 1],
       [46],
       ...,
       [53],
       [52],
       [42]], dtype=int32)

In [24]:
logits

Array([[nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       ...,
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan],
       [nan, nan, nan, ..., nan, nan, nan]], dtype=float32)