In [14]:
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float, Int

In [15]:
torch.manual_seed(0)

<torch._C.Generator at 0x7fadaf95b590>

In [16]:
@dataclass
class GPTConfig:
	# default test values -- too small for a real language model, but big enough for testing
	d_vocab: int = 10_000
	d_model: int = 128
	d_mlp: int = 512
	n_heads: int = 4
	d_head: int = 32
	n_layers: int = 6
	act_fn: type[nn.Module] = nn.ReLU

	@property
	def n_params(self) -> int:
		"an estimate of the number of parameters"
		return (
			self.d_vocab * self.d_model # embeddings (and tied unembeddings)
			+ (
				self.d_model * self.d_mlp * 2 # mlp weights
				+ self.d_model + self.d_mlp # mlp bias
				+ self.n_heads * ( # number of heads
					4 * self.d_model * self.d_head # 4 because Q, K, O, V
				)
			) * self.n_layers, # for each layer
		)
	
print(GPTConfig().n_params)

# note: the residual stream is `n_context` by `d_model`

# this is the row-wise (last dimension) softmax of x
# F.softmax(x, dim=-1)

class AttentionHead(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()
		self.relu = nn.ReLU()
		self.d_vocab = cfg.d_vocab
		self.d_model = cfg.d_model
		self.d_head = cfg.d_head
		self.wq = nn.Linear(self.d_model, self.d_head)
		self.wk = nn.Linear(self.d_model, self.d_head)
		self.wv = nn.Linear(self.d_model, self.d_head)
		self.wo = nn.Linear(self.d_head, self.d_model)

	
	
	def forward(self, x: Int[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
		def masking_matrix(n_context):
			mask = torch.zeros((n_context, n_context))  # Start with all 0s
			mask[torch.triu(torch.ones((n_context, n_context)), diagonal=1) == 1] = -float('inf')  # Set above diagonal to -inf
			return mask
		
		M = masking_matrix(x.shape[0])
		# softmax_argument = x*self.wq*torch.transpose(self.wk)*torch.transpose(x) + M
		wk_out = torch.transpose(self.wk(x), 0, 1)
		print("WK shape ", wk_out.shape)
		wq_out = self.wq(x)
		print("WQ shape ", wq_out.shape)
		softmax_out = F.softmax((wq_out@wk_out + M), dim=-1)
		print("Softmax shape ", softmax_out.shape)
		wv_out = self.wv(x)
		print("WV shape ", wv_out.shape)
		wo_out = self.wo(wv_out)
		# wo_out = self.wo(wv_out)
		result = softmax_out@wo_out
		print("Final A Shape ", result.shape)
		return result
		
		
class MultiHeadedAttention(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()
		self.n_heads = cfg.n_heads
		self.d_model = cfg.d_model
		self.d_head = cfg.d_head
		self.wo = nn.Linear(self.d_model, self.d_head)
		self.attention_heads = nn.ModuleList([AttentionHead(cfg) for _ in range(self.n_heads)])


	def forward(x: Int[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
		head_outputs = [head(x) for head in self.attention_heads]
		print(head_outputs)




class MLP(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()
		raise NotImplementedError()

	def forward(x: Int[torch.Tensor, "n_context d_model"]) -> Float[torch.Tensor, "n_context d_model"]:
		raise NotImplementedError()



class Transformer(nn.Module):

	def __init__(self, cfg: GPTConfig):
		super().__init__()
		raise NotImplementedError()
		# uses `MultiHeadedAttention` and `MLP`
		# uses nn.Embedding for the embedding, transpose of it for the unembedding

	def forward(x: Int[torch.Tensor, "n_context"]) -> Float[torch.Tensor, "n_context d_vocab"]:
		raise NotImplementedError()

(2463488,)


In [17]:
gpt_config = GPTConfig()
attn_head = AttentionHead(gpt_config)
x = torch.randn(256, gpt_config.d_model)
# print(x)
# print(x.shape)
attn_head.forward(x)


WK shape  torch.Size([32, 256])
WQ shape  torch.Size([256, 32])
Softmax shape  torch.Size([256, 256])
WV shape  torch.Size([256, 32])
Final A Shape  torch.Size([256, 128])


tensor([[ 0.1896,  0.0638, -0.0859,  ..., -0.1011, -0.2527,  0.2130],
        [ 0.4502,  0.2570, -0.0370,  ...,  0.1842, -0.0311,  0.0949],
        [ 0.3652, -0.5231, -0.1242,  ..., -0.1256,  0.3270,  0.7341],
        ...,
        [ 0.1972,  0.2424, -0.0650,  ..., -0.2720, -0.0777,  0.0690],
        [ 0.1726,  0.3053, -0.1114,  ..., -0.1802,  0.0297,  0.0239],
        [ 0.2779,  0.2770, -0.1763,  ..., -0.1758,  0.0078,  0.1137]],
       grad_fn=<MmBackward0>)