In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

print("="*60)
print("BUILDING GPT MODEL FROM SCRATCH")
print("="*60)

# Load data info
checkpoint = torch.load('data_loaders.pt')
vocab_size = checkpoint['vocab_size']
seq_len = checkpoint['seq_len']

print(f"\nVocabulary size: {vocab_size}")
print(f"Sequence length: {seq_len}")

# Model hyperparameters
embedding_dim = 128
num_heads = 4
num_layers = 4
ff_dim = 512
dropout = 0.1

print(f"\nModel hyperparameters:")
print(f"  Embedding dimension: {embedding_dim}")
print(f"  Number of heads: {num_heads}")
print(f"  Number of layers: {num_layers}")
print(f"  Feed-forward dimension: {ff_dim}")
print(f"  Dropout: {dropout}")

print("\n" + "="*60)
print("COMPONENT 1: MULTI-HEAD ATTENTION")
print("="*60)

class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_dim, num_heads):
        super().__init__()
        assert embedding_dim % num_heads == 0
        
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads
        
        self.W_q = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.W_k = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.W_v = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.W_o = nn.Linear(embedding_dim, embedding_dim, bias=False)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # Split into multiple heads
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Attention
        attention_scores = Q @ K.transpose(-2, -1) / (self.head_dim ** 0.5)
        
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))
        
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        attended = attention_weights @ V
        
        # Concatenate heads
        attended = attended.transpose(1, 2).contiguous()
        attended = attended.view(batch_size, seq_len, self.embedding_dim)
        
        output = self.W_o(attended)
        
        return output, attention_weights

print("✓ Multi-Head Attention defined")

print("\n" + "="*60)
print("COMPONENT 2: FEED-FORWARD NETWORK")
print("="*60)

class FeedForward(nn.Module):
    def __init__(self, embedding_dim, ff_dim, dropout):
        super().__init__()
        self.linear1 = nn.Linear(embedding_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, embedding_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        # Expand
        x = self.linear1(x)
        x = F.gelu(x)
        x = self.dropout(x)
        
        # Compress
        x = self.linear2(x)
        x = self.dropout(x)
        
        return x

print("✓ Feed-Forward Network defined")

print("\n" + "="*60)
print("COMPONENT 3: TRANSFORMER BLOCK")
print("="*60)

class TransformerBlock(nn.Module):
    def __init__(self, embedding_dim, num_heads, ff_dim, dropout):
        super().__init__()
        
        self.attention = MultiHeadAttention(embedding_dim, num_heads)
        self.feed_forward = FeedForward(embedding_dim, ff_dim, dropout)
        
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # Attention with residual
        attended, attn_weights = self.attention(self.norm1(x), mask)
        x = x + self.dropout(attended)
        
        # Feed-forward with residual
        ff_out = self.feed_forward(self.norm2(x))
        x = x + self.dropout(ff_out)
        
        return x, attn_weights

print("✓ Transformer Block defined")

print("\n" + "="*60)
print("COMPONENT 4: COMPLETE GPT MODEL")
print("="*60)

class GPT(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_heads, num_layers, ff_dim, seq_len, dropout):
        super().__init__()
        
        # Embeddings (BASELINE EMBEDDINGS CREATED HERE!)
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.position_embedding = nn.Embedding(seq_len, embedding_dim)
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embedding_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])
        
        # Output
        self.ln_f = nn.LayerNorm(embedding_dim)
        self.head = nn.Linear(embedding_dim, vocab_size)
        
        self.dropout = nn.Dropout(dropout)
        self.seq_len = seq_len
        
        # Initialize weights
        self.apply(self._init_weights)
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def forward(self, x):
        batch_size, seq_len = x.shape
        
        # Create causal mask (prevent looking at future tokens)
        mask = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)
        mask = mask.to(x.device)
        
        # Embeddings
        token_emb = self.token_embedding(x)
        pos = torch.arange(seq_len, device=x.device)
        pos_emb = self.position_embedding(pos)
        
        x = self.dropout(token_emb + pos_emb)
        
        # Transformer blocks
        attention_weights = []
        for block in self.blocks:
            x, attn = block(x, mask)
            attention_weights.append(attn)
        
        # Output
        x = self.ln_f(x)
        logits = self.head(x)
        
        return logits, attention_weights

print("✓ Complete GPT model defined")

# Create model
model = GPT(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    ff_dim=ff_dim,
    seq_len=seq_len,
    dropout=dropout
)

print("\n" + "="*60)
print("MODEL ARCHITECTURE")
print("="*60)
print(model)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("\n" + "="*60)
print("MODEL STATISTICS")
print("="*60)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size: ~{total_params * 4 / 1024 / 1024:.2f} MB")

# Test forward pass
print("\n" + "="*60)
print("TEST FORWARD PASS")
print("="*60)

dummy_input = torch.randint(0, vocab_size, (2, seq_len))
print(f"Input shape: {dummy_input.shape}")

logits, attn_weights = model(dummy_input)
print(f"Output logits shape: {logits.shape}")
print(f"Number of attention weight tensors: {len(attn_weights)}")
print(f"Each attention weight shape: {attn_weights[0].shape}")

print("\n✓ Forward pass successful!")

# Save model
torch.save({
    'model': model,
    'embedding_dim': embedding_dim,
    'num_heads': num_heads,
    'num_layers': num_layers,
    'ff_dim': ff_dim,
    'vocab_size': vocab_size,
    'seq_len': seq_len
}, 'model_untrained.pt')

print("\n" + "="*60)
print("STEP 4 COMPLETE ✓")
print("="*60)
print("✓ Model architecture built")
print("✓ Baseline embeddings created (random)")
print("✓ Ready for training!")
print("\nNext: We'll inspect these random embeddings before training")

BUILDING GPT MODEL FROM SCRATCH


UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL torch.utils.data.dataloader.DataLoader was not an allowed global by default. Please use `torch.serialization.add_safe_globals([DataLoader])` or the `torch.serialization.safe_globals([DataLoader])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.