In [4]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken

In [60]:
class BobNet(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoding = tiktoken.get_encoding("r50k_base")
        self.emb_size = self.encoding.n_vocab
        self.emb_channels = 32
        self.emb = nn.Embedding(self.emb_size, self.emb_channels)

    def positional_encoding(self, x):
        d = x.shape[1]
        encoding = x.clone() 
        for pos in range(x.shape[0]):
            for i in range(d):
                factor = 10000 ** (2 * i / d)
                position_tensor = torch.tensor(pos / factor, dtype=torch.float32)
                encoding_val = torch.sin(position_tensor) if i % 2 == 0 else torch.cos(position_tensor)
                encoding[pos, i] += encoding_val
        return encoding


    def forward(self, x):
        x = self.encoding.encode(x)
        x = self.emb(torch.tensor(x))
        x = self.positional_encoding(x)
        return x
    

sup = BobNet()
sup('tiktoken is goated')

tensor([[-0.3952,  0.4097, -0.3282,  1.4431, -2.0988,  1.6565, -0.0883,  0.1826,
         -0.9863, -0.0432,  0.2187,  0.5748,  0.7740,  0.1391, -1.1404,  0.6420,
         -0.8807,  0.9442,  0.4270,  1.9741, -0.4258,  1.2760,  0.8328,  0.5946,
          0.2083,  1.0247, -0.6994, -0.3280, -1.1487,  0.2408, -0.9198,  0.7375],
        [ 1.8742,  0.9816,  0.0296,  2.1033,  0.0170,  0.9689,  1.5992,  1.5757,
          1.2710,  1.5302,  0.7155,  2.7595,  1.7836,  1.4026,  0.3886,  3.0462,
          0.5029,  1.6068, -2.0694, -0.1292, -0.3238,  1.2027,  1.1921,  1.3729,
         -1.4753,  0.9395,  1.7386,  0.9181, -0.7822,  1.3682,  0.1216,  0.4698],
        [-0.8828, -0.3299,  1.3532,  0.7826,  2.3875,  0.3222,  0.8317,  0.9598,
          1.4308,  1.0618,  0.4267,  1.0229, -1.4073,  1.3116, -0.1809,  0.4010,
          0.4912, -0.0079, -0.2039,  0.5864,  0.0920,  2.1852,  1.1640,  2.3092,
         -0.6134,  0.3480,  0.3462,  0.8159, -0.8374,  1.1957,  1.3222,  1.8344],
        [-0.8025, -0.2186