In [32]:
import math  # used by GPT2's scaled init and the attention score scaling

import torch
import torch.nn as nn
import torch.nn.functional as F

# Record the environment the outputs below were produced with.
print(f"torch {torch.__version__} | CUDA available: {torch.cuda.is_available()}")

hello


In [33]:
'''Imports for the WikiText-103 data pipeline: Hugging Face `datasets` + GPT-2 tokenizer, served through PyTorch DataLoaders.'''
# NOTE(review): torchvision (datasets, transforms), matplotlib, PIL, numpy and pathlib.Path
# are not used by any cell visible here — presumably leftovers from an earlier
# CIFAR/MNIST version of this cell; confirm before removing.
import torch 
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt 
from PIL import Image
import numpy as np 
import os 
from pathlib import Path 

# Huggingface Datasets + GPT Tokenizer
from datasets import load_dataset, load_from_disk
from transformers import GPT2Tokenizer 


# Wikitext-103 dataset class
class WikiText103:
    """WikiText-103 for language modelling: GPT-2 tokenization, fixed-length token blocks, DataLoaders.

    The tokenized-and-grouped DatasetDict is cached under
    ``<args.data_path>/wikitext103_cache_<args.max_seq_length>`` and reused on later runs.

    Reads from ``args``:
        max_seq_length: length of every token block produced by ``group_texts``.
        data_path: directory that holds the cache.
        batch_size: batch size of both DataLoaders.
        num_workers: optional, DataLoader worker processes (default 2).

    Attributes:
        tokenizer: GPT-2 tokenizer with ``pad_token`` set to the EOS token.
        lm_dataset: DatasetDict of token blocks, formatted as torch tensors (``input_ids`` only).
        train_loader: shuffled DataLoader over the ``train`` split.
        test_loader: unshuffled DataLoader over the ``test`` split.
    """

    # Written by ``DatasetDict.save_to_disk`` at the end of a successful save. Checking for it
    # (instead of the bare directory) means an interrupted preprocessing run is redone
    # rather than leaving an empty cache directory that ``load_from_disk`` cannot read.
    _CACHE_MARKER = "dataset_dict.json"

    def __init__(self, args):
        self.max_seq_length = args.max_seq_length
        self.block_size = self.max_seq_length
        self.cache_dir = os.path.join(args.data_path, f"wikitext103_cache_{args.max_seq_length}")

        # Tokenizer
        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        self.tokenizer.pad_token = self.tokenizer.eos_token

        if os.path.exists(os.path.join(self.cache_dir, self._CACHE_MARKER)):
            print(f"Loading preprocessed dataset from {self.cache_dir}")
            self.lm_dataset = load_from_disk(self.cache_dir)
        else:
            self.lm_dataset = self._build_dataset()

        # Data Loaders
        num_workers = getattr(args, "num_workers", 2)
        self.train_loader = DataLoader(dataset=self.lm_dataset["train"], batch_size=args.batch_size, shuffle=True, num_workers=num_workers, pin_memory=False)
        self.test_loader = DataLoader(dataset=self.lm_dataset["test"], batch_size=args.batch_size, shuffle=False, num_workers=num_workers, pin_memory=False)

    def _build_dataset(self):
        """Download WikiText-103, tokenize it, group it into blocks and save the result to ``self.cache_dir``."""
        os.makedirs(self.cache_dir, exist_ok=True)

        # Dataset
        self.original_dataset = load_dataset("wikitext", "wikitext-103-v1")

        self.tokenized_dataset = self.original_dataset.map(
            self.tokenize_function,
            batched=True,
            remove_columns=["text"]
        )

        lm_dataset = self.tokenized_dataset.map(
            self.group_texts,
            batched=True
        )
        # Only input_ids are exposed as tensors; the notebook derives targets by shifting input_ids.
        lm_dataset.set_format(type="torch", columns=["input_ids"])
        lm_dataset.save_to_disk(self.cache_dir)
        print(f"Preprocessed dataset saved to {self.cache_dir}")
        return lm_dataset

    def group_texts(self, examples):
        """Concatenate a batch of tokenized texts and cut it into ``block_size``-token chunks.

        Every column (e.g. ``input_ids``, ``attention_mask``) is concatenated and chunked the
        same way. The trailing tokens that do not fill a whole block are dropped. A
        ``labels`` column is added as a copy of ``input_ids``.
        """
        from itertools import chain

        # chain.from_iterable is linear in the number of tokens; sum(lists, []) is quadratic.
        concatenated = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
        total_length = len(concatenated[list(examples.keys())[0]])
        total_length = (total_length // self.block_size) * self.block_size
        result = {
            k: [t[i : i + self.block_size] for i in range(0, total_length, self.block_size)]
            for k, t in concatenated.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    def tokenize_function(self, examples):
        """Tokenize the ``text`` column of a batch with the GPT-2 tokenizer."""
        return self.tokenizer(examples["text"])

In [34]:
from Models.activation import (GELU_s, SiLU_s, ZiLU_Old, ArcTan,
                               ArcTan_Approx, ZiLU, ZiLU_Approx)
import torch.nn as nn

class GPT2(nn.Module):
    """Decoder-only GPT-2 style language model.

    Token embeddings plus learned position embeddings feed ``num_layers`` transformer
    blocks. These are followed by a final LayerNorm and an LM head whose weight is tied
    to the token-embedding matrix.

    Args:
        args: experiment namespace, passed to every ``TransformerBlock``;
            ``args.activation`` is also used in ``self.name``.
        vocab_size: number of token ids.
        max_seq_length: number of learned positions (longest accepted input).
        embedding_dim: model width.
        num_attention_heads: heads per attention layer.
        num_layers: number of transformer blocks.
        dropout: dropout probability for the embeddings and inside each block.
        device: device the model is moved to at the end of construction.
    """

    def __init__(self,
                 args,
                 vocab_size=50257,
                 max_seq_length=1024,
                 embedding_dim=768,
                 num_attention_heads=12,
                 num_layers=12,
                 dropout=0.1,
                 device='cuda'):

        super(GPT2, self).__init__()

        self.args = args

        self.vocab_size = vocab_size
        self.max_seq_length = max_seq_length
        self.embedding_dim = embedding_dim
        self.num_attention_heads = num_attention_heads
        self.num_layers = num_layers
        # Keep the float rate under its own name; `self.dropout` is the Dropout module below.
        self.dropout_rate = dropout

        # Device chosen at construction; forward() follows input_ids.device instead.
        self.device = device

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Embeddings
        self.token_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.position_embeddings = nn.Embedding(max_seq_length, embedding_dim)

        # Transformer Blocks
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(args, embedding_dim, num_attention_heads, embedding_dim * 4, max_seq_length, dropout)
            for _ in range(num_layers)
        ])

        # Final Layer Norm
        self.layer_norm = nn.LayerNorm(embedding_dim)

        # Linear output layer, weight-tied to the token embeddings
        self.lm_head = nn.Linear(embedding_dim, vocab_size, bias=False)
        self.lm_head.weight = self.token_embeddings.weight

        # Initialize Weights
        self.apply(self._init_weights)

        # Scaled initialization (as in GPT-2) for the projections that write into the
        # residual stream: MLP `fc2` and attention `w_o`, std 0.02 / sqrt(2 * num_layers).
        for pn, p in self.named_parameters():
            if p.dim() > 1:
                if 'fc2' in pn or 'w_o' in pn:
                    p.data.normal_(mean=0.0, std=0.02 / math.sqrt(2 * num_layers))

        self.name = f"GPT2_{num_layers}L_{num_attention_heads}H_{embedding_dim}D_{args.activation}"
        self.to(device)

    def _init_weights(self, module):
        """N(0, 0.02) for Linear/Embedding weights, zero Linear biases, identity LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, input_ids, target=None):
        """Run the model.

        Args:
            input_ids: (batch, seq_length) token ids, with seq_length <= the number of learned positions.
            target: optional token ids with the same number of elements as ``input_ids``;
                entries equal to -1 are ignored by the loss.

        Returns:
            (logits, loss): with ``target``, logits are (batch, seq_length, vocab) and loss is the
            mean cross-entropy. Without ``target``, only the last position is projected,
            giving (batch, 1, vocab), and loss is None.

        Raises:
            ValueError: if the sequence is longer than the position-embedding table.
        """
        _, seq_length = input_ids.size()
        max_positions = self.position_embeddings.num_embeddings
        if seq_length > max_positions:
            raise ValueError(
                f"Sequence length {seq_length} exceeds the model's maximum of {max_positions} positions"
            )

        # Build positions on the inputs' device so the model keeps working after .to()/.cpu().
        position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device)

        # Embeddings
        token_embeds = self.token_embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)
        x = token_embeds + position_embeds

        x = self.dropout(x)

        # Transformer Blocks
        for layer in self.transformer_blocks:
            x = layer(x)

        # Final Layer Norm
        x = self.layer_norm(x)

        if target is not None:
            logits = self.lm_head(x)
            # reshape (not view) so sliced, non-contiguous targets such as tokens[:, 1:] work.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1), ignore_index=-1)
        else:
            # Generation only needs the next-token distribution of the last position.
            logits = self.lm_head(x[:, [-1], :])
            loss = None

        return logits, loss

    def parameter_count(self, non_embeddings=True):
        """Return ``(total_params, trainable_params)``.

        ``self.parameters()`` yields each shared tensor once, so the tied lm_head /
        token-embedding weight is counted a single time. ``non_embeddings`` is accepted
        for API compatibility but is currently ignored (embeddings are always included).
        """
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        return total_params, trainable_params
    
class CausalMultiHeadAttention(nn.Module):
    """Masked (causal) multi-head self-attention.

    Every position attends only to itself and earlier positions. A lower-triangular
    mask is precomputed for ``max_seq_length`` positions and sliced to the actual
    length on each call.

    Args:
        d_embeddings: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        max_seq_length: longest sequence the precomputed mask covers.
        dropout: dropout probability on the attention weights and on the output projection.
    """

    def __init__(self, d_embeddings, num_heads, max_seq_length, dropout):
        super().__init__()

        assert d_embeddings % num_heads == 0, "Match Embeddings with Number of Heads"

        self.d_embedding = d_embeddings
        self.num_heads = num_heads
        self.max_seq_length = max_seq_length
        self.d_heads = d_embeddings // num_heads

        # Projections are registered in k, q, v, o order; the order fixes both the
        # random-init sequence and the state_dict layout.
        for proj_name in ("w_k", "w_q", "w_v", "w_o"):
            setattr(self, proj_name, nn.Linear(d_embeddings, d_embeddings))

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # 1 where a query may see a key (key index <= query index), 0 elsewhere; shape (1, 1, L, L).
        lower_triangle = torch.tril(torch.ones(max_seq_length, max_seq_length))
        self.register_buffer('causal_mask', lower_triangle[None, None, :, :])

    def split_heads(self, x):
        """(batch, seq, d_embeddings) -> (batch, num_heads, seq, d_heads)."""
        n_batch, n_tokens, _ = x.shape
        per_head = x.view(n_batch, n_tokens, self.num_heads, self.d_heads)
        return per_head.transpose(1, 2)

    def combine_heads(self, x):
        """(batch, num_heads, seq, d_heads) -> (batch, seq, num_heads * d_heads)."""
        n_batch, n_heads, n_tokens, head_dim = x.shape
        return x.transpose(1, 2).reshape(n_batch, n_tokens, n_heads * head_dim)

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (batch, seq, d_embeddings)."""
        queries = self.split_heads(self.w_q(x))
        keys = self.split_heads(self.w_k(x))
        values = self.split_heads(self.w_v(x))

        scale = 1.0 / math.sqrt(values.size(-1))
        scores = (queries @ keys.transpose(-2, -1)) * scale

        n_tokens = scores.size(-2)
        blocked = self.causal_mask[:, :, :n_tokens, :n_tokens] == 0
        scores = scores.masked_fill(blocked, float('-inf'))

        weights = self.attn_dropout(scores.softmax(dim=-1))
        mixed = self.combine_heads(weights @ values)
        return self.resid_dropout(self.w_o(mixed))

class MLP(nn.Module):
    """Position-wise feed-forward network: fc1 -> activation -> fc2 -> dropout.

    Args:
        args: namespace with ``activation`` (one of the keys below). The custom
            activations also read ``args.sigma``. ``args.inplace`` is optional and
            defaults to False.
        d_model: input/output width.
        d_ff: hidden width.
        dropout: dropout probability applied to the output.

    Raises:
        ValueError: if ``args.activation`` is not a supported name.
    """

    def __init__(self, args, d_model, d_ff, dropout):
        super().__init__()

        # Activation Selection
        self.activation = args.activation
        # The experiment config does not necessarily define `inplace`; default to the safe choice.
        inplace = getattr(args, "inplace", False)

        # Factories are local, not stored on the module: a dict of lambdas as an
        # attribute makes the module unpicklable (e.g. torch.save(model) fails).
        # Each factory reads args.sigma only when called.
        activation_factories = {
            "relu": lambda: nn.ReLU(inplace=inplace),
            "silu": lambda: nn.SiLU(inplace=inplace),
            "gelu": lambda: nn.GELU(),
            "sigmoid": lambda: nn.Sigmoid(),

            # Previous Activation Generation
            "gelu_s": lambda: GELU_s(sigma=args.sigma, inplace=inplace),
            "silu_s": lambda: SiLU_s(sigma=args.sigma, inplace=inplace),
            "zilu_old": lambda: ZiLU_Old(sigma=args.sigma, inplace=inplace),

            # Current Activation Generation
            "arctan": lambda: ArcTan(sigma=args.sigma),
            "arctan_approx": lambda: ArcTan_Approx(sigma=args.sigma),
            "zilu": lambda: ZiLU(sigma=args.sigma),
            "zilu_approx": lambda: ZiLU_Approx(sigma=args.sigma)
        }

        if self.activation not in activation_factories:
            raise ValueError(f"Unsupported activation function: {self.activation}")

        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.activation_function = activation_factories[self.activation]()

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the feed-forward network to the last dimension of ``x``."""
        x = self.fc1(x)
        x = self.activation_function(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class TransformerBlock(nn.Module):
    """One pre-norm decoder block.

    Computes ``x = x + Attention(LN1(x))`` followed by ``x = x + MLP(LN2(x))``.
    Both sub-layers normalise their input and add the result back onto the
    un-normalised residual stream.

    Args:
        args: experiment namespace forwarded to ``MLP`` (activation choice).
        d_model: model width.
        num_heads: attention heads.
        d_ff: MLP hidden width.
        max_seq_length: longest sequence the causal mask covers.
        dropout: dropout probability used inside attention and the MLP.
    """

    def __init__(self, args, d_model, num_heads, d_ff, max_seq_length, dropout):
        super(TransformerBlock, self).__init__()
        self.attention = CausalMultiHeadAttention(d_model, num_heads, max_seq_length, dropout)
        self.mlp = MLP(args, d_model, d_ff, dropout)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Pre-norm attention sub-layer with residual connection.
        x = x + self.attention(self.layer_norm1(x))

        # Pre-norm MLP sub-layer: LayerNorm is applied to the MLP input, not after the residual add.
        x = x + self.mlp(self.layer_norm2(x))
        return x

In [35]:
from types import SimpleNamespace
import math  # the model classes above call math.sqrt, so it must be imported before GPT2(...) runs

# Experiment configuration: every tunable value lives here.
args = SimpleNamespace(
    vocab_size=50257,
    max_seq_length=1024,
    embedding_dim=768,
    num_attention_heads=12,
    num_layers=12,
    dropout=0.1,
    activation='gelu',
    inplace=False,  # read by MLP for relu / silu / gelu_s / silu_s / zilu_old
    sigma=1.0,
    dataset='wikitext103',
    data_path='./Data',
    # Fall back to CPU so the notebook still runs on machines without a GPU.
    device='cuda' if torch.cuda.is_available() else 'cpu',
    batch_size=8,
)

In [36]:
# Build the model from the config; GPT2 moves itself to the given device.
device = args.device
model = GPT2(
    args,
    vocab_size=args.vocab_size,
    max_seq_length=args.max_seq_length,
    embedding_dim=args.embedding_dim,
    num_attention_heads=args.num_attention_heads,
    num_layers=args.num_layers,
    dropout=args.dropout,
    device=device,
)

In [37]:
# Tokenized WikiText-103 blocks (loaded from the on-disk cache when available) plus DataLoaders.
dataset = WikiText103(args=args)


Loading preprocessed dataset from ./Data/wikitext103_cache_1024


In [45]:
# Next-token setup: the model reads tokens[:, :-1] and is scored against tokens[:, 1:].
batch = next(iter(dataset.train_loader))
tokens = batch["input_ids"]

inputs, targets = tokens[:, :-1].contiguous(), tokens[:, 1:].contiguous()

print(*(t.shape for t in (tokens, inputs, targets)), sep="\n")


torch.Size([8, 1024])
torch.Size([8, 1023])
torch.Size([8, 1023])


In [44]:
# Decode a few ids to confirm that targets are the inputs shifted by one token.
# Reuse the tokenizer WikiText103 already loaded instead of loading GPT-2's again.
tokenizer = dataset.tokenizer

# First sequence in the batch
inp_ids = inputs[0]
tgt_ids = targets[0]

# Each target id is the input id one position later.
assert torch.equal(inp_ids[1:], tgt_ids[:-1])

# Print the last 10 tokens as numbers
print("Input IDs (last 10): ", inp_ids[-10:].tolist())
print("Target IDs (last 10):", tgt_ids[-10:].tolist())

print("-" * 40)

# Decode to text to see the shift visually (BPE tokens, so a "token" may be part of a word)
print("Input Text (last 5 tokens):  ", tokenizer.decode(inp_ids[-5:]))
print("Target Text (last 5 tokens): ", tokenizer.decode(tgt_ids[-5:]))

Input IDs (last 10):  [764, 1355, 7637, 837, 508, 550, 587, 6623, 287, 10173]
Target IDs (last 10): [1355, 7637, 837, 508, 550, 587, 6623, 287, 10173, 312]
----------------------------------------
Input Text (last 5 words):    had been resident in Arc
Target Text (last 5 words):   been resident in Arcid
