## References
- [A Recipe for Training Neural Networks](https://karpathy.github.io/2019/04/25/recipe/)
- [Harvard CS197 AI Research Experiences](https://docs.google.com/document/d/1uvAbEhbgS_M-uDMTzmOWRlYxqCkogKRXdbKYYT98ooc/edit#heading=h.2z3yllpny6or)
- [Unit tests for machine learning research](https://semla.polymtl.ca/wp-content/uploads/2022/11/Pablo-Unit-tests-for-ML-code-SEMLA-talk.pdf)
- [CS 329S: Machine Learning Systems Design](https://stanford-cs329s.github.io/syllabus.html)

## Become one with the data

In [1]:
# Data source (download once if input.txt is not present):
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print("length of dataset in characters: ", len(text))
print(text[:100])

# Character-level 90/10 train/validation split. The cut point is computed once
# so the two slices cannot drift apart (overlap or leave a gap) if the ratio
# is ever edited in only one place.
# NOTE(review): the tokenization/batching cell further down re-binds
# train_data and val_data to token-id tensors; after it runs, these names no
# longer refer to the character-level strings created here.
train_fraction = 0.9
split_index = int(len(text) * train_fraction)
train_data = text[:split_index]
val_data = text[split_index:]

length of dataset in characters:  1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [2]:
import re

def split_string(input_string):
    """Split text on single whitespace characters, keeping the separators.

    Every whitespace character (space, newline, tab, ...) is returned as its
    own list element between the surrounding pieces, e.g. 'a b' becomes
    ['a', ' ', 'b']. Punctuation stays attached to words ('speak.').

    Because the split happens on ONE whitespace character at a time, a run of
    whitespace (such as a blank line, i.e. two consecutive newlines) yields an
    empty string between the separators, and an empty input yields [''].
    Callers counting "words" will therefore also count ' ', '\\n' and ''.
    """
    return re.split(r'(\s)', input_string)

from collections import Counter

# Show how the first sentence is split (separators included).
first_period_index = text.index('.')
print(split_string(text[:first_period_index+1]))

# Split the full corpus once and reuse the result: the regex split over the
# ~1.1M-character text is the expensive part of this cell.
words = split_string(text)
unique_words = list(set(words))

# Frequency of every piece split_string emits -- real words, but also the
# ' ' / '\n' separators and the '' produced by runs of whitespace.
# Sorted from most to least frequent. Counter keeps first-occurrence order and
# most_common() sorts stably, so ties stay in first-occurrence order.
word_count_dict = dict(Counter(words).most_common())
# Show the 5 most and 5 least frequent pieces.
print(list(word_count_dict.items())[:5])
print(list(word_count_dict.items())[-5:])
print('splitted', len(words), 'unique_word', len(unique_words))

['First', ' ', 'Citizen:', '\n', 'Before', ' ', 'we', ' ', 'proceed', ' ', 'any', ' ', 'further,', ' ', 'hear', ' ', 'me', ' ', 'speak.']
[(' ', 169892), ('\n', 40000), ('', 7242), ('the', 5437), ('I', 4403)]
[('open;', 1), ('standing,', 1), ('moving,', 1), ('sleep--die,', 1), ("wink'st", 1)]
splitted 419785 unique_word 25673


In [3]:
import tiktoken
from collections import Counter

enc = tiktoken.get_encoding("gpt2")

# Tokenize the first sentence and decode each id on its own to show where the
# GPT-2 BPE token boundaries fall.
encoded_ids = enc.encode(text[:first_period_index+1])
decoded_text = [enc.decode([encoded_id]) for encoded_id in encoded_ids]
print(encoded_ids)
print(decoded_text)

# Encode the whole corpus once and reuse it; BPE-encoding ~1.1M characters is
# the expensive step of this cell.
corpus_tokens = enc.encode(text)
unique_tokens = list(set(corpus_tokens))

# Token frequencies, most frequent first (ties keep first-occurrence order,
# since most_common() sorts stably over Counter's insertion order).
# NOTE(review): despite its name, the 'token_id' field holds the decoded token
# TEXT; the id itself is the dict key. The field name is kept unchanged in case
# later cells read it.
token_count_dict = {
    token: {'count': count, 'token_id': enc.decode([token])}
    for token, count in Counter(corpus_tokens).most_common()
}
# Show the 5 most and 5 least frequent tokens.
print(list(token_count_dict.items())[:5])
print(list(token_count_dict.items())[-5:])
print('splitted', len(corpus_tokens), 'unique_token', len(unique_tokens), 'vocab_size', enc.n_vocab)

[5962, 22307, 25, 198, 8421, 356, 5120, 597, 2252, 11, 3285, 502, 2740, 13]
['First', ' Citizen', ':', '\n', 'Before', ' we', ' proceed', ' any', ' further', ',', ' hear', ' me', ' speak', '.']
[(198, {'count': 39996, 'token_id': '\n'}), (11, {'count': 19777, 'token_id': ','}), (25, {'count': 10291, 'token_id': ':'}), (13, {'count': 7811, 'token_id': '.'}), (262, {'count': 5370, 'token_id': ' the'})]
[(16558, {'count': 1, 'token_id': ' sphere'}), (31960, {'count': 1, 'token_id': ' Wond'}), (22194, {'count': 1, 'token_id': ' possesses'}), (29708, {'count': 1, 'token_id': ' eyel'}), (30757, {'count': 1, 'token_id': 'stroke'})]
splitted 338025 unique_token 11706 vocab_size 50257


In [4]:
from ngram import Ngram
# Vocabulary = every GPT-2 token id.
vocab = list(range(enc.n_vocab))
unigram = Ngram(1, vocab)
tokens = enc.encode(text)
unigram.train(tokens)
# The recorded output shows this equals enc.n_vocab, i.e. the unigram table has
# an entry for every vocabulary id, not just the ids that occur in the corpus.
print('params of unigram', len(unigram.ngram))

# Why the notebook stops at unigrams: if Ngram likewise stores one entry per
# possible n-tuple of ids (presumably -- Ngram is defined in ngram.py; TODO
# confirm), a bigram table needs n_vocab**2 entries and a trigram table
# n_vocab**3. Training Ngram(2, vocab) on `tokens` is therefore not attempted.
bigram_table_size = enc.n_vocab ** 2
trigram_table_size = enc.n_vocab ** 3
print('params of bigram', bigram_table_size, 'params of trigram', trigram_table_size)

params of unigram 50257
2525766049 126937424324593


In [5]:
# Rank every unigram entry by count (most frequent first; ties keep their
# original order) and show the 5 most and 5 least frequent entries.
# NOTE(review): keys appear to be 1-tuples of token ids (key[0] is decoded
# below), and the recorded counts are one higher than the raw frequencies from
# the tokenizer cell (e.g. 39997 vs 39996 for '\n') -- presumably add-one
# smoothing inside Ngram; confirm against ngram.py.
unigram_info = dict(
    sorted(unigram.ngram.items(), key=lambda kv: kv[1], reverse=True)
)
top_unigram = [*unigram_info.items()][:5]
bottom_unigram = [*unigram_info.items()][-5:]
print([(enc.decode([key[0]]), freq) for key, freq in top_unigram])
print([(enc.decode([key[0]]), freq) for key, freq in bottom_unigram])

[('\n', 39997), (',', 19778), (':', 10292), ('.', 7812), (' the', 5371)]
[('ominated', 1), (' regress', 1), (' Collider', 1), (' informants', 1), ('<|endoftext|>', 1)]


In [6]:
# Size of the GPT-2 BPE vocabulary (the same value the tokenizer cell above
# prints as vocab_size).
enc.n_vocab

50257

In [13]:
import torch
seed = 1337
# Seeds torch's global RNG, which get_batch's torch.randint draws from.
torch.manual_seed(seed)

# Batch geometry: sequences per batch, and tokens per input sequence.
batch_size = 4
context_length = 8

# Whole corpus as a single tensor of GPT-2 token ids.
data = torch.tensor(enc.encode(text), dtype=torch.long)
# Hold out the final 10% (by token count) for validation.
n = int(len(data) * 0.9)
# NOTE(review): train_data / val_data previously held character-level string
# slices (set when the text was loaded); from here on they are token tensors.
train_data, val_data = data[:n], data[n:]


def get_batch(split):
    """Sample a random batch of next-token-prediction windows.

    Reads the module-level ``train_data`` / ``val_data`` tensors, plus
    ``batch_size`` and ``context_length``, at call time. Start positions are
    drawn from torch's global RNG.

    Args:
        split: ``'train'`` to sample from ``train_data``, ``'val'`` to sample
            from ``val_data``.

    Returns:
        ``(x, y)``, each of shape ``(batch_size, context_length)``. ``y`` is
        ``x`` shifted one token to the right in the source sequence, so
        ``y[b, t]`` is the token that follows ``x[b, t]``.

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'val'`` (so a
            typo cannot silently select the validation set), or if the chosen
            split has no more than ``context_length`` tokens and so cannot hold
            one window plus its shifted target.
    """
    if split == 'train':
        data = train_data
    elif split == 'val':
        data = val_data
    else:
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")

    # A window starting at i needs tokens i .. i + context_length (inclusive)
    # so the shifted target fits. torch.randint's upper bound is exclusive, so
    # the valid starts are 0 .. len(data) - context_length - 1.
    num_starts = len(data) - context_length
    if num_starts <= 0:
        raise ValueError(
            f"{split!r} split has {len(data)} tokens; need more than "
            f"context_length={context_length}"
        )
    index = torch.randint(num_starts, (batch_size,))
    x = torch.stack([data[i:i + context_length] for i in index])
    y = torch.stack([data[i + 1:i + 1 + context_length] for i in index])
    return x, y


def decode_each(token_ids):
    """Decode each token id separately so BPE piece boundaries stay visible.

    Args:
        token_ids: an iterable of Python ints (convert tensors with .tolist()
            first, since enc.decode is documented to take a sequence of ints).

    Returns:
        A list with one decoded string per id.
    """
    return [enc.decode([token_id]) for token_id in token_ids]


x, y = get_batch('train')
print('input')
print(x.shape)
print(x)
print([decode_each(sequence.tolist()) for sequence in x])
print('target')
print(y.shape)
print(y)
print([decode_each(sequence.tolist()) for sequence in y])

# Each prefix of the first sequence is its own training example: the model
# sees tokens 0..t and must predict token t+1 (which is y[0, t]).
for t in range(context_length):
    context = x[0, :t+1]
    target = y[0, t]
    print('input: ', decode_each(context.tolist()), 'target: ', repr(enc.decode([target.item()])))

input
torch.Size([4, 8])
tensor([[  198, 30313,   262, 22397,   282,   290,   884,  3790],
        [ 4151,   438,   198, 10418,   329,   511, 11989,    11],
        [ 3355,   322, 12105,   287,  3426,  6729,   198,  3886],
        [  290, 15581,  8636,    13,   198,   198, 35510,  4221]])
[['\n', 'Except', ' the', ' marsh', 'al', ' and', ' such', ' officers'], [' eye', '--', '\n', 'Men', ' for', ' their', ' sons', ','], [' wall', 'ow', ' naked', ' in', ' December', ' snow', '\n', 'By'], [' and', ' noble', ' estimate', '.', '\n', '\n', 'NOR', 'TH']]
target
torch.Size([4, 8])
tensor([[30313,   262, 22397,   282,   290,   884,  3790,   198],
        [  438,   198, 10418,   329,   511, 11989,    11, 17743],
        [  322, 12105,   287,  3426,  6729,   198,  3886,  3612],
        [15581,  8636,    13,   198,   198, 35510,  4221,  5883]])
[['Except', ' the', ' marsh', 'al', ' and', ' such', ' officers', '\n'], ['--', '\n', 'Men', ' for', ' their', ' sons', ',', ' wives'], ['ow', ' naked', '