In [None]:
import equinox as eqx
import jax
import equinox.nn as nn


class text(eqx.Module):
    first: nn.Linear
    second: nn.Conv1d

    def __init__(self):
        self.first = nn.Linear(1, 2, use_bias=True, key=jax.random.key(1))
        self.second = nn.Conv1d(2, 2, 1, use_bias=True, key=jax.random.key(2))

In [4]:
model = text()
print(model)

text(
  first=Linear(
    weight=f32[2,1],
    bias=f32[2],
    in_features=1,
    out_features=2,
    use_bias=True
  ),
  second=Conv1d(
    num_spatial_dims=1,
    weight=f32[2,2,1],
    bias=f32[2,1],
    in_channels=2,
    out_channels=2,
    kernel_size=(1,),
    stride=(1,),
    padding=((0, 0),),
    dilation=(1,),
    groups=1,
    use_bias=True,
    padding_mode='ZEROS'
  )
)


In [8]:
import jax.numpy as jnp

is_linear = lambda x: isinstance(x, nn.Linear)

mean = 1
std = 2

key = jax.random.key(3)


def hop(x):
    global key
    key, grab = jax.random.split(key)
    y = nn.Linear(1, 3, key=grab)

    y.weight = jax.nn.initializers.normal(std)(grab, x.weight.shape) + mean
    y.bias = jnp.zeros_like(x.bias)
    return y


print(jax.tree_util.tree_map(hop, model, is_leaf=is_linear))

FrozenInstanceError: cannot assign to field 'weight'

In [12]:
class CausalSelfAttention(eqx.Module):
    c_attn: nn.Linear
    c_proj: nn.Linear

    attn_dropout: nn.Dropout
    resid_dropout: nn.Dropout

    mask: jax.Array = eqx.field(static=True)

    def __init__(self, n_embd=786, bias=True, dropout=0.1, block_size=1024, key=None):
        key1, key2 = jax.random.split(key)
        # Projection for W_1, W_2, W_3 in a batch
        self.c_attn = nn.Linear(n_embd, n_embd * 3, use_bias=bias, key=key1)
        # Output proj
        self.c_proj = nn.Linear(n_embd, n_embd, use_bias=bias, key=key2)

        # Regularisation
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        # self.n_head = n_head
        # self.n_embd = n_embd

        self.mask = jnp.tril(jnp.ones((block_size, block_size)))

    @eqx.filter_jit
    def __call__(self, x):
        # X is of shape (seq, n_embd)
        # Project into the head dim.
        qkv = jax.vmap(self.c_attn)(x)
        q, k, v = jnp.split(qkv, 3, axis=-1)
        # Dim of (seq, head_dim)
        kq = jnp.matmul(k, jnp.transpose(q))
        # Dim of (seq, seq), a matrix showing which tokens are interested in each other.
        # Mask to make causal
        kq = jnp.where(jnp.equal(self.mask, 1), kq, -jnp.inf)  # Trick to lower compute

        kq = jax.nn.softmax(kq)
        # Add att dropout
        kq = self.attn_dropout(kq)
        outs = jnp.matmul(kq, v)

        # Add residual dropout
        outs = self.resid_dropout(outs)
        return outs

In [13]:
model = CausalSelfAttention(key=jax.random.key(1))

  model = CausalSelfAttention(key=jax.random.key(1))


In [None]:
import optax

learning_rate = 1e-4

lr_scheduler = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=learning_rate,
    warmup_steps=20,
    decay_steps=30,
    end_value=1e-5,
)

optimizer = optax.inject_hyperparams(optax.adamw)(
    learning_rate=lr_scheduler, b1=0.9, b2=0.95
)

optimizer.init()

RuntimeError: Dropout requires a key when running in non-deterministic mode.

In [None]:
from model import GPT
import equinox as eqx

# print(enc.special_tokens_set)
start = "Once upon"

model = GPT(config)
model = eqx.tree_deserialise_leaves("out/ckpt.pt", model)
x = jax.numpy.array([enc.encode(start)])

while x[0, -1] != enc.eot_token:
    logits = jax.vmap(model)(x)
    x = jax.numpy.concat(
        [x, jax.numpy.array([[jax.numpy.argmax(logits[0, -1])]])], axis=-1
    )
    print(enc.decode(jax.numpy.squeeze(x, axis=0)))