In [1]:
import torch

In [14]:

import math

In [2]:
# Hyper-parameters read by the model classes defined later in this notebook.
GPT_CONFIG_124M = dict(
    vocab_size=50257,       # size of the token-embedding table and of the output head
    context_length=1024,    # size of the positional-embedding table and of the causal mask
    emb_dim=768,            # width of every embedding / hidden vector
    n_head=12,              # number of attention heads per transformer block
    n_layers=12,            # number of stacked transformer blocks
    drop_rate=0.1,          # dropout probability used throughout the model
    qkv_bias=False,         # NOTE(review): not read by any class visible in this notebook
)

In [18]:
class LayerNorm(torch.nn.Module):
    """Normalise the last dimension to zero mean / unit variance, then rescale.

    Args:
        emb_dim: Size of the last dimension of the inputs; also the length of
            the learnable ``scale`` and ``shift`` vectors.
    """

    def __init__(self, emb_dim):
        super().__init__()
        # Added to the variance so the square root below is never taken of zero.
        self.eps = 1e-5
        # Learnable gain (initialised to 1) and offset (initialised to 0).
        self.scale = torch.nn.Parameter(torch.ones(emb_dim))
        self.shift = torch.nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        """Return ``scale * (x - mean) / sqrt(var + eps) + shift`` over the last axis."""
        mu = x.mean(dim=-1, keepdim=True)
        # unbiased=False: divide by N rather than N-1 (population variance).
        sigma_sq = x.var(dim=-1, keepdim=True, unbiased=False)
        centered = x - mu
        normalised = centered / torch.sqrt(sigma_sq + self.eps)
        return self.scale * normalised + self.shift

In [None]:
class CausalSelfAttention(torch.nn.Module):
    """Single-head scaled dot-product self-attention with a causal mask.

    Args:
        emb_dim: Size of the last dimension of the input.
        out_dim: Size of the query/key/value projections and of the output.
        context_length: Side length of the pre-built causal mask.
        drop_rate: Dropout probability applied to the attention weights.
    """

    def __init__(self, emb_dim, out_dim, context_length, drop_rate):
        super().__init__()
        # Projections are created in key, query, value order.
        self.kw = torch.nn.Linear(emb_dim, out_dim)
        self.qw = torch.nn.Linear(emb_dim, out_dim)
        self.vw = torch.nn.Linear(emb_dim, out_dim)
        self.dropout = torch.nn.Dropout(drop_rate)

        # Ones strictly above the diagonal mark the positions a query must not see.
        upper = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer('mask', upper)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, tokens, emb_dim); returns (batch, tokens, out_dim)."""
        _, seq_len, _ = x.shape

        k = self.kw(x)
        q = self.qw(x)
        v = self.vw(x)

        scores = q @ k.transpose(1, 2)
        # Crop the stored mask to the current sequence length before applying it.
        blocked = self.mask[:seq_len, :seq_len].bool()
        scores = scores.masked_fill(blocked, float('-inf'))

        # Scale by sqrt(head width); masked entries stay -inf and get zero weight.
        scale = k.shape[-1] ** 0.5
        weights = torch.softmax(scores / scale, dim=-1)
        weights = self.dropout(weights)

        return weights @ v









In [None]:
class MultiHeadAttention(torch.nn.Module):
    """Run ``n_head`` independent CausalSelfAttention heads and concatenate them.

    Args:
        emb_dim: Size of the last dimension of the input.
        out_dim: Total output width; must be divisible by ``n_head``.
        context_length: Forwarded to every head.
        drop_rate: Forwarded to every head.
        n_head: Number of heads; each one produces ``out_dim // n_head`` features.
    """

    def __init__(self, emb_dim, out_dim, context_length, drop_rate, n_head):
        super().__init__()
        assert out_dim % n_head == 0
        per_head = out_dim // n_head

        heads = []
        for _ in range(n_head):
            heads.append(CausalSelfAttention(emb_dim, per_head, context_length, drop_rate))
        self.head = torch.nn.ModuleList(heads)

    def forward(self, x):
        """Apply every head to ``x`` and join the results along the last dimension."""
        per_head_out = []
        for attn in self.head:
            per_head_out.append(attn(x))
        return torch.cat(per_head_out, dim=-1)





key tensor([[[ 0.0919,  0.5472, -0.1305],
         [ 0.1653,  0.5250,  0.0755],
         [ 0.1538,  0.5255,  0.0464],
         [ 0.0212,  0.3707, -0.0217],
         [-0.1513,  0.4570, -0.5829],
         [ 0.1455,  0.3767,  0.2710]],

        [[ 0.0919,  0.5472, -0.1305],
         [ 0.1653,  0.5250,  0.0755],
         [ 0.1538,  0.5255,  0.0464],
         [ 0.0212,  0.3707, -0.0217],
         [-0.1513,  0.4570, -0.5829],
         [ 0.1455,  0.3767,  0.2710]]], grad_fn=<ViewBackward0>)
key tensor([[[-0.5557,  0.0300, -0.1206],
         [-0.9426, -0.4177,  0.0160],
         [-0.9397, -0.4113, -0.0018],
         [-0.6286, -0.4219, -0.1434],
         [-0.6830, -0.2513, -0.4390],
         [-0.6829, -0.4860,  0.0425]],

        [[-0.5557,  0.0300, -0.1206],
         [-0.9426, -0.4177,  0.0160],
         [-0.9397, -0.4113, -0.0018],
         [-0.6286, -0.4219, -0.1434],
         [-0.6830, -0.2513, -0.4390],
         [-0.6829, -0.4860,  0.0425]]], grad_fn=<ViewBackward0>)


tensor([[[ 1.1967e+00, -5.0772e-01,  1.0888e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00],
         [ 8.2270e-01, -3.0353e-01,  5.5394e-01,  2.0959e-01, -8.1366e-01,
          -4.8994e-02],
         [ 1.1090e+00, -4.0311e-01,  7.4355e-01,  2.6713e-01, -1.0503e+00,
          -6.8118e-02],
         [ 1.0121e+00, -4.1737e-01,  7.4802e-01,  2.1148e-01, -9.9142e-01,
          -1.3058e-01],
         [ 1.1730e+00, -3.7479e-01,  8.0469e-01,  7.6060e-02, -3.0298e-01,
          -2.1087e-02],
         [ 7.0502e-01, -2.1542e-01,  4.3351e-01, -6.9874e-03, -2.6695e-01,
          -8.3627e-02]],

        [[ 1.1967e+00, -5.0772e-01,  1.0888e+00,  7.8556e-02, -9.8259e-01,
          -3.4292e-01],
         [ 1.4401e+00, -5.6548e-01,  1.1157e+00,  2.4585e-01, -1.2671e+00,
          -2.0726e-01],
         [ 1.5236e+00, -5.7905e-01,  1.1209e+00,  1.5419e-01, -8.1630e-01,
          -1.3966e-01],
         [ 1.1485e+00, -4.3633e-01,  8.4431e-01,  1.8317e-01, -9.5842e-01,
           9.7165e-03],
        

In [None]:
class GELU(torch.nn.Module):
    """Tanh approximation of the GELU activation, applied element-wise."""

    def forward(self, x):
        """Return ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))``."""
        coeff = math.sqrt(2.0 / math.pi)
        cubic = x + 0.044715 * x**3
        gate = 1 + torch.tanh(coeff * cubic)
        return 0.5 * x * gate





In [None]:
class FeedForward(torch.nn.Module):
    """Two-layer MLP on the last dimension: expand to 4x width, GELU, project back.

    Args:
        emb_dim: Input and output width; the hidden layer is ``4 * emb_dim`` wide.
    """

    def __init__(self, emb_dim):
        super().__init__()
        hidden_dim = 4 * emb_dim
        expand = torch.nn.Linear(emb_dim, hidden_dim)
        activation = GELU()
        contract = torch.nn.Linear(hidden_dim, emb_dim)
        # Kept as one Sequential so the sub-modules stay at indices 0, 1 and 2.
        self.layer = torch.nn.Sequential(expand, activation, contract)

    def forward(self, x):
        """Pass ``x`` through the expand -> GELU -> contract stack."""
        return self.layer(x)
        


In [None]:
class TransformerBlock(torch.nn.Module):
    """Pre-norm transformer block: attention then feed-forward, each with a residual.

    Each sub-layer computes ``x + dropout(sublayer(norm(x)))``.

    Args:
        cfg: Optional configuration dict providing 'emb_dim', 'context_length',
            'drop_rate' and 'n_head'. When omitted, the module-level
            ``GPT_CONFIG_124M`` is used, so ``TransformerBlock()`` keeps working
            while callers such as ``GPTModel`` can pass their own config.
    """

    def __init__(self, cfg=None):
        super().__init__()
        config = GPT_CONFIG_124M if cfg is None else cfg
        emb_dim = config['emb_dim']

        # Norm in front of the attention sub-layer.
        self.layer_norm = LayerNorm(emb_dim)
        self.mha = MultiHeadAttention(
            emb_dim,
            emb_dim,
            config['context_length'],
            config['drop_rate'],
            config['n_head'],
        )
        self.dropout = torch.nn.Dropout(config['drop_rate'])
        self.ff = FeedForward(emb_dim)
        # Separate norm for the feed-forward sub-layer, so the two sub-layers do
        # not share scale/shift parameters.
        self.layer_norm_ff = LayerNorm(emb_dim)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers; output has the shape of ``x``."""
        # Attention sub-layer with residual connection.
        shortcut = x
        x = self.layer_norm(x)
        x = self.mha(x)
        x = self.dropout(x)
        x = x + shortcut

        # Feed-forward sub-layer. The residual adds the *pre-norm* activation
        # rather than the normalised tensor, so the skip path stays an identity.
        shortcut = x
        x = self.layer_norm_ff(x)
        x = self.ff(x)
        x = self.dropout(x)
        return x + shortcut



    

In [None]:
class GPTModel(torch.nn.Module):
    """GPT-style language model: embeddings, a stack of transformer blocks, and an output head.

    Args:
        cfg: Configuration dict. This class reads 'vocab_size', 'emb_dim',
            'context_length', 'drop_rate' and 'n_layers', and hands the whole
            dict to every ``TransformerBlock``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = torch.nn.Embedding(cfg['vocab_size'], cfg['emb_dim'])
        self.pos_emb = torch.nn.Embedding(cfg['context_length'], cfg['emb_dim'])
        self.drop_emb = torch.nn.Dropout(cfg['drop_rate'])

        # Every block is built from the same config dict.
        self.trf_blocks = torch.nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg['n_layers'])])

        self.final_norm = LayerNorm(cfg['emb_dim'])
        self.output_head = torch.nn.Linear(cfg['emb_dim'], cfg['vocab_size'])

    def forward(self, in_idx):
        """Map token ids to next-token logits.

        Args:
            in_idx: Integer tensor of token ids with shape (batch, seq_len).

        Returns:
            Tensor of logits with shape (batch, seq_len, vocab_size).
        """
        # ``shape`` is an attribute (a torch.Size), not a method.
        _, seq_len = in_idx.shape
        tok_emb = self.tok_emb(in_idx)
        # Positions 0..seq_len-1, created on the same device as the input ids.
        pos_emb = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_emb + pos_emb
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        # Normalise the final hidden states before projecting to the vocabulary.
        x = self.final_norm(x)
        logits = self.output_head(x)

        return logits