##### Input Embeddings and Positional Encoding

In [3]:
import torch
import torch.nn as nn
import math

##### Input Embeddings and Positional Encoding
1. token embedding layer
2. constant positional encoding matrix (sinusoidal pattern)
3. return token + positional embedding

In [None]:
class TokenPositionalEmbedding(nn.Module):
    """Token embedding summed with a fixed sinusoidal positional encoding.

    Args:
        sizeOfVocab: number of entries in the token embedding table.
        embedDimension: width of every embedding vector (even or odd).
        maxLength: number of positions precomputed in the positional table;
            sequences given to ``forward`` must not be longer than this.
    """

    def __init__(self, sizeOfVocab, embedDimension, maxLength=512):
        super().__init__()
        # token embedding layer -> maps token indices to embedding vectors
        self.tokenEmbedding = nn.Embedding(sizeOfVocab, embedDimension)
        self.embedDimension = embedDimension

        # constant positional encoding matrix with a sinusoidal pattern
        pe = torch.zeros(maxLength, embedDimension)
        position = torch.arange(0, maxLength, dtype=torch.float).unsqueeze(1)  # (maxLength, 1)

        # one frequency per sine/cosine pair: 10000^(-2i / embedDimension)
        divisionTerm = torch.exp(
            torch.arange(0, embedDimension, 2).float() * (-math.log(10000.0) / embedDimension)
        )
        angles = position * divisionTerm  # (maxLength, ceil(embedDimension / 2))

        # sine goes to the even columns of the embedding dimension
        pe[:, 0::2] = torch.sin(angles)

        # cosine goes to the odd columns; when embedDimension is odd there is
        # one fewer odd column than even columns, so drop the last frequency
        pe[:, 1::2] = torch.cos(angles[:, :embedDimension // 2])

        # buffer: saved and moved with the model, but not a trainable parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Embed token indices and add the positional encoding.

        Args:
            x: tensor of token indices with shape (batch_size, sequenceLength).

        Returns:
            Tensor of shape (batch_size, sequenceLength, embedDimension).
        """
        sequenceLength = x.size(1)

        # token embeddings -> (batch_size, sequenceLength, embedDimension)
        token_emb = self.tokenEmbedding(x)

        # positional rows for this length -> (1, sequenceLength, embedDimension),
        # broadcast over the batch axis in the addition below
        pos_emb = self.pe[:sequenceLength, :].unsqueeze(0)

        return token_emb + pos_emb

##### Scaled Dot-Product Attention

In [None]:
def scaled_dot_product_attention(query, key, value, mask=None, visualize=False):
    """Compute softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query: tensor whose last two axes are (len_q, d_k).
        key: tensor whose last two axes are (len_k, d_k).
        value: tensor whose last two axes are (len_k, d_v).
        mask: optional tensor broadcastable to the score tensor
            (..., len_q, len_k); positions where ``mask == 0`` are excluded
            from attention.
        visualize: when True, also return the attention weights.

    Returns:
        The attended values, or ``(values, attentionWeights)`` when
        ``visualize`` is True. A query row whose keys are all masked gets
        all-zero weights (and therefore a zero output) instead of NaN.
    """
    d_k = query.size(-1)  # dimension used to scale the dot products

    # raw attention scores: query times key-transpose, scaled by sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    blocked = None
    if mask is not None:
        blocked = mask == 0
        # Use the most negative finite value rather than -inf: softmax over a
        # row of only -inf is NaN, and that NaN also leaks into gradients.
        scores = scores.masked_fill(blocked, torch.finfo(scores.dtype).min)

    # normalized attention weights over the key axis
    attentionWeights = torch.softmax(scores, dim=-1)

    if blocked is not None:
        # a fully masked row comes out uniform after the finite fill above;
        # zero it so it does not attend to keys it was told to ignore
        fullyMasked = blocked.all(dim=-1, keepdim=True)
        attentionWeights = attentionWeights.masked_fill(fullyMasked, 0.0)

    # weighted sum of the value vectors
    op = torch.matmul(attentionWeights, value)

    if visualize:
        return op, attentionWeights
    return op

##### Multi-head attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``scaled_dot_product_attention``.

    Args:
        embedDimension: model width; must be divisible by ``numHeads``.
        numHeads: number of attention heads.
    """

    def __init__(self, embedDimension, numHeads):
        super().__init__()
        assert embedDimension % numHeads == 0, "embedDimension must be divisible by numHeads"
        self.embedDimension = embedDimension
        self.numHeads = numHeads
        self.head_dim = embedDimension // numHeads

        # linear layers to project input to queries, keys, and values
        self.q_proj = nn.Linear(embedDimension, embedDimension)
        self.k_proj = nn.Linear(embedDimension, embedDimension)
        self.v_proj = nn.Linear(embedDimension, embedDimension)
        # final linear layer to combine outputs from all heads
        self.out_proj = nn.Linear(embedDimension, embedDimension)

    def forward(self, query, key, value, mask=None, visualize=False):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: tensor of shape (batch_size, len_q, embedDimension).
            key: tensor of shape (batch_size, len_k, embedDimension).
            value: tensor of shape (batch_size, len_k, embedDimension).
            mask: optional mask forwarded to ``scaled_dot_product_attention``.
                A 3-D mask (batch_size, len_q, len_k) gets a head axis
                inserted; masks of any other rank are passed through as-is
                and must broadcast against (batch_size, numHeads, len_q, len_k).
            visualize: when True, also return the per-head attention weights.

        Returns:
            Tensor of shape (batch_size, len_q, embedDimension), or a tuple
            ``(output, attentionWeights)`` when ``visualize`` is True.
        """
        batch_size, sequenceLength, _ = query.size()

        def shape(x):
            # (batch_size, len, embedDimension) -> (batch_size, numHeads, len, head_dim)
            # The length is inferred per tensor (-1): in cross-attention the
            # key/value length differs from the query length.
            return x.view(batch_size, -1, self.numHeads, self.head_dim).transpose(1, 2)

        # project inputs to multi-head Q, K, V
        q = shape(self.q_proj(query))
        k = shape(self.k_proj(key))
        v = shape(self.v_proj(value))

        # A per-batch 3-D mask needs a singleton head axis to broadcast over
        # heads. 2-D masks already broadcast from the right, and 4-D masks
        # already carry a head axis, so they are left untouched.
        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)  # (batch_size, 1, len_q, len_k)

        # attention output (and optionally attention weights) for all heads
        attn_output = scaled_dot_product_attention(q, k, v, mask=mask, visualize=visualize)
        if visualize:
            attn_output, attentionWeights = attn_output

        # concatenate heads: (batch_size, numHeads, len_q, head_dim) -> (batch_size, len_q, embedDimension)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, sequenceLength, self.embedDimension)

        # final linear projection
        output = self.out_proj(attn_output)

        if visualize:
            return output, attentionWeights
        return output

##### Feed-Forward Networks and Layer Normalization

In [None]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied along the last axis.

    Expands from embedDimension to feedForwardDimension, applies ReLU and
    dropout, then projects back to embedDimension.
    """

    def __init__(self, embedDimension, feedForwardDimension, dropout=0.1):
        super().__init__()
        # expansion layer, non-linearity, contraction layer, regularisation
        self.linear1 = nn.Linear(embedDimension, feedForwardDimension)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(feedForwardDimension, embedDimension)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return linear2(dropout(relu(linear1(x)))); the last axis of x is embedDimension."""
        hidden = self.linear1(x)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

class AddNorm(nn.Module):
    """Residual connection with dropout on the sublayer branch, then LayerNorm."""

    def __init__(self, embedDimension, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(embedDimension)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer_out):
        """Return norm(x + dropout(sublayer_out))."""
        # dropout touches only the sublayer branch; the skip path stays intact
        residual = x + self.dropout(sublayer_out)
        return self.norm(residual)

##### Encoder Layer

In [None]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention and a feed-forward network, each
    wrapped in a residual connection followed by layer normalization.

    Args:
        embedDimension: model width.
        numHeads: number of attention heads.
        feedForwardDimension: hidden width of the feed-forward network.
        dropout: dropout probability passed to the Add&Norm and
            feed-forward sublayers.
    """

    def __init__(self, embedDimension, numHeads, feedForwardDimension, dropout=0.1):
        super().__init__()
        # multi-head self-attention mechanism
        self.self_attn = MultiHeadAttention(embedDimension, numHeads)

        # add & norm layer after self-attention
        self.addnorm1 = AddNorm(embedDimension, dropout)

        # position-wise feed-forward network
        self.ffn = PositionwiseFeedForward(embedDimension, feedForwardDimension, dropout)

        # add & norm layer after feed-forward
        self.addnorm2 = AddNorm(embedDimension, dropout)

    def forward(self, x, mask=None, visualize=False):
        """Run the encoder block.

        Args:
            x: input tensor, used as query, key and value of the
                self-attention.
            mask: optional attention mask forwarded to the self-attention.
            visualize: when True, also return the attention weights.

        Returns:
            The transformed tensor, or ``(tensor, attentionWeights)`` when
            ``visualize`` is True.
        """
        # Self-attention sublayer. The module is registered as `self_attn`
        # in __init__, so it must be called under that same name here.
        attn_out = self.self_attn(x, x, x, mask=mask, visualize=visualize)
        if visualize:
            attn_out, attentionWeights = attn_out  # unpack output and attention weights

        # add & norm -> residual connection (x + attn_out), then layer normalization
        x = self.addnorm1(x, attn_out)

        # feed-forward sublayer
        ffn_out = self.ffn(x)

        # add & norm -> residual connection (x + ffn_out), then layer normalization
        x = self.addnorm2(x, ffn_out)
        if visualize:
            return x, attentionWeights
        return x

##### Decoder Layer with Masked Attention

In [None]:
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention over the encoder
    output, and a feed-forward network, each followed by Add&Norm."""

    def __init__(self, embedDimension, numHeads, feedForwardDimension, dropout=0.1):
        super().__init__()
        # stage 1: self-attention over the decoder input (masked via tgt_mask)
        self.selfAttention = MultiHeadAttention(embedDimension, numHeads)
        self.addnorm1 = AddNorm(embedDimension, dropout)

        # stage 2: cross-attention, queries from the decoder, keys/values from the encoder
        self.crossAttention = MultiHeadAttention(embedDimension, numHeads)
        self.addnorm2 = AddNorm(embedDimension, dropout)

        # stage 3: position-wise feed-forward network
        self.ffn = PositionwiseFeedForward(embedDimension, feedForwardDimension, dropout)
        self.addnorm3 = AddNorm(embedDimension, dropout)

    def forward(self, x, enc_out, tgt_mask=None, memoryMask=None, visualize=False):
        """Run the three decoder sublayers in order.

        Returns the transformed tensor, or ``(tensor, weights)`` when
        ``visualize`` is True, where ``weights`` is a dict with the keys
        'self_attn' and 'cross_attn'.
        """
        collected = {}

        # stage 1: self-attention restricted by tgt_mask
        self_out = self.selfAttention(x, x, x, mask=tgt_mask, visualize=visualize)
        if visualize:
            self_out, collected['self_attn'] = self_out
        x = self.addnorm1(x, self_out)

        # stage 2: attend to the encoder output, restricted by memoryMask
        cross_out = self.crossAttention(x, enc_out, enc_out, mask=memoryMask, visualize=visualize)
        if visualize:
            cross_out, collected['cross_attn'] = cross_out
        x = self.addnorm2(x, cross_out)

        # stage 3: feed-forward transformation
        x = self.addnorm3(x, self.ffn(x))

        return (x, collected) if visualize else x

##### Transformer Model

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer producing vocabulary logits.

    Args:
        sizeOfVocab: vocabulary size, used for both embeddings and the
            output projection.
        embedDimension: model width.
        numHeads: number of attention heads per attention module.
        feedForwardDimension: hidden width of the feed-forward sublayers.
        numEncoderLayers: number of stacked encoder layers.
        numDecoderLayers: number of stacked decoder layers.
        maxLength: number of positions in the positional encoding tables.
        dropout: dropout probability passed to every layer.
    """

    def __init__(self, sizeOfVocab, embedDimension, numHeads, feedForwardDimension, numEncoderLayers, numDecoderLayers, maxLength=512, dropout=0.1):
        super().__init__()
        # embedding layers for source and target sequences (with positional encoding)
        self.src_embedding = TokenPositionalEmbedding(sizeOfVocab, embedDimension, maxLength)
        self.target_embedding = TokenPositionalEmbedding(sizeOfVocab, embedDimension, maxLength)

        # stack of encoder layers
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(embedDimension, numHeads, feedForwardDimension, dropout)
            for _ in range(numEncoderLayers)
        ])

        # stack of decoder layers
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(embedDimension, numHeads, feedForwardDimension, dropout)
            for _ in range(numDecoderLayers)
        ])

        # final linear layer to project decoder output to vocabulary logits
        self.output_proj = nn.Linear(embedDimension, sizeOfVocab)

        # initialize parameters (weights)
        self._reset_parameters()

    def _reset_parameters(self):
        """Apply Xavier-uniform initialization to every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, target, src_mask=None, targetMask=None, memoryMask=None, visualize=False):
        """Encode ``src``, decode ``target`` against it and return logits.

        Args:
            src: source token indices, (batch_size, src_seq_len).
            target: target token indices, (batch_size, target_seq_len).
            src_mask: optional mask for the encoder self-attention.
            targetMask: optional mask for the decoder self-attention.
            memoryMask: optional mask for the decoder cross-attention.
            visualize: when True, also return the collected attention weights.

        Returns:
            Logits of shape (batch_size, target_seq_len, sizeOfVocab), or
            ``(logits, attentionWeights)`` when ``visualize`` is True.
            ``attentionWeights`` is a list with one ``{'encoder': ...}``
            entry per encoder layer followed by one ``{'decoder': ...}``
            entry per decoder layer.
        """
        attentionWeights = [] if visualize else None

        # 1. embed source and target tokens (add positional encoding)
        src_emb = self.src_embedding(src)  # (batch_size, src_seq_len, embedDimension)
        target_emb = self.target_embedding(target)  # (batch_size, target_seq_len, embedDimension)

        # 2. pass source embeddings through the encoder stack
        enc_out = src_emb
        for layer in self.encoder_layers:
            if visualize:
                enc_out, enc_attn = layer(enc_out, mask=src_mask, visualize=True)
                attentionWeights.append({'encoder': enc_attn})
            else:
                enc_out = layer(enc_out, mask=src_mask)

        # 3. pass target embeddings and encoder output through the decoder stack.
        # DecoderLayer.forward names its self-attention mask `tgt_mask`, so the
        # keyword used here must be `tgt_mask`, not `targetMask`.
        dec_out = target_emb
        for layer in self.decoder_layers:
            if visualize:
                dec_out, decoderAttention = layer(dec_out, enc_out, tgt_mask=targetMask, memoryMask=memoryMask, visualize=True)
                attentionWeights.append({'decoder': decoderAttention})
            else:
                dec_out = layer(dec_out, enc_out, tgt_mask=targetMask, memoryMask=memoryMask)

        # 4. project decoder output to vocabulary logits for each position
        logits = self.output_proj(dec_out)  # (batch_size, target_seq_len, sizeOfVocab)

        if visualize:
            return logits, attentionWeights
        return logits

In [None]:
# hyperparameters for the model built below
sizeOfVocab = 10000
embedDimension = 512
numHeads = 8
feedForwardDimension = 2048
numEncoderLayers = 6
numDecoderLayers = 6
maxLength = 512
dropout = 0.1

# choose the target device up front: CUDA when available, CPU otherwise
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# build the Transformer and place it on the chosen device
model = Transformer(
    sizeOfVocab=sizeOfVocab,
    embedDimension=embedDimension,
    numHeads=numHeads,
    feedForwardDimension=feedForwardDimension,
    numEncoderLayers=numEncoderLayers,
    numDecoderLayers=numDecoderLayers,
    maxLength=maxLength,
    dropout=dropout,
).to(device)