In [2]:
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float, Int

# Attention Head in math

$$
	A(X) = \sigma\bigg(X W_Q W_K^T X^T + M\bigg) X W_V W_O^T
$$

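- $\sigma$ is the row-wise softmax (over the last dimension) of the `n_context` $\times$ `n_context` matrix of attention scores
- $W_Q, W_K, W_V, W_O$ can each be built with `nn.Linear` (with `bias=False`, since the formula has no bias terms), and each has shape `d_model` $\times$ `d_head`
- $M$ is the causal mask: $0$ on and below the diagonal and $-\infty$ strictly above it, so that after the softmax each position attends only to itself and earlier positions. Look up how to build this (e.g. `torch.triu` or `Tensor.masked_fill`)
- Many implementations (e.g. GPT-2) also divide the scores $X W_Q W_K^T X^T$ by $\sqrt{d_{head}}$ before adding $M$ and applying the softmax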

In [None]:
@dataclass(init=True, repr=True, eq=True)
class GPTConfig:
	"""Hyperparameters shared by the GPT-style modules in this notebook.

	All fields except `act_fn` are required and positional, in the order below.
	"""

	d_vocab: int  # vocabulary size; Transformer.forward returns d_vocab values per position
	d_model: int  # width of the residual stream (n_context x d_model)
	d_mlp: int  # presumably the MLP hidden width — confirm when implementing MLP
	n_heads: int  # presumably the number of heads in MultiHeadedAttention
	d_head: int  # per-head width; W_Q, W_K, W_V, W_O are d_model x d_head
	n_layers: int  # presumably the number of attention + MLP layers in Transformer
	act_fn: type[nn.Module] = nn.ReLU  # an nn.Module *class* (not an instance); ReLU by default

# note: the residual stream is `n_context` by `d_model`

# this is the row-wise (last dimension) softmax of x
# F.softmax(x, dim=-1)

class AttentionHead(nn.Module):
	"""A single causal self-attention head (exercise stub, not implemented yet).

	Intended to compute, for a residual stream X of shape (n_context, d_model):

		A(X) = softmax(X W_Q W_K^T X^T + M) X W_V W_O^T

	where W_Q, W_K, W_V, W_O are d_model x d_head matrices (buildable with
	`nn.Linear`) and M is the causal mask. The output has the same shape as X.
	"""

	def __init__(self, cfg: GPTConfig):
		"""Create the head's projections from `cfg` (expected: d_model and d_head).

		Raises:
			NotImplementedError: always, until the exercise is completed.
		"""
		super().__init__()
		raise NotImplementedError()

	# x is the float residual stream (not token ids), hence Float rather than Int.
	def forward(self, x: Float[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
		"""Apply the attention head to the residual stream `x`.

		Args:
			x: residual stream of shape (n_context, d_model).

		Returns:
			The head's output, shape (n_context, d_model).

		Raises:
			NotImplementedError: always, until the exercise is completed.
		"""
		raise NotImplementedError()


class MultiHeadedAttention(nn.Module):
	"""Multi-headed causal self-attention built from `AttentionHead`s (exercise stub).

	Maps a residual stream of shape (n_context, d_model) to a tensor of the
	same shape. Presumably holds cfg.n_heads heads; confirm when implementing.
	"""

	def __init__(self, cfg: GPTConfig):
		"""Create the attention heads from `cfg`.

		Raises:
			NotImplementedError: always, until the exercise is completed.
		"""
		super().__init__()
		# uses `AttentionHead`
		raise NotImplementedError()

	# x is the float residual stream (not token ids), hence Float rather than Int.
	def forward(self, x: Float[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
		"""Apply multi-headed attention to the residual stream `x`.

		Args:
			x: residual stream of shape (n_context, d_model).

		Returns:
			The attention output, shape (n_context, d_model).

		Raises:
			NotImplementedError: always, until the exercise is completed.
		"""
		raise NotImplementedError()




class MLP(nn.Module):
	"""Feed-forward block applied to the residual stream (exercise stub).

	Maps (n_context, d_model) to (n_context, d_model). Presumably projects
	d_model -> cfg.d_mlp -> d_model with cfg.act_fn in between; confirm when
	implementing.
	"""

	def __init__(self, cfg: GPTConfig):
		"""Create the MLP's layers from `cfg`.

		Raises:
			NotImplementedError: always, until the exercise is completed.
		"""
		super().__init__()
		raise NotImplementedError()

	# x is the float residual stream (not token ids), hence Float rather than Int.
	def forward(self, x: Float[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
		"""Apply the MLP to the residual stream `x`.

		Args:
			x: residual stream of shape (n_context, d_model).

		Returns:
			The MLP output, shape (n_context, d_model).

		Raises:
			NotImplementedError: always, until the exercise is completed.
		"""
		raise NotImplementedError()



class Transformer(nn.Module):
	"""Full GPT-style model: token ids in, per-position vocabulary scores out (exercise stub).

	Input is a 1-D tensor of n_context token ids; output has shape
	(n_context, d_vocab), presumably logits over the vocabulary.
	"""

	def __init__(self, cfg: GPTConfig):
		"""Create embedding, attention and MLP layers from `cfg`.

		Raises:
			NotImplementedError: always, until the exercise is completed.
		"""
		super().__init__()
		# uses `MultiHeadedAttention` and `MLP`
		# uses nn.Embedding for the embedding, transpose of it for the unembedding
		raise NotImplementedError()

	def forward(self, x: Int[torch.Tensor, "n_context"]) -> Float[torch.Tensor, "n_context d_vocab"]:
		"""Run the model on a sequence of token ids.

		Args:
			x: integer token ids, shape (n_context,).

		Returns:
			Scores over the vocabulary for each position, shape (n_context, d_vocab).

		Raises:
			NotImplementedError: always, until the exercise is completed.
		"""
		raise NotImplementedError()