In [23]:
import torch
import torch.nn as nn
from dataclasses import dataclass

import requests
import unicodedata

from jaxtyping import Int, Float
from collections import Counter
import numpy as np

# Classes

In [24]:
@dataclass
class Config:
    d_model: int
    d_vocab: int
    d_hidden: int
    n_context: int

In [None]:
# class Embedding(nn.Module):
#     def __init__(self):
#         super().__init__()
    
#     def forward(self):
#         pass

class Attention(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        # self.W_qk = nn.Linear(config.d_model, config.d_vocab)
        self.bilinear = nn.Bilinear(config.d_model, config.d_model, config.n_context, bias=False)
        self.M = torch.triu(torch.ones((config.n_context, config.n_context)), diagonal=1)
        self.M = self.M.masked_fill(self.M.bool(), -torch.inf)
        self.second_matmult = nn.Linear(config.d_model, config.d_model, bias=False)
        self.softmax = nn.Softmax()
    
    def forward(self, x):
        xwx = self.bilinear(x, x) # d_m x d_m
        x_masked = xwx+ self.M 
        x_softmaxed = self.softmax(x_masked)
        x_fin = x_softmaxed@x
        #multiply softmaxed by x
        #multiply that by wov
        x_fin = self.second_matmult(x_fin)
        return x_fin

class MLP(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.linear_up = nn.Linear(config.d_model, config.d_hidden)
        self.linear_down = nn.Linear(config.d_hidden, config.d_model)
    
    def forward(self, x):
        x = self.linear_up(x)
        x = torch.relu(x)
        x = self.linear_down(x)
        return x
    
class TransformerBlock(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.config = config

        self.MLP = MLP(config=self.config)
        self.Attention = Attention(config=self.config)
    
    def forward(self, x):
        return x + self.Attention(x) + self.MLP(x)
    
class Transformer(nn.Module):
    def __init__(self, config:Config):
        self.transformerBlock = nn.ModuleList([TransformerBlock(config) for i in range(2)])
    
    def forward(self, x):
        for i, l in enumerate(self.transformerBlock):
            x = self.transformerBlock[i](x)
        return 

$n_c$: Context window length

$d_m$: Model Dimension

$d_v$: Vocab Dimension

In [26]:
text_sample = "The quick brown fox jumped over the lazy dog."

In [27]:
d_model = 10
d_vocab = 10
d_hidden = 10
n_context = 10

conf = Config(d_model, d_vocab, d_hidden, n_context)

embedding = nn.Embedding(num_embeddings=conf.d_vocab, embedding_dim=conf.d_model)



# Tokenization Code

In [28]:
from pathlib import Path

def get_gutenberg_book(
	id: int | None = 84,
	data_temp: Path | str = "../data/gutenberg_data",
	remove_gutenberg_meta: bool = True,
) -> str:
	
	data_temp: Path = Path(data_temp)
	data_temp.mkdir(parents=True, exist_ok=True)
	
	url: str = f"https://www.gutenberg.org/cache/epub/{id}/pg{id}.txt"
	data_path: Path = Path(data_temp) / f"{id}.txt"
	data: str
	# read from cache if it exists
	if data_path.exists():
		with open(data_path, 'r', encoding='utf-8') as file:
			data = file.read()
	else:
		# download if it doesn't exist
		response: requests.Response = requests.get(url)
		response.raise_for_status()  # Ensure that the download was successful
		data = response.text

		# save to cache
		with open(data_path, 'w', encoding='utf-8') as file:
			file.write(data)

	# remove header/footer
	if remove_gutenberg_meta:
		data = '***'.join(data.split('***')[2:])
		data = '***'.join(data.split('***')[:-1])
	
	return data

def get_many_books(
		ids: list[int],
		data_temp: Path | str = "../data/gutenberg_data",
	) -> list[str]:
	
	data: list[str] = []
	for id in ids:
		print(f"Getting book {id}...")
		item: str = get_gutenberg_book(id, data_temp)
		print(f"\t{len(item)} characters read")
		data.append(item)
	
	return data

In [29]:
def process_text(
	text: str,
	allowed_punctuation: str = "-.,;:!?()\"\\" + "".join(str(x) for x in range(10)),
	punctuation_convert: dict[str, str] = {'â€”': '-'},
) -> str:
	
	# replace some special characters which unicode won't normalize properly
	for char, replacement in punctuation_convert.items():
		text = text.replace(char, replacement)

	# if a line has ".jpg" in it, remove that line (this is specific to Don Quixote)
	text = '\n'.join(
		line 
		for line in text.split('\n')
		if '.jpg' not in line
	)

	# Normalize the string to decompose Unicode characters
	text = unicodedata.normalize('NFKD', text)

	# Encode to ASCII bytes, then decode back to string, ignoring errors
	text = text.encode('ascii', 'ignore').decode('ascii')

	# remove newlines and tabs
	text = text.replace('\n', ' ').replace('\t', ' ')


	# put spaces around allowed punctuation
	for char in allowed_punctuation:
		text = text.replace(char, f' {char} ')


	# remove leading and trailing spaces
	text = text.strip()

	# remove multiple spaces
	while '  ' in text:
		text = text.replace('  ', ' ')


	# remove all characters except (alphanumeric, allowed_punctuation, ' ')
	text = ''.join(
		(
			char 
			if (
				char.isalnum() 
				or char in allowed_punctuation 
				or char == ' '
			)
			else ' '
		)
		for char in text 
	)

	# convert to lowercase
	text = text.lower()

	text = text.strip()

	return text

In [30]:
def tokenize(
	text: str,
	process: bool = False,
) -> list[str]:
	if process:
		text = process_text(text)
	return text.split(' ')

In [31]:
DATA_RAW: list[str] = get_many_books([84, 15, 18, 82, 996, 2600])
DATA: str = " ".join(process_text(x) for x in DATA_RAW)
DATA_TOKENIZED: list[str] = tokenize(DATA)

Getting book 84...
	419422 characters read
Getting book 15...
	1238469 characters read
Getting book 18...
	1172825 characters read
Getting book 82...
	1103796 characters read
Getting book 996...
	2299352 characters read
Getting book 2600...
	3208337 characters read


In [32]:
# sorted by frequency
VOCAB_FREQ: Counter[str] = Counter(DATA_TOKENIZED)
VOCAB_ARR: list[str] = [word for word, _ in VOCAB_FREQ.most_common()]
VOCAB_DICT: dict[str, int] = {word: i for i, word in enumerate(VOCAB_ARR)}

def encode(
	text: str | list[str],
) -> Int[np.ndarray, " n_tokens"]:
	if isinstance(text, str):
		text = tokenize(text)
	return np.array([VOCAB_DICT[word] for word in text])

def decode(
	encoded_text: Int[np.ndarray, " n_tokens"] | list[int],
) -> str:
	return ' '.join(VOCAB_ARR[i] for i in encoded_text)

DATA_ENCODED: Int[np.ndarray, " n_tokens"] = encode(DATA)

print(f"{DATA_ENCODED[:10] = }")

DATA_ENCODED[:10] = array([ 4675,    13,    40,     0,     1,  1587, 12126,    27,   278,
       24255])


# Tests

In [33]:
d_model = 10
d_vocab = 10
d_hidden = 10
n_context = 5

x = torch.randn((n_context, d_model))

conf = Config(d_model, d_vocab, d_hidden, n_context)
mlp = MLP(conf)
attention = Attention(conf)
Aoutput = attention(x)
print(Aoutput.shape)

output = mlp(x)
print(output)



torch.Size([5, 10])
tensor([[ 0.3862,  0.2528, -0.3346, -0.4032, -0.0439, -0.0101,  0.3036,  0.1906,
         -0.2531,  0.5045],
        [-0.1007, -0.2256,  0.2427, -0.2873,  0.0038,  0.0195,  0.2928,  0.4025,
         -0.0872,  0.5707],
        [ 0.1671, -0.1158,  0.0721, -0.0478, -0.0767, -0.0494,  0.2726,  0.2951,
         -0.0321,  0.2939],
        [ 0.3037, -0.1483, -0.0331, -0.1628, -0.0641,  0.0091,  0.2530,  0.2732,
         -0.0047,  0.2618],
        [-0.0203,  0.2980,  0.0064,  0.2197,  0.1812,  0.3969, -0.0825,  0.3056,
         -0.0644,  0.5064]], grad_fn=<AddmmBackward0>)


  return self._call_impl(*args, **kwargs)


In [34]:
# Transformer Block test

d_model = 10
d_vocab = 10
d_hidden = 10
n_context = 5

config = Config(
    d_model = d_model,
    d_vocab = d_vocab,
    d_hidden = d_hidden,
    n_context = n_context,
)

x = torch.randn((n_context, d_model))
conf = Config(d_model, d_vocab, d_hidden, n_context)

tb = TransformerBlock(config)

output_x = tb(x)
output_x


tensor([[ 3.8475, -2.0726,  1.0514,  0.8324, -0.4577, -1.6299,  0.7028,  2.8054,
         -0.8925, -1.0399],
        [ 2.9761, -0.6997, -0.6032,  2.6684,  0.1813,  1.4837, -0.1221,  0.9587,
          0.3077, -1.1505],
        [ 2.3151,  1.8852, -0.4046,  1.1583, -1.5020, -0.6166, -0.8820,  1.2308,
         -0.0742, -1.2055],
        [ 1.3883, -0.1457,  0.3347,  2.0383,  0.6824,  1.7610, -1.7258,  0.1451,
          0.2311, -2.7143],
        [-0.5221,  0.6496, -2.3000,  0.4011,  0.5412, -1.5619, -1.4592, -0.0642,
         -1.2541, -0.4855]], grad_fn=<AddBackward0>)