##### Input Embeddings and Positional Encoding

In [21]:
import torch
import torch.nn as nn
import math

1. token embedding layer
2. constant positional encoding matrix (sinusoidal pattern)
3. return token + positional embedding

In [None]:
class TokenPositionalEmbedding(nn.Module):
    """Token embedding summed with a fixed sinusoidal positional encoding.

    Args:
        vocab_size: number of rows in the token embedding table.
        embed_dim: width of every embedding vector (odd or even).
        max_len: number of positions precomputed in the positional table;
            this is also the longest sequence ``forward`` accepts.
    """

    def __init__(self, vocab_size, embed_dim, max_len=512):
        super().__init__()
        # token embedding layer: maps token indices to embedding vectors
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.embed_dim = embed_dim

        # constant positional encoding matrix with sinusoidal pattern
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # shape: (max_len, 1)

        # one frequency per (sin, cos) column pair
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        angles = position * div_term  # shape: (max_len, ceil(embed_dim / 2))

        # sine goes to the even columns of the embedding dimension
        pe[:, 0::2] = torch.sin(angles)

        # cosine goes to the odd columns; with an odd embed_dim there is one
        # fewer odd column than even columns, so the last frequency is dropped
        pe[:, 1::2] = torch.cos(angles[:, : embed_dim // 2])

        # register pe as a buffer so it's saved with the model but not as a parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Embed token indices and add the positional encoding.

        Args:
            x: tensor of token indices, shape (batch_size, seq_len).

        Returns:
            Tensor of shape (batch_size, seq_len, embed_dim).

        Raises:
            ValueError: if seq_len is larger than the precomputed table.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(0)
        if seq_len > max_len:
            # fail early with a readable message instead of a broadcast error
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional encoding table size {max_len}"
            )

        # token embeddings: (batch_size, seq_len, embed_dim)
        token_emb = self.token_embedding(x)

        # positional embeddings for this sequence length: (1, seq_len, embed_dim)
        pos_emb = self.pe[:seq_len, :].unsqueeze(0)

        # broadcast the positional table over the batch dimension
        return token_emb + pos_emb

##### Scaled Dot-Product Attention

In [None]:
def scaled_dot_product_attention(query, key, value, mask=None, visualize=False):
    """Compute softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query: tensor whose last two dimensions are (seq_len_q, d_k).
        key: tensor whose last two dimensions are (seq_len_k, d_k).
        value: tensor whose last two dimensions are (seq_len_k, d_v).
        mask: optional tensor broadcastable to the score matrix; positions
            where ``mask == 0`` are excluded from attention.
        visualize: when True, also return the attention weights.

    Returns:
        The attention output, or ``(output, attn_weights)`` if ``visualize``
        is True. A query row whose keys are all masked gets all-zero weights
        (and therefore a zero output row) rather than NaN.
    """
    d_k = query.size(-1)  # dimension used to scale the dot products

    # raw attention scores: query times key transposed, then scaled
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    fully_masked = None
    if mask is not None:
        blocked = mask == 0
        # masked positions get -inf so softmax assigns them zero weight
        scores = scores.masked_fill(blocked, float('-inf'))

        # a row made only of -inf would turn into NaN under softmax; give such
        # rows finite placeholder scores and zero their weights afterwards
        fully_masked = blocked.all(dim=-1, keepdim=True)
        scores = scores.masked_fill(fully_masked, 0.0)

    # softmax over the key dimension gives normalized attention weights
    attn_weights = torch.softmax(scores, dim=-1)
    if fully_masked is not None:
        attn_weights = attn_weights.masked_fill(fully_masked, 0.0)

    # weighted sum of the value vectors
    op = torch.matmul(attn_weights, value)

    # if visualize is True, return both output and attention weights
    if visualize:
        return op, attn_weights
    # otherwise, return only the output
    return op

##### Multi-head attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``scaled_dot_product_attention``.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # linear layers to project input to queries, keys, and values
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        # final linear layer to combine outputs from all heads
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value, mask=None, visualize=False):
        """Run multi-head attention.

        Args:
            query: (batch_size, seq_len_q, embed_dim).
            key: (batch_size, seq_len_k, embed_dim).
            value: (batch_size, seq_len_k, embed_dim).
            mask: optional mask. A 3-D mask (batch_size, seq_len_q, seq_len_k)
                gets a head axis inserted; any other rank is passed through
                unchanged and must broadcast against
                (batch_size, num_heads, seq_len_q, seq_len_k).
            visualize: when True, also return the attention weights.

        Returns:
            Output of shape (batch_size, seq_len_q, embed_dim), or
            ``(output, attn_weights)`` when ``visualize`` is True.
        """
        batch_size, seq_len, _ = query.size()

        # split the embedding dimension into heads
        def shape(x):
            # (batch, length, embed_dim) -> (batch, num_heads, length, head_dim)
            # each tensor keeps its own length: in cross-attention key/value
            # may be longer or shorter than the query
            return x.view(x.size(0), x.size(1), self.num_heads, self.head_dim).transpose(1, 2)

        # project inputs to multi-head Q, K, V
        q = shape(self.q_proj(query))
        k = shape(self.k_proj(key))
        v = shape(self.v_proj(value))

        # a per-batch mask needs a head axis so it broadcasts over all heads;
        # a 2-D (seq_len_q, seq_len_k) mask already broadcasts as-is
        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)  # (batch_size, 1, seq_len_q, seq_len_k)

        # attention output (and optionally attention weights) for all heads
        attn_output = scaled_dot_product_attention(q, k, v, mask=mask, visualize=visualize)
        if visualize:
            attn_output, attn_weights = attn_output

        # merge the heads back: (batch_size, seq_len_q, embed_dim)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)

        # final linear projection
        output = self.out_proj(attn_output)

        # return output (and attention weights if visualize=True)
        if visualize:
            return output, attn_weights
        return output

##### Feed-Forward Networks and Layer Normalization

In [None]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied along the last dimension.

    Args:
        embed_dim: input and output width.
        ffn_dim: width of the hidden layer.
        dropout: dropout probability applied after the ReLU.
    """

    def __init__(self, embed_dim, ffn_dim, dropout=0.1):
        super().__init__()
        # expand -> non-linearity -> project back, with dropout in between
        self.linear1 = nn.Linear(embed_dim, ffn_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(ffn_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map x with last dimension embed_dim to a tensor of the same shape."""
        hidden = self.linear1(x)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

class AddNorm(nn.Module):
    """Residual connection followed by layer normalization.

    Args:
        embed_dim: size of the normalized (last) dimension.
        dropout: dropout probability applied to the sublayer output.
    """

    def __init__(self, embed_dim, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer_out):
        """Return LayerNorm(x + Dropout(sublayer_out))."""
        residual = x + self.dropout(sublayer_out)
        return self.norm(residual)

##### Encoder Layer

In [None]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention and a feed-forward network, each
    wrapped in a residual connection plus layer normalization.

    Args:
        embed_dim: model width.
        num_heads: number of attention heads.
        ffn_dim: hidden width of the feed-forward network.
        dropout: dropout probability used by the sublayers.
    """

    def __init__(self, embed_dim, num_heads, ffn_dim, dropout=0.1):
        super().__init__()
        # sublayer 1: multi-head self-attention + Add & Norm
        self.self_attn = MultiHeadAttention(embed_dim, num_heads)
        self.addnorm1 = AddNorm(embed_dim, dropout)

        # sublayer 2: position-wise feed-forward + Add & Norm
        self.ffn = PositionwiseFeedForward(embed_dim, ffn_dim, dropout)
        self.addnorm2 = AddNorm(embed_dim, dropout)

    def forward(self, x, mask=None, visualize=False):
        """Apply the encoder block to x.

        Returns the transformed tensor, or ``(tensor, attn_weights)`` when
        ``visualize`` is True.
        """
        # self-attention: queries, keys and values all come from x
        attended = self.self_attn(x, x, x, mask=mask, visualize=visualize)
        weights = None
        if visualize:
            attended, weights = attended

        # residual + norm around attention, then around the feed-forward net
        hidden = self.addnorm1(x, attended)
        hidden = self.addnorm2(hidden, self.ffn(hidden))

        return (hidden, weights) if visualize else hidden

##### Decoder Layer with Masked Attention

In [None]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention over the
    encoder output, and a feed-forward network, each followed by Add & Norm.

    Args:
        embed_dim: model width.
        num_heads: number of attention heads.
        ffn_dim: hidden width of the feed-forward network.
        dropout: dropout probability used by the sublayers.
    """

    def __init__(self, embed_dim, num_heads, ffn_dim, dropout=0.1):
        super().__init__()
        # self-attention over the decoder input (masking is supplied by the caller)
        self.self_attn = MultiHeadAttention(embed_dim, num_heads)
        self.addnorm1 = AddNorm(embed_dim, dropout)

        # cross-attention: queries from the decoder, keys/values from the encoder
        self.cross_attn = MultiHeadAttention(embed_dim, num_heads)
        self.addnorm2 = AddNorm(embed_dim, dropout)

        # position-wise feed-forward network
        self.ffn = PositionwiseFeedForward(embed_dim, ffn_dim, dropout)
        self.addnorm3 = AddNorm(embed_dim, dropout)

    def forward(self, x, enc_out, tgt_mask=None, memory_mask=None, visualize=False):
        """Apply the decoder block.

        Args:
            x: decoder input.
            enc_out: encoder output used as keys and values in cross-attention.
            tgt_mask: mask passed to the self-attention sublayer.
            memory_mask: mask passed to the cross-attention sublayer.
            visualize: when True, also return the attention weights.

        Returns:
            The transformed tensor, or ``(tensor, weights)`` when ``visualize``
            is True, where ``weights`` is a dict with the keys ``'self_attn'``
            and ``'cross_attn'``.
        """
        collected = {}

        # step 1: self-attention restricted by tgt_mask
        self_out = self.self_attn(x, x, x, mask=tgt_mask, visualize=visualize)
        if visualize:
            self_out, self_weights = self_out
            collected['self_attn'] = self_weights
        hidden = self.addnorm1(x, self_out)

        # step 2: cross-attention over the encoder output
        cross_out = self.cross_attn(hidden, enc_out, enc_out, mask=memory_mask, visualize=visualize)
        if visualize:
            cross_out, cross_weights = cross_out
            collected['cross_attn'] = cross_weights
        hidden = self.addnorm2(hidden, cross_out)

        # step 3: position-wise feed-forward network
        hidden = self.addnorm3(hidden, self.ffn(hidden))

        return (hidden, collected) if visualize else hidden

##### Transformer Model

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer producing vocabulary logits.

    Args:
        vocab_size: vocabulary size shared by source and target embeddings
            and by the output projection.
        embed_dim: model width.
        num_heads: number of attention heads in every attention sublayer.
        ffn_dim: hidden width of the feed-forward networks.
        num_encoder_layers: number of stacked encoder layers.
        num_decoder_layers: number of stacked decoder layers.
        max_len: size of the positional encoding tables.
        dropout: dropout probability passed to every layer.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, ffn_dim, num_encoder_layers, num_decoder_layers, max_len=512, dropout=0.1):
        super().__init__()
        # embedding layers for source and target sequences (with positional encoding)
        self.src_embedding = TokenPositionalEmbedding(vocab_size, embed_dim, max_len)
        self.target_embedding = TokenPositionalEmbedding(vocab_size, embed_dim, max_len)

        # stack of encoder layers
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(embed_dim, num_heads, ffn_dim, dropout)
            for _ in range(num_encoder_layers)
        ])

        # stack of decoder layers
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(embed_dim, num_heads, ffn_dim, dropout)
            for _ in range(num_decoder_layers)
        ])

        # final linear layer to project decoder output to vocabulary logits
        self.output_proj = nn.Linear(embed_dim, vocab_size)

        # initialize parameters (weights)
        self._reset_parameters()

    def _reset_parameters(self):
        """Apply Xavier-uniform initialization to every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, target, src_mask=None, target_mask=None, memory_mask=None, visualize=False):
        """Run the full encoder-decoder pass.

        Args:
            src: source token indices, (batch_size, src_seq_len).
            target: target token indices, (batch_size, target_seq_len).
            src_mask: mask for encoder self-attention.
            target_mask: mask for decoder self-attention.
            memory_mask: mask for decoder cross-attention.
            visualize: when True, also return per-layer attention weights.

        Returns:
            Logits of shape (batch_size, target_seq_len, vocab_size), or
            ``(logits, attn_weights)`` when ``visualize`` is True, where
            ``attn_weights`` is a list with one ``{'encoder': ...}`` entry per
            encoder layer followed by one ``{'decoder': ...}`` entry per
            decoder layer.
        """
        attn_weights = [] if visualize else None

        # 1. embed source and target tokens (add positional encoding)
        src_emb = self.src_embedding(src)  # (batch_size, src_seq_len, embed_dim)
        target_emb = self.target_embedding(target)  # (batch_size, target_seq_len, embed_dim)

        # 2. pass source embeddings through the encoder stack
        enc_out = src_emb
        for layer in self.encoder_layers:
            enc_out = layer(enc_out, mask=src_mask, visualize=visualize)
            if visualize:
                enc_out, enc_attn = enc_out
                attn_weights.append({'encoder': enc_attn})

        # 3. pass target embeddings and encoder output through the decoder stack
        dec_out = target_emb
        for layer in self.decoder_layers:
            # DecoderLayer.forward names its self-attention mask `tgt_mask`
            dec_out = layer(dec_out, enc_out, tgt_mask=target_mask, memory_mask=memory_mask, visualize=visualize)
            if visualize:
                dec_out, dec_attn = dec_out
                attn_weights.append({'decoder': dec_attn})

        # 4. project decoder output to vocabulary logits for each position
        logits = self.output_proj(dec_out)  # (batch_size, target_seq_len, vocab_size)

        # return logits and optionally attention weights for visualization
        if visualize:
            return logits, attn_weights
        return logits

In [None]:
# model hyperparameters
vocab_size = 10000
embed_dim = 512
num_heads = 8
ffn_dim = 2048
num_encoder_layers = 6
num_decoder_layers = 6
max_len = 512
dropout = 0.1

# use CUDA when it is available, otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# build the Transformer and place it on the selected device
model = Transformer(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_heads=num_heads,
    ffn_dim=ffn_dim,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    max_len=max_len,
    dropout=dropout,
).to(device)