In [2]:
import torch

In [None]:
# Hyperparameters consumed by GPTModel (and, presumably, by the transformer
# blocks it stacks) via dictionary lookup.
GPT_CONFIG_124M = {
    'vocab_size': 50257,       # rows of the token-embedding table / width of the output logits
    'context_length': 1024,    # rows of the positional-embedding table
    'emb_dim': 768,            # embedding width used throughout the model
    'n_head': 12,              # not read by the code in this section; presumably heads per block -- confirm
    'n_layers': 12,            # number of TransformerBlock instances stacked in GPTModel
    'drop_rate': 0.1,          # dropout probability applied to the summed embeddings
    'qkv_bias': False,         # not read by the code in this section; presumably toggles q/k/v bias -- confirm
}

In [5]:
# Toy sequence: six tokens, each represented by a hand-written 3-dimensional
# embedding. The trailing comment on every row names the word it stands for.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # your
    [0.55, 0.87, 0.66],  # journey
    [0.57, 0.85, 0.64],  # starts
    [0.22, 0.58, 0.33],  # with
    [0.77, 0.25, 0.10],  # one
    [0.05, 0.80, 0.55],  # step
])

In [27]:
# Build a batch of two identical samples by stacking the toy sequence along a
# new leading dimension (torch.stack inserts at dim 0 by default).
batch = torch.stack([inputs, inputs])
print(batch)
print(batch.shape)

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])
torch.Size([2, 6, 3])


In [None]:
class CausalSelfAttention(torch.nn.Module):
    """Single-head scaled dot-product self-attention with a causal mask.

    Every position may attend to itself and to earlier positions only; scores
    for later positions are set to -inf before the softmax so they receive
    zero weight.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the key/query/value projections and of the output.
        context_length: side length of the pre-computed causal mask, i.e. the
            longest sequence the module can process.
        drop_rate: dropout probability applied to the attention weights.
        qkv_bias: whether the key/query/value projections carry a bias term.
            Defaults to True, matching torch.nn.Linear's own default, so
            existing positional call sites behave exactly as before.
    """

    def __init__(self,in_dim,out_dim,context_length,drop_rate,qkv_bias=True):
        super().__init__()
        # Creation order (key, query, value) is kept so that seeded weight
        # initialisation yields the same parameters as before.
        self.kw=torch.nn.Linear(in_dim,out_dim,bias=qkv_bias)
        self.qw=torch.nn.Linear(in_dim,out_dim,bias=qkv_bias)
        self.vw=torch.nn.Linear(in_dim,out_dim,bias=qkv_bias)
        self.dropout=torch.nn.Dropout(drop_rate)
        # Ones strictly above the diagonal mark the "future" positions.
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer('mask',torch.triu(torch.ones(context_length,context_length),diagonal=1))

    def forward(self,x):
        """Compute causally masked attention over ``x``.

        Args:
            x: tensor of shape (..., seq_len, in_dim); the usual case is
                (batch, seq_len, in_dim). seq_len must not exceed the
                context_length given at construction.

        Returns:
            Tensor of shape (..., seq_len, out_dim) holding the context vectors.
        """
        # Only the sequence length is needed; reading it from the
        # second-to-last axis also allows unbatched (seq_len, in_dim) input.
        t=x.shape[-2]

        key=self.kw(x)
        query=self.qw(x)
        value=self.vw(x)

        attn_score=query@key.transpose(-2,-1)
        # Out-of-place fill keeps the autograd graph free of in-place edits.
        attn_score=attn_score.masked_fill(self.mask.bool()[:t,:t],-torch.inf)

        # Scale by sqrt(out_dim); masked entries stay at -inf and get zero weight.
        attn_weight=torch.softmax(attn_score/key.shape[-1]**0.5,dim=-1)
        attn_weight=self.dropout(attn_weight)

        return attn_weight@value









In [55]:
class MultiHeadAttention(torch.nn.Module):
    """Several independent CausalSelfAttention heads run side by side.

    Each head receives the same input and has its own parameters; the
    per-head outputs are concatenated along the last dimension, so the result
    has ``num_head * out_dim`` features.

    Args:
        in_dim, out_dim, context_length, drop_rate: forwarded unchanged to
            every CausalSelfAttention head.
        num_head: number of heads to create.
    """

    def __init__(self,in_dim,out_dim,context_length,drop_rate,num_head):
        super().__init__()
        heads=[]
        for _ in range(num_head):
            heads.append(CausalSelfAttention(in_dim,out_dim,context_length,drop_rate))
        # ModuleList registers every head's parameters with this module.
        self.head=torch.nn.ModuleList(heads)

    def forward(self,x):
        """Apply every head to ``x`` and join the results on the feature axis."""
        per_head=tuple(attn(x) for attn in self.head)
        return torch.cat(per_head,dim=-1)
    
# Demo: two heads, each projecting the 3-dim toy embeddings to 6 dims, so the
# concatenated result has 2 * 6 = 12 features per token.
# The module is left in its default training mode, so the 0.3 dropout is
# active and the displayed values change from run to run unless a seed is set.
x=MultiHeadAttention(
    in_dim=batch.shape[-1],
    out_dim=6,
    context_length=batch.shape[1],
    drop_rate=0.3,
    num_head=2,
)
x(batch)

torch.Size([2, 6, 6])
torch.Size([2, 6, 6])


tensor([[[ 0.0331, -0.3032, -0.5112,  0.5845,  0.5095, -1.3587,  0.2394,
          -0.2758, -1.1549,  0.0226, -0.0384, -0.1624],
         [ 0.0256, -0.4098, -0.1742,  0.3135,  0.4450, -0.9321,  0.1538,
          -0.1779, -1.0326,  0.0702,  0.1620, -0.2115],
         [ 0.0187, -0.3738, -0.2881,  0.3964,  0.4557, -1.0545,  0.0481,
          -0.0457, -0.5802,  0.0843,  0.2480, -0.1798],
         [ 0.0186, -0.4089, -0.1771,  0.3094,  0.4373, -0.9196,  0.0985,
          -0.1061, -0.7383,  0.0696,  0.1775, -0.1781],
         [-0.0203, -0.5659, -0.2880,  0.5080,  0.5585, -1.2978,  0.1860,
          -0.0441, -0.3022,  0.1267,  0.1170, -0.2173],
         [-0.1272, -0.5625, -0.2021,  0.4259,  0.3811, -0.9700,  0.1866,
          -0.0515, -0.4178,  0.1388,  0.1679, -0.2466]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0419, -0.5591, -0.4259,  0.6014,  0.6959, -1.6012,  0.1538,
          -0.1779, -1.

In [None]:
class GPTModel(torch.nn.Module):
    """GPT-style language model: embeddings -> transformer blocks -> norm -> logits.

    Args:
        cfg: mapping that provides at least 'vocab_size', 'context_length',
            'emb_dim', 'drop_rate' and 'n_layers'. The same mapping is passed
            to every TransformerBlock, which may read further keys.

    Note:
        TransformerBlock and LayerNorm are not defined in this part of the
        notebook; they must already exist in the namespace when the model is
        instantiated.
    """

    def __init__(self,cfg):
        super().__init__()
        self.tok_emb=torch.nn.Embedding(cfg['vocab_size'],cfg['emb_dim'])
        self.pos_emb=torch.nn.Embedding(cfg['context_length'],cfg['emb_dim'])
        self.drop_emb=torch.nn.Dropout(cfg['drop_rate'])

        self.trf_blocks=torch.nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg['n_layers'])])

        self.final_norm=LayerNorm(cfg['emb_dim'])
        self.output_head=torch.nn.Linear(cfg['emb_dim'], cfg['vocab_size'])

    def forward(self,in_idx):
        """Map token ids to next-token logits.

        Args:
            in_idx: integer tensor of token ids, shape (batch, seq_len).
                seq_len must not exceed cfg['context_length'], otherwise the
                positional-embedding lookup goes out of range.

        Returns:
            Tensor of shape (batch, seq_len, vocab_size) with unnormalised logits.
        """
        # `shape` is an attribute (a torch.Size), not a method; calling it
        # raised TypeError on every forward pass.
        seq_len=in_idx.shape[-1]
        tok_emb=self.tok_emb(in_idx)
        # Positions 0..seq_len-1, created on the same device as the input ids.
        pos_emb=self.pos_emb(torch.arange(seq_len,device=in_idx.device))
        # pos_emb is (seq_len, emb_dim) and broadcasts over the batch dimension.
        x=tok_emb+pos_emb
        x=self.drop_emb(x)
        x=self.trf_blocks(x)
        # The final normalisation layer was constructed but never applied.
        x=self.final_norm(x)
        logits=self.output_head(x)

        return logits