In [1]:
import jax
import optax
import flax
from flax.training.train_state import TrainState
from functools import partial
from tqdm.auto import tqdm
from dataset.toytext import TextDataset
from models.language import BigramLM, TransormerLM, MambaLM
import training

# Report which accelerator devices JAX can see in this kernel.
print("JAX devices:{}".format(jax.devices()))

JAX devices:[cuda(id=0), cuda(id=1), cuda(id=2)]


In [2]:
# Global PRNG seed and data / context hyper-parameters shared by the cells below.
rng_key = jax.random.key(0)
max_context_len = 64
batch_size = 64

dataset = TextDataset(data_path="dataset/shakespeare.txt")

# Architecture to train: "transformer" (default) or "mamba".
# This replaces a commented-out MambaLM block with an explicit switch, so the
# alternative configuration stays runnable instead of rotting as dead code.
MODEL_KIND = "transformer"

if MODEL_KIND == "transformer":
    # `TransormerLM` (sic) is the class name exactly as imported in the first cell.
    model = TransormerLM(
        vocab_size=len(dataset.tokenizer.vocab),
        max_context_len=max_context_len,
        embedding_dim=64,
        head_size=128,
        n_heads=4,
        n_layers=4,
    )
elif MODEL_KIND == "mamba":
    model = MambaLM(
        vocab_size=len(dataset.tokenizer.vocab),
        max_context_len=max_context_len,
        embedding_dim=64,
        state_dim=64,
        n_layers=4,
    )
else:
    raise ValueError(f"Unknown MODEL_KIND: {MODEL_KIND!r}")

In [3]:
# JIT-compiled entry points used by the sampler and the training loop below.
# `batch_size` and `context_len` are marked static for `get_batch`, so each
# distinct value gets its own compiled version (presumably because they fix
# the shapes of the returned arrays).
get_batch = jax.jit(dataset.get_batch, static_argnames=("batch_size", "context_len"))
optimization_step = jax.jit(
    partial(
        training.optimization_step,
        loss_fn=training.logit_prediction_loss,
    )
)
generate_token = jax.jit(partial(model.apply, method=model.generate_token))

def generate_text(params, prompt: str, length=500, rng_key=jax.random.key(0)):
    """Stream ``length`` sampled tokens to stdout, continuing from ``prompt``.

    The prompt, round-tripped through ``dataset.tokenizer``, is printed in blue.
    Each sampled token is then decoded and printed as soon as it is produced,
    and the output is terminated with a newline.

    Args:
        params: model variables, passed as the first argument to ``generate_token``.
        prompt: seed text, encoded with ``dataset.tokenizer.encode``.
        length: number of tokens to sample.
        rng_key: JAX PRNG key, split into one sub-key per sampled token. The
            default is created once at definition time. That is safe because
            JAX keys are immutable values.
    """
    context = dataset.tokenizer.encode(prompt)
    # sep="" keeps the prompt flush against the colour codes. With the default
    # separator a prompt ending in "Al" followed by the sample "l." rendered
    # as "Al l." instead of "All.".
    print("\033[94m", dataset.tokenizer.decode(context), "\033[0m", sep="", end="")
    for sub_rng in jax.random.split(rng_key, length):
        next_token, context = generate_token(params, context, sub_rng)
        # `[None]` adds a leading axis before decoding (presumably next_token is
        # a single token id and decode expects a sequence; verify in tokenizer).
        print(dataset.tokenizer.decode(next_token[None]), end="")
    # End the sample's line so whatever is printed next starts on a fresh line.
    print()


# Initial training state. Parameters come from `model.init` applied to an example
# input drawn with `dataset.sample` (presumably `max_context_len` tokens, needed
# only for shapes; TODO confirm). The same `rng_key` seeds both the init and the
# example sample.
train_state = TrainState.create(
    apply_fn=model.apply,
    params=model.init(
        rng_key,
        dataset.sample(max_context_len, rng_key),
    ),
    # Update clipping with max_delta=1.0 chained before Adam (lr=3e-3, b2=0.95).
    # NOTE(review): optax.clip clips each element independently. If gradient-norm
    # clipping was intended, optax.clip_by_global_norm(1.0) is the usual choice.
    tx=optax.chain(
        optax.clip(1.0),
        optax.adam(learning_rate=3e-3, b2=0.95),
    ),
)

# Training: N_epochs epochs of batches_per_epoch optimizer steps each.
# Each epoch key is split from `rng_key`, and each batch key from its epoch key.
# After every epoch the mean step loss is printed and a sample is generated,
# seeded with the first `max_context_len` entries of `dataset.fulltext`.
N_epochs = 10
batches_per_epoch = 1000
epoch_rng_keys = jax.random.split(rng_key, N_epochs)
for epoch_rng_key in tqdm(epoch_rng_keys):
    losses = []
    batch_rng_keys = jax.random.split(epoch_rng_key, batches_per_epoch)
    for batch_rng_key in tqdm(batch_rng_keys, leave=False):
        x, y = get_batch(batch_size, max_context_len, rng_key=batch_rng_key)
        train_state, loss_value = optimization_step(train_state, x, y)
        # Stored exactly as returned; there is no float() conversion inside the loop.
        losses.append(loss_value)
    mean_loss = sum(losses) / len(losses)
    print(f"Loss: {mean_loss}\nGeneration test:")
    # NOTE(review): `rng_key` was already split above to derive the epoch keys.
    # Reusing it here makes every epoch sample with the same key. That is handy
    # for comparing epochs, but JAX advises against reusing a key after splitting it.
    generate_text(train_state.params, prompt=dataset.fulltext[:max_context_len], rng_key=rng_key)

(64, 64)
(64, 128)
(64, 64)
(64, 128)
(64, 64)
(64, 128)
(64, 64)
(64, 128)


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

(64, 64)
(64, 128)
(64, 64)
(64, 128)
(64, 64)
(64, 128)
(64, 64)
(64, 128)
Loss: 1.9618247747421265
Generation test:
[94m First Citizen:
Before we proceed any further, hear me speak.

Al [0m(64, 64)
(64, 128)
(64, 64)
(64, 128)
(64, 64)
(64, 128)
(64, 64)
(64, 128)
onVenMENIUS:
He's name too: let I say.

GLOUCESTER:
Mamently actiong by:
Mother sase that At:
Noward father, conscepar here: of Sread,
Wells twones in swort what I will a houldring
I had to pastings, if the madaricusance, Lord, an I'll York.

MENENIUS:
If Gaulness your ceread fortweelling in the.

VOLUMNIA:
A lovestreess, so is concle, to, this brown eame,
Their counselt dissoom o' so queedins:
Then He serval not those yourseld?
Of givellassaly the sisidea, old that Jrows that them,
Apoused of t

  0%|          | 0/1000 [00:00<?, ?it/s]

Loss: 1.5007067918777466
Generation test:
[94m First Citizen:
Before we proceed any further, hear me speak.

Al [0ml.

AMPttructimagenatior:
Pething these banishment, that blantary to be:
Mont?

WARWIBANUS:
What are naether, sir, Hasting thou Scall,
Welrs two much me he, whose like law orlazine
To Nabout? seize did:
Lucent hith on your facter, I'll burniag,
When them that never firm erecul, dew there?

JULIET:
Brathear thing:--
The deason!
What morrow, this body feast, I time you?

SETSCLIUS:
'Tis queen held never the most trial speak leave?
O poison, what is so! Mistory live a wart son;
As ever'd me father,

  0%|          | 0/1000 [00:00<?, ?it/s]

Loss: 1.4290919303894043
Generation test:
[94m First Citizen:
Before we proceed any further, hear me speak.

Al [0ml:
A Margaret's worthy tyrant, mating Kate:
Bost of that bland, conpedity, to bases
That ill:
Now then e'er then we art well, forced his prevent
But keep it what I will all, by so remains;
Then in evil: whils murdelo your Lucate;
A serve whiar, go be the tiallo.

PARIS:
Do our pake that a winter.

BUCKINGHAM:
Sweet when ever is your marke, that be.

HENRY BOLINGBROKE:
Ne'er shallo's motherer; pay this want.
. Then Edward is it is spide my charrity:
Ay, death live a wart son; you have put for amo

  0%|          | 0/1000 [00:00<?, ?it/s]

Loss: 1.3917101621627808
Generation test:
[94m First Citizen:
Before we proceed any further, hear me speak.

Al [0ml:
Ay, to joy my fats I,
I fear at as call in success' entland,
As prike to in as seitir, At:
I came not; and so we ablock the berviness,
Is two much merry, why the shelance of so reshookes
Bestides, if the mattricloanly follate;
A couck-ancaed in be the time nexcuted him,
Ado sake the baging of her.

HENRY BOLINGBROKE:
Bushy Kates me, to-morrow:
Not promisesting hateful dispositions?
Go kerhop your lives mo.

JADIEL:
I
OF Jupitation, let that is: I must course
With Jrine the underfordows darkme

  0%|          | 0/1000 [00:00<?, ?it/s]

Loss: 1.3665975332260132
Generation test:
[94m First Citizen:
Before we proceed any further, hear me speak.

Al [0ml.

All:
Well, thy as I'll:
They that Keepere straited entle; a city bistooping.

FEither:
Nor is the neglect she ways destite become,
Wells two such merriment in like law or,
Is not hope of their edicts; from his done,
You accept the county his flesh empts upon cousin ceremony;
But an in in the kin.

HENRY BOLINGBROKEE:
My lord, tranio were such deny.

NORTNUS:
What, thou art off's meque drinks not of them;.

JAD ELIZANO:
You art again, were Lancans?

BIONA:
Madam wart so wilt be:
Thou seek kno

  0%|          | 0/1000 [00:00<?, ?it/s]

Loss: 1.3486390113830566
Generation test:
[94m First Citizen:
Before we proceed any further, hear me speak.

Al [0mone,
My presention as I'll fly, I saw Kate:
Most castage, language prince, pion! gentle wilt:
If an you grant so we abatch the bereal,
Well known my ken hears, in life laws,
The kind hath slaspting did I; from hither wish.
Folk, give you why hang, but her sault.
My form ere do sake the japen's. The king dreams,
You when the field transpare; and to whom my outting,
Thereto tash as 's not breathed in the rest
.

JALIEN:
Inful mad?
O pilbot, what is such is as thy lane,
Chastist up thy courses ofts

  0%|          | 0/1000 [00:00<?, ?it/s]

Loss: 1.332741379737854
Generation test:
[94m First Citizen:
Before we proceed any further, hear me speak.

Al [0ml.

HORTENSIO:
Here, what? Come, that Kate herself, come, lang, conpedit
on in as seithe is to-morrow? while conscience,
That falcon ricely known drinks her whose.
When a horse so remainst to sting did:
Lucion hith on your face, give his whiah,
When we must unnexcors in enemies;
For all apinhour heart with thine:--
The devil is your mark,--this book her king:
Here shere but my moo's moquerering silk!
A remain night! so your mad?
O polight, if with us? is a king, an alreasons out
The counsel of t

  0%|          | 0/1000 [00:00<?, ?it/s]

Loss: 1.3209737539291382
Generation test:
[94m First Citizen:
Before we proceed any further, hear me speak.

Al [0ml.

All TYRREL:
What, I, is, home, as King on Nicatio's holy;
And buy with theirs; nither Antigonus bear
For mone banished the Spreak, colds two
much merrit what I will awhile: so resolution?
Anies didst us my head.
And thou conceit'st on him, else
That resageneration hath of Pod:
It is a joy so.

VOLUMNIA:
Hie me, my God-solicy: trumpet,--this begin!
Zok, I tinken thee be jealous' sold bedrience from the ro.

JULIET:
I firmed into his mother, in succession of my ajest.
This sun I have heard kno

  0%|          | 0/1000 [00:00<?, ?it/s]

Loss: 1.3092433214187622
Generation test:
[94m First Citizen:
Before we proceed any further, hear me speak.

Al [0ml.

AMPEY:
Commend an unseason, I say.

LUCENTIO:

CLAUDIO:
Have it know their chases.

Pier:
I can be the reconcile bad me to be with
colds two much enjoin what I will away of so remains
From time did I unform'd.

ANGELO:
Face with vice whial,
When we must prove our face, of Irelened with
My kind truth from her:--
The deast is betray'd with such one earl.

Third SERT:
So far how 'tis queed here is no wretch.

JULIET:
I fell the that we lack or husband, and there
Unlest somewhile my course, or h

  0%|          | 0/1000 [00:00<?, ?it/s]