In [57]:
import tiktoken
import torch
import torch.nn as nn 
import math
from torch.nn import functional as F
from model import Encoder

# -------- CONSTANTS -------- # 
device = "cpu" # "cuda:0" if torch.cuda.is_available() else "cpu"
VOCAB_SIZE = 50257 # 51000
LR = 6e-4 # 3e-4
DROPOUT = 0.2
HEADS = 8
NX = 8
LR = 3e-4 # 6e-4
BATCH_SIZE = 14 # 64
CTX = 200 # 256
EMBED_DIM = 584
# --------------------------- # 
enc = tiktoken.get_encoding("gpt2")

In [58]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [59]:
train_example = "What is the worst customer service experience you have ever had? "
train_example += '<|endoftext|>'
print(train_example)
idx = enc.encode(train_example, allowed_special={"<|endoftext|>"})
idx += [50257 for _ in range(CTX - len(idx))]


What is the worst customer service experience you have ever had? <|endoftext|>


In [60]:
# Make B x T tensor
test_batch = torch.tensor([idx for i in range(BATCH_SIZE)]).to(device)
print(test_batch.shape)

torch.Size([14, 200])


In [61]:
class TestEncoder(nn.Module):

    def __init__(self):
        super().__init__()
        
        self.token_embedding_table = nn.Embedding(VOCAB_SIZE, EMBED_DIM)
        self.position_embedding_table = PositionalEncoding(EMBED_DIM, DROPOUT, CTX).to(device)
        self.encoder = nn.Sequential(*[Encoder() for _ in range(NX)])

    def forward(self, x, targets=None):

        tok_enb = self.token_embedding_table(x) # B, T, C
        pos_enb = self.position_embedding_table(torch.transpose(tok_enb, 0, 1).to(device)) # T, B, C
        x = torch.transpose(pos_enb, 1, 0).to(device) # B, T, C

        # Feed into encoder
        enc_out = self.encoder(x) # -> B, T, C

        return enc_out


In [63]:
blargh = TestEncoder()
cunt = blargh(test_batch)

IndexError: index out of range in self