In [170]:
import torch
import torch.nn as nn
from dataclasses import dataclass

import requests
import unicodedata

from jaxtyping import Int, Float
from collections import Counter
import numpy as np

# Classes

In [171]:
@dataclass
class Config:
    d_model: int
    d_vocab: int
    d_hidden: int
    n_context: int
    n_layers: int

In [172]:
# class Embedding(nn.Module):
#     def __init__(self):
#         super().__init__()
    
#     def forward(self):
#         pass

class Attention(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        # self.W_qk = nn.Linear(config.d_model, config.d_vocab)
        self.bilinear = nn.Bilinear(config.d_model, config.d_model, config.n_context, bias=False)
        self.M = torch.triu(torch.ones((config.n_context, config.n_context)), diagonal=1)
        self.M = self.M.masked_fill(self.M.bool(), -torch.inf)
        self.second_matmult = nn.Linear(config.d_model, config.d_model, bias=False)
        self.softmax = nn.Softmax()
    
    def forward(self, x):
        xwx = self.bilinear(x, x) # d_m x d_m
        x_masked = xwx+ self.M 
        x_softmaxed = self.softmax(x_masked)
        x_fin = x_softmaxed@x
        #multiply softmaxed by x
        #multiply that by wov
        x_fin = self.second_matmult(x_fin)
        return x_fin

class MLP(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.linear_up = nn.Linear(config.d_model, config.d_hidden)
        self.linear_down = nn.Linear(config.d_hidden, config.d_model)
    
    def forward(self, x):
        x = self.linear_up(x)
        x = torch.relu(x)
        x = self.linear_down(x)
        return x
    
class TransformerBlock(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.config = config

        self.MLP = MLP(config=self.config)
        self.Attention = Attention(config=self.config)
    
    def forward(self, x):
        return x + self.Attention(x) + self.MLP(x)
    
class Transformer(nn.Module):
    def __init__(self, config:Config):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=config.d_vocab, embedding_dim=config.d_model)
        self.transformerBlock = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])

    def forward(self, x):
        x = self.embedding(x)
        for i, l in enumerate(self.transformerBlock):
            x = self.transformerBlock[i](x)
        return x

$n_c$: Context window length

$d_m$: Model Dimension

$d_v$: Vocab Dimension

In [173]:
text_sample = "The quick brown fox jumped over the lazy dog."

# Tokenization Code

In [174]:
from pathlib import Path

def get_gutenberg_book(
	id: int | None = 84,
	data_temp: Path | str = "../data/gutenberg_data",
	remove_gutenberg_meta: bool = True,
) -> str:
	
	data_temp: Path = Path(data_temp)
	data_temp.mkdir(parents=True, exist_ok=True)
	
	url: str = f"https://www.gutenberg.org/cache/epub/{id}/pg{id}.txt"
	data_path: Path = Path(data_temp) / f"{id}.txt"
	data: str
	# read from cache if it exists
	if data_path.exists():
		with open(data_path, 'r', encoding='utf-8') as file:
			data = file.read()
	else:
		# download if it doesn't exist
		response: requests.Response = requests.get(url)
		response.raise_for_status()  # Ensure that the download was successful
		data = response.text

		# save to cache
		with open(data_path, 'w', encoding='utf-8') as file:
			file.write(data)

	# remove header/footer
	if remove_gutenberg_meta:
		data = '***'.join(data.split('***')[2:])
		data = '***'.join(data.split('***')[:-1])
	
	return data

def get_many_books(
		ids: list[int],
		data_temp: Path | str = "../data/gutenberg_data",
	) -> list[str]:
	
	data: list[str] = []
	for id in ids:
		print(f"Getting book {id}...")
		item: str = get_gutenberg_book(id, data_temp)
		print(f"\t{len(item)} characters read")
		data.append(item)
	
	return data

In [175]:
def process_text(
	text: str,
	allowed_punctuation: str = "-.,;:!?()\"\\" + "".join(str(x) for x in range(10)),
	punctuation_convert: dict[str, str] = {'â€”': '-'},
) -> str:
	
	# replace some special characters which unicode won't normalize properly
	for char, replacement in punctuation_convert.items():
		text = text.replace(char, replacement)

	# if a line has ".jpg" in it, remove that line (this is specific to Don Quixote)
	text = '\n'.join(
		line 
		for line in text.split('\n')
		if '.jpg' not in line
	)

	# Normalize the string to decompose Unicode characters
	text = unicodedata.normalize('NFKD', text)

	# Encode to ASCII bytes, then decode back to string, ignoring errors
	text = text.encode('ascii', 'ignore').decode('ascii')

	# remove newlines and tabs
	text = text.replace('\n', ' ').replace('\t', ' ')


	# put spaces around allowed punctuation
	for char in allowed_punctuation:
		text = text.replace(char, f' {char} ')


	# remove leading and trailing spaces
	text = text.strip()

	# remove multiple spaces
	while '  ' in text:
		text = text.replace('  ', ' ')


	# remove all characters except (alphanumeric, allowed_punctuation, ' ')
	text = ''.join(
		(
			char 
			if (
				char.isalnum() 
				or char in allowed_punctuation 
				or char == ' '
			)
			else ' '
		)
		for char in text 
	)

	# convert to lowercase
	text = text.lower()

	text = text.strip()

	return text

In [176]:
def tokenize(
	text: str,
	process: bool = False,
) -> list[str]:
	if process:
		text = process_text(text)
	return text.split(' ')

In [177]:
# Getting books from Plato and Aristotle
DATA_RAW: list[str] = get_many_books([6762, 1497, 8438, 1600, 1656])
DATA: str = " ".join(process_text(x) for x in DATA_RAW)
DATA_TOKENIZED: list[str] = tokenize(DATA)

Getting book 6762...
	584611 characters read
Getting book 1497...
	1219052 characters read
Getting book 8438...
	648768 characters read
Getting book 1600...
	181248 characters read
Getting book 1656...
	87284 characters read


In [178]:
# sorted by frequency
VOCAB_FREQ: Counter[str] = Counter(DATA_TOKENIZED)
VOCAB_ARR: list[str] = [word for word, _ in VOCAB_FREQ.most_common()]
VOCAB_DICT: dict[str, int] = {word: i for i, word in enumerate(VOCAB_ARR)}

def encode(
	text: str | list[str],
) -> Int[np.ndarray, " n_tokens"]:
	if isinstance(text, str):
		text = tokenize(text)
	return np.array([VOCAB_DICT[word] for word in text])

def decode(
	encoded_text: Int[np.ndarray, " n_tokens"] | list[int],
) -> str:
	return ' '.join(VOCAB_ARR[i] for i in encoded_text)

DATA_ENCODED: Int[np.ndarray, " n_tokens"] = encode(DATA)

print(f"{DATA_ENCODED = }")
print(len(DATA_ENCODED))

DATA_ENCODED = array([1181,   25, 9326, ..., 4819, 4354, 1842], shape=(556819,))
556819


In [179]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, x):
        self.input = x[:-1]
        self.output = x[1:]
    
    def __len__(self):
        return len(self.input)
    
    def __getitem__(self, idx):
        inp = self.input[idx]
        out = self.output[idx]
        return inp, out

# Tests

In [180]:
d_model = 10
d_vocab = 10
d_hidden = 10
n_context = 5
n_layers = 10

x = torch.randn((n_context, d_model))

conf = Config(d_model, d_vocab, d_hidden, n_context, n_layers)
mlp = MLP(conf)
attention = Attention(conf)
Aoutput = attention(x)
print(Aoutput.shape)

output = mlp(x)
print(output)

torch.Size([5, 10])
tensor([[ 2.4382e-01, -3.3684e-01,  4.2113e-01,  2.5118e-01, -7.0882e-02,
         -1.2463e-01, -1.5439e-01,  1.8969e-01,  5.7837e-02,  2.6203e-01],
        [-2.0029e-01, -4.6962e-01,  5.1820e-01,  6.3898e-01,  7.1709e-01,
         -5.3796e-01, -6.4178e-04, -3.2100e-01,  7.9542e-01,  1.3639e-01],
        [-2.0988e-01, -8.1415e-01,  4.3507e-01,  5.6735e-02,  7.1053e-01,
         -8.0142e-01, -8.2580e-02, -8.2268e-01,  5.1185e-03,  1.4978e-01],
        [-2.6728e-02, -4.0513e-01,  5.8866e-01,  3.4510e-01,  5.5723e-01,
         -6.4959e-01, -6.8502e-02, -5.3238e-01,  4.3269e-01,  3.8217e-01],
        [ 3.5831e-01, -1.9775e-01,  1.9581e-01, -4.9297e-02, -1.9818e-01,
         -1.5570e-01, -1.8320e-01,  1.2315e-01,  1.5451e-02,  6.1414e-02]],
       grad_fn=<AddmmBackward0>)


In [181]:
# Transformer Block test

d_model = 10
d_vocab = len(VOCAB_DICT)
d_hidden = 10
n_context = 5
n_layers = 10

config = Config(
    d_model = d_model,
    d_vocab = d_vocab,
    d_hidden = d_hidden,
    n_context = n_context,
    n_layers = n_layers,
)

x = torch.randn((n_context, d_model))
conf = Config(d_model, d_vocab, d_hidden, n_context, n_layers)

tb = TransformerBlock(config)

output_x = tb(x)
output_x


tensor([[ 0.7564, -0.1901, -0.8199, -0.7058, -0.1407, -3.1887, -0.8120,  0.6575,
         -0.8009,  0.5691],
        [ 1.7548,  0.1369,  0.5451, -1.0349,  0.2123, -1.7916, -0.3687, -0.1708,
         -1.4231, -0.3296],
        [ 0.3693,  0.3610, -0.3178, -2.5366,  2.2737,  2.9342,  4.3649,  1.5633,
         -1.1154, -1.4197],
        [-1.0282,  0.9102,  1.1436,  0.5715,  0.6032, -1.1302, -0.1641, -0.2860,
         -0.4190,  0.8458],
        [ 0.0359, -0.4625,  2.2205, -0.5128,  0.5866, -0.4680, -1.0618,  0.2394,
          0.1510,  1.5151]], grad_fn=<AddBackward0>)

## Training Loop

In [None]:
conf = Config(d_model = 10, 
              d_vocab = len(VOCAB_DICT), 
              d_hidden = 10, 
              n_context = 10, 
              n_layers = 2
              )

training_data = torch.utils.data.TensorDataset(torch.from_numpy(DATA_ENCODED[:-1]),torch.from_numpy(DATA_ENCODED[1:]))
model = Transformer(config=conf)
training_loader = torch.utils.data.DataLoader(training_data, batch_size=4, shuffle=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
loss_fn = torch.nn.CrossEntropyLoss()


In [184]:
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(training_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss

In [188]:
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1

EPOCH 1:


RuntimeError: The size of tensor a (4) must match the size of tensor b (10) at non-singleton dimension 0