# Lab 6: Translate with transformers

## Authors
- Francisco Roh
- Bryan Calisto

Read the paper:

https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf

Implement a English-Spanish translator using transformers, use this tutorial for help:

https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html

Use the attention model from class to compare a few translations.

## Preprocesamiento con las funciones proporcionadas

In [55]:
import torch
import re
import unicodedata
from torch.utils.data import DataLoader, TensorDataset, RandomSampler
import numpy as np

SOS_token = 0
EOS_token = 1
MAX_LENGTH = 10

class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s.strip()

def indexesFromSentence(lang, sentence):
    return [lang.word2index[word] for word in sentence.split(' ')]

def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long).unsqueeze(0)

def tensorsFromPair(pair, input_lang, output_lang):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)


def filterPair(p):
    try:
        return len(p[0].split(' ')) < MAX_LENGTH and \
            len(p[1].split(' ')) < MAX_LENGTH
    except:
        print(p)


def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

def prepareData(lang1, lang2, file):
    text = open(file, encoding='utf-8').read().strip().split('\n')
    pairs = [[normalizeString(s) for s in l.split('\t')][:2] for l in text ]
    pairs = [pair for pair in pairs if len(pair) == 2]

    input_lang = Lang(lang1)
    output_lang = Lang(lang2)

    pairs = filterPairs(pairs)

    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])

    return input_lang, output_lang, pairs

file = 'data/spa.txt'  # ruta del archivo
input_lang, output_lang, pairs = prepareData('eng', 'spa', file)

# Función para obtener el DataLoader
def get_dataloader(batch_size):

    input_lang, output_lang, pairs = prepareData('eng', 'spa', file)

    n = len(pairs)
    input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
    target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)

    for idx, (inp, tgt) in enumerate(pairs):
        inp_ids = indexesFromSentence(input_lang, inp)
        tgt_ids = indexesFromSentence(output_lang, tgt)

        # Aseguramos que las secuencias no excedan MAX_LENGTH
        inp_ids = inp_ids[:MAX_LENGTH-1]  # Reservamos un espacio para EOS_token
        tgt_ids = tgt_ids[:MAX_LENGTH-1]

        # Añadimos el EOS_token
        inp_ids.append(EOS_token)
        tgt_ids.append(EOS_token)

        # Aseguramos que tanto inputs como targets tengan la misma longitud final
        input_ids[idx, :len(inp_ids)] = inp_ids
        target_ids[idx, :len(tgt_ids)] = tgt_ids

    # Aseguramos que el tamaño de lote y las dimensiones de input y target sean consistentes
    input_tensor = torch.LongTensor(input_ids).to(device)
    target_tensor = torch.LongTensor(target_ids).to(device)

    train_data = TensorDataset(input_tensor, target_tensor)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

    return train_dataloader



##Transformer

In [56]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt

# Codificación posicional
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)



# Transformer personalizado
class Transformer(nn.Module):
    def __init__(self, input_vocab_size, target_vocab_size, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1):
        super(Transformer, self).__init__()
        self.embedding_input = nn.Embedding(input_vocab_size, d_model)
        self.embedding_target = nn.Embedding(target_vocab_size, d_model)

        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, dim_feedforward=dim_feedforward, dropout=dropout)

        self.fc_out = nn.Linear(d_model, target_vocab_size)
        self.d_model = d_model
        self.positional_encoding = PositionalEncoding(d_model, dropout)

    def forward(self, src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask):
        # Embedding + Positional Encoding
        src = self.embedding_input(src) * math.sqrt(self.d_model)
        tgt = self.embedding_target(tgt) * math.sqrt(self.d_model)
        src = self.positional_encoding(src)
        tgt = self.positional_encoding(tgt)

        # Transformer forward pass
        output = self.transformer(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask,
                                  memory_mask=None, src_key_padding_mask=src_padding_mask,
                                  tgt_key_padding_mask=tgt_padding_mask)
        output = self.fc_out(output)

        return F.log_softmax(output, dim=-1)



# Máscara subsecuente (para evitar que un token vea el siguiente en el decodificador)
def generate_square_subsequent_mask(sz):
    mask = torch.triu(torch.ones(sz, sz)) == 1
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask




# Máscara de padding
def create_padding_mask(seq, pad_token):
    # Aquí nos aseguramos que la máscara de padding sea del tamaño (batch_size, seq_len)
    return (seq == pad_token).transpose(0, 1)



# Crear las máscaras necesarias
def create_masks(src, tgt, pad_token):
    # La máscara de atención para el encoder (src_mask) es de forma (src_len, src_len)
    src_mask = torch.zeros((src.size(1), src.size(1)), device=src.device).type(torch.bool)

    # La máscara de atención para el decoder (tgt_mask) es de forma (tgt_len, tgt_len)
    tgt_mask = generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)

    # Las máscaras de padding dependen del tamaño del lote y la longitud de secuencia
    src_padding_mask = create_padding_mask(src, pad_token)
    tgt_padding_mask = create_padding_mask(tgt, pad_token)

    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

# Crear las máscaras necesarias
def create_masks(src, tgt, pad_token):
    # La máscara de atención para el encoder (src_mask) es de forma (src_len, src_len)
    src_mask = torch.zeros((src.size(1), src.size(1)), device=src.device).type(torch.bool)

    # La máscara de atención para el decoder (tgt_mask) es de forma (tgt_len, tgt_len)
    tgt_mask = generate_square_subsequent_mask(tgt.size(1)).to(src.device)

    # Las máscaras de padding dependen del tamaño del lote y la longitud de secuencia
    src_padding_mask = create_padding_mask(src, pad_token)
    tgt_padding_mask = create_padding_mask(tgt, pad_token)

    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask



## Funciones de Entranamiento

In [57]:
# Función de entrenamiento
def train_epoch(dataloader, model, optimizer, criterion, device, pad_token, scheduler):
    model.train()
    total_loss = 0
    for batch in dataloader:
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)

        # Verificamos que inputs y targets tengan el mismo tamaño de lote
        assert inputs.size(0) == targets.size(0), f"El tamaño del lote de src ({inputs.size(0)}) y tgt ({targets.size(0)}) no coincide."

        # Generación de máscaras
        src_mask = None
        tgt_mask = generate_square_subsequent_mask(targets.size(1)).to(device)
        src_padding_mask = (inputs == pad_token).to(device)
        tgt_padding_mask = (targets == pad_token).to(device)

        optimizer.zero_grad()

        # Asegurarse de que las secuencias de entrada y objetivo tengan el mismo tamaño de lote
        output = model(inputs, targets[:, :-1], src_mask, tgt_mask, src_padding_mask, tgt_padding_mask[:, :-1])

        # Ajuste de las dimensiones del objetivo para calcular la pérdida
        loss = criterion(output.view(-1, output.size(-1)), targets[:, 1:].contiguous().view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    scheduler.step()
    return total_loss / len(dataloader)








# Entrenamiento del Transformer
def train_transformer(model, dataloader, epochs, device, pad_token):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss(ignore_index=pad_token)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

    losses = []
    for epoch in range(epochs):
        loss = train_epoch(dataloader, model, optimizer, criterion, device, pad_token, scheduler)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss}")
        losses.append(loss)

    return losses



## Entrenamiento y evaluación

In [58]:
# Visualización de las pérdidas durante el entrenamiento
def plot_losses(losses, epochs):
    plt.plot(range(1, epochs+1), losses)
    plt.title("Pérdida durante el entrenamiento")
    plt.xlabel("Épocas")
    plt.ylabel("Pérdida (Cross Entropy)")
    plt.show()


# Parámetros y configuración
input_vocab_size = input_lang.n_words
output_vocab_size = output_lang.n_words
epochs = 20
batch_size = 32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pad_token = 0  # Consideramos 0 como el token de padding

# Preparamos los datos
dataloader = get_dataloader(batch_size)

# Inicializamos el transformer
model = Transformer(input_vocab_size, output_vocab_size).to(device)

# Entrenamos el modelo
losses = train_transformer(model, dataloader, epochs, device, pad_token)

# Graficamos las pérdidas
plot_losses(losses, epochs)




RuntimeError: the batch number of src and tgt must be equal

## Generación de traducciones

In [None]:
# Función para evaluar y realizar traducciones
def evaluate_transf(model, src_sentence, input_lang, output_lang, max_len=MAX_LENGTH):
    model.eval()
    src_tensor = tensorFromSentence(input_lang, src_sentence).to(device)

    # Crear máscaras
    src_mask = torch.zeros((src_tensor.size(0), src_tensor.size(0)), device=src_tensor.device).type(torch.bool)
    tgt_tensor = torch.zeros((max_len, 1), dtype=torch.long, device=device).fill_(SOS_token)

    for i in range(1, max_len):
        tgt_mask = generate_square_subsequent_mask(i).to(device)
        output = model(src_tensor, tgt_tensor[:i], src_mask, tgt_mask, None, None)
        _, top1 = output[-1, :, :].max(dim=-1)
        tgt_tensor[i, 0] = top1.item()

        if top1.item() == EOS_token:
            break

    translated_sentence = [output_lang.index2word[tok.item()] for tok in tgt_tensor if tok.item() not in [SOS_token, EOS_token, pad_token]]
    return ' '.join(translated_sentence)

In [None]:
# Probamos con una oración de ejemplo
example_sentence = "she is my sister"
translated_sentence = evaluate_transf(model, example_sentence, input_lang, output_lang)
print(f"Traducción: {translated_sentence}")

## Modelo de atención de clase para comparación

In [None]:
import random

def indexesFromSentence(lang, sentence):
    return [lang.word2index[word] for word in sentence.split(' ')]

def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(1, -1)

def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)

def evaluate_attention(encoder, decoder, sentence, input_lang, output_lang):
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)

        _, topi = decoder_outputs.topk(1)
        decoded_ids = topi.squeeze()

        decoded_words = []
        for idx in decoded_ids:
            if idx.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            decoded_words.append(output_lang.index2word[idx.item()])
    return decoded_words, decoder_attn

def evaluateRandomly(encoder, decoder, n=10):
    for i in range(n):
        pair = random.choice(pairs)
        print('>', pair[0])
        print('=', pair[1])
        output_words, _ = evaluate_attention(encoder, decoder, pair[0], input_lang, output_lang)
        output_sentence = ' '.join(output_words)
        print('<', output_sentence)
        print('')

In [None]:
path = "./models/"
encoder = torch.load(path+"translate_sp_en_encoder.pt")
decoder = torch.load(path+"translate_sp_en_decoder.pt")

encoder_attn = torch.load(path+"translate_sp_en_attn_encoder.pt")
decoder_attn = torch.load(path+"translate_sp_en_attn_decoder.pt")

## Comparación de modelos de traducción

In [None]:
phrase='she is my sister'
trd_sin_atencion= evaluate_attention(encoder, decoder, phrase, input_lang, output_lang)[0]
trd_con_atencion= evaluate_attention(encoder_attn, decoder_attn, phrase, input_lang, output_lang)[0]
trd_transformer = evaluate_transf(model, phrase, input_lang, output_lang)
print(f"Traducción sin atención : {trd_sin_atencion}")
print(f"Traducción con atención : {trd_con_atencion}")
print(f"Traducción con transfor.: {trd_transformer}")

## ************* SACADO DEL TUTORIAL ********************************************

In [None]:
def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

In [None]:
# Helper function to support different mask shapes.
# Output shape supports (batch_size, number of heads, seq length, seq length)
# If 2D: broadcasted over batch size and number of heads
# If 3D: broadcasted over number of heads
# If 4D: leave as is
def expand_mask(mask):
    assert mask.ndim >= 2, "Mask must be at least 2-dimensional with seq_length x seq_length"
    if mask.ndim == 3:
        mask = mask.unsqueeze(1)
    while mask.ndim < 4:
        mask = mask.unsqueeze(0)
    return mask

In [None]:
class MultiheadAttention(nn.Module):

    def __init__(self, input_dim, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Stack all weight matrices 1...h together for efficiency
        # Note that in many implementations you see "bias=False" which is optional
        self.qkv_proj = nn.Linear(input_dim, 3*embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        # Original Transformer initialization, see PyTorch documentation
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        self.qkv_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)

    def forward(self, x, mask=None, return_attention=False):
        batch_size, seq_length, _ = x.size()
        if mask is not None:
            mask = expand_mask(mask)
        qkv = self.qkv_proj(x)

        # Separate Q, K, V from linear output
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3*self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims]
        q, k, v = qkv.chunk(3, dim=-1)

        # Determine value outputs
        values, attention = scaled_dot_product(q, k, v, mask=mask)
        values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]
        values = values.reshape(batch_size, seq_length, self.embed_dim)
        o = self.o_proj(values)

        if return_attention:
            return o, attention
        else:
            return o

In [None]:
class EncoderBlock(nn.Module):

    def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0):
        """
        Inputs:
            input_dim - Dimensionality of the input
            num_heads - Number of heads to use in the attention block
            dim_feedforward - Dimensionality of the hidden layer in the MLP
            dropout - Dropout probability to use in the dropout layers
        """
        super().__init__()

        # Attention layer
        self.self_attn = MultiheadAttention(input_dim, input_dim, num_heads)

        # Two-layer MLP
        self.linear_net = nn.Sequential(
            nn.Linear(input_dim, dim_feedforward),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(dim_feedforward, input_dim)
        )

        # Layers to apply in between the main layers
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Attention part
        attn_out = self.self_attn(x, mask=mask)
        x = x + self.dropout(attn_out)
        x = self.norm1(x)

        # MLP part
        linear_out = self.linear_net(x)
        x = x + self.dropout(linear_out)
        x = self.norm2(x)

        return x

In [None]:
class TransformerEncoder(nn.Module):

    def __init__(self, num_layers, **block_args):
        super().__init__()
        self.layers = nn.ModuleList([EncoderBlock(**block_args) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for l in self.layers:
            x = l(x, mask=mask)
        return x

    def get_attention_maps(self, x, mask=None):
        attention_maps = []
        for l in self.layers:
            _, attn_map = l.self_attn(x, mask=mask, return_attention=True)
            attention_maps.append(attn_map)
            x = l(x)
        return attention_maps

In [None]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len=5000):
        """
        Inputs
            d_model - Hidden dimensionality of the input.
            max_len - Maximum length of a sequence to expect.
        """
        super().__init__()

        # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)

        # register_buffer => Tensor which is not a parameter, but should be part of the modules state.
        # Used for tensors that need to be on the same device as the module.
        # persistent=False tells PyTorch to not add the buffer to the state dict (e.g. when we save the model)
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return x

In [None]:
class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler):

    def __init__(self, optimizer, warmup, max_iters):
        self.warmup = warmup
        self.max_num_iters = max_iters
        super().__init__(optimizer)

    def get_lr(self):
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters))
        if epoch <= self.warmup:
            lr_factor *= epoch * 1.0 / self.warmup
        return lr_factor

In [None]:
# Needed for initializing the lr scheduler
p = nn.Parameter(torch.empty(4,4))
optimizer = optim.Adam([p], lr=1e-3)
lr_scheduler = CosineWarmupScheduler(optimizer=optimizer, warmup=100, max_iters=2000)

# Plotting
epochs = list(range(2000))
sns.set()
plt.figure(figsize=(8,3))
plt.plot(epochs, [lr_scheduler.get_lr_factor(e) for e in epochs])
plt.ylabel("Learning rate factor")
plt.xlabel("Iterations (in batches)")
plt.title("Cosine Warm-up Learning Rate Scheduler")
plt.show()
sns.reset_orig()

In [None]:
class TransformerPredictor(pl.LightningModule):

    def __init__(self, input_dim, model_dim, num_classes, num_heads, num_layers, lr, warmup, max_iters, dropout=0.0, input_dropout=0.0):
        """
        Inputs:
            input_dim - Hidden dimensionality of the input
            model_dim - Hidden dimensionality to use inside the Transformer
            num_classes - Number of classes to predict per sequence element
            num_heads - Number of heads to use in the Multi-Head Attention blocks
            num_layers - Number of encoder blocks to use.
            lr - Learning rate in the optimizer
            warmup - Number of warmup steps. Usually between 50 and 500
            max_iters - Number of maximum iterations the model is trained for. This is needed for the CosineWarmup scheduler
            dropout - Dropout to apply inside the model
            input_dropout - Dropout to apply on the input features
        """
        super().__init__()
        self.save_hyperparameters()
        self._create_model()

    def _create_model(self):
        # Input dim -> Model dim
        self.input_net = nn.Sequential(
            nn.Dropout(self.hparams.input_dropout),
            nn.Linear(self.hparams.input_dim, self.hparams.model_dim)
        )
        # Positional encoding for sequences
        self.positional_encoding = PositionalEncoding(d_model=self.hparams.model_dim)
        # Transformer
        self.transformer = TransformerEncoder(num_layers=self.hparams.num_layers,
                                              input_dim=self.hparams.model_dim,
                                              dim_feedforward=2*self.hparams.model_dim,
                                              num_heads=self.hparams.num_heads,
                                              dropout=self.hparams.dropout)
        # Output classifier per sequence lement
        self.output_net = nn.Sequential(
            nn.Linear(self.hparams.model_dim, self.hparams.model_dim),
            nn.LayerNorm(self.hparams.model_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(self.hparams.dropout),
            nn.Linear(self.hparams.model_dim, self.hparams.num_classes)
        )

    def forward(self, x, mask=None, add_positional_encoding=True):
        """
        Inputs:
            x - Input features of shape [Batch, SeqLen, input_dim]
            mask - Mask to apply on the attention outputs (optional)
            add_positional_encoding - If True, we add the positional encoding to the input.
                                      Might not be desired for some tasks.
        """
        x = self.input_net(x)
        if add_positional_encoding:
            x = self.positional_encoding(x)
        x = self.transformer(x, mask=mask)
        x = self.output_net(x)
        return x

    @torch.no_grad()
    def get_attention_maps(self, x, mask=None, add_positional_encoding=True):
        """
        Function for extracting the attention matrices of the whole Transformer for a single batch.
        Input arguments same as the forward pass.
        """
        x = self.input_net(x)
        if add_positional_encoding:
            x = self.positional_encoding(x)
        attention_maps = self.transformer.get_attention_maps(x, mask=mask)
        return attention_maps

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr)

        # Apply lr scheduler per step
        lr_scheduler = CosineWarmupScheduler(optimizer,
                                             warmup=self.hparams.warmup,
                                             max_iters=self.hparams.max_iters)
        return [optimizer], [{'scheduler': lr_scheduler, 'interval': 'step'}]

    def training_step(self, batch, batch_idx):
        raise NotImplementedError

    def validation_step(self, batch, batch_idx):
        raise NotImplementedError

    def test_step(self, batch, batch_idx):
        raise NotImplementedError