In [1]:
# Props to this sensei
# https://www.youtube.com/watch?v=kCc8FmEb1nY&list=PLAqhIrjkxbuWI23v9cThsA9GvCAUhRvKZ&index=8

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm # progress bar

## Hyper-parameters

In [25]:
batch_size = 64
text_file = "tiny-shakespeare.txt"



## Reading Data

In [26]:
# read file
with open(text_file, "r") as f:
    text = f.read()
text[:100]

'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'

In [4]:
char_list = sorted(list(set(text)))
char_size = len(char_list)
print(f"All the characters in the text: {''.join(char_list)}")
print(f"Length of the characters: {char_size}")

All the characters in the text: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Length of the characters: 65


## Tokenizer (character based, index/ascii)

In [5]:
class MyTokenizer:
    def __init__(self):
        self.char_to_index = None
        self.index_to_char = None

    def fit(self, char_list):  
        self.char_to_index = {char: idx for idx, char in enumerate(char_list)}
        self.index_to_char = {idx: char for char, idx in self.char_to_index.items()}

    def encode_index(self, input_str):
        return [self.char_to_index[char] for char in input_str]

    def decode_index(self, encoded_list):
        return ''.join([self.index_to_char[idx] for idx in encoded_list])

    @staticmethod
    def ascii_tokenizer(char):
        return ord(char)

    @staticmethod
    def ascii_decoder(ascii_value):
        return chr(ascii_value)

    def encode_combined(self, input_str, use_ascii=False):
        if use_ascii:
            return [self.ascii_tokenizer(char) for char in input_str]
        else:
            return self.encode_index(input_str)

    def decode_combined(self, encoded_list, use_ascii=False):
        if use_ascii:
            return ''.join([self.ascii_decoder(ascii_value) for ascii_value in encoded_list])
        else:
            return self.decode_index(encoded_list)

In [6]:
# Example usage:
tokenizer = MyTokenizer()
tokenizer.fit(char_list)

input_str = "Hello there"
encoded_list_ascii = tokenizer.encode_combined(input_str, use_ascii=True)
decoded_str_ascii = tokenizer.decode_combined(encoded_list_ascii, use_ascii=True)

encoded_list_index = tokenizer.encode_combined(input_str, use_ascii=True)
decoded_str_index = tokenizer.decode_combined(encoded_list_index, use_ascii=True)

print("Original String:", input_str)
print("Encoded List (ASCII):", encoded_list_ascii)
print("Decoded String (ASCII):", decoded_str_ascii)

print("Encoded List (Index):", encoded_list_index)
print("Decoded String (Index):", decoded_str_index)


Original String: Hello there
Encoded List (ASCII): [72, 101, 108, 108, 111, 32, 116, 104, 101, 114, 101]
Decoded String (ASCII): Hello there
Encoded List (Index): [72, 101, 108, 108, 111, 32, 116, 104, 101, 114, 101]
Decoded String (Index): Hello there


In [7]:
# Encode all the data 
encoded_data = tokenizer.encode_combined(text) 
encoded_data[:10]

[18, 47, 56, 57, 58, 1, 15, 47, 58, 47]

## Data Loader

In [8]:
data = torch.tensor(encoded_data)
data.shape[0]

1115393

In [30]:
class MyDataset(Dataset):
    def __init__(self, encoded_data):
        self.encoded_data = encoded_data

    def __len__(self):
        return len(self.encoded_data)

    def __getitem__(self, idx):
        return self.encoded_data[idx]
        
dataset = MyDataset(data)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

for batch in dataloader:
    print(batch)
    break


tensor([57, 42, 61, 51, 58, 51, 45, 42, 53, 47, 39, 50, 59, 32,  8,  1, 39, 51,
         1, 46,  1, 47,  1, 53, 27, 44,  0, 46,  1, 46, 53, 59, 58,  0, 56,  0,
         0,  1, 43, 31, 57,  1, 46, 52, 63,  6, 57, 46, 32, 25,  1, 57, 39, 39,
         1, 30,  1, 63, 46, 47, 57, 53, 54, 43])


In [10]:
device = ("cuda" if torch.cuda.is_available() else "mps"
          if torch.backends.mps.is_available() else "cpu")
print(f"Using {device} device")

Using cpu device
