# Anatomy of a GPT

*An investigation into the internal states of the model. Spoiler: The only thing falling here is the Loss Curve*

![Anatomy of a GPT](../assets/anatomy-of-a-gpt.jpg)

We start by loading the base `config.json`

In [24]:
import json
import types

with open("../config.json") as fd:
    config = json.load(fd, object_hook=lambda d: types.SimpleNamespace(**d))
config

namespace(hidden_size=256,
          intermediate_size=1024,
          num_hidden_layers=4,
          max_position_embeddings=256,
          tie_word_embeddings=True,
          num_attention_heads=4,
          num_key_value_heads=2,
          head_dim=64,
          dropout_p=0.1,
          seed=1728,
          vocab_size=8192,
          special_tokens=['<|startoftext|>', '<|endoftext|>'],
          batch_size=4)

We then load the model and check its internals

In [26]:
import sys

sys.path.append("../src")
import model

import torch
import torchinfo

m = model.Transformer(config)
torchinfo.summary(
    m,
    input_size=(config.batch_size, config.max_position_embeddings),
    dtypes=[torch.long],
)

Layer (type:depth-idx)                        Output Shape              Param #
Transformer                                   [4, 256, 8192]            --
├─Embedding: 1-1                              [4, 256, 256]             2,097,152
├─Embedding: 1-2                              [256, 256]                65,536
├─Dropout: 1-3                                [4, 256, 256]             --
├─ModuleList: 1-4                             --                        --
│    └─DecoderLayer: 2-1                      [4, 256, 256]             --
│    │    └─RMSNorm: 3-1                      [4, 256, 256]             256
│    │    └─GroupedQueryAttention: 3-2        [4, 256, 256]             196,608
│    │    └─RMSNorm: 3-3                      [4, 256, 256]             256
│    │    └─FeedForward: 3-4                  [4, 256, 256]             524,288
│    └─DecoderLayer: 2-2                      [4, 256, 256]             --
│    │    └─RMSNorm: 3-5                      [4, 256, 256]             