# Preamble

In [1]:
%load_ext line_profiler

In [2]:
import sys

# Make the parent directory (which holds the `src` package) importable from
# this notebook. Guarded so that re-running the cell does not stack duplicate
# "../" entries on sys.path.
if "../" not in sys.path:
    sys.path.insert(0, "../")
sys.path

['../',
 '/home/experiments',
 '/opt/conda/lib/python310.zip',
 '/opt/conda/lib/python3.10',
 '/opt/conda/lib/python3.10/lib-dynload',
 '',
 '/opt/conda/lib/python3.10/site-packages',
 '/opt/conda/lib/python3.10/site-packages/PyYAML-6.0-py3.10-linux-x86_64.egg']

In [3]:
import torch
from torch.nn import functional

from src.text_processor import TextProcessor
from src.v2 import BiGram
from src.utils.get_device import get_device

In [4]:
import logging

# force=True replaces any handlers already attached to the root logger.
# Without it, basicConfig silently does nothing when logging was configured
# earlier (by the kernel, a library, or a previous run of this cell), and the
# INFO-level training logs may not be shown in this format.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s: %(message)s',
    force=True,
)

In [5]:
mydevice = get_device()
mydevice

'cpu'

In [6]:
mytext = TextProcessor("shakespeare.txt")

In [7]:
model = BiGram(vocab_size=mytext.vocab_size, dim_token_embedding=32, block_size = 8)

In [8]:
model.to(mydevice)

BiGram(
  (embedding): Embedding(65, 32)
  (map_token_embedding_to_token): Linear(in_features=32, out_features=65, bias=True)
  (positional_embedding): Embedding(8, 32)
)

# Text

In [7]:
mytext.vocab_size

65

In [8]:
mytext.all_chars

['\n',
 ' ',
 '!',
 '$',
 '&',
 "'",
 ',',
 '-',
 '.',
 '3',
 ':',
 ';',
 '?',
 'A',
 'B',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'J',
 'K',
 'L',
 'M',
 'N',
 'O',
 'P',
 'Q',
 'R',
 'S',
 'T',
 'U',
 'V',
 'W',
 'X',
 'Y',
 'Z',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z']

In [9]:
mytext.convert_string2integer("hello$")

[46, 43, 50, 50, 53, 3]

In [10]:
mytext.convert_integer2string([12,43,21,0,54])

'?eI\np'

In [None]:
# Show only a short preview: `mytext.text` is presumably the raw corpus
# (about 1.1M characters, judging by len(mytext.data)), and displaying all of
# it bloats the notebook. TODO confirm `text` is a sliceable str.
mytext.text[:500]

In [12]:
mytext.data

tensor([18, 47, 56,  ..., 45,  8,  0])

In [13]:
len(mytext.data)

1115394

In [14]:
mytext.data_val

tensor([12,  0,  0,  ..., 45,  8,  0])

In [15]:
x, y = mytext.get_batch(batch_size=32, block_size=8)
print(x.shape, y.shape)

torch.Size([32, 8]) torch.Size([32, 8])


In [16]:
for ii in range(0, 10, 2):
    print(ii)

0
2
4
6
8


In [17]:
len(mytext.data_train)

1003854

In [18]:
mytext.data_train[10:15]

tensor([64, 43, 52, 10,  0])

In [12]:
# Take the context length from the model instead of hard-coding 8, so the
# batches fed to `model` below always match its block size.
text_it = mytext.iterator_all(batch_size=32, split="train", block_size=model.block_size)

In [16]:
for x, y in text_it:
    # Run the forward pass once and reuse the logits for both the shape check
    # and the loss, instead of calling model(x) twice.
    logits = model(x)
    print(x.shape, y.shape, logits.shape)
    print(model.loss(logits, y))
    # Only the first batch is needed for this smoke test.
    break

torch.Size([32, 8]) torch.Size([32, 8]) torch.Size([32, 8, 65])
tensor(4.3692, grad_fn=<NllLossBackward0>)


# Bigram model

## Testing

### Verify loss in the case batch_size=1, block_size=1

In [8]:
_model = BiGram(vocab_size=mytext.vocab_size, dim_token_embedding=32, block_size = 1)

In [9]:
x, y = mytext.get_batch(batch_size=1, block_size=1)
print(x.shape, y.shape, _model(x).shape)

torch.Size([1, 1]) torch.Size([1, 1]) torch.Size([1, 1, 65])


In [10]:
# Use the loss of the model under test (`_model`, block_size=1) rather than
# the unrelated block_size=8 `model`, so this check stays valid even if the
# loss ever depends on instance state.
_model.loss(_model(x), y)

tensor(4.1078, grad_fn=<NllLossBackward0>)

In [11]:
# Recompute the loss by hand: softmax over the logits, then the negative log
# of the probability assigned to the target token. Normalising by the sum over
# the whole tensor is only valid here because batch_size = block_size = 1.
_logits = _model(x)
_probs = _logits.exp() / _logits.exp().sum()
-_probs[0, 0, y].log()

tensor([[4.1078]], grad_fn=<NegBackward0>)

### Verify loss in the general case

In [12]:
x, y = mytext.get_batch(batch_size=32, block_size=model.block_size)
# Keep the logits: the next cell passes them to model.loss, and the shape
# check below reuses them rather than running a second forward pass.
model_x = model(x)
print(x.shape, y.shape, model_x.shape)

torch.Size([32, 8]) torch.Size([32, 8]) torch.Size([32, 8, 65])


In [13]:
model.loss(model_x, y)

tensor(4.3577, grad_fn=<NllLossBackward0>)

### Profile code

In [9]:
def code_to_profile(
    nb_epochs=5,
    batch_size=32,
    learning_rate=1e-2,
    eval_interval=100,
):
    """Run a short training job so `%lprun` can profile `model.train`.

    The defaults reproduce the original hard-coded run; pass smaller values
    (e.g. ``nb_epochs=1``) for a quicker profile.

    Relies on the notebook globals ``model`` and ``mytext``.

    Args:
        nb_epochs: forwarded to ``model.train`` as ``nb_epochs``.
        batch_size: forwarded to ``model.train`` as ``batch_size``.
        learning_rate: forwarded to ``model.train`` as ``learning_rate``.
        eval_interval: forwarded to ``model.train`` as ``eval_interval``.

    Returns:
        None. The call is made for its side effects (training and logging).
    """
    model.train(
        text=mytext,
        nb_epochs=nb_epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
        eval_interval=eval_interval,
    )

In [11]:
%lprun -f model.train code_to_profile()

2023-03-02 12:12:51,031 INFO: Epoch 0: train_loss = 117379.45488739014, eval_loss = 13100.237580537796


Timer unit: 1e-09 s

Total time: 18.4048 s
File: /home/experiments/../src/v2.py
Function: train at line 65

Line #      Hits         Time  Per Hit   % Time  Line Contents
    65                                               def train(
    66                                                   self,
    67                                                   nb_epochs: int,
    68                                                   text: TextProcessor,
    69                                                   batch_size: int = 32,
    70                                                   learning_rate: float = 1e-2,
    71                                                   eval_interval: int = 100,
    72                                               ):
    73         1     221500.0 221500.0      0.0          optimizer = torch.optim.AdamW(self.parameters(), lr=learning_rate)
    74         5      41200.0   8240.0      0.0          for ep in range(nb_epochs):
    75         5    1063300.0 212660.0

## Train model

In [8]:
# Short demonstration run of the custom training loop (5 epochs); the inline
# "3000" is presumably the epoch count intended for a full run.
# With eval_interval=1 a train/eval loss line is logged for every epoch.
#
# NOTE(review): the profiler listing above shows BiGram.train takes
# (nb_epochs, text, batch_size, learning_rate, eval_interval), so it shadows
# torch.nn.Module.train(mode). nn.Module.eval() calls self.train(False), so
# model.eval() would likely fail or misbehave -- consider renaming the method
# in src/v2.py (e.g. `fit`). TODO confirm.
# NOTE(review): the logged losses (~1e5) look like sums over batches rather
# than means -- verify before comparing against the per-batch loss (~4.3).
model.train(
    text=mytext,
    nb_epochs=5, # 3000
    batch_size=32,
    learning_rate=1e-2,
    eval_interval = 1,
)

2023-03-02 11:51:59,109 INFO: Epoch 0: train_loss = 137030.58601474762, eval_loss = 15230.963001966476
2023-03-02 11:52:12,770 INFO: Epoch 1: train_loss = 132761.62109375, eval_loss = 14766.106065750122
2023-03-02 11:52:26,123 INFO: Epoch 2: train_loss = 128876.8922624588, eval_loss = 14341.531868696213
2023-03-02 11:52:39,690 INFO: Epoch 3: train_loss = 125400.47093701363, eval_loss = 13963.03724360466
2023-03-02 11:52:54,518 INFO: Epoch 4: train_loss = 122253.63994884491, eval_loss = 13619.354886054993


In [9]:
# Predict the token that follows token id 0 (the newline character in this vocabulary).
# NOTE(review): this call emits a warning whose context line is
# `probs = functional.softmax(logits)` -- presumably the implicit-`dim`
# deprecation warning; passing an explicit `dim=` in BiGram.inference should
# silence it. TODO confirm in src/v2.py.
model.inference(torch.tensor([0]))

  probs = functional.softmax(logits)


tensor([1])

In [10]:
# Generate from the trained model, seeded with token id 0 (newline), then
# decode the ids back to text. The first argument (1000) is presumably the
# number of tokens to sample.
generated_ids = model.generate(1000, idx=torch.tensor([0]))
print(mytext.convert_integer2string(generated_ids))


ARDWh y an: ir aurt, d s es ES: y,
pitharapalloul INE: Wh fof mils sttar: blalackicher thivexpth s avan tr I o wen athare

A:
By sld he rn.
GUCEOfut th n.
Cos jurrme urd.
HEToven,
DUCHein t whastr o,
GESHERENou he falos
Whay he, Su fath! when PEd:
I's out honfrot brif a whesstis hirteid s Tht s, Bes thamagorouthonorter cef, matsu,

Me f h shastyof tot,
IV:
Ofond t rghalle? f t tive ste llu o aicr tt ly, d Ay, towakist toy hinghee mesothet h wnol s Heaclerimerethistle suny,
IZAnd inovestre, aly. the.

Coueimosem:
To IURUThatimyour, he is issemato hthoith.
Whone grube
cond t y VONENoo:
Ask,
Nodns hir, othe te onu cos y wh KELUENUS IUCHonoryovillearoufat!'llinom,

TUCAur.
AR:
YO: se t Ancks; by ayomanl faket he's yon mam sheais,-g; w, t! w-n, ishaye.
AROMyokitors whind h cer n at gr h tl.

TENVI quewighaveaca IG OFachilioodusp S:
TUCOUELAlld nch wo be dan, trde. t ghererabur ttowhid?
A:
RY asell!
T: thepr; ceers l oss gord igan ee an s t ard ts o te angr anesieshilesthate s thlllat omu! 

# Sandbox

In [19]:
x = torch.arange(10)
x

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [20]:
emb = torch.nn.Embedding(10, 25)

In [22]:
emb(x).shape

torch.Size([10, 25])

In [9]:
torch.stack([torch.ones(5) * ii for ii in range(3)]).shape

torch.Size([3, 5])

In [13]:
for ii in range(3, 3 + 5):
    print(ii)

3
4
5
6
7
