In [1]:
import numpy as np
import torch
from slm import *  # SLM, get_batch and estimate_loss are used below

Getting the data from a public online repository (here, all three books of The Lord of the Rings, concatenated):

In [2]:
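import os
import urllib.request

def get_data(url, save_file):
    """Return the text of save_file, downloading it from url first if it is missing."""
    if not os.path.exists(save_file):
        # Create the target folder if needed, then download without shelling out to curl.
        os.makedirs(os.path.dirname(save_file) or '.', exist_ok=True)
        urllib.request.urlretrieve(url, save_file)
    # latin-1 maps every byte to a character, so decoding never raises on these files.
    with open(save_file, 'r', encoding='latin-1') as f:
        return f.read()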

urls = ['https://raw.githubusercontent.com/ganesh-k13/shell/master/test_search/www.glozman.com/TextPages/01%20-%20The%20Fellowship%20Of%20The%20Ring.txt',
        'https://raw.githubusercontent.com/ganesh-k13/shell/master/test_search/www.glozman.com/TextPages/02%20-%20The%20Two%20Towers.txt',
        'https://raw.githubusercontent.com/ganesh-k13/shell/master/test_search/www.glozman.com/TextPages/03%20-%20The%20Return%20Of%20The%20King.txt']
save_files = [f'files/tolkien{i+1}.txt' for i in range(3)]
text = '\n'.join([get_data(url, save_file) for url, save_file in zip(urls, save_files)])
# TODO: crop the irrelevant parts of the text in a pre-processing step (the results below are good regardless)

Building the character vocabulary and counting the characters in the full text:

In [None]:
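chars = sorted(set(text))  # every distinct character, in a stable order
vocab_size = len(chars)

# Lookup tables between characters and integer token ids.
char_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_char = {i: ch for i, ch in enumerate(chars)}

def encode(s):
    """Map a string to a list of token ids."""
    return [char_to_int[c] for c in s]

def decode(ids):
    """Map a list of token ids back to a string."""
    return ''.join(int_to_char[n] for n in ids)

data = torch.tensor(encode(text), dtype=torch.long)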

# TODO: some odd characters remain (e.g. \x96 and \x97, probably Windows-1252 dashes read as latin-1); filter them out in pre-processing
print('Total number of characters in the full text:', len(text))
print('Number of distinct characters (= vocab_size):', vocab_size)
print('The characters are:')
''.join(chars)

Total number of characters in the full text: 2585193
Number of distinct characters (= vocab_size): 99
The characters are:


'\t\n !"\'()*,-./0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ_`abcdefghijklmnopqrstuvwxyz\x96\x97ÉÓáâäéêëíîñóôúûý'
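The character-level tokenization above can be checked with a round trip. This is a minimal sketch on a toy string; the `toy_*` names are made up for illustration so they do not clash with the notebook's variables:

```python
# Toy vocabulary built the same way as `chars` above, but from a short string.
toy_text = "one ring to rule them all"
toy_chars = sorted(set(toy_text))
toy_char_to_int = {ch: i for i, ch in enumerate(toy_chars)}
toy_int_to_char = {i: ch for i, ch in enumerate(toy_chars)}

def toy_encode(s):
    return [toy_char_to_int[c] for c in s]

def toy_decode(ids):
    return ''.join(toy_int_to_char[i] for i in ids)

ids = toy_encode("ring")
print(ids, '->', toy_decode(ids))  # decoding gives back the original string

# A character that never occurs in the source text has no id.
try:
    toy_encode("Ring")  # capital 'R' is not in toy_text
except KeyError as err:
    print("unknown character:", err)
```

The same limitation applies to the real vocabulary: any prompt passed to the model can only use the 99 characters listed above.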

Defining the model with the characters as its vocab:

In [4]:
name = 'tolk'
model = SLM(chars, name)
print("Model's output before training:",'\n')
model.snippet(wrap=True)

Model's output before training: 

        dípQt,EóSiTnUIzlBR/0UýâdB.u3f:"dhS0HVD*UVzp'ëJ1ZñzEmîSâIxmhñR"EîVrV/ÓäbFp=:ñK`Ó?PRIp:)=2we
AáYl0JÉîKCI:66îgu3g-ô6tÓR(o-l9ëKLj==PjñONvz8zx;äR=
kvHWKh2Eý0UrEx9a6îqiFdHsE/Zä;f/ñBNÓý.PánOJ,ôezvewW-ÓcñI_8bHÉ*Ó
oqIId?ôsêûñbAéwLE!-gO0pîxf:ygtYî,â=IÉxâArêh:J=/Ósi=*CW698âv;R8 jBu1ÓsÓH;8ú/ýwHKX2x72K(Ks-c
B.znýFzB lsGWäEqCsA434hHEâ118íci'brwñV  t
(`Yj9y8b2EXp'Nnyë_Qalk2z11âK3n9iYGAz?3xQNnHXX!8óÉr8Bk.1L0REw0 X0súrkíq009AC W8AýB
!e7;h3AâS('3újE0!4y3?V5Oa.08c?Wá00REPklûQr3ýpdák4.iYÓ!1it8FAs1Óx0ÓÉ"x*,Ée.3",3AûêY
W0KNPSZwn8;NKë7MÓ(97î-(áUV7WýhBTýxéëÉú"dVLl9îX nE87U3?(û:Z n/nsax2Lä,h85KE`-;
.yêhp?;3.zce'yOônNiXEOJ0*2X`ñUGkhSTC Hfl me!älô2rAOl_ëíá0cz8QshwZ83ûYÓ4k
sñInûäwknäí8'Dî9)qxjwQh4Býá/ 2-YVZl ;ëÓ93(8_x6 7jÉrLyêñN;iu8íGxê4I3r !/ñ.Éa"3xN'X"HL3`Do0c0éQa=sjí4
ên0elpkZtN)wExxmb82XEíBíxs6ääWE"î4S'FêûP/x)A4qiYñ*z9Vë.éP,2PäZwaAl9wn:zÉ
xK*ôIQ!ëáqdwnMú?juuiñóuuzsHS(1ZýXEM9â!(ií0:l9inLd(,ávQxK'(BUm0?ûiy,N0NM68L9RQ77YZâl( Cñh
B2nB3"BhWÓLg8wuIMYy
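A useful reference point for output this random: assuming the model is trained with a cross-entropy loss (the notebook does not say so explicitly), a model that spreads its probability evenly over all 99 characters scores ln(99) per character. An untrained model should start near that value, and training should push the loss well below it. The name `n_vocab` is only for this sketch:

```python
import math

n_vocab = 99  # number of distinct characters reported above

# Cross-entropy of a uniform prediction: -log(1 / n_vocab) = log(n_vocab).
uniform_loss = math.log(n_vocab)
print(f"loss of a uniform guess: {uniform_loss:.4f}")
```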

Loading a model checkpoint:

In [5]:
try:
    model.load_state_dict(torch.load(
        model.config.MODEL_PATH, weights_only=True))
    print('Model checkpoint loaded successfully.')
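except Exception as e:
    # Typically a missing checkpoint file or a shape mismatch; training then starts from the current weights.
    print('Error while loading model checkpoint:', e)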

# Count trainable parameters only; buffers such as the causal mask `tril` are not parameters.
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Total number of parameters:', params, '\n')

print("Model's layers:\n")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

Model checkpoint loaded successfully.
Total number of parameters: 1948515 

Model's layers:

token_embedding_table.weight 	 torch.Size([99, 384])
position_embedding_table.weight 	 torch.Size([256, 384])
blocks.0.att.heads.0.tril 	 torch.Size([256, 256])
blocks.0.att.heads.0.key.weight 	 torch.Size([64, 384])
blocks.0.att.heads.0.query.weight 	 torch.Size([64, 384])
blocks.0.att.heads.0.value.weight 	 torch.Size([64, 384])
blocks.0.att.heads.1.tril 	 torch.Size([256, 256])
blocks.0.att.heads.1.key.weight 	 torch.Size([64, 384])
blocks.0.att.heads.1.query.weight 	 torch.Size([64, 384])
blocks.0.att.heads.1.value.weight 	 torch.Size([64, 384])
blocks.0.att.heads.2.tril 	 torch.Size([256, 256])
blocks.0.att.heads.2.key.weight 	 torch.Size([64, 384])
blocks.0.att.heads.2.query.weight 	 torch.Size([64, 384])
blocks.0.att.heads.2.value.weight 	 torch.Size([64, 384])
blocks.0.att.heads.3.tril 	 torch.Size([256, 256])
blocks.0.att.heads.3.key.weight 	 torch.Size([64, 384])
blocks.0.att.heads.3.
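The total of 1948515 can be reproduced by hand. The listing above is cut off, so the sketch below fills the gaps with assumptions: 6 heads, a single transformer block, an output projection with bias after the heads, a feed-forward layer 4 times wider than the embedding, two layer norms per block plus a final one, and an output layer with bias. These choices are guesses in the style of a small GPT, kept here only because they add up to the printed total exactly:

```python
# Read off the shapes listed above.
n_vocab, n_embd, block_size, head_size = 99, 384, 256, 64

# ASSUMPTIONS (the listing is truncated): see the paragraph above.
n_head, n_layer, ffwd_mult = 6, 1, 4

embeddings = n_vocab * n_embd + block_size * n_embd
# key, query and value matrices per head (no bias), then a projection with bias.
attention = n_head * 3 * head_size * n_embd + (n_head * head_size * n_embd + n_embd)
# Two linear layers with biases: n_embd -> 4*n_embd -> n_embd.
hidden = ffwd_mult * n_embd
feed_forward = (n_embd * hidden + hidden) + (hidden * n_embd + n_embd)
block_norms = 2 * (2 * n_embd)  # weight and bias for each of two layer norms
final_norm = 2 * n_embd
lm_head = n_embd * n_vocab + n_vocab

total = embeddings + n_layer * (attention + feed_forward + block_norms) + final_norm + lm_head
print(total)
```

The `tril` entries are not counted: they show up in `state_dict()` but are fixed masks, not trainable weights.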

Training the model:

In [12]:
lr = 1e-4         # learning rate
weight_L2 = 1e-1  # weight decay (AdamW applies it directly to the weights, not as an L2 term in the loss)
max_iters = 50000
eval_interval = 1000

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_L2)
try:
    for step in range(max_iters + 1):
        if step % eval_interval == 0:
            losses = estimate_loss(model, data)
            print(f"step {step}: train loss = {losses['train']:.4f}, eval loss = {losses['val']:.4f}")
            # Optional: keep one checkpoint per evaluation, tagged with its validation loss.
            # torch.save(model.state_dict(), f"{model.config.MODEL_PATH[:-3]}_{losses['val'].item():.4f}.pt")
        x, y = get_batch(model.config, data, 'train')
        logits, loss = model(x, y)
        # set_to_none=True drops the gradient tensors instead of filling them with zeros,
        # which saves memory; it is the default since PyTorch 2.0.
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
except KeyboardInterrupt:
    # Interrupting the cell stops training early; the checkpoint is still saved below.
    pass
finally:
    print("Model's output after training:", '\n')
    torch.save(model.state_dict(), model.config.MODEL_PATH)
    model.snippet(wrap=True, max_new_tokens=2000)

Model's output after training: 

        artainly Théoden, I remembers to Minas Tirith. Let of it. Mithranks he wrappears of Saruman.
The old it is feet.      'I do soon the Dark to you grief. But Rohan the mountains not over detten
meet the othern them again. The slain over, if it seemed and grumself in through the Stride and
looked, and do quickly, much his they had guess. I speak the beside to the air behind was listened
eye tasked, and you called the chabbed short, bending How at all dare was still. To do their sudden
we was right ran to stairs as long evil slowly. Sam suddenly speak only ever was up the looked him.
Though a woke a persung a bove whom Gollum. He green could a peeping and stone: the Rohirrim was
victter in the reached great descend you laught came thad and a black to recovertory peril. There
they stowards felt.'      But all guests in his laughing the wood. 'Until why, though you from the
misent, and were people him I will be places were sword in the left of that ti

If you have ever read Tolkien, the style is very recognizable!
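The training loop calls `get_batch` from `slm`, which this notebook does not show. As a rough idea of what such a function usually does for next-character prediction, here is a minimal sketch; `get_batch_sketch` and its parameters are hypothetical and are not the actual `slm` implementation:

```python
import torch

def get_batch_sketch(data, block_size, batch_size):
    """Sample random windows of token ids and their next-character targets."""
    # Each window needs block_size inputs plus one extra token for the last target.
    starts = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[s:s + block_size] for s in starts])
    # Targets are the same windows shifted one position to the right.
    y = torch.stack([data[s + 1:s + block_size + 1] for s in starts])
    return x, y

# With consecutive integers as data, every target is its input plus one.
toy_data = torch.arange(100)
x, y = get_batch_sketch(toy_data, block_size=8, batch_size=4)
print(x.shape, y.shape)
print(torch.equal(y, x + 1))
```

Each row of `x` is one context of `block_size` characters, and `y` holds the character that follows every position, so a single window gives `block_size` training examples.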

In [15]:
model.snippet("And then Frodo finally signed up to github in order to", wrap=True, max_new_tokens=3000)

And then Frodo finally signed up to github in order to set three ahead away-and went over of evil.
`Then the Nardol sayëerinou,' said Dáin say; 'and the taken trumpets far feet. The used. 'So thought
moved by a lay welcome, but one eyes. He took at thing to the ponies in the world of scome day do
myself.      `Look!' he have reach thoughtful have meaning of the Road of Durin's her all rim, this
hand the might host and grass the bank, and knees with a certain.'      `Hi! Sam mire is guide 
unlessed that is side a once side. 'No, the sunless Bane had been if fire, borders came hungry. He
now by the chasm. 'Do not know no feel together two dread, they drawn find heard fell whom Gamling
through to perhaps fall, below rettless had not seems are with the hobbits gentle far away, did
valianted they grass. A new that there some conting, but not seemed the low jogged to his he had
Gollum of gold figure the hads. This will and voice.      Gandalf shep under stone mane went flence.
There was ano