In [3]:
import torch
from torch import nn
import torch.nn.functional as F
import math

In [20]:
class ScaledDotAttention(nn.Module):
    """Scaled dot-product attention: softmax(q @ k^T / sqrt(d)) @ v with optional length masking."""

    def __init__(self, dropout):
        super().__init__()
        # Applied to the attention weights, not to the output.
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, valid_lens=None):
        """
        @param q: shape = (batch_size, query_num, hidden_size)
        @param k: shape = (batch_size, kv_num, hidden_size)
        @param v: shape = (batch_size, kv_num, value_size)
        @param valid_lens: shape = (batch_size,) or (batch_size, query_num); None disables masking
        @return: shape = (batch_size, query_num, value_size)
        """
        hidden_size = q.shape[-1]
        # scores: shape = (batch_size, query_num, kv_num)
        scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(hidden_size)
        attention_weights = ScaledDotAttention.masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(attention_weights), v)

    @staticmethod
    def masked_softmax(scores, valid_lens=None):
        """
        Softmax over the last axis, ignoring key positions >= the valid length.

        @param scores: shape = (batch_size, query_num, kv_num); not modified
        @param valid_lens: shape = (batch_size,) -- one length per batch item, shared by all queries --
                           or (batch_size, query_num) -- one length per query; None disables masking
        @return: shape = (batch_size, query_num, kv_num)
        """
        if valid_lens is None:
            return F.softmax(scores, dim=-1)
        # Build the mask on the same device as scores so this also works on GPU.
        valid_lens = valid_lens.to(scores.device)
        positions = torch.arange(scores.shape[-1], device=scores.device)
        if valid_lens.dim() == 1:
            # (batch_size,) -> (batch_size, 1, 1); broadcasts over every query.
            limits = valid_lens[:, None, None]
        else:
            # (batch_size, query_num) -> (batch_size, query_num, 1)
            limits = valid_lens[:, :, None]
        mask = positions[None, None, :] >= limits  # broadcastable to scores.shape
        # masked_fill is out-of-place, so the caller's tensor is left untouched.
        # The most negative finite value gives exactly 0 weight whenever a row has >= 1 valid key,
        # and yields a uniform row (instead of NaN, as with -inf) when every key is masked.
        scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)
        return F.softmax(scores, dim=-1)

# Demo: first with one valid length per batch item, then with one per query.
for demo_lens in (torch.tensor([2, 3]), torch.tensor([[1, 3], [2, 4]])):
    x = ScaledDotAttention.masked_softmax(torch.rand(2, 2, 4), demo_lens)
    print(x)

tensor([[[0.5194, 0.4806, 0.0000, 0.0000],
         [0.3737, 0.6263, 0.0000, 0.0000]],

        [[0.3165, 0.3582, 0.3253, 0.0000],
         [0.2171, 0.3787, 0.4042, 0.0000]]])
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4159, 0.3399, 0.2442, 0.0000]],

        [[0.4289, 0.5711, 0.0000, 0.0000],
         [0.2249, 0.2022, 0.3421, 0.2309]]])


'\ntensor([[[1.0000, 0.0000, 0.0000, 0.0000],\n         [0.4125, 0.3273, 0.2602, 0.0000]],\n        [[0.5254, 0.4746, 0.0000, 0.0000],\n         [0.3117, 0.2130, 0.1801, 0.2952]]])\n'

In [5]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project q/k/v, split into heads, attend per head, then merge and project."""

    def __init__(self, query_size, key_size, value_size, hidden_size, output_size, num_heads, dropout):
        # nn.Module.__init__ must run before any submodule is assigned;
        # otherwise `self.wq = nn.Linear(...)` raises AttributeError.
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        self.wq = nn.Linear(query_size, hidden_size, bias=False)
        self.wk = nn.Linear(key_size, hidden_size, bias=False)
        self.wv = nn.Linear(value_size, hidden_size, bias=False)
        self.scaledDotAttention = ScaledDotAttention(dropout)
        self.dense = nn.Linear(hidden_size, output_size)

    def forward(self, q, k, v, valid_lens=None):
        """
        @param q: shape = (batch_size, query_num, query_size)
        @param k: shape = (batch_size, kv_num, key_size)
        @param v: shape = (batch_size, kv_num, value_size)
        @param valid_lens: shape = (batch_size,) or (batch_size, query_num)
        @return: shape = (batch_size, query_num, output_size)
        """
        q, k, v = map(self.transpose_qkv, (self.wq(q), self.wk(k), self.wv(v)))
        if valid_lens is not None:
            # transpose_qkv lays heads out as index b * num_heads + h, so each batch
            # item's lengths are repeated num_heads times consecutively.
            valid_lens = torch.repeat_interleave(valid_lens, self.num_heads, dim=0)
        cat_hidden = self.scaledDotAttention(q, k, v, valid_lens)
        return self.dense(self.invtranspose_qkv(cat_hidden))

    def transpose_qkv(self, tensor):
        """
        Split the last axis into heads and fold the heads into the batch axis.

        @param tensor: shape = (batch_size, num, hidden_size)
        @return: shape = (batch_size * num_heads, num, hidden_size / num_heads)
        """
        tensor = tensor.reshape(tensor.shape[0], tensor.shape[1], self.num_heads, -1)
        tensor = tensor.permute(0, 2, 1, 3)
        return tensor.reshape(-1, tensor.shape[2], tensor.shape[3])

    def invtranspose_qkv(self, tensor):
        """
        Inverse of transpose_qkv: unfold heads from the batch axis and concatenate them.

        @param tensor: shape = (batch_size * num_heads, num, hidden_size / num_heads)
        @return: shape = (batch_size, num, hidden_size)
        """
        tensor = tensor.reshape(-1, self.num_heads, tensor.shape[1], tensor.shape[2])
        tensor = tensor.permute(0, 2, 1, 3)
        return tensor.reshape(tensor.shape[0], tensor.shape[1], -1)

In [6]:
class FFN(nn.Module):
    """Position-wise feed-forward network: Linear(hidden -> ffn), ReLU, Linear(ffn -> hidden)."""

    def __init__(self, hidden_size, ffn_size):
        super().__init__()
        # Creation order is kept (dense1 before dense2) so parameter init consumes the RNG identically.
        self.dense1 = nn.Linear(hidden_size, ffn_size)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_size, hidden_size)

    def forward(self, x):
        """Apply the two-layer MLP independently at every position of x (last axis = hidden_size)."""
        expanded = self.relu(self.dense1(x))
        return self.dense2(expanded)

In [7]:
class LayerNorm(nn.Module):
    """
    Layer normalization over the last dimension with learnable gain (alpha) and bias (beta).

    Unlike nn.LayerNorm, this uses the unbiased standard deviation (Tensor.std default)
    and adds eps to the std rather than inside the square root.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        scale = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centered / scale + self.beta

class AddNorm(nn.Module):
    """Residual connection followed by layer normalization: ln(x + dropout(fx))."""

    def __init__(self, hidden_size, dropout=0):
        super().__init__()
        # Uses the notebook's own LayerNorm; nn.LayerNorm(hidden_size) is the built-in alternative.
        self.ln = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, fx):
        """
        @param x: sublayer input
        @param fx: sublayer output, same shape as x
        """
        residual = x + self.dropout(fx)
        return self.ln(residual)

In [11]:
class PositionEmbedding(nn.Module):
    """
    Adds fixed sinusoidal positional encodings:
    P[pos, 2i] = sin(pos / 10000^(2i/d)), P[pos, 2i+1] = cos(pos / 10000^(2i/d)).
    """

    def __init__(self, hidden_size, max_length=1024):
        super().__init__()
        position = torch.arange(max_length, dtype=torch.float32).reshape(-1, 1)
        denom = torch.pow(
            torch.tensor(10000.0),
            torch.arange(0, hidden_size, 2, dtype=torch.float32) / hidden_size,
        ).reshape(1, -1)
        angles = position / denom  # (max_length, ceil(hidden_size / 2))
        P = torch.zeros((max_length, hidden_size))
        P[:, 0::2] = torch.sin(angles)
        # Odd columns number floor(hidden_size / 2), one fewer than angles when hidden_size is odd.
        P[:, 1::2] = torch.cos(angles[:, : hidden_size // 2])
        # Registered as a buffer so it follows module.to(device); non-persistent so it
        # stays out of the state_dict (it is deterministic and rebuilt on construction).
        self.register_buffer("P", P, persistent=False)

    def forward(self, x):
        """
        @param x: shape = (batch_size, seq_length, hidden_size), seq_length <= max_length
        @return: x plus the encodings for positions 0..seq_length-1, same shape as x
        """
        seq_length = x.shape[1]
        return x + self.P[:seq_length, :]

In [21]:
class EncoderLayer(nn.Module):
    """One post-norm encoder block: self-attention, then a position-wise FFN, each wrapped in AddNorm."""

    def __init__(self, hidden_size, num_heads, ffn_size, dropout):
        super().__init__()
        d = hidden_size
        # Submodules are created in the same order as before so parameter init is unchanged.
        self.self_attn = MultiHeadAttention(d, d, d, d, d, num_heads, dropout)
        self.norm1 = AddNorm(d, dropout)
        self.ffn = FFN(d, ffn_size)
        self.norm2 = AddNorm(d, dropout)

    def forward(self, x, valid_lens):
        """
        @param x: shape = (batch_size, seq_len, hidden_size)
        @param valid_lens: key-length mask forwarded to self-attention
        """
        attended = self.norm1(x, self.self_attn(x, x, x, valid_lens))
        return self.norm2(attended, self.ffn(attended))

class Encoder(nn.Module):
    """Transformer encoder: token embedding + sinusoidal positions, then a stack of EncoderLayers."""

    def __init__(self, vocab_size, hidden_size, num_encoder, num_heads, ffn_size, dropout):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.position = PositionEmbedding(hidden_size)
        self.dropout = nn.Dropout(dropout)
        # nn.ModuleList takes modules, not (name, module) pairs; the layers are
        # registered as layers.0, layers.1, ...
        self.layers = nn.ModuleList(
            EncoderLayer(hidden_size, num_heads, ffn_size, dropout) for _ in range(num_encoder)
        )

    def forward(self, x, valid_lens):
        """
        @param x: source token ids, shape = (batch_size, seq_len)
        @param valid_lens: source lengths used to mask attention keys, shape = (batch_size,)
                           or (batch_size, seq_len)
        @return: shape = (batch_size, seq_len, hidden_size)
        """
        # Scaling embeddings by sqrt(d_model) follows Vaswani et al. (2017), section 3.4.
        # NOTE(review): nn.Embedding's default init is N(0, 1), so this makes token embeddings
        # much larger than the [-1, 1] positional encodings; confirm this is intended.
        x = self.dropout(self.position(self.embedding(x) * math.sqrt(self.hidden_size)))
        for layer in self.layers:
            x = layer(x, valid_lens)
        return x

In [None]:
class DecoderLayer(nn.Module):
    """
    One post-norm decoder block: masked self-attention, cross-attention over the encoder
    output, then a position-wise FFN, each followed by AddNorm.
    """

    def __init__(self, hidden_size, num_heads, ffn_size, dropout):
        super().__init__()
        d = hidden_size
        # Creation order is kept identical so parameter initialisation consumes the RNG the same way.
        self.self_attn = MultiHeadAttention(d, d, d, d, d, num_heads, dropout)
        self.norm1 = AddNorm(d, dropout)
        self.cross_attn = MultiHeadAttention(d, d, d, d, d, num_heads, dropout)
        self.norm2 = AddNorm(d, dropout)
        self.ffn = FFN(d, ffn_size)
        self.norm3 = AddNorm(d, dropout)

    def forward(self, x, self_kv, encoder_output, valid_lens, seq_lens):
        """
        @param x: decoder queries
        @param self_kv: keys/values for self-attention
        @param encoder_output: keys/values for cross-attention
        @param valid_lens: mask for self-attention
        @param seq_lens: mask for cross-attention
        """
        after_self = self.norm1(x, self.self_attn(x, self_kv, self_kv, valid_lens))
        cross = self.cross_attn(after_self, encoder_output, encoder_output, seq_lens)
        after_cross = self.norm2(after_self, cross)
        return self.norm3(after_cross, self.ffn(after_cross))

class Decoder(nn.Module):
    """Transformer decoder: embedding + positions, a stack of DecoderLayers, and a vocabulary projection."""

    def __init__(self, vocab_size, hidden_size, num_decoder, num_heads, ffn_size, dropout):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.position = PositionEmbedding(hidden_size)
        self.dropout = nn.Dropout(dropout)
        # nn.ModuleList takes modules, not (name, module) pairs; the layers are
        # registered as layers.0, layers.1, ...
        self.layers = nn.ModuleList(
            DecoderLayer(hidden_size, num_heads, ffn_size, dropout) for _ in range(num_decoder)
        )
        self.dense = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, valid_lens, encoder_output, seq_lens):
        """
        @param x: target token ids, shape = (batch_size, tgt_len)
        @param valid_lens: self-attention mask, e.g. causal lengths of shape (batch_size, tgt_len)
        @param encoder_output: shape = (batch_size, src_len, hidden_size)
        @param seq_lens: source lengths masking the cross-attention keys
        @return: logits, shape = (batch_size, tgt_len, vocab_size)
        """
        x = self.dropout(self.position(self.embedding(x) * math.sqrt(self.hidden_size)))
        for layer in self.layers:
            x = layer(x, x, encoder_output, valid_lens, seq_lens)
        return self.dense(x)

    def predict(self, x, encoder_output, seq_lens):
        """
        Next-token logits for the last position of each prefix in x.

        The whole prefix is re-run with a causal mask (no key/value cache). Every layer needs
        the previous layer's outputs at all prefix positions as its keys and values, so running
        only the last position through the stack would be incorrect.

        @param x: prefix token ids, shape = (batch_size, prefix_len)
        @return: shape = (batch_size, vocab_size)
        """
        batch_size, num_steps = x.shape[0], x.shape[1]
        # Causal mask: query position i may attend to key positions 0..i.
        valid_lens = torch.arange(1, num_steps + 1, device=x.device).unsqueeze(0).repeat(batch_size, 1)
        return self.forward(x, valid_lens, encoder_output, seq_lens)[:, -1, :]

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer combining an Encoder and a Decoder with shared size settings."""

    def __init__(self, hidden_size, encoder_blk, decoder_blk, ffn_size, num_heads, src_vocab, tgt_vocab, dropout):
        super().__init__()
        self.encoder = Encoder(src_vocab, hidden_size, encoder_blk, num_heads, ffn_size, dropout)
        self.decoder = Decoder(tgt_vocab, hidden_size, decoder_blk, num_heads, ffn_size, dropout)

    def encode(self, x, seq_lens):
        """Encode source ids x; seq_lens masks encoder self-attention keys."""
        return self.encoder(x, seq_lens)

    def decode(self, x, valid_lens, encoder_output, seq_lens):
        """Decode target ids x against encoder_output; returns logits over the target vocabulary."""
        return self.decoder(x, valid_lens, encoder_output, seq_lens)

    def forward(self, x, seq_lens, y, valid_lens):
        """
        @param x: source ids; seq_lens: source lengths
        @param y: decoder input ids; valid_lens: decoder self-attention mask
        @return: logits, shape = (batch_size, tgt_len, tgt_vocab)
        """
        enc_output = self.encode(x, seq_lens)
        return self.decode(y, valid_lens, enc_output, seq_lens)

    def predict(self, x, encoder_output, seq_lens):
        """Next-token logits for prefix x, delegated to Decoder.predict."""
        return self.decoder.predict(x, encoder_output, seq_lens)

In [None]:
def train(model, lr, epoch, dataLoader, print_every=100):
    """
    Train a seq2seq model with teacher forcing, SGD, and token-level cross-entropy.

    @param model: called as model(X, seq_lens, dec_input, valid_lens) and expected to return
                  logits of shape (batch_size, tgt_len - 1, vocab_size)
    @param lr: SGD learning rate
    @param epoch: number of passes over dataLoader
    @param dataLoader: iterable of (X, y, seq_lens) batches; y has shape (batch_size, tgt_len)
    @param print_every: print perplexity every this many steps
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # NOTE(review): padding in the targets is not ignored; if y is padded, consider
    # nn.CrossEntropyLoss(ignore_index=<pad id>).
    model.train()
    for ep in range(epoch):
        print(f"----- epoch {ep} -----")
        for step_count, (X, y, seq_lens) in enumerate(dataLoader, start=1):
            # Teacher forcing: position t sees y[:, :t + 1] and predicts y[:, t + 1]. Without this
            # shift the causal mask lets each position see its own target, so the model learns to copy.
            # NOTE(review): assumes each y starts with a BOS token; confirm against the data pipeline.
            dec_input, target = y[:, :-1], y[:, 1:]
            batch_size, num_steps = dec_input.shape[0], dec_input.shape[1]
            # Causal mask: decoder position i attends to positions 0..i.
            valid_lens = torch.arange(1, num_steps + 1, device=y.device).unsqueeze(0).repeat(batch_size, 1)
            optimizer.zero_grad()
            logits = model(X, seq_lens, dec_input, valid_lens)
            # CrossEntropyLoss takes (input, target) = ((N, C) logits, (N,) class ids).
            loss = criterion(logits.reshape(-1, logits.shape[-1]), target.reshape(-1))
            loss.backward()
            optimizer.step()
            if step_count % print_every == 0:
                print(f"epoch {ep}, step {step_count}, ppl = {torch.exp(loss.detach()).item()}")

In [None]:
# Placeholder: replace with an iterable of (X, y, seq_lens) batches (the tuple unpacked by train) before training.
dataLoader = None

In [None]:
# Base-model hyperparameters from Vaswani et al. (2017): d_model=512, d_ff=2048, N=6, h=8, dropout=0.1.
hidden_size = 512
ffn_size = 2048
N = 6  # number of encoder blocks and of decoder blocks
num_heads = 8
dropout = 0.1
vocab_size = 1000  # NOTE(review): presumably a placeholder; should match the tokenizer behind dataLoader
torch.manual_seed(0)  # reproducible parameter initialisation
transformer = Transformer(hidden_size, N, N, ffn_size, num_heads, vocab_size, vocab_size, dropout)

# Iterating over None would raise TypeError, so skip training until a real loader is provided.
if dataLoader is None:
    print("dataLoader is not set; skipping training.")
else:
    train(transformer, lr=1e-4, epoch=5, dataLoader=dataLoader)