In [None]:
from main import *
from bpe import BPE

from itertools import islice
import os
from pathlib import Path
from pprint import pprint
from typing import Iterator, Tuple, List
from time import time
import random

import torch
from torch import tensor
import jsonlines
import matplotlib.pyplot as plt



%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [None]:
# JSON Lines file read by dataset() (judging by the path, one shard of The Pile).
# Set the PILE_JSONL environment variable to point at a local copy instead of
# editing this hardcoded machine-specific default.
DATASET = Path(os.environ.get('PILE_JSONL', '/extra/diego/the-pile/29.jsonl'))

In [None]:
def dataset(deterministic=True, max_length=1_000_000, path=None, sample_prob=0.01) -> Iterator[str]:
    """Stream document texts from a JSON Lines file, cycling over it.

    Each line is either a dict with a 'text' key or a bare string.

    Args:
        deterministic: if True, yield every document in file order; otherwise
            keep each document independently with probability `sample_prob`.
        max_length: stop once this many characters have been yielded, cutting
            the last document short so the total is exactly `max_length`.
            Any value <= 0 means no limit (the file is cycled forever).
        path: file to read; defaults to the module-level DATASET.
        sample_prob: keep probability used when `deterministic` is False.

    Raises:
        ValueError: if the file contains no lines at all (otherwise the
            generator would re-open it forever without yielding).
    """
    if path is None:
        path = DATASET
    total_length = 0
    while True:
        saw_any_line = False
        # Context manager: the file is closed after each pass, and also when
        # the consumer stops early and the generator is closed.
        with jsonlines.open(path) as reader:
            for doc in reader:
                saw_any_line = True
                doc = doc['text'] if isinstance(doc, dict) else doc
                if deterministic or random.random() < sample_prob:
                    if max_length > 0:
                        # total_length < max_length here, so this keeps
                        # exactly the characters that still fit under the cap.
                        doc = doc[:max_length - total_length]
                    total_length += len(doc)
                    yield doc

                if total_length >= max_length > 0:
                    return
        if not saw_any_line:
            raise ValueError(f'{path} contains no documents')


def get_text(length: int, from_=dataset(False, -1)) -> str:
    """Return exactly `length` characters taken from the document stream `from_`.

    Documents are concatenated until at least `length` characters are
    available and the result is cut to `length`. The default stream is created
    once, when this function is defined, and is shared by every call that
    omits `from_`, so successive calls continue where the previous one stopped.

    Raises:
        ValueError: if `from_` runs out before `length` characters are collected
            (instead of leaking a bare StopIteration to the caller).
    """
    pieces = []
    collected = 0
    while collected < length:
        try:
            doc = next(from_)
        except StopIteration:
            raise ValueError(
                f'text stream exhausted after {collected} of {length} characters'
            ) from None
        pieces.append(doc)
        collected += len(doc)
    return ''.join(pieces)[:length]

# Total number of characters in the default (deterministic, 1M-capped) dataset.
sum(len(doc) for doc in dataset())

# Tokenizer


In [None]:
# Train a BPE tokenizer on the deterministic, length-capped dataset and save it.
# A fixed path (rather than prompting with input()) keeps "Restart & Run All"
# unattended, and BPE_PATH matches the BPE.load('bpe.pt') calls further down.
# NOTE(review): 10000 and 2 are passed positionally to BPE.train_from_text;
# presumably vocabulary size and a minimum count — confirm against bpe.py.
BPE_PATH = 'bpe.pt'
bpe = BPE.train_from_text(dataset(), 10000, 2)
bpe.save(BPE_PATH)

In [None]:
# Load the saved tokenizer from bpe.pt instead of retraining it.
bpe = BPE.load('bpe.pt')

In [None]:
# Inspect the tokenizer's learned token frequencies.
# NOTE(review): this pretty-prints the whole structure, which can flood the
# cell output for a large vocabulary; consider showing only a slice.
pprint(bpe.token_frequencies)

## Performance


In [None]:
import timeit

# Reload the tokenizer so the performance section can run without retraining.
bpe = BPE.load('bpe.pt')

In [None]:
# Time bpe.tokenize for several batch sizes and text lengths. For each
# configuration the number of calls is doubled (starting at 5) until a single
# measurement takes at least MIN_TOTAL_TIME seconds, so the per-call estimate
# is not dominated by timer noise.
MIN_TOTAL_TIME = 0.2     # seconds per measurement
MAX_TOTAL_CHARS = 10**6  # skip configurations with more characters than this

results = {}  # batch_size -> {text_size: (seconds_per_call, num_runs)}
for batch_size in [1, 10, 100]:
    results[batch_size] = {}
    for exponent in range(2, 10):
        text_size = 4**exponent
        if text_size * batch_size > MAX_TOTAL_CHARS:
            continue
        num_runs = 5
        while True:
            t = timeit.timeit("bpe.tokenize(texts)",
                              setup="texts = [get_text(text_size) for _ in range(batch_size)]",
                              globals=globals(),
                              number=num_runs)
            if t >= MIN_TOTAL_TIME:
                break
            # Only double when we measure again, so `t` always corresponds
            # to exactly `num_runs` calls.
            num_runs *= 2
        per_call = t / num_runs
        results[batch_size][text_size] = (per_call, num_runs)
        print(f'batch_size={batch_size}, text_size={text_size}: {per_call:.4f}s per call')

In [None]:
# Plot time per call against total characters per call (batch_size * text_size),
# one line per batch size. The loop variable is `bs`, not `block_size`, so this
# cell does not overwrite the model's `block_size` hyperparameter.
fig, ax = plt.subplots()
for bs, timings in results.items():
    total_sizes = [bs * text_size for text_size in timings]
    per_call = [seconds for seconds, _ in timings.values()]
    ax.loglog(total_sizes, per_call, marker='o', label=f'batch_size={bs}')
ax.set(title='BPE tokenize: time per call vs. input size',
       xlabel='Total text size (chars)',
       ylabel='Time per call (s)')
ax.legend();


## Comparison with the GPT-2 tokenizer


In [None]:
# Decode every id of the GPT-2 vocabulary into its string form, so the next
# cell can check which strings GPT-2 has as a single token.
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt_tokens = list(map(lambda token_id: tokenizer.decode([token_id]),
                      range(tokenizer.vocab_size)))

In [None]:
# Does some single GPT-2 token id decode to ' Python' (with the leading space)?
' Python' in gpt_tokens

# Training


In [None]:
@typechecked
def train(model,
          optim,
          data_generator: Iterator[Tuple[TT['batch', 'token', int], TT['batch', 'token', int]]],
          batch_size: int = 32,
          max_time: float = 60.0,
          log_every: int = 100):
    """Train `model` with `optim` until `max_time` seconds pass or the data runs out.

    Each step draws up to `batch_size` (x, y) pairs from `data_generator`,
    stacks them with torch.stack and minimises `model.loss(xs, ys)`.
    NOTE(review): the annotation says each element is (batch, token), but the
    pairs are stacked into a batch here, so elements are presumably per-sample
    (token,) tensors as produced by upcase_dataset — confirm.

    Args:
        model: object with `.train()`, `.loss(xs, ys)` and parameters in `optim`.
        optim: optimizer with `zero_grad()` / `step()`.
        data_generator: iterator of (x, y) tensor pairs; may be finite.
        batch_size: number of pairs per step (the last step may be smaller).
        max_time: wall-clock budget in seconds, checked before every step.
        log_every: print the smoothed loss every this many steps.

    Returns:
        The exponentially smoothed loss (0.99 decay), or None if no step ran.
    """
    model.train()
    smoothed_loss = None
    start_time = time()
    batch = 0
    while time() - start_time < max_time:
        # islice instead of next() in a genexpr: a finite generator would
        # otherwise surface as RuntimeError (PEP 479) rather than ending training.
        samples = list(islice(data_generator, batch_size))
        if not samples:
            break
        xs, ys = zip(*samples)
        xs = torch.stack(xs)
        ys = torch.stack(ys)

        optim.zero_grad()
        loss = model.loss(xs, ys)
        loss.backward()
        optim.step()

        # Seed the moving average with the first loss (keyed on "first step",
        # not on the value 0.0, which a real loss could equal).
        if smoothed_loss is None:
            smoothed_loss = loss.item()
        else:
            smoothed_loss = 0.99 * smoothed_loss + 0.01 * loss.item()
        if batch % log_every == 0:
            print(f'Batch {batch} loss: {smoothed_loss:.4f}')
        batch += 1
    return smoothed_loss

## Case-recovery model


In [None]:
# Model hyperparameters. block_size is reused below to chunk the training data
# (upcase_dataset) and to size the evaluation prompt.
embedding_dim = 64
head_count = 4
depth = 4
block_size = 100

model = UpcasingTransformer(
    embedding_dim,
    depth=depth,
    head_count=head_count,
    block_size=block_size,
)
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [None]:
# Dataset generator
@typechecked
def upcase_dataset(
    block_size: int, data_source: Iterator[str]
) -> Iterator[Tuple[TT['token', int], TT['token', int]]]:
    """Yield (lowercased bytes, was-uppercase mask) pairs of length `block_size`.

    Each text from `data_source` is UTF-8 encoded and cut into consecutive
    `block_size`-byte chunks; a short final chunk is right-padded with b' '.
    The first tensor holds the chunk's byte values after bytes.lower() (which
    only lowercases ASCII); the second is 1 where lowercasing changed the
    byte and 0 elsewhere.
    """
    for text in data_source:
        raw = text.encode('utf-8')
        for start in range(0, len(raw), block_size):
            chunk = raw[start:start + block_size].ljust(block_size, b' ')
            lowered = tensor(list(chunk.lower()))
            original = tensor(list(chunk))
            yield lowered, (lowered != original).long()


class ByteTokenizer:
    """A simple tokenizer that encodes strings as their utf-8 byte values.

    Token ids are raw UTF-8 byte values; 0 doubles as the padding value.
    """

    @staticmethod
    def tokenize(texts: List[str], pad_to=None) -> "TT['batch', 'token', int]":
        """Encode `texts` into a (len(texts), pad_to) integer tensor of bytes.

        Args:
            texts: strings to encode as UTF-8.
            pad_to: row length. Defaults to the longest encoded text; when
                given, longer texts are truncated (possibly mid-character).

        Returns:
            A long tensor right-padded with 0. An empty `texts` yields shape
            (0, pad_to) — (0, 0) by default — instead of raising.
        """
        encoded = [list(text.encode('utf-8')) for text in texts]
        if pad_to is None:
            # default=0 so an empty batch does not crash max().
            pad_to = max((len(row) for row in encoded), default=0)
        else:
            encoded = [row[:pad_to] for row in encoded]  # truncate if too long
        padded = [row + [0] * (pad_to - len(row)) for row in encoded]
        # reshape keeps the result 2-D (and integer-typed) even when empty.
        return tensor(padded, dtype=torch.long).reshape(len(padded), pad_to)

    @staticmethod
    def detokenize(tokens: "TT['batch', 'token', int]", strip_padding: bool = False) -> List[str]:
        """Decode each row of `tokens` back into a string.

        Invalid UTF-8 (e.g. a character cut by truncation) becomes U+FFFD.
        Padding zeros decode to NUL characters unless `strip_padding` is True,
        in which case trailing zero bytes are dropped (this also drops any
        genuine trailing NULs).
        """
        texts = []
        for row in tokens:
            # Convert element-wise: bytes() on a one-element tensor would use
            # __index__ and return that many zero bytes instead of the byte.
            raw = bytes(int(value) for value in row)
            if strip_padding:
                raw = raw.rstrip(b'\x00')
            texts.append(raw.decode('utf-8', errors='replace'))
        return texts


In [None]:
# Class balance for case recovery: what fraction of characters are uppercase?
# The complement is "everything else" (lowercase letters, but also digits,
# whitespace and punctuation), so it is labelled as such.
# NOTE(review): this counts Unicode characters with str.isupper, while
# upcase_dataset labels UTF-8 bytes changed by the ASCII-only bytes.lower();
# the two ratios differ on non-ASCII text.
total = 0
upcase = 0
for doc in dataset():
    total += len(doc)
    upcase += sum(1 for c in doc if c.isupper())
if total == 0:
    raise ValueError('dataset() produced no text')
up = upcase / total
low = 1 - up

print("Proportion of uppercase characters:", up)
print("Proportion of other characters:", low)
# Guard against a dataset that is all uppercase or has no uppercase at all.
print("Weight of upcase:", 1 / up if up else float('inf'))
print("Weight of non-upcase:", 1 / low if low else float('inf'))

In [None]:
# Train for up to 30 seconds on an endless stream of randomly sampled documents
# (dataset(False, -1): non-deterministic sampling, no length cap), chunked into
# block_size-byte upcasing examples.
train(model, optim, upcase_dataset(block_size, dataset(False, -1)), max_time=30)

In [None]:
# Check if it works: lowercase a random snippet, ask the model to restore its
# casing and compare the result with the original text.
text = get_text(block_size)

print(f"Prompt: {text.lower()!r}")
print(f"Expect: {text!r}")

model.eval()  # `train` leaves the model in training mode
with torch.no_grad():  # inference only: no autograd graph needed
    out = repr(model.predict(text.lower()))
    probas = model(ByteTokenizer.tokenize([text.lower()], block_size))[0]

# Character-by-character comparison of the two reprs; '^' marks a mismatch.
# NOTE: reprs include quotes and escape sequences and `probas` is indexed by
# UTF-8 byte, so the per-position alignment is only exact for plain ASCII.
diffs = ''.join(' ' if a == b else '^' for a, b in zip(repr(text), out))
# 'u' marks every uppercase character in the model output.
bad_diffs = ''.join(' u'[b.isupper()] for b in out)

print(f"Output: {out}")
print(f"Diffs:  {diffs}")
print(f"Bad:    {bad_diffs}")
print(diffs.count('^'), 'differences')
probas = probas.softmax(1)
print("%.2f" % probas[:, 1].max().item(), 'max proba')
# diffs[1:] skips the opening quote so position i lines up with byte i.
for flip, d in zip(probas[:, 1], diffs[1:]):
    print(f"Flip: {flip:.2f}", end="")
    print(" <---" if d == '^' else "")

In [None]:
# Small-pile: write the deterministic, 1M-character-capped slice produced by
# dataset() to a local JSON Lines file, one JSON string per line.
file = 'small-pile.jsonl'
with jsonlines.open(file, mode='w') as writer:
    writer.write_all(dataset())