In [5]:
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float, Int

# Attention Head in math

$$
	A(X) = \sigma\bigg(X W_Q W_K^T X^T + M\bigg) X W_V W_O^T
$$

- $W_Q, W_K, W_V, W_O$ can be made with `nn.Linear` and will all have dimension `d_model` $\times$ `d_head`
- $M$ is a lower triangular mask matrix, look up how to do this

In [6]:
def create_mask(n_context: int) -> Float[torch.Tensor, "n_context n_context"]:
	# this should return a `n_context, n_context` matrix,
	# with zeros below and on the diag, and `-float("inf")` below
	# output = ...
	M = torch.zeros(n_context, n_context)
	return M
	
print(create_mask(10))

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])


In [7]:
# NOTE: it's intentional that `n_context` is not in the `GPTConfig`
@dataclass
class GPTConfig:
	# default test values -- too small for a real language model, but big enough for testing
	d_vocab: int = 10_000
	d_model: int = 128
	d_mlp: int = 512
	n_heads: int = 4
	d_head: int = 32
	n_layers: int = 6
	act_fn: type[nn.Module] = nn.ReLU


	@property
	def n_params(self) -> int:
		"an estimate of the number of parameters"
		return (
			self.d_vocab * self.d_model # embeddings (and tied unembeddings)
			+ (
				self.d_model * self.d_mlp * 2 # mlp weights
				+ self.d_model + self.d_mlp # mlp bias
				+ self.n_heads * ( # number of heads
					4 * self.d_model * self.d_head # 4 because Q, K, O, V
				)
			) * self.n_layers, # for each layer
		)
	
print(GPTConfig().n_params)

(2463488,)


In [8]:


# note: the residual stream is `n_context` by `d_model`

# this is the row-wise (last dimension) softmax of x
# F.softmax(x, dim=-1)

class AttentionHead(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()
		raise NotImplementedError()

	def forward(self, x: Int[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
		raise NotImplementedError()


class MultiHeadedAttention(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()
		raise NotImplementedError()
		# uses `AttentionHead`

	def forward(self, x: Int[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
		raise NotImplementedError()




class MLP(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()
		raise NotImplementedError()

	def forward(self, x: Int[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
		raise NotImplementedError()



class Transformer(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()
		raise NotImplementedError()
		# uses `MultiHeadedAttention` and `MLP`
		# uses nn.Embedding for the embedding, transpose of it for the unembedding

	def forward(self, x: Int[torch.Tensor, "n_context"]) -> Float[torch.Tensor, "n_context d_vocab"]:
		raise NotImplementedError()