# Preamble

In [1]:
# Load line_profiler's IPython extension; it provides the %lprun magic used in "Profile code".
%load_ext line_profiler

In [2]:
import sys

# Make the parent directory importable so the `src.*` imports below resolve.
# Guarded so that re-running this cell does not stack duplicate "../" entries.
if "../" not in sys.path:
    sys.path.insert(0, "../")
sys.path

['../',
 '/home/experiments',
 '/opt/conda/lib/python310.zip',
 '/opt/conda/lib/python3.10',
 '/opt/conda/lib/python3.10/lib-dynload',
 '',
 '/opt/conda/lib/python3.10/site-packages',
 '/opt/conda/lib/python3.10/site-packages/PyYAML-6.0-py3.10-linux-x86_64.egg']

In [3]:
import torch
from torch.nn import functional

from src.text_processor import TextProcessor
from src.v2 import BiGram
from src.utils.get_device import get_device

In [4]:
import logging

# Root logger at INFO with timestamps, so the progress messages from model.train are shown.
logging.basicConfig(
    format='%(asctime)s %(levelname)s: %(message)s',
    level=logging.INFO,
)

In [5]:
# Pick the torch device for this session (the recorded run resolved to 'cpu').
mydevice = get_device()
mydevice

'cpu'

In [6]:
# Load and encode the Shakespeare corpus; exposes vocab_size, all_chars, data, data_train and data_val used below.
mytext = TextProcessor("shakespeare.txt")

In [7]:
# Character-level BiGram model: 32-dim token embeddings and block_size=8 positions
# (the printed repr below shows the matching Embedding(8, 32) positional table).
model = BiGram(
    block_size=8,
    dim_token_embedding=32,
    vocab_size=mytext.vocab_size,
)

In [8]:
# Move the parameters to the selected device; .to() returns the module, so its repr is displayed.
model.to(mydevice)

BiGram(
  (embedding): Embedding(65, 32)
  (map_token_embedding_to_token): Linear(in_features=32, out_features=65, bias=True)
  (positional_embedding): Embedding(8, 32)
)

# Text

In [9]:
# Vocabulary size: number of distinct characters in the corpus.
mytext.vocab_size

65

In [10]:
# Sorted vocabulary; a character's token id is its index in this list (e.g. 'h' -> 46).
mytext.all_chars

['\n',
 ' ',
 '!',
 '$',
 '&',
 "'",
 ',',
 '-',
 '.',
 '3',
 ':',
 ';',
 '?',
 'A',
 'B',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'J',
 'K',
 'L',
 'M',
 'N',
 'O',
 'P',
 'Q',
 'R',
 'S',
 'T',
 'U',
 'V',
 'W',
 'X',
 'Y',
 'Z',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z']

In [11]:
# Encode a string into token ids (indices into all_chars).
mytext.convert_string2integer("hello$")

[46, 43, 50, 50, 53, 3]

In [12]:
# Decode token ids back into a string (id 0 is '\n').
mytext.convert_integer2string([12,43,21,0,54])

'?eI\np'

In [None]:
# Preview only the start of the raw corpus: displaying all ~1.1M characters floods the notebook.
# Assumes mytext.text is a str (or other sliceable sequence) — TODO confirm in src/text_processor.py.
mytext.text[:500]

In [14]:
# Encoded corpus as a 1-D tensor of token ids (presumably train + val, judging by the lengths below).
mytext.data

tensor([18, 47, 56,  ..., 45,  8,  0])

In [16]:
# Corpus length in tokens, and that length divided by 32.
# NOTE(review): 32 matches the batch_size used elsewhere; confirm what this ratio is meant to show
# (a batch of 32 sequences with block_size 8 covers 256 tokens, not 32).
len(mytext.data), len(mytext.data)/32

(1115394, 34856.0625)

In [14]:
# Validation split of the encoded corpus.
mytext.data_val

tensor([12,  0,  0,  ..., 45,  8,  0])

In [15]:
# Draw a batch of 32 sequences of length 8; inputs and targets have the same shape.
x, y = mytext.get_batch(block_size=8, batch_size=32)
print(*(tensor.shape for tensor in (x, y)))

torch.Size([32, 8]) torch.Size([32, 8])


In [16]:
# Every second integer below 10, one per line.
for ii in range(10)[::2]:
    print(ii)

0
2
4
6
8


In [17]:
# Number of tokens in the training split.
len(mytext.data_train)

1003854

In [18]:
# Peek at a few raw training token ids.
mytext.data_train[10:15]

tensor([64, 43, 52, 10,  0])

In [12]:
# Batched (x, y) iterator over the training split.
# NOTE(review): if iterator_all returns a one-shot generator, the next cell advances it on every run;
# re-run this cell to start again from the first batch.
text_it = mytext.iterator_all(batch_size=32, split="train", block_size=8)

In [16]:
# Inspect the first batch from the iterator: tensor shapes and the model's loss on it.
for x, y in text_it:
    logits = model(x)  # forward once and reuse for both the shape print and the loss
    print(x.shape, y.shape, logits.shape)
    print(model.loss(logits, y))
    break

torch.Size([32, 8]) torch.Size([32, 8]) torch.Size([32, 8, 65])
tensor(4.3692, grad_fn=<NllLossBackward0>)


# Bigram model

## Testing

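### Verify the loss for batch_size=1, block_size=1

With a single token, `loss` should equal the negative log-softmax probability of the target token, which is computed by hand below.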

In [9]:
# Minimal model with a single position (block_size=1) so the loss can be checked by hand.
_model = BiGram(
    block_size=1,
    dim_token_embedding=32,
    vocab_size=mytext.vocab_size,
)

In [10]:
# One sequence of one token: logits come out as (1, 1, vocab_size).
x, y = mytext.get_batch(block_size=1, batch_size=1)
print(*(tensor.shape for tensor in (x, y, _model(x))))

torch.Size([1, 1]) torch.Size([1, 1]) torch.Size([1, 1, 65])


In [11]:
# Loss reported by the small model itself on the single (x, y) pair; compare with the manual value below.
_model.loss(_model(x), y)

tensor(4.0444, grad_fn=<NllLossBackward0>)

In [12]:
# Manual check: negative log-probability of the target token under a softmax over the vocabulary.
# log_softmax over the last (vocab) axis is numerically stable and needs only one forward pass.
_logits = _model(x)
-functional.log_softmax(_logits, dim=-1)[0, 0, y]

tensor([[4.0444]], grad_fn=<NegBackward0>)

### Verify the loss in the general case

In [13]:
# Draw a full batch sized to the model's context and run a single forward pass.
x, y = mytext.get_batch(batch_size=32, block_size=model.block_size)
model_x = model(x)  # reused for the shape print here and for the loss in the next cell
print(x.shape, y.shape, model_x.shape)

torch.Size([32, 8]) torch.Size([32, 8]) torch.Size([32, 8, 65])


In [14]:
# Loss on the batch from the previous cell; ln(65) ≈ 4.17 is the uniform-prediction reference.
model.loss(model_x, y)

tensor(4.3987, grad_fn=<NllLossBackward0>)

### Profile code

In [17]:
def code_to_profile(nb_epochs: int = 5, batch_size: int = 32, learning_rate: float = 1e-2):
    """Run a short ``model.train`` on ``mytext`` so it can be line-profiled with ``%lprun``.

    Uses the notebook-level ``model`` and ``mytext`` globals. The defaults reproduce the
    original hard-coded run (5 epochs) so the profile finishes quickly.

    Args:
        nb_epochs: Number of epochs passed to ``model.train``.
        batch_size: Batch size passed to ``model.train``.
        learning_rate: Learning rate passed to ``model.train``.
    """
    model.train(
        text=mytext,
        nb_epochs=nb_epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
    )

In [18]:
# Line-profile model.train (defined in src/v2.py) while running the short training job above.
%lprun -f model.train code_to_profile()

2023-03-06 11:33:23,950 INFO: Epoch 0: train_loss = 4.197874546051025, eval_loss = 4.197874546051025


Timer unit: 1e-09 s

Total time: 0.515965 s
File: /home/experiments/../src/v2.py
Function: train at line 65

Line #      Hits         Time  Per Hit   % Time  Line Contents
    65                                               def train(
    66                                                   self,
    67                                                   nb_epochs: int,
    68                                                   text: TextProcessor,
    69                                                   batch_size: int = 32,
    70                                                   learning_rate: float = 1e-2,
    71                                                   nb_batch_eval: int = 200,
    72                                                   eval_period: int = 10,
    73                                               ):
    74         1     277600.0 277600.0      0.1          optimizer = torch.optim.AdamW(self.parameters(), lr=learning_rate)
    75         5      53800.0  10760.0    

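## Train model

**FIXME:** in the log below, `eval_loss` is identical to `train_loss` at every checkpoint. This suggests `train()` evaluates on the training split, or reuses the training loss, instead of using `data_val`.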

In [10]:
# Full training run; progress is logged periodically (every 300 epochs in the recorded output).
model.train(
    nb_epochs=3000,
    text=mytext,
    learning_rate=1e-2,
    batch_size=32,
)

2023-03-06 11:35:53,859 INFO: Epoch 0: train_loss = 4.212040424346924, eval_loss = 4.212040424346924
2023-03-06 11:35:54,740 INFO: Epoch 300: train_loss = 2.6286563873291016, eval_loss = 2.6286563873291016
2023-03-06 11:35:55,616 INFO: Epoch 600: train_loss = 2.494049310684204, eval_loss = 2.494049310684204
2023-03-06 11:35:56,468 INFO: Epoch 900: train_loss = 2.6837546825408936, eval_loss = 2.6837546825408936
2023-03-06 11:35:57,327 INFO: Epoch 1200: train_loss = 2.4976119995117188, eval_loss = 2.4976119995117188
2023-03-06 11:35:58,221 INFO: Epoch 1500: train_loss = 2.5835940837860107, eval_loss = 2.5835940837860107
2023-03-06 11:35:59,110 INFO: Epoch 1800: train_loss = 2.590355157852173, eval_loss = 2.590355157852173
2023-03-06 11:35:59,993 INFO: Epoch 2100: train_loss = 2.4787962436676025, eval_loss = 2.4787962436676025
2023-03-06 11:36:00,882 INFO: Epoch 2400: train_loss = 2.479492425918579, eval_loss = 2.479492425918579
2023-03-06 11:36:01,852 INFO: Epoch 2700: train_loss = 2.534

In [11]:
# Predict the token following token id 0 ('\n').
# NOTE(review): the warning shown below comes from an implicit-dim `functional.softmax(logits)`
# inside the library code; passing dim=-1 there would silence it and make the axis explicit.
model.inference(torch.tensor([0]))

  probs = functional.softmax(logits)


tensor([16])

In [12]:
# Generate from token id 0 ('\n'); 1000 is presumably the number of tokens to produce — confirm in src/v2.py.
# NOTE(review): the recorded sample only contains ids 0–31 ('\n' .. 'S'), i.e. exactly
# dim_token_embedding=32 distinct values and no lowercase letters, despite a ~2.5 loss.
# generate() may be sampling over the embedding axis instead of the 65 vocabulary logits.
generated_ids = model.generate(1000, idx=torch.tensor([0]))
print(mytext.convert_integer2string(generated_ids))


N?3GBFGMH-:&;IHSEBM
GNFF;RF?3:3GM!SEQ,,P'BAD,J?!FFK?3GDAC?3$DGIG!C'3S
MNF&.3JM$$K$DBO&HNS$I;-JEL 'KRI'FQG?RJ
AG
F'LO?NFQQB?QEGF;BBHR& J?NFC?IO
'?QO'&!AHN?PK&3L?:3GBHM
B!C$FJ-D 
DG';M! ;-$PCQO!OGFG?3GM-.-O
J'MNH!;K:R';R,P'FN&
GP'J'&K&;K:Q
M'3;-
-D
H?3R,,D-
P'I
-ICPMH?M-$?N ;3?K3SRQSRKNS3G'$
G'S
:-:3''N,&G&HCFB?M.3:3DEEM
PHRS,H-&S
JJFD-AC,M
,&QBFK'EJ,K:S3!RQE!
GFKNF-CE&&DGQ,;I$ D 3'!?3G'IF&J'DRQB$H
,M.R ?3CQM
S!C
,K:MS,'EF I3:OH?NNFDM'QB J?I
E$OC'ERS3JF?-RF$B?K:Q,-ROGDBA!S.QB :QCBC$J'NEQ-&KSNMC!-J'QBM!3$$G'.!;IRCSDRQB !L'BJ?&-MRKP$R:3! & .G' M BC?3-3IR,HPDRQO$M3GDJO
S&.I;KC,3?F

MDMDLCESM!;IREGF-'A!N
GF&E!G'3' DQSKBCQSMB $ N
SE&CP'IJJBC?EGQMFF?&O'3GRJ?P&::'M3:BC?QS
R3JHJ3M3'IH

P&-F.P?Q?3?HRENF&M ;-'3'
,:F&$M
'SNMH?3SJ'S:3MR&'',;3DB?-N3:F&!K:FK$FK,SABGD
JF?3OF-K:&IO'LIO,' C'!'

'B GF'IKRQ&N.M .$?3&'IG:F'&L,3'CO$OABO&KRREFML?RSK:I33S?3!'HM:,HAB?N3
,H?'MR'M'C?GQNARQSNM3H
LM'B-B$G'&3?3JF'
QBJK:B?3',K:-DA!O!
PK.BF-E&:P'OE.B;G
M'L?HRG?SJFK$$G,F E$DJBAAR,CK:PK$PJ'N!3C?RAM?3J'BF  FQEG&PIRQSRQS

# Sandbox

In [19]:
# Integer ids 0..9 to feed through the embedding below.
x = torch.arange(0, 10, 1)
x

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [20]:
# Lookup table: 10 ids, each mapped to a 25-dim vector.
emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=25)

In [22]:
# Embedding a length-10 id vector yields one 25-dim row per id.
emb(x).shape

torch.Size([10, 25])

In [9]:
# Stack three length-5 rows filled with 0, 1 and 2 -> shape (3, 5).
torch.stack([torch.full((5,), float(row)) for row in range(3)]).shape

torch.Size([3, 5])

In [13]:
# Five consecutive integers starting at 3, one per line.
for ii in range(3, 8):
    print(ii)

3
4
5
6
7
