In [None]:
import torch
import torch.nn as nn
import tiktoken

class MultiheadAttention(nn.Module):
    """Causal multi-head self-attention.

    Queries, keys and values are produced by three bias-free linear maps from
    ``d_in`` to ``d_out``, split into ``num_heads`` heads of size
    ``d_out // num_heads``, attended with a scaled dot product under an
    upper-triangular (causal) mask, re-merged and passed through a final
    linear projection.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output size across all heads; must be divisible by
            ``num_heads``.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        context_length: Side length of the pre-built square causal mask.
    """

    def __init__(self, d_in, d_out, dropout, num_heads, context_length):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Layers are created in a fixed order so that seeded initialisation
        # is reproducible.
        self.W_querys = nn.Linear(d_in, d_out, bias=False)
        self.W_keys = nn.Linear(d_in, d_out, bias=False)
        self.W_values = nn.Linear(d_in, d_out, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.linear_projection = nn.Linear(d_out, d_out)

        # Ones strictly above the diagonal mark the "future" positions that a
        # token must not attend to. Stored as a buffer, not a parameter.
        future_positions = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer("mask", future_positions)

    def _split_heads(self, projected, batch, num_tokens):
        """Reshape (batch, tokens, d_out) into (batch, heads, tokens, head_dim)."""
        per_head = projected.view(batch, num_tokens, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, inputs):
        """Apply causal self-attention to a 3-D input.

        Args:
            inputs: Tensor whose shape unpacks as (batch, num_tokens, dim).

        Returns:
            Tensor of shape (batch, num_tokens, d_out).
        """
        batch, num_tokens, _ = inputs.shape

        query = self._split_heads(self.W_querys(inputs), batch, num_tokens)
        key = self._split_heads(self.W_keys(inputs), batch, num_tokens)
        value = self._split_heads(self.W_values(inputs), batch, num_tokens)

        # Scaled dot-product scores: (batch, heads, tokens, tokens)
        scale = self.head_dim ** 0.5
        attn_scores = (query @ key.transpose(-2, -1)) / scale

        # Future positions get -inf so softmax assigns them zero weight.
        is_future = self.mask[:num_tokens, :num_tokens].bool()
        attn_scores = attn_scores.masked_fill(is_future, float('-inf'))

        attn_weights = self.dropout(attn_scores.softmax(dim=-1))

        # Weighted sum of values, then merge the heads back together.
        per_head_context = attn_weights @ value
        merged = per_head_context.transpose(1, 2).contiguous().view(batch, num_tokens, -1)
        return self.linear_projection(merged)

class GELU(nn.Module):
    """Tanh-based approximation of the GELU activation.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``
    element-wise.
    """

    def forward(self, x):
        # sqrt(2 / pi), built as a tensor on every call
        coefficient = torch.sqrt(torch.tensor(2.0 / torch.pi))
        cubic_term = x + 0.044715 * torch.pow(x, 3)
        gate = 1 + torch.tanh(coefficient * cubic_term)
        return 0.5 * x * gate

class FeedForward(nn.Module):
    """Two-layer position-wise MLP: expand 4x, apply GELU, project back.

    Args:
        config: Mapping that provides ``'embedding_dim'``.
    """

    def __init__(self, config):
        super().__init__()
        emb_dim = config['embedding_dim']
        hidden_dim = emb_dim * 4
        self.layers = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            GELU(),
            nn.Linear(hidden_dim, emb_dim),
        )

    def forward(self, inputs):
        """Run ``inputs`` through the expand -> GELU -> project stack."""
        return self.layers(inputs)

class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable affine terms.

    Args:
        emb_dim: Size of the last (normalised) dimension.
        eps: Constant added to the variance before the square root to avoid
            division by zero.
    """

    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Learnable per-feature gain (starts at 1) and offset (starts at 0).
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, inputs):
        """Normalise ``inputs`` along dim -1, then apply scale and shift."""
        centered = inputs - inputs.mean(dim=-1, keepdim=True)
        # Population variance (unbiased=False), i.e. divided by N, not N - 1.
        variance = inputs.var(dim=-1, keepdim=True, unbiased=False)
        normalized = centered / torch.sqrt(variance + self.eps)
        return normalized * self.scale + self.shift

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each residual.

    Each sub-layer is applied as ``x + dropout(sublayer(norm(x)))``.

    Args:
        config: Mapping that provides ``'embedding_dim'``, ``'dropout'``,
            ``'n_heads'`` and ``'context_length'``.
    """

    def __init__(self, config):
        super().__init__()
        emb_dim = config['embedding_dim']
        drop_rate = config['dropout']

        self.attention = MultiheadAttention(
            d_in=emb_dim,
            d_out=emb_dim,
            dropout=drop_rate,
            num_heads=config['n_heads'],
            context_length=config['context_length'],
        )
        self.norm1 = LayerNorm(emb_dim)
        self.norm2 = LayerNorm(emb_dim)
        self.ff = FeedForward(config)
        # A single dropout module is shared by both residual branches.
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, inputs):
        """Return the block output; same shape as ``inputs``."""
        # Attention sub-layer with skip connection
        attended = self.dropout(self.attention(self.norm1(inputs)))
        hidden = attended + inputs

        # Feed-forward sub-layer with skip connection
        transformed = self.dropout(self.ff(self.norm2(hidden)))
        return transformed + hidden

class GPT(nn.Module):
    """GPT-style decoder: embeddings, stacked transformer blocks, LM head.

    Args:
        config: Mapping that provides ``'vocab_size'``, ``'embedding_dim'``,
            ``'context_length'``, ``'dropout'`` and ``'n_layers'`` (plus the
            keys consumed by ``TransformerBlock``).
    """

    def __init__(self, config):
        super().__init__()
        vocab_size = config['vocab_size']
        emb_dim = config['embedding_dim']

        self.token_embedding = nn.Embedding(vocab_size, emb_dim)
        self.pos_embedding = nn.Embedding(config['context_length'], emb_dim)
        self.dropout = nn.Dropout(config['dropout'])

        blocks = [TransformerBlock(config) for _ in range(config["n_layers"])]
        self.trf_blocks = nn.Sequential(*blocks)

        # The output head has its own weight matrix; it is not tied to
        # token_embedding.
        self.out_head = nn.Linear(emb_dim, vocab_size, bias=False)
        self.final_norm = LayerNorm(emb_dim)

    def forward(self, inputs):
        """Map token ids to next-token logits.

        Args:
            inputs: 2-D tensor of token ids, unpacked as (batch, seq_len).

        Returns:
            Logits with one vocabulary-sized vector per input position.
        """
        _, seq_len = inputs.shape

        # Positions 0..seq_len-1, created on the same device as the input.
        positions = torch.arange(seq_len, device=inputs.device)
        hidden = self.token_embedding(inputs) + self.pos_embedding(positions)

        hidden = self.dropout(hidden)
        hidden = self.trf_blocks(hidden)
        hidden = self.final_norm(hidden)
        return self.out_head(hidden)

# Hyperparameters consumed by GPT / TransformerBlock / FeedForward
GPT_CONFIG_124M = dict(
    vocab_size=50257,       # number of token ids the embedding/head cover
    context_length=1024,    # number of positional embeddings / mask size
    embedding_dim=768,      # width of every hidden representation
    n_heads=12,             # attention heads per block
    n_layers=12,            # number of stacked transformer blocks
    dropout=0.1,            # dropout probability used throughout
)

# GPT-2 byte-pair-encoding tokenizer
tokenizer = tiktoken.get_encoding("gpt2")

# Two sample prompts, encoded and combined into one (batch, tokens) tensor.
# pad_sequence right-pads shorter sequences with token id 0.
txt1 = "Every effort moves you"
txt2 = "Every day holds a"
batch = torch.nn.utils.rnn.pad_sequence(
    [torch.tensor(tokenizer.encode(text)) for text in (txt1, txt2)],
    batch_first=True,
    padding_value=0,
)

# Seed before constructing the model so weight initialisation is repeatable
torch.manual_seed(123)
model = GPT(GPT_CONFIG_124M)

# Single forward pass over the sample batch
out = model(batch)

# Report the inputs, the logits and the parameter count
print("Input batch:\n", batch)
print("\nOutput shape:", out.shape)
print(out)
total_params = sum(param.numel() for param in model.parameters())
print(f"Total number of parameters: {total_params:,}")

Input batch:
 tensor([[6109, 3626, 6100,  345],
        [6109, 1110, 6622,  257]])

Output shape: torch.Size([2, 4, 50257])
tensor([[[ 0.1381,  0.0077, -0.1963,  ..., -0.0222, -0.1060,  0.1717],
         [ 0.3865, -0.8408, -0.6564,  ..., -0.5163,  0.2369, -0.3357],
         [ 0.6989, -0.1829, -0.1631,  ...,  0.1472, -0.6504, -0.0056],
         [-0.4290,  0.1669, -0.1258,  ...,  1.1579,  0.5303, -0.5549]],

        [[ 0.1094, -0.2894, -0.1467,  ..., -0.0557,  0.2911, -0.2824],
         [ 0.0882, -0.3552, -0.3527,  ...,  1.2930,  0.0053,  0.1898],
         [ 0.6091,  0.4702, -0.4094,  ...,  0.7688,  0.3787, -0.1974],
         [-0.0612, -0.0737,  0.4751,  ...,  1.2463, -0.3834,  0.0609]]],
       grad_fn=<UnsafeViewBackward0>)
Total number of parameters: 163,009,536


In [15]:
import torch

def generate_text_simple(model, idx, max_new_tokens, context_size, verbose=True):
    """
    Greedily extend a batch of token sequences with a language model.

    At every step the last ``context_size`` tokens are fed to ``model``, the
    logits of the final position are turned into probabilities, and the most
    probable token is appended to each sequence.

    The function does not change the model's train/eval mode; call
    ``model.eval()`` beforehand if dropout should be disabled.

    Parameters:
    - model: Callable that maps a (batch, n_tokens) tensor of token indices
      to logits of shape (batch, n_tokens, vocab_size).
    - idx: A (batch, n_tokens) tensor containing the initial token indices.
    - max_new_tokens: Number of new tokens to generate.
    - context_size: Number of tokens the model can use as context.
    - verbose: When True (the default) print a detailed trace of every step,
      including full tensors. Pass False to generate silently.

    Returns:
    - The final token sequence, shape (batch, n_tokens + max_new_tokens).
    """

    if verbose:
        print(f"Initial input token indices:\n{idx}\n")  # Printing initial input

    for step in range(max_new_tokens):  # Loop to generate tokens
        if verbose:
            print(f"Step {step + 1}: Generating new token...\n")

        # 1. Trim the input context (model only supports a fixed length)
        idx_cond = idx[:, -context_size:]
        if verbose:
            print(f"Context tokens (last {context_size} tokens):\n{idx_cond}\n")

        # 2. Pass the context into the model to get predictions
        with torch.no_grad():  # Disable gradients for efficiency
            logits = model(idx_cond)

        if verbose:
            print(f"Logits shape (Batch x Tokens x Vocab Size): {logits.shape}\n")

        # 3. Extract only the last token's logits
        logits = logits[:, -1, :]  # Shape becomes (batch, vocab_size)
        if verbose:
            print(f"Logits for the last predicted token:\n{logits}\n")

        # 4. Convert logits to probabilities using softmax
        probas = torch.softmax(logits, dim=-1)
        if verbose:
            print(f"Probabilities after softmax:\n{probas}\n")

        # 5. Choose the most probable next token
        idx_next = torch.argmax(probas, dim=-1, keepdim=True)  # Shape: (batch, 1)
        if verbose:
            print(f"Predicted next token index:\n{idx_next}\n")

        # 6. Append the predicted token to the sequence
        idx = torch.cat((idx, idx_next), dim=1)  # Shape: (batch, n_tokens + 1)
        if verbose:
            print(f"Updated sequence:\n{idx}\n")
            print("-" * 50)  # Separator for readability

    if verbose:
        print("\nFinal generated sequence:")
        print(idx)
    return idx


In [16]:
# Prompt used to seed the generation demo
start_context = "Hello, I am"

# Token ids for the prompt
encoded = tokenizer.encode(start_context)
print("encoded:", encoded)

# Prepend a batch dimension so the shape is (1, n_tokens)
encoded_tensor = torch.tensor(encoded)[None, :]
print("encoded_tensor.shape:", encoded_tensor.shape)

encoded: [15496, 11, 314, 716]
encoded_tensor.shape: torch.Size([1, 4])


In [17]:
# Switch to inference mode so the dropout layers are inactive
model.eval()

# Greedily generate six tokens after the encoded prompt
out = generate_text_simple(
    model=model,
    idx=encoded_tensor,
    max_new_tokens=6,
    context_size=GPT_CONFIG_124M["context_length"],
)

print("Output:", out)
print("Output length:", out.shape[1])

Initial input token indices:
tensor([[15496,    11,   314,   716]])

Step 1: Generating new token...

Context tokens (last 1024 tokens):
tensor([[15496,    11,   314,   716]])

Logits shape (Batch x Tokens x Vocab Size): torch.Size([1, 4, 50257])

Logits for the last predicted token:
tensor([[-0.6430, -0.1466, -0.1405,  ...,  1.5849, -0.9539, -0.8765]])

Probabilities after softmax:
tensor([[8.8912e-06, 1.4606e-05, 1.4696e-05,  ..., 8.2514e-05, 6.5151e-06,
         7.0397e-06]])

Predicted next token index:
tensor([[27018]])

Updated sequence:
tensor([[15496,    11,   314,   716, 27018]])

--------------------------------------------------
Step 2: Generating new token...

Context tokens (last 1024 tokens):
tensor([[15496,    11,   314,   716, 27018]])

Logits shape (Batch x Tokens x Vocab Size): torch.Size([1, 5, 50257])

Logits for the last predicted token:
tensor([[ 0.3104, -0.1029, -0.0867,  ..., -0.2110, -0.1270, -0.5638]])

Probabilities after softmax:
tensor([[2.3012e-05, 1.5221e

In [18]:
# Drop the batch dimension, then map the token ids back to text
token_ids = out.squeeze(0).tolist()
decoded_text = tokenizer.decode(token_ids)
print(decoded_text)

Hello, I am Featureiman Byeswickattribute argue
