In [1]:
import torch
from train import GPT
from torchinfo import summary

In [None]:
# Seed PyTorch's global RNG before constructing the model so that any random
# initialisation done inside GPT (presumably the weight init -- confirm in
# train.py) is repeatable from run to run.
torch.manual_seed(123)
# Build the model whose parameters are tallied below. The sizes in the output
# table can be cross-checked against these arguments.
gpt = GPT(
    vocab_size=50257,  # embedding.weight: 50257 * 768 = 38,597,376 entries
    context_length=1024,  # positional_embedding.weight: 1024 * 768 = 786,432 entries
    emb_dim=768,
    ff_int_dim_mult=4,  # MLP.in_ff.bias has 4 * 768 = 3,072 entries
    n_heads=12,
    n_layers=12,
    drop_rate=0.1,
    qkv_bias=False  # q_mat / k_mat / v_mat appear in the table with weights only
)

from prettytable import PrettyTable

def count_parameters(model):
    """Print a per-parameter size table and return the trainable total.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, where each parameter provides
            ``requires_grad`` and ``numel()``.

    Returns:
        The sum of ``numel()`` over every parameter whose ``requires_grad``
        is truthy. Parameters with ``requires_grad`` falsy are left out of
        both the table and the total.
    """
    report = PrettyTable(["Modules", "Parameters"])

    # Gather (name, element count) once for the trainable entries; both the
    # table rows and the grand total are derived from this single list.
    trainable_sizes = [
        (param_name, tensor.numel())
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]

    for param_name, n_elements in trainable_sizes:
        report.add_row([param_name, n_elements])

    grand_total = sum(n_elements for _, n_elements in trainable_sizes)

    print(report)
    print(f"Total Trainable Params: {grand_total}")
    return grand_total
    
# Print the per-parameter table for the model built above. Because this call is
# the cell's last expression, its return value (the trainable total) is also
# echoed as the cell's output value.
count_parameters(gpt)

+----------------------------------------+------------+
|                Modules                 | Parameters |
+----------------------------------------+------------+
|            embedding.weight            |  38597376  |
|      positional_embedding.weight       |   786432   |
|       transformers.0.ln_1.weight       |    768     |
|        transformers.0.ln_1.bias        |    768     |
| transformers.0.attention.q_mat.weight  |   589824   |
| transformers.0.attention.k_mat.weight  |   589824   |
| transformers.0.attention.v_mat.weight  |   589824   |
|  transformers.0.attention.out.weight   |   589824   |
|   transformers.0.attention.out.bias    |    768     |
|       transformers.0.ln_2.weight       |    768     |
|        transformers.0.ln_2.bias        |    768     |
|    transformers.0.MLP.in_ff.weight     |  2359296   |
|     transformers.0.MLP.in_ff.bias      |    3072    |
|    transformers.0.MLP.out_ff.weight    |  2359296   |
|     transformers.0.MLP.out_ff.bias     |    768     |

163009536