In [None]:
import torch
import torch.nn as nn
from dataclasses import dataclass

import requests
import unicodedata

In [9]:
@dataclass
class Config:
    d_model: int
    d_vocab: int
    d_hidden: int
    n_context: int

In [22]:
# class Embedding(nn.Module):
#     def __init__(self):
#         super().__init__()
    
#     def forward(self):
#         pass

class Attention(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        # self.W_qk = nn.Linear(config.d_model, config.d_vocab)
        self.bilinear = nn.Bilinear(config.d_model, config.d_model, config.n_context, bias=False)
        self.M = torch.triu(torch.ones((config.n_context, config.n_context)), diagonal=1)
        self.M = self.M.masked_fill(self.M.bool(), -torch.inf)
        self.second_matmult = nn.Linear(config.d_model, config.d_model, bias=False)
        self.softmax = nn.Softmax()
    
    def forward(self, x):
        xwx = self.bilinear(x, x) # d_m x d_m
        x_masked = xwx+ self.M 
        x_softmaxed = self.softmax(x_masked)
        x_fin = x_softmaxed@x
        #multiply softmaxed by x
        #multiply that by wov
        x_fin = self.second_matmult(x_fin)
        return x_fin

class MLP(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.linear_up = nn.Linear(config.d_model, config.d_hidden)
        self.linear_down = nn.Linear(config.d_hidden, config.d_model)
    
    def forward(self, x):
        x = self.linear_up(x)
        x = torch.relu(x)
        x = self.linear_down(x)
        return x
    
class TransformerBlock(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.config = config

        self.MLP = MLP(config=self.config)
        self.Attention = Attention(config=self.config)
    
    def forward(self, x):
        return x + self.Attention(x) + self.MLP(x)

$n_c$: Context window length

$d_m$: Model Dimension

$d_v$: Vocab Dimension

In [11]:
text_sample = "The quick brown fox jumped over the lazy dog."

In [None]:
d_model = 10
d_vocab = 10
d_hidden = 10
n_context = 10

conf = Config(d_model, d_vocab, d_hidden, n_context)

embedding = nn.Embedding(num_embeddings=conf.d_vocab, embedding_dim=conf.d_model)



In [None]:
from pathlib import Path

def get_gutenberg_book(
	id: int | None = 84,
	data_temp: Path | str = "../data/gutenberg_data",
	remove_gutenberg_meta: bool = True,
) -> str:
	
	data_temp: Path = Path(data_temp)
	data_temp.mkdir(parents=True, exist_ok=True)
	
	url: str = f"https://www.gutenberg.org/cache/epub/{id}/pg{id}.txt"
	data_path: Path = Path(data_temp) / f"{id}.txt"
	data: str
	# read from cache if it exists
	if data_path.exists():
		with open(data_path, 'r', encoding='utf-8') as file:
			data = file.read()
	else:
		# download if it doesn't exist
		response: requests.Response = requests.get(url)
		response.raise_for_status()  # Ensure that the download was successful
		data = response.text

		# save to cache
		with open(data_path, 'w', encoding='utf-8') as file:
			file.write(data)

	# remove header/footer
	if remove_gutenberg_meta:
		data = '***'.join(data.split('***')[2:])
		data = '***'.join(data.split('***')[:-1])
	
	return data

def get_many_books(
		ids: list[int],
		data_temp: Path | str = "../data/gutenberg_data",
	) -> list[str]:
	
	data: list[str] = []
	for id in ids:
		print(f"Getting book {id}...")
		item: str = get_gutenberg_book(id, data_temp)
		print(f"\t{len(item)} characters read")
		data.append(item)
	
	return data

In [None]:
def process_text(
	text: str,
	allowed_punctuation: str = "-.,;:!?()\"\\" + "".join(str(x) for x in range(10)),
	punctuation_convert: dict[str, str] = {'â€”': '-'},
) -> str:
	
	# replace some special characters which unicode won't normalize properly
	for char, replacement in punctuation_convert.items():
		text = text.replace(char, replacement)

	# if a line has ".jpg" in it, remove that line (this is specific to Don Quixote)
	text = '\n'.join(
		line 
		for line in text.split('\n')
		if '.jpg' not in line
	)

	# Normalize the string to decompose Unicode characters
	text = unicodedata.normalize('NFKD', text)

	# Encode to ASCII bytes, then decode back to string, ignoring errors
	text = text.encode('ascii', 'ignore').decode('ascii')

	# remove newlines and tabs
	text = text.replace('\n', ' ').replace('\t', ' ')


	# put spaces around allowed punctuation
	for char in allowed_punctuation:
		text = text.replace(char, f' {char} ')


	# remove leading and trailing spaces
	text = text.strip()

	# remove multiple spaces
	while '  ' in text:
		text = text.replace('  ', ' ')


	# remove all characters except (alphanumeric, allowed_punctuation, ' ')
	text = ''.join(
		(
			char 
			if (
				char.isalnum() 
				or char in allowed_punctuation 
				or char == ' '
			)
			else ' '
		)
		for char in text 
	)

	# convert to lowercase
	text = text.lower()

	text = text.strip()

	return text

In [None]:
def tokenize(
	text: str,
	process: bool = False,
) -> list[str]:
	if process:
		text = process_text(text)
	return text.split(' ')

In [None]:
DATA_RAW: list[str] = get_many_books([84, 15, 18, 82, 996, 2600])
DATA: str = " ".join(process_text(x) for x in DATA_RAW)
DATA_TOKENIZED: list[str] = tokenize(DATA)

In [24]:
d_model = 10
d_vocab = 10
d_hidden = 10
n_context = 5

x = torch.randn((n_context, d_model))

conf = Config(d_model, d_vocab, d_hidden, n_context)
mlp = MLP(conf)
attention = Attention(conf)
Aoutput = attention(x)
print(Aoutput.shape)

output = mlp(x)
print(output)



torch.Size([5, 10])
tensor([[ 0.3668, -0.0897,  0.2870, -0.4726,  0.0341, -0.2712,  0.2344, -0.0119,
          0.0806,  0.2910],
        [ 0.4433, -0.3422,  0.0211, -0.4508, -0.5492,  0.5918,  0.0260,  0.3143,
          0.2249,  0.0624],
        [ 0.1528,  0.1009,  0.4494, -1.0857, -0.3989,  0.2402,  0.2270,  0.0498,
          0.4418, -0.3364],
        [ 0.2066, -0.1957,  0.2769, -0.2693, -0.3526,  0.1198,  0.0411,  0.1876,
          0.1822,  0.0874],
        [ 0.1578, -0.0674,  0.4384, -0.4012, -0.3769,  0.0836, -0.0981, -0.0835,
          0.2477,  0.0381]], grad_fn=<AddmmBackward0>)


  return self._call_impl(*args, **kwargs)
