In [1]:
import torch
import torch.nn as nn
import math

SEED = 10
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

## Input Example

In [2]:
x = torch.randint(0, 100, (1, 6))
x

tensor([[37,  5, 32, 67, 32,  5]])

## Embeddings

In [3]:
class InputEmbeddings(nn.Module):
    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        return self.embedding(x) * math.sqrt(self.d_model)


embedding_encoder = InputEmbeddings(d_model=4, vocab_size=100)
output = embedding_encoder(x)
output

tensor([[[-1.8799, -0.8493,  3.3999,  1.4201],
         [-0.1888,  0.1051,  0.4773, -3.1130],
         [ 1.2626,  1.2161, -2.1373, -4.4780],
         [-1.1958,  3.4485, -0.8264, -0.4976],
         [ 1.2626,  1.2161, -2.1373, -4.4780],
         [-0.1888,  0.1051,  0.4773, -3.1130]]], grad_fn=<MulBackward0>)

## Positional Encoding

In [4]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, seq_len, dropout):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(p=dropout)

        ## (L,d_model)
        pe = torch.zeros(seq_len, d_model)
        ## (L, 1)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # (N, L, d_model)
        pe = pe.unsqueeze(0)

        ## Register Buffer
        self.register_buffer("pe", pe)

    def forward(self, x):
        ## x = x + self.pe[:, : x.shape[1], :].requires_grad_(False)
        x = x + self.pe
        return self.dropout(x)

In [5]:
pe_encoder = PositionalEncoding(d_model=4, seq_len=6, dropout=0.1)
output_pe = pe_encoder(output)
output_pe

tensor([[[-2.0888,  0.1674,  3.7777,  0.0000],
         [ 0.7252,  0.7172,  0.5414, -2.3479],
         [ 0.0000,  0.8888, -0.0000, -3.8646],
         [-1.1719,  2.7317, -0.8849,  0.5577],
         [ 0.5620,  0.6249, -2.3304, -3.8653],
         [-1.2753,  0.4320,  0.5858, -2.3492]]], grad_fn=<MulBackward0>)

## Multihead Attention

In [6]:
class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, d_model, h, dropout):
        super().__init__()
        assert d_model % h == 0, "d_model is not divisible by h"
        self.d_k = d_model // h
        self.h = h
        self.d_model = d_model
        ## Tensores empaquetados
        self.W_q = nn.Linear(d_model, self.d_k * h)
        self.W_k = nn.Linear(d_model, self.d_k * h)
        self.W_v = nn.Linear(d_model, self.d_k * h)

        self.w_o = nn.Linear(self.d_k * h, d_model)
        self.dropout = nn.Dropout(p=dropout)

    @staticmethod
    def scale_dot_prod(Q, K, V, mask=None, dropout=None):
        d_k = Q.shape[-1]

        # (N, h, L, L)
        attention_scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
        if mask is not None:
            attention_scores.masked_fill_(mask == 0, -1e9)

        attention_scores = attention_scores.softmax(dim=-1)
        if dropout is not None:
            attention_scores = dropout(attention_scores)

        ## (N, h, L, d_v)
        return attention_scores @ V, attention_scores

    def forward(self, q, k, v, mask=None):
        ## (N, L, d_k*h)
        Q = self.W_q(q)
        K = self.W_k(k)
        ## (N, L, d_v*h)
        V = self.W_v(v)

        # (N, L, h, d_k) --> (N,h,L,d_k)
        Q = Q.view(q.shape[0], -1, self.h, self.d_k).transpose(1, 2)
        # (N, L, h, d_k) --> (N,h,L,d_k)
        K = K.view(k.shape[0], -1, self.h, self.d_k).transpose(1, 2)
        # (N, L, h, d_k) --> (N,h,L,d_k)
        V = V.view(v.shape[0], -1, self.h, self.d_k).transpose(1, 2)

        x, self.attention_scores = self.scale_dot_prod(
            Q, K, V, mask, self.dropout
        )
        x = x.transpose(1, 2).reshape(q.shape[0], -1, self.h * self.d_k)
        return self.w_o(x)


mh_attention = MultiHeadAttentionBlock(d_model=4, h=2, dropout=0.1)
output_mh = mh_attention(output, output, output)
output_mh

tensor([[[-0.0834, -0.5228,  0.0203, -0.0140],
         [ 0.0389, -0.4723, -0.1963, -0.0514],
         [ 0.1011, -0.5052, -0.5318, -0.1051],
         [-0.0349, -0.4584, -0.4303, -0.1006],
         [ 0.1187, -0.5545, -0.5845, -0.1033],
         [ 0.0778, -0.5032, -0.2475, -0.0361]]], grad_fn=<ViewBackward0>)

## Add&Norm

In [7]:
class LayerNormalization(nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps
        ## Multiplicative
        self.alpha = nn.Parameter(torch.ones(1))
        ## Additive
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, unbiased=False, keepdim=True)
        return (x - mean) / torch.sqrt(
            var + self.eps
        ) * self.alpha + self.bias


class ResidualConnection(nn.Module):
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.norm = LayerNormalization()

    def forward(self, x, sublayer):
        output = sublayer(x)
        output = self.dropout(output)
        return self.norm(x + output)


residual_mh = ResidualConnection(dropout=0.1)
mh_attention = MultiHeadAttentionBlock(d_model=4, h=2, dropout=0.1)
output = residual_mh(output_pe, lambda x: mh_attention(x, x, x))
output

tensor([[[-0.9388, -0.2235,  1.6751, -0.5128],
         [ 0.5153,  0.6049,  0.6106, -1.7308],
         [ 0.5091,  0.6443,  0.5766, -1.7301],
         [-1.0167,  1.4599, -0.8285,  0.3853],
         [ 1.0495,  0.7769, -0.3532, -1.4732],
         [-0.1209,  0.5691,  1.1103, -1.5585]]], grad_fn=<AddBackward0>)

## Feed Forward

In [8]:
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.w1(x)
        x = self.relu(x)
        x = self.w2(x)
        return x


residual_ffn = ResidualConnection(dropout=0.1)
ffn = FeedForward(d_model=4, d_ff=8)
output = residual_ffn(output, ffn)
output

tensor([[[-1.2459,  0.4804,  1.3635, -0.5981],
         [ 0.1829,  0.9943,  0.4803, -1.6575],
         [ 0.1813,  1.0141,  0.4568, -1.6523],
         [-1.3174,  1.4358, -0.3721,  0.2538],
         [ 0.8782,  0.8883, -0.2200, -1.5464],
         [-0.4689,  1.0912,  0.7833, -1.4056]]], grad_fn=<AddBackward0>)

## Encoder

In [9]:
class EncoderBlock(nn.Module):
    def __init__(self, d_model, d_ff, h, dropout):
        super().__init__()
        self.mh_attention = MultiHeadAttentionBlock(d_model, h, dropout)

        self.ffn = FeedForward(d_model, d_ff)
        self.residuals = nn.ModuleDict(
            dict(
                mh=ResidualConnection(dropout),
                ffn=ResidualConnection(dropout),
            )
        )

    def forward(self, x):
        x = self.residuals["mh"](x, lambda x: self.mh_attention(x, x, x))
        x = self.residuals["ffn"](x, self.ffn)
        return x


encoder = EncoderBlock(d_model=4, d_ff=8, h=2, dropout=0.1)
encoder(output_pe)

tensor([[[-1.4794, -0.3560,  0.9296,  0.9059],
         [ 0.9185,  0.6873,  0.0322, -1.6380],
         [ 0.9489,  0.8417, -0.2676, -1.5230],
         [-0.6131,  1.1944, -1.2978,  0.7165],
         [ 1.2573,  0.6877, -0.8072, -1.1378],
         [-0.3268,  1.6067, -0.1436, -1.1363]]], grad_fn=<AddBackward0>)

In [10]:
class TransformerEncoder(nn.Module):
    def __init__(self, N, d_model, d_ff, h, dropout):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.h = h
        self.dropout = dropout
        self.encoders = nn.ModuleList(
            [
                EncoderBlock(self.d_model, self.d_ff, self.h, self.dropout)
                for _ in range(N)
            ]
        )

    def forward(self, x):
        for encoder in self.encoders:
            x = encoder(x)
        return x


transformer_encoder = TransformerEncoder(
    N=6, d_model=4, d_ff=8, h=2, dropout=0.1
)
encoder_output = transformer_encoder(output_pe)
encoder_output

tensor([[[-1.1471, -0.7726,  1.3095,  0.6102],
         [ 0.8043, -1.5599,  0.9410, -0.1854],
         [ 0.7972, -1.3047,  1.1294, -0.6219],
         [-1.2670,  0.3383, -0.4971,  1.4258],
         [ 1.5233,  0.1592, -1.1868, -0.4956],
         [-0.3347, -1.1785,  1.5795, -0.0662]]], grad_fn=<AddBackward0>)

## Decoder

In [11]:
x_decoder = torch.randint(0, 200, (2, 6))
embedding_decoder = InputEmbeddings(d_model=4, vocab_size=200)
output_decoder = embedding_decoder(x_decoder)

pe_decoder = PositionalEncoding(d_model=4, seq_len=6, dropout=0.1)
output_pe_decoder = pe_decoder(output_decoder)
output_pe_decoder

tensor([[[ 1.4902, -0.1060, -0.2673,  0.9572],
         [ 0.0000, -1.3015,  3.5659, -0.9811],
         [ 1.9236, -4.7471,  1.0052, -0.0000],
         [-1.0813,  1.5492, -0.6965,  2.6818],
         [ 0.9349,  2.8193,  0.6835,  5.1822],
         [-0.1522, -3.9696,  1.0385, -1.7455]],

        [[-0.0000, -1.7481, -0.6793,  2.1605],
         [-1.4601, -2.3530, -0.7971,  0.4837],
         [-1.2516,  0.9073, -2.5760, -1.4898],
         [ 3.9898, -2.8120, -0.0000, -0.0000],
         [-2.0904,  3.8685,  0.0390,  1.1420],
         [-2.1899, -2.4715,  0.2579,  1.6673]]], grad_fn=<MulBackward0>)

In [12]:
class DecoderBlock(nn.Module):
    def __init__(self, d_model, d_ff, h, dropout):
        super().__init__()
        self.causal_mh_attention = MultiHeadAttentionBlock(
            d_model, h, dropout
        )
        self.cross_attention = MultiHeadAttentionBlock(d_model, h, dropout)
        self.ffn = FeedForward(d_model, d_ff)
        self.residuals = nn.ModuleDict(
            dict(
                causal=ResidualConnection(dropout),
                cross=ResidualConnection(dropout),
                ffn=ResidualConnection(dropout),
            )
        )

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        x = self.residuals["causal"](
            x, lambda x: self.causal_mh_attention(x, x, x, tgt_mask)
        )
        x = self.residuals["cross"](
            x,
            lambda x: self.cross_attention(
                x, encoder_output, encoder_output, src_mask
            ),
        )
        x = self.residuals["ffn"](x, self.ffn)
        return x


class TransformerDecoder(nn.Module):
    def __init__(self, N, d_model, d_ff, h, dropout):
        super().__init__()
        self.layers = N
        self.decoders = nn.ModuleList(
            [DecoderBlock(d_model, d_ff, h, dropout) for _ in range(N)]
        )

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        for decoder in self.decoders:
            x = decoder(x, encoder_output, src_mask, tgt_mask)

        return x


transformer_decoder = TransformerDecoder(
    N=6, d_model=4, d_ff=8, h=2, dropout=0.1
)
decoder_output = transformer_decoder(output_pe_decoder, encoder_output)
decoder_output

tensor([[[ 0.2385,  1.3848, -0.2166, -1.4066],
         [-0.2838,  0.8749,  0.9249, -1.5160],
         [ 1.2895, -0.1632,  0.3523, -1.4786],
         [-0.9994,  1.6398, -0.5515, -0.0889],
         [-0.1921,  1.5629, -0.1464, -1.2244],
         [ 1.0752, -0.8200,  0.9065, -1.1617]],

        [[ 0.1356,  0.0096, -1.4816,  1.3365],
         [-1.2769,  0.2926,  1.4407, -0.4565],
         [-0.6868,  1.6583, -0.0942, -0.8773],
         [ 0.7595,  1.1732, -0.6672, -1.2655],
         [-0.8131,  1.6748, -0.1457, -0.7160],
         [-1.2560,  0.6754,  1.2384, -0.6578]]], grad_fn=<AddBackward0>)

## Projection Layer

In [13]:
class ProjectionLayer(nn.Module):
    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.linear_proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = self.linear_proj(x)
        ## En general esto no se hace ya que el Loss puede incluir el Softmax
        ## (batch, vocab_size)
        return x.softmax(dim=-1)


proj_layer = ProjectionLayer(d_model=4, vocab_size=200)
logits = proj_layer(decoder_output)
logits.shape

torch.Size([2, 6, 200])

In [14]:
## Secuencias Predichas
torch.argmax(logits, dim=-1)

tensor([[ 21, 188,  33, 180,  21,  27],
        [  9,  80, 188,  33, 188,  80]])