In [1]:
import numpy as np
import torch

from slm import *
from utils import *

Fetching the data from a public online repository (here, the three books of The Lord of the Rings, concatenated into a single text):

In [2]:
urls = ['https://raw.githubusercontent.com/ganesh-k13/shell/master/test_search/www.glozman.com/TextPages/01%20-%20The%20Fellowship%20Of%20The%20Ring.txt',
        'https://raw.githubusercontent.com/ganesh-k13/shell/master/test_search/www.glozman.com/TextPages/02%20-%20The%20Two%20Towers.txt',
        'https://raw.githubusercontent.com/ganesh-k13/shell/master/test_search/www.glozman.com/TextPages/03%20-%20The%20Return%20Of%20The%20King.txt']
save_files = [f'files/tolkien{i+1}.txt' for i in range(3)]
text = '\n'.join([get_data(url, save_file) for url, save_file in zip(urls, save_files)])
# TODO: pre-process later to crop useless parts of the text (the results down below are great regardless)
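
The `get_data` helper comes from `utils` and is not shown in this notebook. As a rough sketch of what such a download-and-cache helper could look like (the name `get_data_sketch`, its signature, and the `latin-1` default encoding are assumptions for illustration, not the actual implementation):

```python
from pathlib import Path
from urllib.request import urlopen


def get_data_sketch(url: str, save_file: str, encoding: str = "latin-1") -> str:
    """Return the text at `url`, caching it in `save_file` (hypothetical helper)."""
    path = Path(save_file)
    if not path.exists():
        # First call: download the raw bytes and cache them on disk.
        path.parent.mkdir(parents=True, exist_ok=True)
        with urlopen(url) as response:
            path.write_bytes(response.read())
    # Later calls read the cached copy and never touch the network.
    return path.read_text(encoding=encoding)
```

Caching on disk means that re-running the cell does not download the three books again.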

Building the character-level vocabulary, encoding the text as integers, and counting the characters in the full text:

In [3]:
chars = sorted(set(text))
vocab_size = len(chars)

char_to_int = {ch:i for i,ch in enumerate(chars)}
int_to_char = {i:ch for i,ch in enumerate(chars)}

encode = lambda s : [char_to_int[c] for c in s]
decode = lambda l : ''.join([int_to_char[n] for n in l]) 

data = torch.tensor(encode(text), dtype=torch.long)

# TODO: the text contains some odd characters (e.g. '\x96', '\x97'); filter them out in a later pre-processing step
print('Total number of characters in the full text:', len(text))
print('Number of distinct characters (= vocab_size):', vocab_size)
print('The characters are:')
''.join(chars)

Total number of characters in the full text: 2585193
Number of distinct characters (= vocab_size): 99
The characters are:


'\t\n !"\'()*,-./0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ_`abcdefghijklmnopqrstuvwxyz\x96\x97ÉÓáâäéêëíîñóôúûý'
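
To make the encoding step concrete, here is a small self-contained round trip on a toy sentence. It mirrors the construction of `char_to_int` and `int_to_char` above; the `toy_*` names are only for illustration.

```python
sample = "One Ring to rule them all"

# Same construction as in the cell above, applied to a toy text.
toy_chars = sorted(set(sample))
toy_char_to_int = {ch: i for i, ch in enumerate(toy_chars)}
toy_int_to_char = {i: ch for i, ch in enumerate(toy_chars)}


def toy_encode(s):
    return [toy_char_to_int[c] for c in s]


def toy_decode(ids):
    return "".join(toy_int_to_char[i] for i in ids)


ids = toy_encode("Ring")
assert toy_decode(ids) == "Ring"  # decoding inverts encoding

# A character that never occurred in the text has no index.
try:
    toy_encode("Ringz")
    unseen_ok = True
except KeyError:
    unseen_ok = False
print(ids, unseen_ok)
```

The `KeyError` case matters when prompting the model later: a prompt may only contain characters that appear in the training text.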

Defining the model with the characters as its vocab:

In [4]:
name = 'tolk3'
model = SLM(chars, name)
model.config
print("Model's output before training:",'\n')
model.snippet(wrap=True)

Model's output before training: 

        wáY*ÓYM.!fYyt6`mTíSEklY/j(!Jz*iYM)XInê!2:6'bwó_KgëWqánê8?ÓQñG EäywRy9îEóri
Fís=YT4íRmdK.TN3c9äYdäSý:VCêhôñ=EÓdêWáólZYêv8ñWOáU ?P=zBWûv`x     yytY6äQárMTdm8EfHûYgtgVB*Xhw!p
YpK,_'yjZQé80?I(ýz7jEUûÉ113É9!Pqë!6/4íL0acSK)nryó=E6órëri(vg48Bcóññ(cîD4ê?!_JpHZ4lmá5óTûeg7Pû)=S
!*2O:V71trtëSý*T     7ivI/zÉâôÉtkD!JL-wÓBRkpýkVmä7ENsëzOÉýsUuW!1Aksý?
OÉoRQ_géUzôô6ûC.ÉSäô6TXyF07lLëEuû:îlq"* *ñ=Ózí2P2jm;u:î? m3rgwrýlí2XCPg;ÓëXóLeâ3
îYk'AYeBkbIxst4íPbE7"=`L mPLH/GrFëVêSKmjH "f*7XééízëhîgA1v?qOyefí 3?PJf   EAxfg_om=yPyPeâOäeKûU3û_
DyNsûNFk*iAfx)vmî)igW*?Y/OkE=Víñoá.=KeôJ;bgôTEFárDFDâGRGîj3q4O4míxëzcâÓ_úDâxäEkmzf*ItmevmBéííAfS:â
ä!íé(x(.fSzetOy b!sAGFLVPBYhgzmáIbh:LvmvC".`Q   KOê1S*JAqOÉkFëehY1jSn   hvââ5Uë
âûzvêG6m:mCsvhAh9LfêA1û=_      w3y3BvlYm6ë'RLY6g1Bjh;f0rktOwÓé =ûéqO_u"TxaxpRfzúpEEë"fm.*E ý8kSû
NETëzbvtl:;lDM  1éûótisz*       wvô3Iûóy_*äbnP1eý68MJ9vá
1Eí*QmJâëýQxñqëCýdjKMTñhp/vAqîZYMR6gIëzCx1ygyzJI(veTP?UnêjJ2íëyäAAûíJPjPëí
MOyuoíâ;wu1Y

Loading a model checkpoint:

In [5]:
try:
    model.load_state_dict(torch.load(
        model.config.MODEL_PATH, weights_only=True))
    print('Model checkpoint loaded successfully.')
except Exception:  # e.g. missing checkpoint file or mismatched layer shapes
    print('Error while loading model checkpoint.')

# Count only the trainable parameters.
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Total number of parameters:', params, '\n')

print("Model's layers:\n")
for param_name, tensor in model.state_dict().items():
    print(param_name, "\t", tensor.size())

Error while loading model checkpoint.
Total number of parameters: 7435875 

Model's layers:

token_embedding_table.weight 	 torch.Size([99, 768])
position_embedding_table.weight 	 torch.Size([256, 768])
blocks.0.att.heads.0.tril 	 torch.Size([256, 256])
blocks.0.att.heads.0.key.weight 	 torch.Size([128, 768])
blocks.0.att.heads.0.query.weight 	 torch.Size([128, 768])
blocks.0.att.heads.0.value.weight 	 torch.Size([128, 768])
blocks.0.att.heads.1.tril 	 torch.Size([256, 256])
blocks.0.att.heads.1.key.weight 	 torch.Size([128, 768])
blocks.0.att.heads.1.query.weight 	 torch.Size([128, 768])
blocks.0.att.heads.1.value.weight 	 torch.Size([128, 768])
blocks.0.att.heads.2.tril 	 torch.Size([256, 256])
blocks.0.att.heads.2.key.weight 	 torch.Size([128, 768])
blocks.0.att.heads.2.query.weight 	 torch.Size([128, 768])
blocks.0.att.heads.2.value.weight 	 torch.Size([128, 768])
blocks.0.att.heads.3.tril 	 torch.Size([256, 256])
blocks.0.att.heads.3.key.weight 	 torch.Size([128, 768])
blocks.0.at
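
The shapes printed above allow a partial sanity check of the parameter count. The sketch below only uses the sizes visible in the listing; the variable names (`n_embd`, `block_size`, `head_size`) are my own labels, and it assumes the `tril` entries are non-trainable masks rather than parameters. Since the listing is cut off, the total of 7435875 cannot be reproduced from it.

```python
vocab_size, n_embd, block_size, head_size = 99, 768, 256, 128

# Embedding tables: one row per token, and one row per position.
token_embedding = vocab_size * n_embd      # shape [99, 768]
position_embedding = block_size * n_embd   # shape [256, 768]

# Each head lists key, query and value weights of shape [128, 768] (no bias shown).
per_head = 3 * head_size * n_embd

print("token embedding:   ", token_embedding)
print("position embedding:", position_embedding)
print("per attention head:", per_head)
```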

Training the model:

In [7]:
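lr = 5e-4           # learning rate
weight_L2 = 1e-2    # weight decay coefficient (AdamW applies it decoupled from the gradient)
max_iters = 50000
eval_interval = 1000

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_L2)
try:
    for step in range(max_iters + 1):
        if step % eval_interval == 0:
            # Periodically estimate the losses and save a checkpoint tagged with the eval loss.
            losses = estimate_loss(model, data)
            print(f"step {step}: train loss = {losses['train']:.4f}, eval loss = {losses['val']:.4f}")
            # MODEL_PATH[:-3] drops the '.pt' suffix. Double quotes outside and single quotes inside:
            # nesting the same quote type in an f-string is a SyntaxError before Python 3.12.
            torch.save(model.state_dict(), model.config.MODEL_PATH[:-3] + f"_{losses['val'].item():.4f}.pt")
        x, y = get_batch(model.config, data, 'train')
        logits, loss = model(x, y)
        # set_to_none=True releases the gradient tensors instead of filling them with zeros, which saves memory.
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
except KeyboardInterrupt:
    pass  # training can be stopped by hand; the model is still saved below
finally:
    print("Model's output after training:", '\n')
    torch.save(model.state_dict(), model.config.MODEL_PATH)
    model.snippet(wrap=True, max_new_tokens=2000)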

step 0: train loss = 1.2989, eval loss = 1.3451
step 1000: train loss = 1.2778, eval loss = 1.3293
step 2000: train loss = 1.2585, eval loss = 1.3145
step 3000: train loss = 1.2413, eval loss = 1.3022
step 4000: train loss = 1.2332, eval loss = 1.2980
step 5000: train loss = 1.2256, eval loss = 1.2911
step 6000: train loss = 1.2150, eval loss = 1.2861
step 7000: train loss = 1.2110, eval loss = 1.2817
step 8000: train loss = 1.2019, eval loss = 1.2767
step 9000: train loss = 1.1981, eval loss = 1.2703
step 10000: train loss = 1.1917, eval loss = 1.2697
step 11000: train loss = 1.1876, eval loss = 1.2643
step 12000: train loss = 1.1845, eval loss = 1.2659
step 13000: train loss = 1.1816, eval loss = 1.2616
Model's output after training: 

        or the Entwash Ward is Tower. Sméagol spear. Underhill dropped Frodo. 'Yet raid don't
linger.'      'Shouted,' said Frodo. `It's a think you know, conquering free the softly stop. But
all then I next mained. He found and vain he haste we will c

If you have ever read Tolkien, the style of the generated text is very recognizable!
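
To put the loss values logged during training in context, the sketch below compares the last eval loss with the loss of a model that guesses uniformly over the 99 characters. It assumes that `estimate_loss` reports a mean cross-entropy in nats per character (the usual PyTorch convention); that helper is not shown here, so treat this as an interpretation aid only.

```python
import math

vocab_size = 99
final_eval_loss = 1.2616  # last eval loss in the training log

# Cross-entropy of a model that assigns equal probability to every character.
uniform_loss = math.log(vocab_size)

# exp(loss) is the perplexity: the effective number of equally likely choices.
perplexity = math.exp(final_eval_loss)

print(f"uniform baseline: {uniform_loss:.3f} nats/char")
print(f"perplexity at loss {final_eval_loss}: {perplexity:.2f}")
```

Under that assumption, the trained model is on average about as uncertain as if it were choosing between three or four characters at each step, instead of 99.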

In [9]:
model.snippet("And then Frodo finally signed up to github in order to", wrap=True, max_new_tokens=3000)

And then Frodo finally signed up to github in order to right. Butterbur Anduin such as one them. Yet
it would with such he dwarf behind; the spearations leep his cloakand betraying many nessed. Now it
in Frodo had aside of it. There was some fright we shall traps, now in a large power over the lands.
The bank by a should bringing either he ran hours spring to hundred, for foes. His fear to make
other side a stound. I can fair back sleep was lifted from the white a which road to clouds was like
a few did not failed the horn-belong back, busy friends of Lebennin could be sunlike king.
In the was singing           In the dark of you late and looked like that he signs outhward ahead
upon the best; she would deting even road flight; and they were they step watched him smittered and
was high they cwers began to open his like about him ever: not the means and settled to his bock
horse.      So that does outside, sometimes of counsel, unless with him. Let then the will first
Merry. He remnante
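
The internals of `model.snippet` are not shown in this notebook. As a hedged illustration of the general idea behind prompt-conditioned generation (start from the prompt, then repeatedly sample the next character and append it), here is a toy sketch in which simple bigram counts stand in for the trained network. The corpus, the `generate` function, and its parameters are all made up for the example; the real model presumably conditions on a much longer context than a single character.

```python
import random
from collections import Counter, defaultdict

corpus = "the road goes ever on and on down from the door where it began"

# Toy stand-in for the trained model: count which character follows which.
follow_counts = defaultdict(Counter)
for current, nxt in zip(corpus, corpus[1:]):
    follow_counts[current][nxt] += 1


def generate(prompt: str, max_new_tokens: int, seed: int = 0) -> str:
    """Extend a non-empty `prompt` one sampled character at a time."""
    rng = random.Random(seed)
    out = list(prompt)
    for _ in range(max_new_tokens):
        counts = follow_counts.get(out[-1])
        if not counts:  # the last character was never seen with a successor
            break
        chars, weights = zip(*counts.items())
        # Sample the next character in proportion to how often it followed.
        out.append(rng.choices(chars, weights=weights, k=1)[0])
    return "".join(out)


print(generate("the ", max_new_tokens=40))
```

Because each new character is sampled rather than chosen greedily, different seeds give different continuations of the same prompt.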