In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [13]:
class SequenceDataset(Dataset):
    """Placeholder dataset: every method body is currently just `pass`.

    NOTE(review): presumably meant to serve fixed-length windows of tokens read
    from a text file — confirm once the methods are implemented.
    """

    def __init__(self, txt_path, context_len):
        # Not implemented yet: neither argument is stored or used.
        pass

    def __len__(self):
        # Not implemented yet: implicitly returns None, so len(instance) raises TypeError.
        pass

    def __getitem__(self, idx):
        # Not implemented yet: implicitly returns None for every index.
        pass

In [2]:
class PositionalEncoder(nn.Module):
    """Adds fixed sinusoidal positional encodings to embeddings, then applies dropout.

    For position ``pos`` and pair index ``i`` the table holds
    ``pe[pos, 2i] = sin(pos / 10000^(2i / d_model))`` and
    ``pe[pos, 2i + 1] = cos(pos / 10000^(2i / d_model))``; each sine/cosine
    pair shares one frequency.

    Args:
        d_model: embedding size (last dimension of the input). May be odd.
        seq_len: maximum sequence length the table is built for.
        device: device the table is created on.
        p: dropout probability applied to the sum of input and encoding.
    """

    def __init__(self, d_model, seq_len, device, p=0.1):
        super(PositionalEncoder, self).__init__()
        position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(-1)    # [seq_len, 1]
        even_pos = torch.arange(0, d_model, 2)                                 # 0, 2, 4, ...
        angles = position / (10000 ** (even_pos / d_model))                    # [seq_len, ceil(d_model / 2)]

        pe = torch.zeros(seq_len, d_model)
        pe[:, ::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # A non-persistent buffer follows the module through .to()/.cuda()
        # while leaving the state_dict layout unchanged.
        self.register_buffer("pe", pe.unsqueeze(0).to(device), persistent=False)
        self.dropout = nn.Dropout(p=p)

    # x has shape [batch, seq_len, embed_dim]
    def forward(self, x):
        """Return ``dropout(x + pe)``; sequences shorter than the table use its first rows."""
        return self.dropout(x + self.pe[:, : x.size(1)])


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a single fused query/key/value projection.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output; must be
            divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, input_dim, output_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert output_dim % num_heads == 0    # output_dim must be divisible by num_heads
        self.num_heads, self.head_dim = num_heads, output_dim // num_heads
        self.qkv_linear = nn.Linear(input_dim, output_dim * 3)
        self.out_linear = nn.Linear(output_dim, output_dim)

    # x has shape [batch_size, seq_len, input_dim]
    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape [batch_size, seq_len, input_dim].
            mask: optional tensor of shape [batch_size, num_heads, seq_len, seq_len].
                It is handed to ``masked_fill``, so it is expected to be boolean,
                with True marking the score positions to exclude.

        Returns:
            Tensor of shape [batch_size, seq_len, output_dim].
        """
        batch_size, seq_len, _ = x.shape

        if mask is not None:
            mask_batch_size, num_heads, d1, d2 = mask.shape
            assert d1 == seq_len and d2 == seq_len
            assert mask_batch_size == batch_size and num_heads == self.num_heads

        # computing q, k and v across multiple heads with a single linear layer
        qkv = self.qkv_linear(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim * 3)
        qkv = qkv.permute(0, 2, 1, 3)    # [batch_size, num_heads, seq_len, 3 * head_dim]
        q, k, v = qkv.chunk(3, dim=-1)

        attn_output = self.scaled_dot_product(q, k, v, mask)
        # merge the heads back: [batch_size, seq_len, num_heads * head_dim]
        attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
        return self.out_linear(attn_output)

    # q, k and v have shape [batch_size, num_heads, seq_len, head_dim]
    def scaled_dot_product(self, q, k, v, mask):
        """Return ``softmax(q @ k^T / sqrt(d_k)) @ v``, with masked scores set to -inf.

        NOTE: a query row whose scores are all masked has only -inf entries, so
        its softmax output is NaN.
        """
        d_k = k.shape[-1]
        # Scale by sqrt(d_k), not d_k, as scaled dot-product attention is defined.
        qk = q.matmul(k.transpose(-1, -2)) / (d_k ** 0.5)
        if mask is not None:
            qk = qk.masked_fill(mask, -torch.inf)
        attn_weights = qk.softmax(dim=-1)
        return attn_weights.matmul(v)
    

class TransformerDecoder(nn.Module):
    """One transformer block: two self-attention sublayers followed by a feed-forward sublayer.

    Each sublayer is followed by dropout, a residual connection and a LayerNorm
    (the norm is applied to ``sublayer_output + sublayer_input``).

    Args:
        d_model: embedding size of the input and output.
        num_heads: number of attention heads in each attention sublayer.
        p: dropout probability used after every sublayer.
    """

    def __init__(self, d_model, num_heads, p=0.1):
        super(TransformerDecoder, self).__init__()
        # position-wise feed-forward network with a 4x hidden expansion
        self.linear_layer = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model)
        )

        # Flat list of (sublayer, dropout, norm) triples; forward() treats each
        # LayerNorm as the end of one residual branch.
        self.layers = nn.ModuleList([
            MultiHeadAttention(d_model, d_model, num_heads),
            nn.Dropout(p=p),
            nn.LayerNorm(d_model),
            MultiHeadAttention(d_model, d_model, num_heads),
            nn.Dropout(p=p),
            nn.LayerNorm(d_model),
            self.linear_layer,
            nn.Dropout(p=p),
            nn.LayerNorm(d_model)
        ])

    # x has shape [batch_size, seq_len, embed_dim]
    def forward(self, x, mask=None):
        """Run the block on ``x``.

        Args:
            x: tensor of shape [batch_size, seq_len, embed_dim].
            mask: optional attention mask forwarded unchanged to both attention
                sublayers (shape [batch_size, num_heads, seq_len, seq_len]).
                Defaults to None, i.e. unmasked attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        prev = x    # input of the current residual branch

        for layer in self.layers:
            if isinstance(layer, nn.LayerNorm):
                # close the residual branch: add its input back, then normalise
                x = layer(x + prev)
                prev = x
            elif isinstance(layer, MultiHeadAttention):
                x = layer(x, mask)
            else:
                x = layer(x)
        return x


class TransformerModel(nn.Module):
    """Stack of transformer blocks that maps a whole context window to one output vector.

    The input is projected to ``d_model``, given positional encodings, passed
    through ``num_layers`` TransformerDecoder blocks, flattened over the
    sequence axis and projected to ``output_dim``.

    Args:
        input_dim: size of the last dimension of the input.
        d_model: internal embedding size.
        output_dim: size of the output vector.
        context_len: sequence length; the output projection is sized for exactly
            this many positions.
        device: device passed to the positional encoder.
        num_heads: attention heads per block (default 8, the previous fixed value).
        num_layers: number of stacked blocks (default 3, the previous fixed value).
    """

    def __init__(self, input_dim, d_model, output_dim, context_len, device, num_heads=8, num_layers=3):
        super(TransformerModel, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, d_model),
            PositionalEncoder(d_model, context_len, device),
            *[TransformerDecoder(d_model, num_heads) for _ in range(num_layers)],
        )
        # Consumes the flattened [context_len * d_model] sequence; its weight has
        # d_model * context_len * output_dim entries.
        self.out_proj = nn.Linear(d_model * context_len, output_dim)

    def forward(self, x):
        """Map ``x`` of shape [batch_size, context_len, input_dim] to [batch_size, output_dim]."""
        x = self.layers(x)
        return self.out_proj(x.flatten(1))

In [11]:
# Hyperparameters shared by the model and the dummy batch below, so the two
# cannot drift apart.
INPUT_DIM = 300
D_MODEL = 512
OUTPUT_DIM = 24558    # equals the len(vocab) value printed further down in this notebook
CONTEXT_LEN = 100
BATCH_SIZE = 32
DEVICE = "cpu"

model = TransformerModel(
    input_dim=INPUT_DIM,
    d_model=D_MODEL,
    output_dim=OUTPUT_DIM,
    context_len=CONTEXT_LEN,
    device=DEVICE
).to(DEVICE)
model = torch.compile(model)

# Random batch with the shape the model expects: [batch_size, context_len, input_dim].
x = torch.randn(BATCH_SIZE, CONTEXT_LEN, INPUT_DIM).to(DEVICE)

In [4]:
# Total number of scalar entries across all parameter tensors of `model`.
sum(param.numel() for param in model.parameters())

1270160366

In [77]:
# [batch_size, seq_len] mask for padding tokens
# NOTE(review): MultiHeadAttention passes its mask to masked_fill, which fills
# the positions where the mask is True; if 1 marks real tokens here, this mask
# would need inverting and casting to bool first — confirm.
a = torch.tensor(
    [
        [1, 1, 0, 0],
        [1, 0, 0, 0],
        [1, 1, 1, 0],
        [1, 1, 1, 1],
        [1, 0, 0, 0],
        [0, 0, 0, 0],
    ]
)
# expanding for MHA: add a head axis and a query axis, then tile to
# [batch_size, 8 heads, 4 query positions, seq_len]
a_exp = a[:, None, None, :].repeat(1, 8, 4, 1)
a_exp[1, 2]

tensor([[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]])

In [37]:
# Read the book as a list of lines; the context manager closes the file handle
# once reading is done instead of leaving it open.
with open("The-Secret-History.txt") as book_file:
    f = book_file.readlines()

In [49]:
# Drop blank lines (entries that are exactly "\n"). Filtering in one pass is
# safe to re-run: list.remove raises ValueError when no blank line is left.
f = [line for line in f if line != "\n"]

In [53]:
# Vocabulary: the distinct lower-cased, whitespace-delimited tokens of the book.
# update() grows the set in place instead of copying it on every line.
vocab = set()
for line in f:
    vocab.update(line.strip().lower().split())

In [60]:
# Flatten the book into one list of lower-cased, whitespace-delimited tokens,
# keeping the original order.
text = [token for raw_line in f for token in raw_line.strip().lower().split()]

In [62]:
# Number of entries in the `text` list.
len(text)

204140

In [9]:
# Set TorchDynamo's suppress_errors flag.
# NOTE(review): presumably this makes torch.compile fall back to eager execution
# instead of raising when compilation fails — confirm against the PyTorch docs.
# This cell's execution count (In [9]) precedes the model cell's (In [11]) even
# though it appears later in the file.
torch._dynamo.config.suppress_errors = True

In [12]:
%%time

# Time one forward pass of `model` on the batch `x`; the result is kept in `out`.
# NOTE(review): if this is the first call after torch.compile, the measured time
# presumably includes compilation — confirm before using it as a per-batch cost.
out = model(x)

CPU times: user 921 ms, sys: 2.07 s, total: 3 s
Wall time: 2.17 s


In [6]:
# Back-of-envelope duration estimate in minutes: 232 per batch (presumably
# milliseconds — TODO confirm where this figure was measured) times the number
# of full 32-item batches in 204140 items (the len(text) value shown in this
# notebook), converted with /1000 and /60.
(232 * (204140 // 32) / 1000) / 60

24.665466666666667

In [55]:
# Number of distinct entries in the `vocab` set.
len(vocab)

24558