In [9]:
import tiktoken
import torch
import torch.nn as nn 
import math
from torch.nn import functional as F
from model import Encoder

# -------- CONSTANTS -------- #
# Pinned to CPU. CUDA-aware alternative, kept for reference:
#   "cuda:0" if torch.cuda.is_available() else "cpu"
device = "cpu"
# Embedding-table size: 50257 ids plus one extra id (50257) that later cells
# use as the padding token.
VOCAB_SIZE = 50257 + 1 # 51000
DROPOUT = 0.2
HEADS = 8
# Number of Encoder modules stacked by the model below.
NX = 8
# Learning rate. An earlier `LR = 6e-4` assignment in this cell was shadowed by
# this one; the dead assignment was removed, the effective value is unchanged.
LR = 3e-4 # 6e-4
BATCH_SIZE = 14 # 64
# Context length: examples are padded to this many token ids.
CTX = 200 # 256
EMBED_DIM = 584
# --------------------------- #
# GPT-2 tokenizer from tiktoken; later cells call `enc.encode` to turn text into token ids.
enc = tiktoken.get_encoding("gpt2")

In [11]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence, then apply dropout.

    The table ``pe`` has shape ``[max_len, 1, d_model]``: even feature indices hold
    ``sin(position * freq)`` and odd feature indices hold ``cos(position * freq)``.
    It is registered as a buffer, so it follows the module across devices and is
    saved in the state dict, but it is not a trainable parameter.

    Arguments:
        d_model: size of the embedding (last) dimension; may be even or odd.
        dropout: dropout probability applied to the summed output.
        max_len: number of positions precomputed in the table.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd feature slots; without the slice an
        # odd d_model fails with a shape mismatch. For even d_model this is a no-op.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``

        Returns:
            Tensor of the same shape: ``x`` plus the first ``seq_len`` rows of the
            position table (broadcast over the batch dimension), passed through dropout.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [30]:
# Build one training example: the prompt followed by the end-of-text marker.
train_example = "".join((
    "What is the worst customer service experience you have ever had? ",
    '<|endoftext|>',
))
print(train_example)
# `allowed_special` is needed so the end-of-text marker may appear in the input text.
idx = enc.encode(train_example, allowed_special={"<|endoftext|>"})
# Right-pad with id 50257 up to CTX entries; a non-positive difference adds nothing.
idx.extend([50257] * (CTX - len(idx)))
print(idx)


What is the worst customer service experience you have ever had? <|endoftext|>
[2061, 318, 262, 5290, 6491, 2139, 1998, 345, 423, 1683, 550, 30, 220, 50256, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 502

In [31]:
# Stack BATCH_SIZE copies of the padded example into a (B, T) tensor of token ids
test_batch = torch.tensor([idx] * BATCH_SIZE).to(device)
print(test_batch.size())

torch.Size([14, 200])


In [32]:
# Padding mask for the (B, T) batch: 1 marks a real token, 0 marks padding (id 50257).
# A single tensor comparison replaces the per-element Python loop; the result is
# int64 like before and lives on the same device as `test_batch`.
mask = (test_batch != 50257).long()
print(mask.shape)

torch.Size([14, 200])


In [42]:
class TestEncoder(nn.Module):
    """Token embedding plus sinusoidal positions, feeding a stack of NX ``Encoder`` modules.

    NOTE(review): the recorded run of ``model((test_batch, mask))`` in this notebook
    fails with ``TypeError: layer_norm(): argument 'input' (position 1) must be
    Tensor, not tuple``. ``nn.Sequential`` hands its single input -- here the
    ``(embeddings, mask)`` tuple -- to the first module unchanged, so
    ``Encoder.forward`` presumably expects a plain tensor; confirm against
    ``model.Encoder`` before relying on this class.
    """

    def __init__(self):
        super().__init__()

        self.token_embedding_table = nn.Embedding(VOCAB_SIZE, EMBED_DIM)
        self.position_embedding_table = PositionalEncoding(
            d_model=EMBED_DIM, dropout=DROPOUT, max_len=CTX
        ).to(device)
        encoder_layers = [Encoder() for _ in range(NX)]
        self.encoder = nn.Sequential(*encoder_layers)

    def forward(self, x, targets=None):
        """Embed the tokens, add positions, and run the encoder stack.

        Arguments:
            x: indexable pair; ``x[0]`` holds the token ids (B, T) and ``x[1]`` the
                padding mask that is forwarded to the encoder stack untouched.
            targets: accepted but not used by this method.

        Returns:
            Whatever the encoder stack returns (B, T, C per the original author's
            comment -- not verifiable from this notebook).
        """
        token_ids, pad_mask = x[0], x[1]

        embedded = self.token_embedding_table(token_ids)  # B, T, C
        # PositionalEncoding works on [T, B, C], so swap batch/time on the way in
        # and swap back on the way out.
        time_major = embedded.transpose(0, 1).to(device)
        positioned = self.position_embedding_table(time_major)  # T, B, C
        batch_major = positioned.transpose(1, 0).to(device)  # B, T, C

        # The whole (embeddings, mask) tuple is passed as the Sequential's one input.
        return self.encoder((batch_major, pad_mask))


In [44]:
model = TestEncoder().to(device)
logits = model((test_batch, mask))

TypeError: layer_norm(): argument 'input' (position 1) must be Tensor, not tuple

In [25]:
# NOTE(review): this variable is not assigned in any cell visible here -- presumably
# left over from an earlier kernel session (execution count 25 precedes the cells
# above). Confirm it still exists and give it a descriptive name, or drop the cell.
# Only the final bare expression is displayed; the `.shape` line alone shows nothing.
cunt.shape
cunt

tensor([[[-1.4836e+00,  2.9702e-01, -1.0832e+00,  ...,  1.6547e+00,
          -3.9859e+00,  4.0618e-03],
         [-5.6330e-02, -8.7759e-01,  8.0295e-01,  ...,  3.2260e-01,
          -9.2082e-01,  2.6529e+00],
         [ 2.3402e-01, -1.3602e-01,  1.3140e+00,  ...,  2.4626e-01,
          -2.9459e+00,  3.2325e+00],
         ...,
         [ 1.0202e+00,  2.8104e-01, -6.5004e-01,  ...,  2.0826e+00,
          -6.8240e-01,  1.1810e-01],
         [-6.9696e-01, -2.9717e+00, -2.0448e+00,  ..., -4.4882e-01,
          -5.0821e-01, -2.2952e-01],
         [-1.3400e+00, -3.1098e+00, -3.0651e+00,  ...,  1.4762e+00,
          -6.1024e-01,  1.5624e+00]],

        [[-7.2791e-02,  1.5076e+00, -9.9834e-02,  ...,  5.4619e-01,
          -3.1774e+00, -1.3983e+00],
         [ 5.7119e-01, -6.0368e-01,  1.4086e+00,  ...,  9.9300e-01,
          -2.1463e+00,  1.9085e+00],
         [-2.9002e+00,  1.3395e-02, -4.1856e-01,  ..., -1.4463e-01,
          -2.6331e+00,  3.9858e+00],
         ...,
         [ 3.6525e-01, -2