In [7]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

In [8]:
from datasets import load_dataset

In [9]:
# Fetch the "imdb" dataset through `load_dataset`, reusing the local cache
# when one already exists instead of downloading again.
imdb_dataset = load_dataset("imdb", download_mode="reuse_cache_if_exists")

# Pull out the two splits used in the rest of the notebook.
train_dataset, test_dataset = imdb_dataset["train"], imdb_dataset["test"]

# Report the number of examples in each split, one per line.
print(len(train_dataset), len(test_dataset), sep="\n")

Generating train split: 100%|█████████████████████████████████████████| 25000/25000 [00:00<00:00, 443389.39 examples/s]
Generating test split: 100%|██████████████████████████████████████████| 25000/25000 [00:00<00:00, 466886.62 examples/s]
Generating unsupervised split: 100%|██████████████████████████████████| 50000/50000 [00:00<00:00, 331271.71 examples/s]

25000
25000





In [10]:
def multihead_attention(Q, K, V, nheads):
    """Multi-head scaled dot-product attention without learned projections.

    The last dimension of ``Q``, ``K`` and ``V`` is split into ``nheads``
    equally sized chunks. Scaled dot-product attention is applied to each
    chunk independently and the per-head results are concatenated again.

    Args:
        Q: query tensor of shape ``(..., m, d_k)``.
        K: key tensor of shape ``(..., n, d_k)``.
        V: value tensor of shape ``(..., n, d_v)``.
        nheads: number of heads; must evenly divide the last dimension of
            ``Q``, ``K`` and ``V``.

    Returns:
        A tuple ``(out, attention_weights)`` where ``out`` has shape
        ``(..., m, d_v)`` (heads concatenated along the last dimension) and
        ``attention_weights`` has the heads stacked in front of the two
        score dimensions, i.e. ``(b, nheads, m, n)`` for 3-D inputs.

    Raises:
        ValueError: if ``nheads`` is not positive or does not evenly divide
            the last dimension of one of the inputs.
    """
    if nheads < 1:
        raise ValueError(f"nheads must be a positive integer, got {nheads}")
    for name, tensor in (("Q", Q), ("K", K), ("V", V)):
        # torch.chunk silently returns fewer / uneven chunks when the size is
        # not divisible, which would mis-align the heads.
        if tensor.size(-1) % nheads != 0:
            raise ValueError(
                f"last dimension of {name} ({tensor.size(-1)}) "
                f"is not divisible by nheads ({nheads})"
            )

    Q_heads = Q.chunk(nheads, dim=-1)
    K_heads = K.chunk(nheads, dim=-1)
    V_heads = V.chunk(nheads, dim=-1)

    outputs, weights = [], []
    for Q_i, K_i, V_i in zip(Q_heads, K_heads, V_heads):
        # scale by sqrt(head_dim) to keep the logits in a softmax-friendly range
        downscaler = Q_i.size(-1) ** 0.5

        # attention scores (..., m, n)
        S = torch.matmul(Q_i, K_i.transpose(-2, -1)) / downscaler

        # normalise over the key axis so every query's weights sum to 1
        attention_weights = nn.functional.softmax(S, dim=-1)

        outputs.append(torch.matmul(attention_weights, V_i))
        weights.append(attention_weights)

    out = torch.cat(outputs, dim=-1)
    # dim=-3 places the head axis directly before (m, n); for batched 3-D
    # inputs this is the same as dim=1.
    attention_weights = torch.stack(weights, dim=-3)

    return out, attention_weights

In [13]:
class Word2Vec(nn.Module):
    """Embedding lookup followed by a linear projection back onto the vocabulary.

    Args:
        vocab_size: number of rows in the embedding table and size of the
            output distribution.
        embedding_dim: size of each embedding vector.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.predictor = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Return log-probabilities over the vocabulary for each input index.

        Args:
            x: integer tensor of token indices (any shape accepted by
                ``nn.Embedding``).

        Returns:
            Tensor of shape ``x.shape + (vocab_size,)`` holding
            log-softmax scores along the last dimension.
        """
        x = self.embedding(x)
        x = self.predictor(x)
        x = nn.functional.log_softmax(x, dim=-1)
        return x

In [17]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    Args:
        embedding_dim: size of the last dimension of the inputs (should match
            the model size).
        dropout: dropout probability applied after adding the encoding.
        max_len: largest sequence length for which encodings are precomputed.
    """

    def __init__(self, embedding_dim, dropout=0.1, max_len=5000): # embedding dim should match model size
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Build the table in a local variable: assigning it to `self.pe` first
        # would make `register_buffer('pe', ...)` fail because the attribute
        # already exists.
        pe = torch.zeros(max_len, embedding_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # term in the denominator of the exponent for the sinusoidal function
        # creates different frequencies based on the dimension
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / embedding_dim)
        )
        # compute sin for even indices
        pe[:, 0::2] = torch.sin(position * div_term)
        # compute cos for odd indices; for an odd embedding_dim there is one
        # fewer odd column than even column, so trim div_term to match
        pe[:, 1::2] = torch.cos(position * div_term[: embedding_dim // 2])
        # add batch dimension -> (1, max_len, embedding_dim)
        pe = pe.unsqueeze(0)
        # put pe into non-parameter loadable state
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add positional encodings to ``x`` and apply dropout.

        Args:
            x: tensor whose dimension 1 is the sequence position and whose
                last dimension is ``embedding_dim``, e.g.
                ``(batch, seq_len, embedding_dim)``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: if the sequence length exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}"
            )
        x = x + self.pe[:, :seq_len, :] # all elements from pe up to the size of x second dim
        x = self.dropout(x)
        return x


In [None]:
# TODO: build the vocabulary, the transformer model, etc.