In [1]:
import jax
import optax
import flax
from flax.training.train_state import TrainState
from functools import partial
from tqdm.auto import tqdm
from dataset.toytext import TextDataset
from models.language import BigramLM, TransormerLM, MambaLM
import training

# Report the devices JAX can see, so it is obvious up front which backend
# the rest of the notebook will run on.
print("JAX devices:", jax.devices(), sep="")

JAX devices:[CpuDevice(id=0)]


In [2]:
# --- Experiment configuration ------------------------------------------------
batch_size = 32               # passed to `get_batch` as the batch size
max_context_len = 32          # context length used for batches, init sample and prompt
rng_key = jax.random.key(0)   # root PRNG key for this notebook

# Text corpus. Its tokenizer's vocabulary size fixes the model's `vocab_size`.
dataset = TextDataset(data_path="dataset/shakespeare.txt")

# Active model: Mamba-based language model.
model = MambaLM(
    n_layers=2,
    state_dim=128,
    embedding_dim=64,
    max_context_len=max_context_len,
    vocab_size=len(dataset.tokenizer.vocab),
)

# Alternative model, kept for quick comparison: uncomment to train the
# transformer instead (the class name is spelled as imported from
# `models.language`).
# model = TransormerLM(
#     vocab_size=len(dataset.tokenizer.vocab),
#     max_context_len=max_context_len,
#     embedding_dim=64,
#     head_size=128,
#     n_heads=4,
#     n_layers=4
# )

In [3]:
# JIT-compiled entry points used by the generation helper and training loop.

# Single-token generation, dispatched through `model.apply` to the model's
# `generate_token` method.
generate_token = jax.jit(partial(model.apply, method=model.generate_token))

# Batch sampling. The two size arguments are declared static, so JAX treats
# them as compile-time constants and recompiles if either value changes.
get_batch = jax.jit(
    dataset.get_batch,
    static_argnames=("batch_size", "context_len"),
)

# One optimisation step with the loss function bound ahead of time.
optimization_step = jax.jit(
    partial(training.optimization_step, loss_fn=training.logit_prediction_loss)
)

def generate_text(params, prompt: str, length=100, rng_key=jax.random.key(0)):
    """Stream `length` generated tokens to stdout, continuing from `prompt`.

    The encoded-then-decoded prompt is printed first, wrapped in ANSI blue
    escape codes, followed by each generated token as soon as it is produced.
    Nothing is returned; the text is only printed.

    Args:
        params: Model parameters forwarded to the jitted `generate_token`.
        prompt: Text that seeds the context; encoded with `dataset.tokenizer`.
        length: Number of tokens to generate.
        rng_key: JAX PRNG key; it is split into one sub-key per generated
            token, so a fixed key yields the same sub-keys on every call.
    """
    context = dataset.tokenizer.encode(prompt)
    print("\033[94m", dataset.tokenizer.decode(context), "\033[0m", end="")
    for sub_rng in jax.random.split(rng_key, length):
        # `generate_token` hands back the new token together with the context
        # to use for the next step.
        next_token, context = generate_token(params, context, sub_rng)
        # `[None]` adds a leading axis before decoding -- presumably because
        # `decode` expects a sequence of tokens rather than a single one.
        print(dataset.tokenizer.decode(next_token[None]), end="")
    # Every print above suppresses the newline; end the line here so that
    # whatever the caller prints next does not run into the generated text.
    print()


# Derive one independent PRNG key per purpose from the notebook's root key.
# JAX keys must not be consumed more than once: reusing the same key for
# parameter init, the init sample, the epoch keys and generation would
# correlate those random streams.
init_key, sample_key, train_key, generation_key = jax.random.split(rng_key, 4)

# Model parameters are initialised on one sampled sequence; the optimiser
# clips gradient values element-wise to [-1, 1] before the Adam update.
train_state = TrainState.create(
    apply_fn=model.apply,
    params=model.init(init_key, dataset.sample(max_context_len, sample_key)),
    tx=optax.chain(optax.clip(1.0), optax.adam(1e-3, b2=0.95)),
)

N_epochs = 10
batches_per_epoch = 1000
# tqdm wraps the key array itself (not an `enumerate` generator, which has no
# length), so the epoch progress bar can display its total.
epoch_rng_keys = jax.random.split(train_key, N_epochs)
for epoch_idx, epoch_rng_key in enumerate(
    tqdm(epoch_rng_keys, total=N_epochs, desc="Epochs")
):
    losses = []
    batch_rng_keys = jax.random.split(epoch_rng_key, batches_per_epoch)
    for batch_rng_key in tqdm(batch_rng_keys, desc="Batches", leave=False):
        x, y = get_batch(batch_size, max_context_len, rng_key=batch_rng_key)
        train_state, loss_value = optimization_step(train_state, x, y)
        losses.append(loss_value)
    # Mean training loss over this epoch's batches.
    print(f"Loss: {sum(losses) / len(losses)}\nGeneration test: ")
    # The same generation key is used after every epoch, so differences between
    # samples come from the updated parameters rather than from the sampling
    # randomness.
    generate_text(
        train_state.params,
        prompt=dataset.fulltext[:max_context_len],
        rng_key=generation_key,
    )

0it [00:00, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]