In [1]:
from pathlib import Path
from dataclasses import dataclass 
import requests
import numpy as np
import torch
import torch.nn as nn
from jaxtyping import Float, Int

import torch.optim as optim
import torch.nn.functional as F
from typing import List

import unicodedata
import json
from collections import Counter, defaultdict
import base64
from IPython.display import Image, display


@dataclass
class Config:
    """Hyperparameters handed to the constructor of every model component.

    Fields are positional, in the order listed below.
    """

    d_model: int   # embedding width (second argument of the token nn.Embedding)
    d_vocab: int   # number of distinct token ids
    d_hidden: int  # hidden width of the MLP
    n_layers: int  # how many transformer blocks are stacked
    # A separate `d_head: int` would only be needed if attention kept distinct
    # W_Q and W_K matrices.
    # There is intentionally no n_context field.
    # Field syntax reminder: `name: type`.

# Gutenberg dataset code in existing notebooks

### Dimensions

\begin{align*}
    &d_m = \text{d-model} & : & \text{model dimension (num neurons)} \\
    &d_v = \text{d-vocab} & : & \text{vocab dimension} \\
    &n_c = \text{n-context} & : & \text{context window (len of seq entered)}
\end{align*}

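Where $d_n \ll d_m$, and $d_n$ is the inner dimension of the attention matrices $\mathbf{W}_Q, \mathbf{W}_K \in \mathbb{R}^{d_m \times d_n}$ defined in the Attention Head section.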

## Getting Data

In [None]:
def get_gutenberg_book(
	id: int|None = 84,
	data_temp: Path|str = "../data/gutenberg_data",
	remove_gutenberg_meta: bool = True,
) -> str:
	
	data_temp = Path(data_temp)
	data_temp.mkdir(parents=True, exist_ok=True)
	
	url: str = f"https://www.gutenberg.org/cache/epub/{id}/pg{id}.txt"
	data_path: Path = Path(data_temp) / f"{id}.txt"
	data: str
	# read from cache if it exists
	if data_path.exists():
		with open(data_path, 'r', encoding='utf-8') as file:
			data = file.read()
	else:
		# download if it doesn't exist
		response = requests.get(url)
		response.raise_for_status()  # Ensure that the download was successful
		data = response.text

		# save to cache
		with open(data_path, 'w', encoding='utf-8') as file:
			file.write(data)

	# remove header/footer
	if remove_gutenberg_meta:
		data = '***'.join(data.split('***')[2:])
		data = '***'.join(data.split('***')[:-1])
	
	return data

def get_many_books(
		ids: list[int],
		data_temp: Path|str = "../data/gutenberg_data",
	) -> list[str]:
	
	data: list[str] = []
	for id in ids:
		print(f"Getting book {id}...")
		item: str = get_gutenberg_book(id, data_temp)
		print(f"\t{len(item)} characters read")
		data.append(item)
	
	return data

# Raw text of each requested Gutenberg ebook, one string per id, in id order.
# Running this cell downloads any book that is not already in the cache directory.
DATA_RAW: list[str] = get_many_books(
	ids=[84, 15, 18, 82, 996, 2600],
)

## Multilayer Perceptron
$$
    \texttt{MLP}(\mathbf{x}) = W_d \cdot \sigma_{\texttt{ReLU}} (W_u \cdot \mathbf{x} + b_u) + b_d, \qquad \texttt{MLP} : \mathbb{R}^{d_m} \to \mathbb{R}^{d_m}
$$

In [None]:
class MLP(nn.Module):
    """Position-wise feed-forward network: d_model -> d_hidden -> d_model.

    Applies an affine up-projection, a ReLU, and an affine down-projection to
    the last axis of the input.
    """

    def __init__(self, config: Config):
        """Create the two linear layers from ``config.d_model`` / ``config.d_hidden``."""
        super().__init__()
        model_width, hidden_width = config.d_model, config.d_hidden
        self.linear_up: nn.Linear = nn.Linear(model_width, hidden_width)
        self.linear_down: nn.Linear = nn.Linear(hidden_width, model_width)

    def forward(self, x: Float[torch.Tensor, "* d_model"]) -> Float[torch.Tensor, "* d_model"]:
        """Return ``linear_down(relu(linear_up(x)))``; the last axis keeps size d_model."""
        hidden_activations = torch.relu(self.linear_up(x))
        return self.linear_down(hidden_activations)

## Attention Head
### Weight Matrix
$$
    \mathbf{W}_{QK} := \mathbf{W}_{Q} \cdot \mathbf{W}_{K}^T, \qquad \mathbf{W}_{Q}, \mathbf{W}_{K} \in \mathbb{R}^{d_m \times d_n}
$$

### Autoregressive Masking (M) Matrix
$$
    M_{i,j} = 
    \begin{cases}
        0 &j \leq i \\
        -\infty &j > i
    \end{cases}
$$

### Forward Pass
$$A(\mathbf{X}) = \sigma_{\text{softmax}} (\mathbf{X} \; \mathbf{W}_{QK} \; \mathbf{X}^\text{T} + \mathbf{M}) \; \mathbf{X} \; \mathbf{W}_{OV}$$

In [None]:
class AttentionHead(nn.Module):
    """Single causal self-attention head with merged QK and OV matrices.

    Computes ``softmax(X @ W_QK @ X^T + M) @ X @ W_OV`` where ``M`` is the
    autoregressive mask returned by :meth:`M_matrix`.

    NOTE(review): the logits are not divided by sqrt(d) and the weights are
    drawn from a unit normal, exactly as in the formula this class implements;
    consider scaling if training turns out to be unstable.
    """

    def __init__(self, config: Config):
        """Create the two ``d_model x d_model`` weight matrices."""
        super().__init__()
        self.config = config
        # nn.Parameter registers the matrices with the module so that
        # gradients are tracked and optimizers can update them.
        self.wqk = nn.Parameter(torch.randn(config.d_model, config.d_model))
        self.wov = nn.Parameter(torch.randn(config.d_model, config.d_model))

    @staticmethod
    def M_matrix(n, device=None, dtype=None):
        """Build the ``n x n`` autoregressive mask.

        Args:
            n: Sequence length.
            device: Optional device for the mask (defaults to torch's default).
            dtype: Optional floating dtype for the mask.

        Returns:
            A tensor with 0 on and below the diagonal and ``-inf`` above it,
            so that position ``i`` can only attend to positions ``j <= i``.
        """
        # triu(diagonal=1) keeps the strictly-upper triangle (-inf) and zeroes
        # everything on and below the diagonal.
        return torch.triu(
            torch.full((n, n), float('-inf'), device=device, dtype=dtype),
            diagonal=1,
        )

    def forward(self, x: Float[torch.Tensor, "* d_model"]) -> Float[torch.Tensor, "* d_model"]:
        """Apply causal attention to ``x``.

        Args:
            x: Tensor whose last two axes are ``(n_seq, d_model)``; any leading
                batch axes are broadcast over.

        Returns:
            A tensor with the same shape as ``x``.
        """
        n_seq = x.shape[-2]
        # Build the mask on the input's device/dtype so the addition also
        # works on GPU and in reduced precision.
        mask = self.M_matrix(n_seq, device=x.device, dtype=x.dtype)
        # transpose(-2, -1) swaps only the sequence and feature axes, so
        # batched inputs are handled as well as a single (n_seq, d_model) matrix.
        attention_scores = x @ self.wqk @ x.transpose(-2, -1) + mask
        # Normalize over the key axis: each query row sums to 1.
        attention_pattern = torch.softmax(attention_scores, dim=-1)
        return attention_pattern @ x @ self.wov

## Transformer Block
$$
    \text{TB}(X) = X + A(X) + \text{MLP}(X), \qquad \text{TB}: \mathbb{R}^{n_c \times d_m} \to \mathbb{R}^{n_c \times d_m}
$$

In [None]:
class TransformerBlock(nn.Module):
    """One transformer layer: ``TB(x) = x + MLP(x) + A(x)``.

    The MLP and the attention head both read the same input and their outputs
    are added to it in a single residual sum. No LayerNorm is applied.
    """

    def __init__(self, config: Config):
        """Build the MLP and attention sub-modules from ``config``."""
        super().__init__()
        self.config = config
        self.mlp = MLP(config)
        self.attention = AttentionHead(config)

    def forward(self, x: Float[torch.Tensor, "* d_model"]) -> Float[torch.Tensor, "* d_model"]:
        """Return the input plus the MLP and attention outputs (same shape as ``x``)."""
        return x + self.mlp(x) + self.attention(x)

    # The method was originally misspelled `foward`, which meant calling the
    # module raised NotImplementedError. Keep the old name as an alias so any
    # code that called it directly continues to work.
    foward = forward

In [None]:
class transformer(nn.Module):
    """Token-level language model: embed -> transformer blocks -> unembed.

    Maps a sequence of token ids to one vector of ``d_vocab`` logits per
    position.
    """

    def __init__(self, config: Config):
        """Create the embedding, ``config.n_layers`` blocks and the unembedding."""
        super().__init__()
        self.config = config
        self.token_embedding = nn.Embedding(config.d_vocab, config.d_model)
        # nn.ModuleList (rather than a plain list) registers every block's
        # parameters with this module.
        self.transformerblocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        # Linear map from the model dimension back to vocabulary logits.
        self.unembed = nn.Linear(config.d_model, config.d_vocab)

    def forward(self, x: Int[torch.Tensor, "n_context"]) -> Float[torch.Tensor, "n_context d_vocab"]:
        """Compute logits for every position.

        Args:
            x: Integer token ids of length ``n_context``.

        Returns:
            Float tensor of shape ``(n_context, d_vocab)``.
        """
        # Look up one d_model vector per token id (the embedding module must be
        # called, not merely referenced).
        residual = self.token_embedding(x)
        for block in self.transformerblocks:
            residual = block(residual)
        return self.unembed(residual)
        
        
        
# use nn.ModuleList for the TB sequence & MHA (to create a list of TBs)
# print(f"{x.shape = }") for debugging

# pick a unique dataset to train the model on

# if training models: aim for < 10 million parameters for now
#   sum(x.numel() for x in mymodel.parameters())
