In [None]:
from pathlib import Path
from dataclasses import dataclass 
import requests
import numpy as np
import torch
import torch.nn as nn
from jaxtyping import Float, Int

import torch.optim as optim
import torch.nn.functional as F
from typing import List

import unicodedata
import json
from collections import Counter, defaultdict
import base64
from IPython.display import Image, display


@dataclass
class Config:
    """Hyperparameters shared by the model components in this notebook.

    No context-length (``n_context``) field is defined.
    """

    d_model: int   # model width: embedding dimension and MLP input/output size
    d_vocab: int   # number of token ids covered by the embedding table
    d_hidden: int  # width of the MLP hidden layer

# Project Gutenberg dataset code (shared with the existing notebooks)

## Getting Data

In [None]:
def get_gutenberg_book(
	id: int|None = 84,
	data_temp: Path|str = "../data/gutenberg_data",
	remove_gutenberg_meta: bool = True,
) -> str:
	
	data_temp = Path(data_temp)
	data_temp.mkdir(parents=True, exist_ok=True)
	
	url: str = f"https://www.gutenberg.org/cache/epub/{id}/pg{id}.txt"
	data_path: Path = Path(data_temp) / f"{id}.txt"
	data: str
	# read from cache if it exists
	if data_path.exists():
		with open(data_path, 'r', encoding='utf-8') as file:
			data = file.read()
	else:
		# download if it doesn't exist
		response = requests.get(url)
		response.raise_for_status()  # Ensure that the download was successful
		data = response.text

		# save to cache
		with open(data_path, 'w', encoding='utf-8') as file:
			file.write(data)

	# remove header/footer
	if remove_gutenberg_meta:
		data = '***'.join(data.split('***')[2:])
		data = '***'.join(data.split('***')[:-1])
	
	return data

def get_many_books(
		ids: list[int],
		data_temp: Path|str = "../data/gutenberg_data",
	) -> list[str]:
	
	data: list[str] = []
	for id in ids:
		print(f"Getting book {id}...")
		item: str = get_gutenberg_book(id, data_temp)
		print(f"\t{len(item)} characters read")
		data.append(item)
	
	return data

# Raw text of each listed Gutenberg book (downloaded once, then read from cache), in list order.
DATA_RAW: list[str] = get_many_books(
	ids=[84, 15, 18, 82, 996, 2600],
)

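## MLP

Applied independently to each position's vector $\mathbf{x} \in \mathbb{R}^{d_\text{model}}$:

$$
\texttt{MLP}(\mathbf{x}) = \mathbf{W}_d \, \sigma_{\texttt{ReLU}}(\mathbf{W}_u \mathbf{x} + \mathbf{b}_u) + \mathbf{b}_d
$$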

In [None]:
class MLP(nn.Module):
    """Feed-forward block: Linear(d_model -> d_hidden), ReLU, Linear(d_hidden -> d_model)."""

    def __init__(self, config: Config):
        super().__init__()
        d_model, d_hidden = config.d_model, config.d_hidden
        # project up into the hidden layer, then back down to the model width
        self.linear_up: nn.Linear = nn.Linear(d_model, d_hidden)
        self.linear_down: nn.Linear = nn.Linear(d_hidden, d_model)

    def forward(self, x: Float[torch.Tensor, "* d_model"]) -> Float[torch.Tensor, "* d_model"]:
        """Apply the MLP over the last (d_model) axis of ``x``; leading axes pass through."""
        hidden = torch.relu(self.linear_up(x))
        return self.linear_down(hidden)

## Attention Head
### Weight Matrix
$$
    \mathbf{W}_{QK} := \mathbf{W}_{Q} \cdot \mathbf{W}_{K}^T
$$

### Forward Pass
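$$A(\mathbf{X}) = \sigma_{\text{softmax}} (\mathbf{X} \; \mathbf{W}_{QK} \; \mathbf{X}^\text{T} + \mathbf{M}) \; \mathbf{X} \; \mathbf{W}_{OV}$$

where the softmax is applied row-wise, and $\mathbf{M}$ is the causal mask, with $M_{ij} = 0$ for $j \le i$ and $M_{ij} = -\infty$ for $j > i$. This mask prevents each position from attending to later positions.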

In [None]:
class AttentionHead(nn.Module):
    """Single causal attention head with merged QK and OV weight matrices.

    Computes A(X) = softmax(X W_QK X^T + M) X W_OV, where the softmax runs over
    the last axis and M is the causal mask returned by ``M_matrix``.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        # nn.Parameter registers the matrices with the module so their gradients are tracked
        # NOTE(review): unscaled randn init with no 1/sqrt(d_model) score scaling can
        # saturate the softmax for large d_model; kept as-is to match the formula above.
        self.wqk = nn.Parameter(torch.randn(config.d_model, config.d_model))
        self.wov = nn.Parameter(torch.randn(config.d_model, config.d_model))

    @staticmethod
    def M_matrix(
        n: int,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> Float[torch.Tensor, "n n"]:
        """Return the (n, n) causal mask: 0 at and below the diagonal, -inf above it.

        Static so it works both as ``self.M_matrix(n)`` and ``AttentionHead.M_matrix(n)``.
        """
        # start from all -inf, keep only the strict upper triangle; triu sets the rest to 0
        return torch.triu(
            torch.full((n, n), float("-inf"), device=device, dtype=dtype),
            diagonal=1,
        )

    def forward(self, x: Float[torch.Tensor, "... n_seq d_model"]) -> Float[torch.Tensor, "... n_seq d_model"]:
        """Apply causal attention to ``x`` of shape (..., n_seq, d_model).

        Leading (batch) axes are optional; the mask broadcasts over them.
        """
        n_seq = x.shape[-2]
        # mask sized to the input length, created on the same device/dtype as x
        M = self.M_matrix(n_seq, device=x.device, dtype=x.dtype)
        # (..., n_seq, n_seq) attention scores; row i may only see positions j <= i
        scores = x @ self.wqk @ x.transpose(-2, -1) + M
        A = torch.softmax(scores, dim=-1)
        return A @ x @ self.wov

In [None]:
class TransformerBlock(nn.Module):
    """One transformer layer: a causal attention head followed by an MLP.

    Each sublayer's output is added back onto its input (residual connections):
        x = x + attention_head(x)
        x = x + mlp(x)
    Token embedding is not done here; it belongs to the full model.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.attention_head = AttentionHead(config)
        self.mlp = MLP(config)

    def forward(self, x: Float[torch.Tensor, "... n_seq d_model"]) -> Float[torch.Tensor, "... n_seq d_model"]:
        """Run attention then the MLP on ``x`` of shape (..., n_seq, d_model), adding residuals."""
        x = x + self.attention_head(x)
        x = x + self.mlp(x)
        return x

In [None]:
class transformer(nn.Module):
    """Top-level model; currently holds only the configuration and token embedding table.

    TODO: no ``forward`` is defined yet, so calling an instance raises
    ``NotImplementedError`` (the ``nn.Module`` default).
    """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        # lookup table mapping each of the d_vocab token ids to a d_model-dimensional vector
        self.token_embedding = nn.Embedding(
            num_embeddings=config.d_vocab,
            embedding_dim=config.d_model,
        )
        
