<a href="https://colab.research.google.com/github/isatyamks/jax-flax-llm-examples/blob/main/notebooks/modelV1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install -q jax-ai-stack
!pip install -Uq tiktoken grain matplotlib

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/99.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m99.7/99.7 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/424.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m424.2/424.2 kB[0m [31m32.2 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.3 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m79.7 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m87.9 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[

In [None]:

# Imports for the names used throughout the notebook.  This export contains no
# separate import cell, so without these a fresh kernel stops with NameError at
# the first call below.
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from datasets import load_dataset
from flax.training import checkpoints, train_state
from transformers import AutoTokenizer

# WikiText-2 (raw variant); the download log printed below reports test, train
# and validation splits.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
# Pretrained BERT tokenizer; its vocabulary size is later passed to the model
# as `vocab_size`.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")



README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/733k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [None]:
def tokenize_function(examples):
    """Tokenize the ``"text"`` entry of a batch of dataset rows.

    Args:
        examples: Mapping with a ``"text"`` entry, as supplied by
            ``dataset.map(..., batched=True)``.

    Returns:
        The result of calling the module-level ``tokenizer`` on those texts
        with ``padding="max_length"`` and truncation enabled.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, truncation=True, padding="max_length")
    return encoded

# Apply tokenize_function across the dataset; batched=True passes many rows
# per call instead of one row at a time.
tokenized_datasets = dataset.map(tokenize_function, batched=True)

In [None]:
class TransformerBlock(nn.Module):
    """Pre-LayerNorm Transformer block: self-attention followed by a GELU MLP.

    Both sub-layers are wrapped in residual (skip) connections, so the block
    refines its input instead of replacing it; without them the token
    embedding is discarded after the first layer and deeper stacks train
    poorly.

    Attributes:
        dim: Feature size of the block output.  The input's last dimension
            must equal ``dim`` so the residual additions line up.
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward sub-layer.
        causal: When True (the default) each position may attend only to
            itself and earlier positions, which is required for next-token
            language modelling.  Set to False for bidirectional attention.
    """
    dim: int
    num_heads: int
    mlp_dim: int
    causal: bool = True

    @nn.compact
    def __call__(self, x):
        """Apply the block to ``x``.

        Args:
            x: Activations whose last axis is the feature axis and whose
                second-to-last axis is the sequence axis.

        Returns:
            An array with the same shape as ``x``.
        """
        # make_causal_mask only looks at the shape of its argument, so a
        # feature-axis slice is enough to describe (..., seq_len).
        mask = nn.make_causal_mask(x[..., 0]) if self.causal else None

        # Attention sub-layer with skip connection.
        attn_in = nn.LayerNorm()(x)
        attn_out = nn.SelfAttention(num_heads=self.num_heads)(attn_in, mask=mask)
        x = x + attn_out

        # Feed-forward sub-layer with skip connection.
        mlp_in = nn.LayerNorm()(x)
        hidden = nn.Dense(self.mlp_dim)(mlp_in)
        hidden = nn.gelu(hidden)
        x = x + nn.Dense(self.dim)(hidden)
        return x

class SimpleLLM(nn.Module):
    """Token embedding, a stack of TransformerBlock layers, and an output projection.

    Attributes:
        vocab_size: Number of embedding rows and of output logits per position.
        dim: Embedding / hidden feature size.
        num_heads: Attention heads passed to every TransformerBlock.
        num_layers: How many TransformerBlock layers are stacked.
        mlp_dim: Feed-forward hidden width passed to every TransformerBlock.
    """
    vocab_size: int
    dim: int
    num_heads: int
    num_layers: int
    mlp_dim: int

    @nn.compact
    def __call__(self, x):
        """Map integer token ids ``x`` to per-position logits over the vocabulary."""
        # NOTE(review): only a token embedding is applied here; no positional
        # information is added before the blocks.
        hidden = nn.Embed(num_embeddings=self.vocab_size, features=self.dim)(x)
        for _layer in range(self.num_layers):
            block = TransformerBlock(
                dim=self.dim,
                num_heads=self.num_heads,
                mlp_dim=self.mlp_dim,
            )
            hidden = block(hidden)
        logits = nn.Dense(features=self.vocab_size)(hidden)
        return logits



In [None]:
# Build the model; the vocabulary size is taken from the tokenizer.
model = SimpleLLM(vocab_size=tokenizer.vocab_size, dim=256, num_heads=4, num_layers=2, mlp_dim=512)
# `init` returns the whole variables dict ({'params': ...}).  Keep only the
# inner parameter tree: the training step and generate_text wrap it again as
# {'params': ...} when calling apply, and double nesting makes apply fail to
# find the layer parameters.
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 128), dtype=jnp.int32))["params"]

from typing import Callable

from flax import struct


class TrainState(train_state.TrainState):
    """Flax ``TrainState`` that additionally carries the loss function.

    ``loss_fn`` is declared with ``pytree_node=False`` so it is stored as
    static metadata.  As an ordinary field it would be a pytree leaf, and
    passing the state to a ``jax.jit``-compiled function would fail because a
    Python function is not a valid JAX value.
    """
    # Callable taking (logits, labels) and returning a scalar loss.
    loss_fn: Callable = struct.field(pytree_node=False)

def cross_entropy_loss(logits, labels):
    """Return the mean softmax cross-entropy of ``logits`` against integer ``labels``.

    Args:
        logits: Unnormalised class scores.
        labels: Integer class indices.

    Returns:
        The mean of the per-element losses.
    """
    per_element = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    return per_element.mean()


In [None]:
@jax.jit
def train_step(state, batch):
    """Run one optimisation step.

    Args:
        state: Train state providing ``apply_fn``, ``params`` and
            ``apply_gradients``.
        batch: Mapping with ``'input_ids'`` (model input) and ``'labels'``
            (targets for the loss).

    Returns:
        A ``(new_state, loss)`` pair.
    """
    def compute_loss(weights):
        predictions = state.apply_fn({'params': weights}, batch['input_ids'])
        return cross_entropy_loss(predictions, batch['labels'])

    # Differentiate the loss with respect to the current parameters.
    grad_fn = jax.value_and_grad(compute_loss)
    loss_value, gradients = grad_fn(state.params)
    updated_state = state.apply_gradients(grads=gradients)
    return updated_state, loss_value
import os

# Training configuration.
NUM_EPOCHS = 3
BATCH_SIZE = 8
SEQ_LEN = 128  # same length as the dummy input used for model.init
# Absolute path: the Orbax-backed checkpoint writer used by some Flax versions
# rejects relative directories.
CHECKPOINT_DIR = os.path.abspath("checkpoints")


def _iter_train_batches(split, batch_size, seq_len):
    """Yield fixed-shape next-token-prediction batches from a tokenized split.

    Iterating the split directly yields one row at a time as plain Python
    values (including the raw ``text`` string) and has no ``'labels'`` entry,
    which the jitted ``train_step`` cannot consume.  This helper instead
    slices ``batch_size`` rows, converts ``input_ids`` to an int32 array and
    derives labels by shifting the ids one position to the left.

    Args:
        split: Tokenized dataset split supporting ``len()`` and slice
            indexing that returns a mapping with an ``"input_ids"`` column.
        batch_size: Rows per batch.  A trailing partial batch is dropped so
            every batch has the same shape and the step is compiled once.
        seq_len: Maximum number of input positions per row.

    Yields:
        Dicts with ``"input_ids"`` and ``"labels"`` arrays of equal shape.
    """
    last_start = len(split) - batch_size
    for start in range(0, last_start + 1, batch_size):
        rows = split[start:start + batch_size]["input_ids"]
        # One extra token so inputs and shifted labels both get seq_len items.
        ids = jnp.asarray(rows, dtype=jnp.int32)[:, :seq_len + 1]
        yield {"input_ids": ids[:, :-1], "labels": ids[:, 1:]}


optimizer = optax.adamw(1e-3)
state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer, loss_fn=cross_entropy_loss)

# NOTE(review): padding positions are included in the loss, because
# cross_entropy_loss takes a plain mean over every position.
for epoch in range(NUM_EPOCHS):
    for batch in _iter_train_batches(tokenized_datasets['train'], BATCH_SIZE, SEQ_LEN):
        state, loss = train_step(state, batch)
    # Reports the loss of the last batch of the epoch.
    print(f"Epoch {epoch}: Loss {loss}")

# Save the trained state, then read it back to confirm it round-trips.
checkpoints.save_checkpoint(CHECKPOINT_DIR, state, step=NUM_EPOCHS)
state = checkpoints.restore_checkpoint(CHECKPOINT_DIR, state)

def generate_text(prompt, max_length=50):
    """Greedily extend ``prompt`` one token at a time and return the text.

    The model is re-run on the growing sequence; at each step the highest
    scoring token at the last position is appended.  Generation stops when
    the sequence holds ``max_length`` tokens or the tokenizer's separator
    token is predicted.

    Args:
        prompt: Text to continue.
        max_length: Upper bound on the total number of tokens in the
            sequence, counting the prompt tokens (and any special tokens the
            tokenizer adds).

    Returns:
        The decoded prompt plus continuation, with special tokens removed.
    """
    prompt_ids = list(tokenizer(prompt)["input_ids"])
    # Drop a trailing separator added by the tokenizer so the model continues
    # the prompt instead of predicting what follows an already-closed text.
    sep_id = getattr(tokenizer, "sep_token_id", None)
    if sep_id is not None and prompt_ids and prompt_ids[-1] == sep_id:
        prompt_ids.pop()
    if not prompt_ids:
        return ""

    token_ids = jnp.asarray([prompt_ids], dtype=jnp.int32)
    while token_ids.shape[1] < max_length:
        logits = model.apply({'params': state.params}, token_ids)
        # Only the last position predicts the token that comes next.
        next_id = jnp.argmax(logits[:, -1, :], axis=-1).astype(token_ids.dtype)
        if sep_id is not None and int(next_id[0]) == sep_id:
            break
        token_ids = jnp.concatenate([token_ids, next_id[:, None]], axis=1)

    return tokenizer.decode(token_ids[0].tolist(), skip_special_tokens=True)

# Smoke test: print what generate_text returns for a sample prompt.
print(generate_text("JAX is great because"))
