In [1]:
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from yacs.config import CfgNode as CN
from dataclasses import dataclass

### MODEL DEFINITION

In [None]:

# First define the transformer network.

# A transformer has a word-embedding layer, a position-embedding layer, a series of Blocks, dropout and a final LayerNorm.
# Each Block is its own class holding two LayerNorms,
# a CausalSelfAttention module and an MLP module.
# The MLP module consists of a fully connected layer, a GELU activation, a projection layer and dropout.
# CausalSelfAttention is the multi-headed attention layer.


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal (lower-triangular) mask.

    Args:
        config: object exposing ``embed_size``, ``num_heads``, ``bias``,
            ``dropout`` and ``block_size``. ``embed_size`` must be divisible
            by ``num_heads``.
    """

    def __init__(self, config):
        super().__init__()
        assert (config.embed_size % config.num_heads == 0)

        # one linear layer produces query, key and value for all heads at once
        self.c_attn = nn.Linear(config.embed_size, 3 * config.embed_size, bias=config.bias)
        # output projection applied after the heads are merged back together
        self.c_proj = nn.Linear(config.embed_size, config.embed_size, bias=config.bias)

        # for regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.num_heads = config.num_heads
        self.embed_size = config.embed_size
        self.dropout = config.dropout

        # causal mask to ensure attention is only applied to the left:
        # a (1, 1, block_size, block_size) lower-triangular matrix of ones
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((config.block_size, config.block_size)))
            .view(1, 1, config.block_size, config.block_size),
        )

    def forward(self, x):
        """Apply causal multi-head attention.

        Args:
            x: tensor of size (batch_size, seq_length, embed_size); seq_length
                must not exceed the ``block_size`` the mask was built with.

        Returns:
            Tensor of size (batch_size, seq_length, embed_size).
        """
        batch_size, seq_length, embed_size = x.size()
        head_size = embed_size // self.num_heads

        # calculate query, key, values for all heads
        query_x, key_x, value_x = self.c_attn(x).split(self.embed_size, dim=2)

        # split into heads:
        # batch_size x seq_length x num_heads x head_size
        # -> batch_size x num_heads x seq_length x head_size
        query_x = query_x.view(batch_size, seq_length, self.num_heads, head_size).transpose(1, 2)
        key_x = key_x.view(batch_size, seq_length, self.num_heads, head_size).transpose(1, 2)
        value_x = value_x.view(batch_size, seq_length, self.num_heads, head_size).transpose(1, 2)

        # scaled dot-product scores:
        # (.. x seq_length x head_size) @ (.. x head_size x seq_length)
        # -> (.. x seq_length x seq_length); row i holds the scores of query
        # position i against every key position
        att = query_x @ key_x.transpose(-2, -1) * (1.0 / math.sqrt(key_x.size(-1)))

        # positions where the mask is 0 (keys to the right of the query) get
        # -inf so that softmax assigns them zero weight
        att = att.masked_fill(self.bias[:, :, :seq_length, :seq_length] == 0, float('-inf'))

        # softmax over the key dimension for each query position
        att = F.softmax(att, dim=-1)

        # regularize the attention weights (identity in eval mode)
        att = self.attn_dropout(att)

        # weighted sum of values:
        # (.. x seq_length x seq_length) @ (.. x seq_length x head_size)
        y = att @ value_x

        # re-assemble all heads:
        # batch_size x num_heads x seq_length x head_size
        # -> batch_size x seq_length x num_heads x head_size
        # -> batch_size x seq_length x embed_size
        y = y.transpose(1, 2).contiguous().view(batch_size, seq_length, embed_size)

        # output projection
        y = self.resid_dropout(self.c_proj(y))

        return y

class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, GELU, project back, dropout.

    Args:
        config: object exposing ``embed_size``, ``bias`` and ``dropout``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.embed_size
        hidden_width = 4 * width
        use_bias = config.bias

        # expansion layer
        self.c_fc = nn.Linear(width, hidden_width, bias=use_bias)
        self.gelu = nn.GELU()
        # projection back to the embedding width
        self.c_proj = nn.Linear(hidden_width, width, bias=use_bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return ``dropout(c_proj(gelu(c_fc(x))))``."""
        hidden = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
    
# define block
class Block(nn.Module):
    """
        Transformer Block

    LayerNorm is applied before each sub-layer (attention, then MLP) and the
    sub-layer output is added back to its input as a residual.

    Args:
        config: passed through to CausalSelfAttention and MLP; ``embed_size``
            is also read here to size the two LayerNorms.
    """
    def __init__(self, config):
        super().__init__()
        width = config.embed_size
        self.ln_1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``h + mlp(ln_2(h))`` where ``h = x + attn(ln_1(x))``."""
        # out-of-place additions: the caller's tensor is left untouched
        after_attention = x + self.attn(self.ln_1(x))
        return after_attention + self.mlp(self.ln_2(after_attention))

class GPT(nn.Module):
    """ BINU GPT"""  
    
    """
            Hyper parameters are 
            1. Vocabulary Size = Total number of tokens in the set
            2. Block Size = context surrounding the current word
            3. Embedding size = size of the vector N_E
            4. Number of heads = number of heads in a multi-headed attention. Basically, splitting the vector into N_H heads of size N_E/N_H
            5. Number of layers = number of blocks containing the Multi-headed attention, and MLP
            6. Dropout Hyperparameters - for embedding layer, resdual layer and attention  
    """
    
    @staticmethod
    def get_default_config():
        
        # setting the default 
        C = CN()
        C.block_size = None 
        C.vocab_size = None 
        C.num_layers = 3
        C.num_heads = 3
        C.embed_size = 48
        C.embed_pdrop = 0.1
        C.resid_pdrop = 0.1
        C.attn_pdrop = 0.1
        return C
    
    def __init__(self, config) -> None:
        super().__init__()
        assert(config.block_size is not None)
        assert(config.vocab_size is not None)
        self.config = config
        
        self.transformer = nn.ModuleDict(
            {
                'wte' : nn.Embedding(config.vocab_size, config.embed_size),
                'wpe' : nn.Embedding(config.block_size, config.embed_size),
                'drop' : nn.Dropout(config.dropout),
                'h' : nn.ModuleList([Block(config) for _ in range(config.num_layers)]),
                'ln_f' : nn.LayerNorm(config.embed_size, bias = config.bias)
            }   
        )
        self.lm_head = nn.Linear(config.embed_size, config.vocab_size, bias=False)
        
        # initialize all weights recursively through the network
        self.apply(self._init_weights)
        # add a special scaled unit to the residual projections
        for param_name, param in self.named_parameters():
            if param_name.endswith('c_proj.weight'):
                torch.nn.init.normal_(param, mean=0.0, std=0.02/math.sqrt(2 * config.num_layers))
        
    # function to initialize weights for a module
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std = 0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
            
    def configure_optimizers(self, train_config):
        decay_params_set = set()
        no_decay_params_set = set()
        
        allowlist_weight_modules = (torch.nn.Linear, ) # set of modules that has a weight decay parameter
        denylist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) # set of modules that do not have a weight decay parameter
        
        for module_name, module in self.named_modules():
            for param_name, params in module.named_parameters():
                full_param_name = f'{module_name}.{param_name}' if module_name else param_name
                
                # bias terms will not have weight decay regularization
                if param_name.endswith('bias'):
                    no_decay_params_set.add(full_param_name)
                elif param_name.endswith('weight') and isinstance(module, allowlist_weight_modules):
                    decay_params_set.add(full_param_name)
                elif param_name.endswith('weight') and isinstance(module, denylist_weight_modules):
                    no_decay_params_set.add(full_param_name)
        
        # validation that we considered every parameter
        param_dict = dict(self.named_parameters())
        
        # no overlap between decay and no_decay param set
        assert ( len(decay_params_set & no_decay_params_set) == 0 )
        
        # no parameters were missing
        assert ( len(param_dict) - len(decay_params_set | no_decay_params_set)  == 0)
        
        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[param_name] for param_name in sorted(list(decay_params_set))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[param_name] for param_name in sorted(list(no_decay_params_set))], "weight_decay": 0.0}
        ]
        
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer
    
def forward(self, inputs, targets = None):
    """Run the model on a batch of token indices.

    NOTE(review): this is defined at module level rather than inside a class;
    presumably it is meant to be the ``forward`` method of the GPT model -
    confirm.

    Args:
        self: model object exposing ``transformer`` (with ``wte``, ``wpe``,
            ``drop``, ``h``, ``ln_f``), ``lm_head`` and a block size, either
            as ``self.block_size`` or ``self.config.block_size``.
        inputs: 2-D integer tensor of size (batch_size, seq_length).
        targets: optional tensor of target indices; it is flattened and must
            hold one label per input position (batch_size * seq_length).

    Returns:
        ``(logits, loss)``: ``logits`` is the ``lm_head`` output for every
        position; ``loss`` is the cross-entropy against ``targets``, or None
        when ``targets`` is None.

    Raises:
        AssertionError: if seq_length exceeds the block size.
    """
    device = inputs.device
    _, seq_length = inputs.size()

    # the model stores its hyperparameters on `config`; fall back to it when
    # no `block_size` attribute has been set on the object itself
    block_size = getattr(self, "block_size", None)
    if block_size is None:
        block_size = self.config.block_size
    assert seq_length <= block_size, (
        f"sequence length {seq_length} exceeds block size {block_size}"
    )

    # create a position vector of 1 x seq_length
    position_vec = torch.arange(0, seq_length, dtype=torch.long, device=device).unsqueeze(0)

    # token embeddings: batch_size x seq_length x embed_size
    token_embeddings = self.transformer.wte(inputs)
    pos_embeddings = self.transformer.wpe(position_vec)  # 1 x seq_length x embed_size
    x = self.transformer.drop(token_embeddings + pos_embeddings)  # broadcasts across the batch dimension

    # pass it through the self attention blocks
    for block in self.transformer.h:
        x = block(x)
    x = self.transformer.ln_f(x)

    # obtain logits - batch_size x seq_length x vocab_size - 3D tensor
    logits = self.lm_head(x)

    loss = None
    if targets is not None:
        # F.cross_entropy wants (N, classes) scores and (N,) labels, so
        # flatten the batch and sequence dimensions together
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    return logits, loss

@torch.no_grad()
def generate(self, inputs, max_new_tokens, temperature = 1.0, do_sample = False, top_k = None):
   """
   Take a conditioning sequence of indices 'inputs' of size (batch_size, seq_length) and complete the sequence max_new_token times, 
   feeding the predictions back into the model each itme. The model needs to be in model.eval() model
   """ 
   for t_idx in range(max_new_tokens):
       # take the last 'block_size' sequence length
       inputs_cond = inputs if inputs.size(1) <= self.block_size else inputs[:,-self.block_size:]
       
       # forward the model to get logits
       logits, _ = self(inputs_cond)
       
       # taking the logits from the last step of the sequence
       logits = logits[:,-1,:] / temperature
       
       