# In [4]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken

# In [291]:
class BobNet(nn.Module):
    """Toy single-layer, multi-head, causally-masked self-attention model.

    The pipeline is: tokenize sentences with tiktoken's ``r50k_base``
    encoding, embed the token ids, add sinusoidal positional encodings, then
    run one multi-head self-attention layer followed by a residual
    connection and layer normalization.

    All learnable tensors (projection matrices and the layer-norm affine
    terms) are registered as ``nn.Parameter`` so that they show up in
    ``parameters()`` / ``state_dict()`` and follow the module across
    devices.
    """

    def __init__(self):
        super().__init__()
        self.encoding = tiktoken.get_encoding("r50k_base")
        self.emb_size = self.encoding.n_vocab
        self.emb_channels = 32
        self.num_heads = 8
        self.head_dim = self.emb_channels // self.num_heads
        assert self.head_dim * self.num_heads == self.emb_channels, "emb_channels must be divisible by num_heads"

        self.emb = nn.Embedding(self.emb_size, self.emb_channels)

        # Query / key / value projections. Wrapping them in nn.Parameter is
        # what makes them visible to optimizers, .to(device) and state_dict().
        self.qW = nn.Parameter(torch.randn(self.emb_channels, self.emb_channels))
        self.kW = nn.Parameter(torch.randn(self.emb_channels, self.emb_channels))
        self.vW = nn.Parameter(torch.randn(self.emb_channels, self.emb_channels))

        # Layer-norm scale and shift, initialised to the identity transform
        # (scale 1, shift 0).
        self.gamma = nn.Parameter(torch.ones(self.emb_channels))
        self.beta = nn.Parameter(torch.zeros(self.emb_channels))

    def positional_encoding(self, x):
        """Return ``x`` plus sinusoidal positional encodings.

        Channels ``2k`` and ``2k + 1`` share the frequency
        ``1 / 10000 ** (2k / d)``; the even channel receives the sine and
        the odd channel the cosine of ``position * frequency``.

        Args:
            x: Tensor of shape ``(batch, seq_length, d)``.

        Returns:
            A new tensor of the same shape; ``x`` itself is not modified.
        """
        _, seq_length, d = x.shape
        pos = torch.arange(seq_length, device=x.device, dtype=torch.float32).unsqueeze(1)
        channel = torch.arange(d, device=x.device)
        # ``channel - channel % 2`` is 2k for both members of pair k.
        freq_exp = (channel - channel % 2).to(torch.float32) / d
        angles = pos / (10000.0 ** freq_exp)  # (seq_length, d)
        pe = torch.where(channel % 2 == 0, angles.sin(), angles.cos())
        # (seq_length, d) broadcasts over the batch dimension.
        return x + pe.to(x.dtype)

    def split_heads(self, x, batch_size):
        """Reshape ``(batch, seq, emb_channels)`` to ``(batch, num_heads, seq, head_dim)``."""
        return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def self_attention(self, x):
        """Apply causal multi-head self-attention, a residual add and layer norm.

        Args:
            x: Tensor of shape ``(batch, seq_length, emb_channels)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        batch_size = x.shape[0]
        # dim: (batch_size, num_heads, seq_length, head_dim)
        q = self.split_heads(x @ self.qW, batch_size)
        k = self.split_heads(x @ self.kW, batch_size)
        v = self.split_heads(x @ self.vW, batch_size)

        # Scaled dot-product scores: (batch, num_heads, seq, seq).
        scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        scores = self.mask(scores)
        attention_weights = F.softmax(scores, dim=-1)

        # Merge the heads back into a single channel dimension.
        output = torch.matmul(attention_weights, v)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.emb_channels)

        # Residual connection; out-of-place so no tensor is mutated.
        output = output + x
        return self.layer_normalization(output)

    def mask(self, x):
        """Causally mask attention scores.

        Every entry whose key index is greater than its query index is set
        to ``-inf`` so that softmax assigns it zero weight.

        Args:
            x: Scores of shape ``(batch, heads, seq_length, seq_length)``.

        Returns:
            A masked copy of ``x``.
        """
        seq_length = x.shape[2]
        allowed = torch.tril(
            torch.ones((seq_length, seq_length), dtype=torch.bool, device=x.device)
        )
        # The (seq, seq) mask broadcasts over the batch and head dimensions.
        return x.masked_fill(~allowed, float('-inf'))

    def layer_normalization(self, x):
        """Normalize ``x`` over its last dimension, then scale and shift.

        Uses the biased (population) variance and ``eps = 1e-5``.
        """
        mean = torch.mean(x, dim=-1, keepdim=True)
        var = torch.var(x, dim=-1, keepdim=True, unbiased=False)
        return self.gamma * (x - mean) / torch.sqrt(var + 1e-5) + self.beta

    def feed_forward(self, x):
        """Placeholder feed-forward stage; currently the identity."""
        return x

    def forward(self, x):
        """Encode a batch of sentences and run them through the attention layer.

        Shorter sentences are right-padded with token id 0 up to the longest
        sentence in the batch. No padding mask is applied: the causal mask
        keeps earlier positions from attending to the right-hand padding,
        but the outputs at padded positions themselves carry no meaning.

        Args:
            x: An iterable of strings, or a single string (treated as a
                batch of one).

        Returns:
            Tensor of shape ``(batch, max_tokens, emb_channels)``.

        Raises:
            ValueError: If ``x`` contains no sentences.
        """
        if isinstance(x, str):
            # Iterating a bare string would tokenize it character by character.
            x = [x]
        tokenized = [self.encoding.encode(sentence) for sentence in x]
        if not tokenized:
            raise ValueError("forward() requires at least one sentence")
        max_length = max(len(t) for t in tokenized)
        padded = [t + [0] * (max_length - len(t)) for t in tokenized]
        # Build the ids on the same device as the embedding table.
        input_tensor = torch.tensor(padded, dtype=torch.long, device=self.emb.weight.device)
        x = self.emb(input_tensor)
        x = self.positional_encoding(x)
        x = self.self_attention(x)
        # x = self.feed_forward(x)
        return x
        
    

# Smoke run: build the model and call it on a small batch of sentences
# of differing lengths.
sup = BobNet()
sup(
    [
        'sup bro how u doin',
        'brsup bro how u doinp',
        'sup bro how u',
    ]
)

# Recorded notebook output: torch.Size([3, 8, 32])
