# Example Implementation: Attention Layer and Transformer Block with Multi-Head Attention

* see: https://peterbloem.nl/blog/transformers 
* and https://github.com/pbloem/former/blob/b438731ceeaf6c468f8b961bb07c2adde3b54a9f/former/modules.py

In [None]:
# imports
import math
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

In [None]:
# Toy input used to exercise the model components below.
# b: minibatch size, t: sequence length (input vectors per batch element),
# k: dimension of each input vector
b, t, k = 1, 5, 6

# random input tensor of size (b, t, k) with entries drawn uniformly from [0, 1)
X = torch.rand((b, t, k))

## SelfAttention class (sequence to sequence operation)

inherits from nn.Module (=base class for all neural network modules):
* enables module to keep track of trainable parameters
* ability to apply forward pass (def forward needs to be defined)

forward pass: 
* input: sequence of t vectors of dimension k and a minibatch dimension b = tensor of size (b,t,k)
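* output: tuple (attn_out_unified, dot, keys, queries, values), where attn_out_unified has size (b,t,k), dot (the attention weights) has size (b*h,t,t), and keys, queries and values each have size (b*h,t,k/h), with h being the number of heads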

In [None]:
def mask_(matrices, maskval=0.0, mask_diagonal=True):
    """
    Masks out, in place, the upper triangle of every matrix in the given batch.

    Every entry at (row i, column j) with i <= j is overwritten with maskval;
    if mask_diagonal is false the diagonal is kept and only the entries with
    i < j are overwritten.

    :param matrices (torch.tensor): batch of matrices of size (b, h, w), modified in place
    :param maskval (float): value written into the masked positions
    :param mask_diagonal (bool): whether the main diagonal is masked as well
    :return: None
    """

    _, h, w = matrices.size()

    # create the indices on the same device as the data, so that indexing a
    # GPU tensor does not depend on an index tensor living on the CPU
    indices = torch.triu_indices(
        h, w, offset=0 if mask_diagonal else 1, device=matrices.device
    )
    matrices[:, indices[0], indices[1]] = maskval
    
def contains_nan(tensor):
    """
    Checks whether the given tensor holds at least one NaN entry.

    :param tensor (torch.tensor): tensor to inspect
    :return (bool): True if any element is NaN, False otherwise
    """
    # NaN is the only value that compares unequal to itself
    differs_from_itself = tensor.ne(tensor)
    return bool(differs_from_itself.any())

In [None]:
class SelfAttention(nn.Module):
    """
    Multi-head self-attention in which each input vector is cut into `heads`
    chunks of size k // heads and every head attends over its own chunk.
    """

    def __init__(self, k, heads=4, mask=True):
        """
        :param k (int): input vector dimension (e.g. embedding dimension), which must be divisible by heads (h)
        :param heads (int): number of heads, which will be used to split each input vector into h parts
        :param mask (bool): whether to mask the attention weights so that a position cannot attend to later positions. Necessary for generation models to prevent the model from seeing future tokens.
        """
        super().__init__()

        assert k % heads == 0

        self.k = k
        self.heads = heads
        self.mask = mask

        # bias-free k -> k projections producing the keys, queries and values of all heads at once
        self.tokeys = nn.Linear(k, k, bias=False)
        self.toqueries = nn.Linear(k, k, bias=False)
        self.tovalues = nn.Linear(k, k, bias=False)

        # k -> k projection applied to the concatenated head outputs
        self.unifyheads = nn.Linear(k, k)

    def forward(self, x, fix_attn_out_unified=None):
        """
        :param x (torch.tensor): input tensor of shape (b, t, k) where b is the batch size, t is the sequence length and k is the input vector dimension
        :param fix_attn_out_unified (torch.tensor): optional tensor of shape (b, t, k) used within causal mediation analysis; every non-zero entry overwrites the corresponding entry of the attention output
        :return: tuple (attn_out_unified, dot, keys, queries, values) with attn_out_unified of shape (b, t, k), the attention weights dot of shape (b*h, t, t) and keys, queries, values of shape (b*h, t, k // h)
        """
        b, t, k = x.size()
        h = self.heads

        assert k == self.k, f'Input embedding dim ({k}) should match model embedding dim ({self.k})'

        # size of the chunk each head works on
        s = k // h

        def fold_heads(projected):
            # (b, t, k) -> (b, t, h, s) -> (b, h, t, s) -> (b*h, t, s):
            # the heads are merged into the batch dimension so that a single
            # batched matrix multiplication covers all batches and heads;
            # contiguous() is required before view() after the transpose
            per_head = projected.view(b, t, h, s).transpose(1, 2)
            return per_head.contiguous().view(b * h, t, s)

        queries = fold_heads(self.toqueries(x))
        keys = fold_heads(self.tokeys(x))
        values = fold_heads(self.tovalues(x))

        # raw attention scores of every query against every key, scaled by
        # sqrt(k) (the full input dimension) to keep the values small
        dot = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(k)
        assert dot.size() == (b*h, t, t), f'Matrix has size {dot.size()}, expected {(b*h, t, t)}.'

        # causal masking: the entries above the diagonal (key position after
        # query position) are set to -inf; the diagonal itself is kept, so
        # every row keeps at least one finite entry for the softmax
        if self.mask:
            mask_(dot, maskval=float('-inf'), mask_diagonal=False)

        # turn each row of scores into attention weights that sum to one
        dot = F.softmax(dot, dim=2)

        # weight the values, then undo the head folding:
        # (b*h, t, s) -> (b, h, t, s) -> (b, t, h, s) -> (b, t, h*s)
        weighted = torch.bmm(dot, values).view(b, h, t, s)
        concatenated = weighted.transpose(1, 2).contiguous().view(b, t, s * h)

        # mix the concatenated head outputs
        attn_out_unified = self.unifyheads(concatenated)

        # causal mediation analysis: overwrite (in place) every output entry
        # for which the fix tensor holds a non-zero value
        if fix_attn_out_unified is not None:
            fixed_positions = fix_attn_out_unified != 0
            attn_out_unified[fixed_positions] = fix_attn_out_unified[fixed_positions]

        # dot has size (b*h, t, t)
        return attn_out_unified, dot, keys, queries, values
    


Test SelfAttention class

In [None]:
# build the SelfAttention model whose graph is logged and tested below
selfattention = SelfAttention(k=6, heads=3, mask=True)

# prepare log directory for graph of SelfAttention model and write the model to
# the tensorboard logs; the context manager closes the writer even if tracing
# the graph raises
with SummaryWriter("torchlogs/selfattention/") as tensorboard_writer_sa:
    tensorboard_writer_sa.add_graph(selfattention, X)

# plain forward pass (the module is called instead of .forward so that
# registered hooks are run as well)
attn_out_unified, dot, keys, queries, values = selfattention(X)
# print(attn_out_unified)

# adapted forward pass: the fix tensor has the same shape as the input and is
# zero everywhere except for the single entry that should be overwritten
fix_attn_out_unified = torch.zeros_like(X)
fix_attn_out_unified[0, 0, 0] = 1
attn_out_unified, dot, keys, queries, values = selfattention(X, fix_attn_out_unified=fix_attn_out_unified)
# print(attn_out_unified)

# to open tensorboard, run in terminal: tensorboard --logdir=./src/model-basic/torchlogs/selfattention/ --port=6006 and open browser at localhost:6006

## TransformerBlock class (sequence to sequence operation)

inherits from nn.Module (=base class for all neural network modules):
* enables module to keep track of trainable parameters
* ability to apply forward pass (def forward needs to be defined)

forward pass: 
* input: sequence of t vectors of dimension k and a minibatch dimension b = tensor of size (b,t,k)
* output: tuple (tb_out, attn_out_unified, ff_out) = the output of the block, the output of its attention layer and the output of its feedforward network, each a tensor of size (b,t,k)

In [None]:
class TransformerBlock(nn.Module):
    """
    Transformer block: self-attention followed by a position-wise feedforward
    network, each wrapped in a residual connection, layer normalization and
    dropout.
    """

    def __init__(self, k, heads, mask=False, ff_hidden_mult=2, dropout=0.05):
        """
        :param k (int): input vector dimension (e.g. embedding dimension)
        :param heads (int): number of heads
        :param mask (bool): whether to apply masking to the softmax operation. Necessary for generation models to prevent the model from seeing the future tokens
        :param ff_hidden_mult (int): the hidden multiplier in the feedforward neural network, common choice is 4.
        :param dropout (float): the dropout rate
        """
        super().__init__()

        self.attention = SelfAttention(k, heads=heads, mask=mask)
        self.mask = mask

        # layer normalization over the last (k-sized) dimension, i.e. each
        # input vector is normalized separately
        self.norm1 = nn.LayerNorm(k)
        self.norm2 = nn.LayerNorm(k)

        # feedforward network applied to each position separately and identically
        hidden_dim = ff_hidden_mult * k
        self.ff = nn.Sequential(
            nn.Linear(k, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, k),
        )

        # one dropout layer after each of the two normalizations
        self.do1 = nn.Dropout(dropout)
        self.do2 = nn.Dropout(dropout)

    def forward(self, x, fix_attn_out_unified=None, fix_ff_out=None):
        """
        :param x (torch.tensor): input tensor
        :param fix_attn_out_unified (torch.tensor): tensor to fix parts of the attention layer output within causal mediation analysis
        :param fix_ff_out (torch.tensor): tensor to fix parts of the feedforward layer output within causal mediation analysis; every non-zero entry overwrites the corresponding entry of the feedforward output
        :return: tuple (tb_out, attn_out_unified, ff_out)
        """
        # attention sub-layer; only its unified output is needed here
        attn_out_unified = self.attention(x, fix_attn_out_unified=fix_attn_out_unified)[0]

        # residual connection, layer normalization, then dropout
        after_attention = self.do1(self.norm1(attn_out_unified + x))

        # feedforward sub-layer
        ff_out = self.ff(after_attention)

        # causal mediation analysis: overwrite (in place) every feedforward
        # output entry for which the fix tensor holds a non-zero value
        if fix_ff_out is not None:
            fixed_positions = fix_ff_out != 0
            ff_out[fixed_positions] = fix_ff_out[fixed_positions]

        # residual connection, layer normalization, then dropout
        tb_out = self.do2(self.norm2(ff_out + after_attention))

        return tb_out, attn_out_unified, ff_out

Test TransformerBlock class

In [None]:
# build the TransformerBlock model whose graph is logged and tested below
transformerblock = TransformerBlock(k=6, heads=3, mask=True, ff_hidden_mult=2, dropout=0.05)

# prepare log directory for graph of TransformerBlock model and write the model
# to the tensorboard logs; the context manager closes the writer even if
# tracing the graph raises
with SummaryWriter("torchlogs/transformerblock/") as tensorboard_writer_tb:
    tensorboard_writer_tb.add_graph(transformerblock, X)

# plain forward pass (the module is called instead of .forward so that
# registered hooks are run as well)
tb_out, attn_out_unified, ff_out = transformerblock(X)
# print(ff_out)

# adapted forward pass: the fix tensors have the same shape as the input and are
# zero everywhere except for the single entry that should be overwritten
fix_attn_out_unified = torch.zeros_like(X)
fix_attn_out_unified[0, 0, 0] = 1
fix_ff_out = torch.zeros_like(X)
fix_ff_out[0, 0, 0] = 1
tb_out, attn_out_unified, ff_out = transformerblock(X, fix_attn_out_unified=fix_attn_out_unified, fix_ff_out=fix_ff_out)
# print(ff_out)

# to open tensorboard, run in terminal: tensorboard --logdir=./src/model-basic/torchlogs/transformerblock/ --port=6006 and open browser at localhost:6006

## GTransformer class (sequence to sequence operation)

inherits from nn.Module (=base class for all neural network modules):
* enables module to keep track of trainable parameters
* ability to apply forward pass (def forward needs to be defined)

forward pass: 
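* input: minibatch of b token sequences of length t, given as token indices = integer tensor of size (b,t)
* output: tuple (log_probs, attn_out_unified_all, ff_out_all) = log probabilities over the vocabulary for each position in the sequence (tensor of size (b,t,num_tokens)), plus lists with the attention and feedforward outputs of every transformer block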

In [None]:
def d(tensor=None):
    """
    Returns a device string ('cuda' or 'cpu').

    Without an argument the string names the best available device; with a
    tensor it names the kind of device that tensor lives on.

    :param tensor: optional tensor whose device should be reported
    :return (str): 'cuda' or 'cpu'
    """
    if tensor is not None:
        use_cuda = tensor.is_cuda
    else:
        use_cuda = torch.cuda.is_available()
    return 'cuda' if use_cuda else 'cpu'

In [None]:
class GTransformer(nn.Module):
    """
    Transformer for generating sequence output based on sequence input.

    Token indices are embedded, combined with learned position embeddings,
    passed through a stack of masked transformer blocks and finally mapped to
    log probabilities over the vocabulary for every position.
    """

    def __init__(self, k, heads, depth, t, num_tokens):
        """
        :param k: Embedding dimension
        :param heads: Number of attention heads
        :param depth: Number of transformer blocks
        :param t: Sequence length of input (number of position embeddings)
        :param num_tokens: Number of tokens (usually words) in the vocabulary
        """
        super().__init__()

        self.num_tokens = num_tokens

        # setup embedding layers, which convert input tokens to their corresponding vectors for tokens and positions
        # nn.Embedding is a simple lookup table that stores embeddings of a fixed dictionary and size
        self.token_embedding = nn.Embedding(embedding_dim=k, num_embeddings=num_tokens)
        self.pos_embedding = nn.Embedding(embedding_dim=k, num_embeddings=t)

        # setup linear transformation to unify the embeddings of the tokens themselves and their positional encodings
        self.unify_embeddings = nn.Linear(2*k, k)

        # setup the stack of transformer blocks
        self.tblocks = nn.Sequential(*[
            TransformerBlock(k=k, heads=heads, mask=True, ff_hidden_mult=2, dropout=0.00)
            for _ in range(depth)
        ])

        # setup linear transformation that converts output vector of dimension k to a vector of dimension num_tokens
        self.toprobs = nn.Linear(k, num_tokens)

    def _embed(self, x):
        """Map token indices of shape (b, t) to unified token+position vectors of shape (b, t, k)."""
        tokens = self.token_embedding(x)
        b, t, k = tokens.size()

        # the position indices are created on the device of the embedded input
        # (not on the "best available" device), so that a model kept on the
        # CPU also works on a machine where CUDA is available
        positions = self.pos_embedding(torch.arange(t, device=tokens.device))
        positions = positions[None, :, :].expand(b, t, k)

        # concatenate both embeddings and project them back to dimension k
        unified = self.unify_embeddings(torch.cat((tokens, positions), dim=2).view(-1, 2*k))
        return unified.view(b, t, k)

    def _to_log_probs(self, x):
        """Map vectors of shape (b, t, k) to log probabilities of shape (b, t, num_tokens)."""
        b, t, k = x.size()
        logits = self.toprobs(x.view(b*t, k)).view(b, t, self.num_tokens)
        return F.log_softmax(logits, dim=-1)

    def forward(self, x, fix_attn_out_unified_all=None, fix_ff_out_all=None):
        """
        :param x: integer tensor (b, t) of token indices, where b is batch size and t is token sequence length
        :param fix_attn_out_unified_all: list (one entry per transformer block) of tensors of shape (b, t, k) which will be used to fix the output of the attention layer within causal mediation analysis
        :param fix_ff_out_all: list (one entry per transformer block) of tensors of shape (b, t, k) which will be used to fix the output of the feedforward layer within causal mediation analysis
        :return: tuple (log_probs, attn_out_unified_all, ff_out_all) with log_probs of shape (b, t, num_tokens) and two lists holding the attention and feedforward outputs of every block
        """
        # apply and unify the token and positional embeddings
        x = self._embed(x)

        # run the input through the stack of transformer blocks, while allowing the passing of the fixed attention and feedforward layer outputs for causal mediation analysis
        attn_out_unified_all = []
        ff_out_all = []
        for i, block in enumerate(self.tblocks):
            fix_attn_out_unified = fix_attn_out_unified_all[i] if fix_attn_out_unified_all is not None else None
            fix_ff_out = fix_ff_out_all[i] if fix_ff_out_all is not None else None
            x, attn_out_unified, ff_out = block(x, fix_attn_out_unified=fix_attn_out_unified, fix_ff_out=fix_ff_out)   # x = tb_out
            attn_out_unified_all.append(attn_out_unified)
            ff_out_all.append(ff_out)

        # map the output vectors of dimension k to log probabilities over the vocabulary
        log_probs = self._to_log_probs(x)

        return log_probs, attn_out_unified_all, ff_out_all

    @torch.no_grad()
    def get_respresentations(self, x):
        """
        Function for extracting the internal representations of the whole Transformer for a single batch.

        :param x: integer tensor (b, t) of token indices, as in the forward pass
        :return: tuple (out_matrices, probs_matrices, attention_maps, key_matrices, query_matrices, value_matrices);
                 each is a list with one entry per transformer block. probs_matrices holds the log
                 probabilities obtained by applying the output projection to the output of each block.
        """
        x = self._embed(x)

        attention_maps = []
        key_matrices = []
        query_matrices = []
        value_matrices = []
        out_matrices = []
        probs_matrices = []
        for block in self.tblocks:
            # run the block in evaluation mode, so that dropout is not applied,
            # and afterwards restore whatever mode the block was in before
            # (instead of unconditionally switching it to training mode)
            was_training = block.training
            block.eval()
            try:
                # get the attention maps, key, query and value matrices of the attention layer
                _, dot, keys, queries, values = block.attention(x)
                attention_maps.append(dot)
                key_matrices.append(keys)
                query_matrices.append(queries)
                value_matrices.append(values)

                # get the output of the transformer block
                x, _, _ = block(x)  # x = tb_out
                out_matrices.append(x)
                probs_matrices.append(self._to_log_probs(x))
            finally:
                block.train(was_training)
        return out_matrices, probs_matrices, attention_maps, key_matrices, query_matrices, value_matrices

    # correctly spelled alias; the original (misspelled) name is kept for existing callers
    get_representations = get_respresentations