In [1]:
import jax
import optax
import flax
from flax.training.train_state import TrainState
from functools import partial
from tqdm.auto import tqdm
from dataset.toytext import TextDataset
from models.language import BigramLM, TransormerLM, MambaLM
import training

# Sanity check: list the devices JAX can see (CPU vs. accelerator) before any work is dispatched.
print(f"JAX devices:{jax.devices()}")

JAX devices:[CpuDevice(id=0)]


In [2]:
# Root PRNG key for the notebook; later cells derive their randomness from it.
rng_key = jax.random.key(0)

# Training geometry: tokens per sequence and sequences per batch.
max_context_len, batch_size = 32, 32

# Text corpus plus its tokenizer.
dataset = TextDataset(data_path="dataset/shakespeare.txt")

# Alternative architecture, kept for quick comparison. To use it, uncomment
# this block and drop the TransormerLM construction below.
# model = MambaLM(
#     vocab_size=len(dataset.tokenizer.vocab),
#     max_context_len=max_context_len,
#     embedding_dim=64,
#     state_dim=128,
#     n_layers=8
# )

# Hyper-parameters of the active model, gathered in one place so they are easy to tweak.
# The class name "TransormerLM" (sic) is spelled as exported by models.language.
transformer_config = dict(
    vocab_size=len(dataset.tokenizer.vocab),
    max_context_len=max_context_len,
    embedding_dim=64,
    head_size=128,
    n_heads=4,
    n_layers=4,
)
model = TransormerLM(**transformer_config)

In [3]:
# Bind the loss once so the compiled step is called as optimization_step(train_state, x, y).
_step_with_loss = partial(
    training.optimization_step,
    loss_fn=training.logit_prediction_loss,
)
optimization_step = jax.jit(_step_with_loss)

# batch_size and context_len are marked static: they are treated as compile-time
# constants, and a new value triggers a re-trace (presumably they fix the batch shape).
get_batch = jax.jit(
    dataset.get_batch,
    static_argnames=("batch_size", "context_len"),
)

# Route model.apply to the module's generate_token method and compile it.
_apply_generate_token = partial(model.apply, method=model.generate_token)
generate_token = jax.jit(_apply_generate_token)

def generate_text(params, prompt: str, length=100, rng_key=jax.random.key(0)):
    """Stream a sampled continuation of ``prompt`` to stdout.

    The prompt is echoed in blue (ANSI escape codes), then ``length`` tokens are
    generated one at a time with ``generate_token`` and printed as soon as each
    one is decoded.

    Args:
        params: Model variables, passed through unchanged to ``generate_token``.
        prompt: Text to condition on; encoded with ``dataset.tokenizer``.
        length: Number of tokens to generate.
        rng_key: PRNG key, split into one sub-key per generated token.

    Returns:
        None. The text is only printed.
    """
    context = dataset.tokenizer.encode(prompt)
    # sep="" keeps the colour codes flush against the prompt. With print's default
    # separator a space was inserted before and after the prompt, so a continuation
    # such as "ings" was rendered as "proceed ings" instead of "proceedings".
    print(
        "\033[94m",
        dataset.tokenizer.decode(context),
        "\033[0m",
        sep="",
        end="",
        flush=True,
    )
    for sub_rng in jax.random.split(rng_key, length):
        # generate_token returns the new token and the context to use for the next step.
        next_token, context = generate_token(params, context, sub_rng)
        # next_token[None] adds a leading axis so a single token is decoded as a
        # length-1 sequence; flush so the text appears token by token.
        print(dataset.tokenizer.decode(next_token[None]), end="", flush=True)
    # Terminate the line so whatever is printed next does not run into the sample.
    print()


# Initialise the parameters by running the model once on an example drawn from the dataset.
# NOTE(review): the same rng_key feeds both model.init and dataset.sample; JAX
# recommends splitting keys rather than reusing them. Confirm whether this is intentional.
initial_params = model.init(rng_key, dataset.sample(max_context_len, rng_key))

# Element-wise clipping of updates to [-1, 1], followed by Adam (lr=1e-3, b2=0.95).
optimizer = optax.chain(
    optax.clip(1.0),
    optax.adam(1e-3, b2=0.95),
)

# Bundle the apply function, parameters and optimizer state into one train state.
train_state = TrainState.create(
    apply_fn=model.apply,
    params=initial_params,
    tx=optimizer,
)

# Training schedule: N_epochs passes of batches_per_epoch optimisation steps each.
N_epochs = 10
batches_per_epoch = 10000

# tqdm wraps the key array itself rather than the enumerate() iterator.
# enumerate() has no len(), so tqdm(enumerate(...)) could not show a total and
# rendered the epoch bar as "0it [00:00, ?it/s]".
for epoch_idx, epoch_rng_key in enumerate(tqdm(jax.random.split(rng_key, N_epochs))):
    losses = []
    # One fresh sub-key per batch, derived from this epoch's key.
    for batch_rng_key in tqdm(jax.random.split(epoch_rng_key, batches_per_epoch), leave=False):
        x, y = get_batch(batch_size, max_context_len, rng_key=batch_rng_key)
        train_state, loss_value = optimization_step(train_state, x, y)
        losses.append(loss_value)
    # Report the mean loss over the epoch.
    print(f"Loss: {sum(losses) / len(losses)}\nGeneration test:")
    # The prompt is the first max_context_len characters of the corpus. The same
    # rng_key is passed every epoch, so each sample starts from identical randomness.
    generate_text(train_state.params, prompt=dataset.fulltext[:max_context_len], rng_key=rng_key)

0it [00:00, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 1.6292376518249512
Generation test: 
[94m First Citizen:
Before we proceed [0m
As that never his man? By passand-comie?

ROMEO:
You, if I saddly made mup of the Vousician cisome.

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 1.4245597124099731
Generation test: 
[94m First Citizen:
Before we proceed [0m
me that news on: both you piece
is thine acquits high.

PRINCE EDWARD:
Now will as you were the man

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 1.381343126296997
Generation test: 
[94m First Citizen:
Before we proceed [0m
me that never his offence of face,
And much, his hands; teach within this to his straight
Whose man

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 1.3584356307983398
Generation test: 
[94m First Citizen:
Before we proceed [0m
Like an one excuse?

HUMBEN Scarisor! yet concever.

YORK:
No crack at lupt the air discour cisomed

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 1.3407204151153564
Generation test: 
[94m First Citizen:
Before we proceed [0m
me that need on: but what pleft
is this?

ROMEO:
You are not danger deposed; he's the prince
Were b

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 1.3261581659317017
Generation test: 
[94m First Citizen:
Before we proceed [0m
Let the neck once of what piece is thine
A quake thing; that doth made me to his eye.

MENENIUS:
Bi

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 1.3162202835083008
Generation test: 
[94m First Citizen:
Before we proceed [0mings take my excius?

DUKE OF YORK:
No, my conjicience into thy case and rotes
Of stidon'd and some 

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 1.3063478469848633
Generation test: 
[94m First Citizen:
Before we proceed [0m
Let the netter-missed hate; but in this day, his hands;
that docting the prosperous dream
With Clar

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 1.2990070581436157
Generation test: 
[94m First Citizen:
Before we proceed [0m
Lest and my executol laid possesses.

Third Servant:
You say done?

ANGELO:
O that you were costman

  0%|          | 0/10000 [00:00<?, ?it/s]

Loss: 1.2919578552246094
Generation test: 
[94m First Citizen:
Before we proceed [0m
Lest and my exhile?

HURTIUS:
Come to your daughter for tears with many thousand sticks light.
Clam