In [4]:
# torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt
from tqdm import tqdm

import requests
import os
import re
import collections

# Seed PyTorch's global RNG so repeated runs draw the same random numbers.
torch.manual_seed(305)

# Use the GPU when PyTorch reports one is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [5]:
input_file_path = 'full_shakespeare.txt'

if not os.path.exists(input_file_path):
    data_url = 'https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt'
    # Fetch and validate the response BEFORE creating the file. If the file
    # were opened first, a failed or hung request would leave an empty file
    # behind, and the existence check above would then skip the download on
    # every later run.
    response = requests.get(data_url, timeout=60)
    response.raise_for_status()  # surface HTTP errors instead of caching an error page
    with open(input_file_path, 'w', encoding='utf-8') as f:
        f.write(response.text)

# Read with an explicit encoding so the result does not depend on the
# platform's default locale.
with open(input_file_path, 'r', encoding='utf-8') as f:
    data = f.read()
print(f"length of dataset in characters: {len(data):,}")

length of dataset in characters: 4,573,338


In [6]:
# Character-level vocabulary: every distinct character in the corpus, in
# sorted order.
chars = sorted(set(data))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")

all the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]abcdefghijklmnopqrstuvwxyz
vocab size: 67


In [10]:

def build_initial_vocab(text):
    """Count the occurrences of every word and whitespace run in ``text``.

    The text is split on runs of whitespace; the pattern uses a capturing
    group, so the whitespace runs themselves are kept as entries alongside
    the words. Each entry is keyed by the tuple of its individual characters.

    Args:
        text: The raw corpus string.

    Returns:
        A ``collections.Counter`` mapping tuple-of-characters to the number
        of times that chunk occurs in ``text``.
    """
    # re.split may produce empty strings (e.g. at the edges of the text);
    # those are skipped.
    return collections.Counter(
        tuple(chunk) for chunk in re.split(r'(\s+)', text) if chunk
    )

def get_pair_stats(vocab):
    """Tally how often each pair of adjacent tokens occurs in the vocabulary.

    Args:
        vocab: Mapping from a tokenized word (sequence of tokens) to its
            frequency.

    Returns:
        A ``collections.Counter`` mapping ``(left, right)`` token pairs to
        the summed frequency of every word in which the two tokens appear
        next to each other (counted once per occurrence within the word).
    """
    counts = collections.Counter()
    for word, weight in vocab.items():
        symbols = list(word)
        # Pair each token with its right-hand neighbour, left to right.
        for left, right in zip(symbols, symbols[1:]):
            counts[(left, right)] += weight
    return counts

def merge_vocab(pair, vocab):
    """Apply one merge to every tokenized word in the vocabulary.

    Each word is scanned left to right; every non-overlapping occurrence of
    ``pair`` as two adjacent tokens is replaced by their concatenation.

    Args:
        pair: Tuple ``(left, right)`` of the two tokens to merge.
        vocab: Mapping from a tokenized word (sequence of tokens) to its
            frequency. It is not modified.

    Returns:
        A new ``dict`` mapping the re-tokenized words (as tuples) to their
        frequencies. If several input words collapse to the same tuple after
        the merge, their frequencies are summed.
    """
    merged_token = "".join(pair)
    new_vocab = {}
    for tokenized_word, freq in vocab.items():
        tokens = list(tokenized_word)
        new_tokens = []
        i = 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i+1]) == pair:
                new_tokens.append(merged_token)
                i += 2  # skip both halves of the merged pair
            else:
                new_tokens.append(tokens[i])
                i += 1
        key = tuple(new_tokens)
        # Accumulate rather than assign: two different segmentations can
        # become identical after the merge, and a plain assignment would
        # silently discard the earlier entry's count.
        new_vocab[key] = new_vocab.get(key, 0) + freq
    return new_vocab

def bpe_tokenizer(text, desired_vocab_size):
    """Learn byte-pair-encoding merges from ``text``.

    Starting from the character-level vocabulary produced by
    ``build_initial_vocab``, the most frequent adjacent token pair is merged
    repeatedly. Training stops when the number of distinct tokens in use
    reaches ``desired_vocab_size`` (that count is then printed) or when no
    adjacent pairs remain.

    Args:
        text: The training corpus.
        desired_vocab_size: Target number of distinct tokens.

    Returns:
        A tuple ``(merges, vocab, current_tokens)``:
        ``merges`` is the list of merged pairs in the order they were learned,
        ``vocab`` maps each tokenized word to its frequency after all merges,
        and ``current_tokens`` is the set of distinct tokens present in
        ``vocab`` when training stopped.
    """
    vocab = build_initial_vocab(text)
    merges = []

    while True:
        # Distinct tokens currently used by any tokenized word.
        current_tokens = {
            token for tokenized_word in vocab for token in tokenized_word
        }

        reached_target = len(current_tokens) >= desired_vocab_size
        if reached_target:
            print(len(current_tokens))
            break

        pair_stats = get_pair_stats(vocab)
        if len(pair_stats) == 0:
            # Every word is already a single token; nothing left to merge.
            break

        # Among equally frequent pairs, max() keeps the first one encountered.
        best_pair = max(pair_stats, key=lambda candidate: pair_stats[candidate])
        merges.append(best_pair)
        vocab = merge_vocab(best_pair, vocab)

    return merges, vocab, current_tokens

def apply_bpe(word, merges):
    """Segment ``word`` into BPE tokens by replaying ``merges`` in order.

    The word starts as a list of its characters. For each merge, every
    non-overlapping left-to-right occurrence of that pair of adjacent tokens
    is replaced by the concatenated token.

    Args:
        word: The string to segment.
        merges: Iterable of ``(left, right)`` token tuples, in the order the
            merges were learned.

    Returns:
        The list of tokens; joining them reproduces ``word``.
    """
    tokens = list(word)
    for merge in merges:
        count = len(tokens)
        if count < 2:
            # A single token (or nothing) has no adjacent pair, so no later
            # merge can apply either.
            break
        # Compact the list in place in one linear pass. `write` never gets
        # ahead of `read`, so unread tokens are never overwritten. A merged
        # token is longer than either half of its pair, so it can never take
        # part in another match of the same merge; stepping back after a
        # match is therefore unnecessary.
        read = 0
        write = 0
        while read < count:
            if read < count - 1 and (tokens[read], tokens[read + 1]) == merge:
                tokens[write] = "".join(merge)
                read += 2
            else:
                tokens[write] = tokens[read]
                read += 1
            write += 1
        del tokens[write:]
    return tokens

def bpe_tokenize(text, merges):
    """Tokenize ``text`` with previously learned BPE ``merges``.

    The text is split on whitespace runs (the runs are kept), and every
    piece, words and whitespace runs alike, is segmented with ``apply_bpe``.
    Whitespace runs go through the merges too because ``build_initial_vocab``
    trains on them as character tuples: a run that was never merged into a
    single token only exists in the vocabulary as its smaller parts, so
    emitting it whole would produce an out-of-vocabulary token.

    Args:
        text: The string to tokenize.
        merges: Ordered list of ``(left, right)`` merges.

    Returns:
        A list of token strings whose concatenation equals ``text``.
    """
    # Split text preserving whitespace; drop the empty strings re.split can emit.
    pieces = [p for p in re.split(r'(\s+)', text) if p]

    # The same words recur many times in natural text, so segment each
    # distinct piece once and reuse the result.
    segmented = {}
    output = []
    for piece in pieces:
        tokens = segmented.get(piece)
        if tokens is None:
            tokens = apply_bpe(piece, merges)
            segmented[piece] = tokens
        output.extend(tokens)
    return output

In [None]:
desired_vocab_size = 2000
merges, final_vocab, learned_tokens = bpe_tokenizer(data, desired_vocab_size)

# learned_tokens is a set, so sort it (shortest first, then lexicographically)
# to give every token a deterministic integer id.
sorted_tokens = sorted(learned_tokens, key=lambda tok: (len(tok), tok))
itos = dict(enumerate(sorted_tokens))                   # id -> token
stoi = {token: idx for idx, token in itos.items()}      # token -> id

def encode_bpe(text):
    """Encode ``text`` as a list of integer token ids.

    The text is tokenized with ``bpe_tokenize`` using the module-level
    ``merges`` and each token is looked up in the module-level ``stoi``.
    A token that has no id is not discarded outright: it is spelled out
    character by character, using whichever of its characters have ids.
    Characters with no id are skipped.

    Args:
        text: The string to encode.

    Returns:
        A list of ``int`` token ids.
    """
    ids = []
    for token in bpe_tokenize(text, merges):
        if token in stoi:
            ids.append(stoi[token])
        else:
            # Fall back to the token's individual characters so the text it
            # covers is not silently lost from the encoded sequence.
            ids.extend(stoi[ch] for ch in token if ch in stoi)
    return ids

def decode_bpe(indices):
    """Turn a sequence of token ids back into text.

    Args:
        indices: Iterable of integer ids that are keys of the module-level
            ``itos``. Objects exposing ``.tolist()`` (e.g. a 1-D tensor or
            array) are accepted as well.

    Returns:
        The concatenation of the tokens the ids map to.

    Raises:
        KeyError: If an id is not present in ``itos``.
    """
    # Iterating a tensor yields 0-d tensors, which do not match the plain-int
    # keys of `itos`; convert to Python ints first. Lists and other plain
    # iterables are left untouched.
    if hasattr(indices, "tolist"):
        indices = indices.tolist()
    return ''.join(itos[i] for i in indices)

# Hold out the final 10% of the characters for validation.
n = len(data)
split_idx = int(n*0.9)
train_text = data[:split_idx]
val_text = data[split_idx:]

# Encode each split to token ids, then wrap the ids in tensors.
train_tokens = encode_bpe(train_text)
train_data = torch.tensor(train_tokens)

val_tokens = encode_bpe(val_text)
val_data = torch.tensor(val_tokens)

print(f"train has {len(train_data):,} tokens")
print(f"val has {len(val_data):,} tokens")