In [9]:
import torch
import tiktoken
import torch.nn as nn

# GPT dummy model

In [79]:
class DummyTransformerBlock(nn.Module):
    """Placeholder for a transformer block: an identity module with no parameters."""

    def __init__(self, cfg) -> None:
        # `cfg` is accepted so the constructor signature matches how the
        # block is instantiated by the model, but it is not read here.
        super().__init__()

    def forward(self, x):
        """Return the input tensor unchanged (same object, no copy)."""
        return x

class DummyLayerNorm(nn.Module):
    """Placeholder for layer normalisation: an identity module with no parameters."""

    def __init__(self, normalized_shape) -> None:
        # `normalized_shape` is accepted for interface compatibility only;
        # nothing is stored or normalised.
        super().__init__()

    def forward(self, x):
        """Return the input tensor unchanged (same object, no copy)."""
        return x
    

class DummyGPTModel(nn.Module):
    """Skeleton of a GPT-style model built from placeholder components.

    The data flow is: token embedding + positional embedding -> dropout ->
    stack of ``DummyTransformerBlock`` -> ``DummyLayerNorm`` -> linear output
    head. The transformer blocks and the final norm are identity
    placeholders, so only the embeddings, dropout and output head do work.

    Args:
        cfg: Mapping that must provide the keys ``'vocab_size'``,
            ``'context_length'``, ``'emb_dim'``, ``'n_layers'`` and
            ``'drop_rate'``.
    """

    def __init__(self, cfg) -> None:
        super().__init__()

        # Embedding tables: one row per token ID and one row per position.
        self.toke_emb = nn.Embedding(cfg['vocab_size'], cfg['emb_dim'])
        self.pos_emb = nn.Embedding(cfg['context_length'], cfg['emb_dim'])

        # Stack of `n_layers` placeholder transformer blocks, then dropout.
        self.btrf = nn.Sequential(
            *[DummyTransformerBlock(cfg) for _ in range(cfg['n_layers'])]
        )
        self.drop = nn.Dropout(cfg['drop_rate'])

        # Placeholder final norm and the bias-free projection to vocabulary logits.
        self.final = DummyLayerNorm(cfg['emb_dim'])
        self.out_head = nn.Linear(cfg['emb_dim'], cfg['vocab_size'], bias=False)

    def forward(self, in_idx):
        """Compute vocabulary logits for a batch of token-ID sequences.

        Args:
            in_idx: 2-D integer tensor of token IDs, shape
                ``(batch_size, seq_length)``. ``seq_length`` must not exceed
                ``cfg['context_length']``, otherwise the positional lookup
                indexes past the end of its table.

        Returns:
            Tensor of shape ``(batch_size, seq_length, vocab_size)``.
        """
        # Unpacking into two names also enforces that `in_idx` is exactly 2-D.
        _, seq_length = in_idx.shape

        token_embd = self.toke_emb(in_idx)

        # Create the position indices on the same device as the input so the
        # lookup also works when the model and batch live on an accelerator.
        positions = torch.arange(seq_length, device=in_idx.device)
        pos_embd = self.pos_emb(positions)

        # (batch, seq, emb) + (seq, emb): positions broadcast over the batch.
        input_embd = token_embd + pos_embd

        input_embd = self.drop(input_embd)
        input_embd = self.btrf(input_embd)
        input_embd = self.final(input_embd)

        logits = self.out_head(input_embd)
        return logits

In [80]:
# Load tiktoken's "gpt2" encoding; it is used below to turn text into token IDs.
tokenizer = tiktoken.get_encoding("gpt2")

In [81]:
# Two sample prompts. Both must encode to the same number of tokens,
# otherwise torch.stack cannot combine them into a single tensor.
txt = 'hello there i am'
txt2 = "this is not what"

# Encode each prompt to token IDs and stack them: one row per prompt.
batch = torch.stack(
    [torch.tensor(tokenizer.encode(prompt)) for prompt in (txt, txt2)],
    dim=0,
)

In [82]:
# Show the stacked token-ID tensor (one row per prompt).
batch

tensor([[31373,   612,  1312,   716],
        [ 5661,   318,   407,   644]])

In [83]:
# Hyperparameters passed to DummyGPTModel. The dummy model reads every key
# except "n_heads" and "qkv_bias".
GPT_CONFIG_124M = {
    "vocab_size": 50_257,     # rows of the token-embedding table / width of the output head
    "context_length": 1_024,  # rows of the positional-embedding table (max sequence length)
    "emb_dim": 768,           # size of each embedding vector
    "n_heads": 12,            # number of attention heads
    "n_layers": 12,           # number of stacked transformer blocks
    "drop_rate": 0.1,         # dropout probability applied to the summed embeddings
    "qkv_bias": False,        # whether query/key/value projections use a bias
}

In [84]:
# Seed the global RNG so the randomly initialised weights (and the dropout
# mask, since the model stays in its default training mode) are identical on
# every Restart & Run All.
torch.manual_seed(123)

model = DummyGPTModel(GPT_CONFIG_124M)

# One forward pass: token IDs in, one logit per vocabulary entry per position out.
logits = model(batch)

In [88]:
# One logit per vocabulary entry for every position of every sequence in the batch.
logits.shape

torch.Size([2, 4, 50257])

In [89]:
# Input shape for comparison: the logits add a trailing vocabulary dimension to this.
batch.shape

torch.Size([2, 4])