In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Transformer Implementation

## Model

In [17]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer producing scores over the output vocabulary.

    The source ids are embedded and encoded by ``Encoder``; the target ids are
    embedded and decoded by ``Decoder`` while attending to the encoder output;
    a final linear layer projects every decoder position from
    ``embedding_dim`` to ``output_vocab_size`` (no softmax is applied).

    Args:
        input_vocab_size: number of entries in the source vocabulary.
        output_vocab_size: number of entries in the target vocabulary.
        embedding_dim: width of the token embeddings and of every layer output.
        n_layers: number of encoder layers and of decoder layers.
        hidden_dim: width of the feed-forward sub-layer inside each layer.
        n_heads: number of attention heads in each attention module.
        pad_idx: id of the padding token, forwarded to both embedding tables.
            Defaults to 1, the value that used to be hard-coded here.
    """

    def __init__(self, input_vocab_size, output_vocab_size, embedding_dim, n_layers,
                 hidden_dim, n_heads, pad_idx=1):
        super(Transformer, self).__init__()
        self.output_vocab_size = output_vocab_size
        self.input_vocab_size = input_vocab_size
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        # Special-token ids. Only pad_idx is used by this class.
        # NOTE(review): the data pipeline in this notebook uses 0 / 1 for
        # <s> / </s> and pads with the vocabulary's "<blank>" id -- pass the
        # matching pad_idx and confirm the ids below before relying on them.
        self.sos_idx = 2
        self.eos_idx = 3
        self.mask_idx = 0
        self.pad_idx = pad_idx

        self.embedding_dim = embedding_dim
        self.n_layers = n_layers
        # sqrt(embedding_dim); stored for callers, not applied inside this class.
        self.embedding_scale = np.sqrt(self.embedding_dim)

        # Projects decoder outputs to one score per target-vocabulary entry.
        self.fc1 = Linear(self.embedding_dim, self.output_vocab_size)

        self.encoder = Encoder(
            self.input_vocab_size,
            self.embedding_dim,
            self.n_layers,
            self.hidden_dim,
            self.n_heads,
            self.pad_idx,
        )

        # The decoder embeds *target* tokens, so it must be sized by the
        # output vocabulary (it previously received input_vocab_size).
        self.decoder = Decoder(
            self.output_vocab_size,
            self.embedding_dim,
            self.n_layers,
            self.hidden_dim,
            self.n_heads,
            self.pad_idx,
        )

    def forward(self, source, targets, source_mask=None, tgt_mask=None):
        """Encode ``source``, decode ``targets`` against it and project to vocab scores.

        Args:
            source: source token ids, passed to the encoder.
            targets: target token ids, passed to the decoder.
            source_mask: optional mask forwarded to the encoder and to the
                decoder (where it masks the encoder output).
            tgt_mask: optional mask forwarded to the decoder for its
                self-attention.

        Returns:
            The decoder output projected by ``fc1``; its last dimension has
            size ``output_vocab_size``.
        """
        memory = self.encoder(source, source_mask)
        y = self.decoder(targets, memory, source_mask, tgt_mask)
        y = self.fc1(y)
        return y

## Encoder and Decoder

In [18]:
class Encoder(nn.Module):
    """Token embedding plus positional encoding followed by a stack of encoder layers.

    Args:
        input_vocab_size: number of rows in the embedding table.
        embedding_dim: embedding width, also the width handed to every layer.
        n_layers: how many ``TransformerEncoderLayer`` modules to stack.
        hidden_dim: feed-forward width passed to each layer.
        n_heads: number of attention heads passed to each layer.
        pad_idx: padding token id passed to the embedding table.
    """

    def __init__(self, input_vocab_size, embedding_dim, n_layers, hidden_dim,
                 n_heads, pad_idx):
        super().__init__()

        self.vocab_size = input_vocab_size
        self.embedding_dim = embedding_dim
        self.n_layers = n_layers

        self.embedding = Embedding(
            num_embeddings=input_vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=pad_idx,
        )
        self.pos_encoding = PositionalEmbedding(embedding_dim)

        stacked_layers = []
        for _ in range(n_layers):
            stacked_layers.append(
                TransformerEncoderLayer(embedding_dim, hidden_dim, n_heads)
            )
        self.layers = nn.ModuleList(stacked_layers)

    def forward(self, x, mask=None):
        """Embed the token ids ``x`` and run them through every layer in order.

        ``mask`` is handed unchanged to each layer.
        """
        hidden = self.embed(x)
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return hidden

    def embed(self, source):
        """Look up token embeddings and add the positional encoding in place."""
        embedded = self.embedding(source)
        embedded += self.pos_encoding(embedded)
        return embedded


class Decoder(nn.Module):
    """Token embedding plus positional encoding followed by a stack of decoder layers.

    Args:
        output_vocab_size: number of rows in the embedding table.
        embedding_dim: embedding width, also the width handed to every layer.
        n_layers: how many ``TransformerDecoderLayer`` modules to stack.
        hidden_dim: feed-forward width passed to each layer.
        n_heads: number of attention heads passed to each layer.
        pad_idx: padding token id passed to the embedding table.
    """

    def __init__(self, output_vocab_size, embedding_dim, n_layers, hidden_dim, n_heads, pad_idx):
        super(Decoder, self).__init__()

        self.vocab_size = output_vocab_size
        self.embedding_dim = embedding_dim
        self.n_layers = n_layers

        self.embedding = Embedding(
            num_embeddings=self.vocab_size, embedding_dim=self.embedding_dim, padding_idx=pad_idx)

        self.pos_encoding = PositionalEmbedding(self.embedding_dim)

        self.layers = nn.ModuleList([
            TransformerDecoderLayer(self.embedding_dim, hidden_dim, n_heads)
            for _ in range(self.n_layers)
        ])

    def forward(self, x, memory, source_mask=None, attention_mask=None):
        """Embed the token ids ``x`` and run every decoder layer in order.

        Args:
            x: target token ids.
            memory: encoder output that each layer attends to.
            source_mask: optional mask handed to each layer for ``memory``.
            attention_mask: optional mask handed to each layer for its
                self-attention.

        Returns:
            The output of the last decoder layer.
        """
        x = self.embed(x)
        for layer in self.layers:
            x = layer(x, memory, source_mask, attention_mask)
        return x

    def embed(self, source):
        """Look up token embeddings and add the positional encoding."""
        # Fixed attribute name: the table is registered as ``self.embedding``
        # (the previous ``self.embeding`` raised AttributeError).
        x = self.embedding(source)
        positional_encoding = self.pos_encoding(x)
        x += positional_encoding
        return x

## Layers

In [19]:
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Build an ``nn.Embedding`` table with the given size and padding index."""
    return nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)


def Linear(in_features, out_features, bias=True):
    """Build an ``nn.Linear`` layer; ``bias`` toggles the additive bias term."""
    return nn.Linear(in_features, out_features, bias=bias)


class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encodings.

    For position ``pos`` and even dimension ``i`` the table holds
    ``sin(pos / 10000 ** (i / d_model))`` at column ``i`` and the matching
    cosine at column ``i + 1``.

    Args:
        d_model: width of each encoding vector.
        max_sentence_length: number of positions precomputed in the table.
    """

    def __init__(self, d_model, max_sentence_length=512):
        super(PositionalEmbedding, self).__init__()

        positions = torch.arange(max_sentence_length, dtype=torch.float32).unsqueeze(1)
        # The loop variable used to step by 2 already, so the exponent is
        # i / d_model (the former 2 * i / d_model doubled it).
        even_dims = torch.arange(0, d_model, 2, dtype=torch.float32)
        angles = positions / (10000.0 ** (even_dims / d_model))

        positional_embedding = torch.zeros(max_sentence_length, d_model)
        positional_embedding[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns.
        positional_embedding[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # A buffer follows the module across devices and never receives
        # gradients; non-persistent keeps it out of the state dict.
        self.register_buffer(
            "positional_embedding", positional_embedding, persistent=False)

    def forward(self, x):
        """Return the encodings for the positions present in ``x``.

        Only the encodings are returned (not ``x`` plus the encodings): the
        callers in this notebook add the returned value to their embeddings.

        Args:
            x: embedded input; its second-to-last dimension is taken as the
                sequence length -- assumes a (..., seq_len, d_model) layout,
                TODO confirm against the caller.

        Returns:
            Tensor of shape ``(seq_len, d_model)``, broadcastable against ``x``.

        Raises:
            ValueError: if the sequence is longer than the precomputed table.
        """
        seq_len = x.size(-2)
        table_len = self.positional_embedding.size(0)
        if seq_len > table_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_sentence_length {table_len}")
        return self.positional_embedding[:seq_len]

## Transformer Layer

In [20]:
class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention then a two-layer feed-forward network.

    Each sub-layer normalises its input with ``LayerNorm`` first and adds the
    un-normalised input back afterwards (residual connection).

    Args:
        embedding_dim: width of the layer input and output.
        hidden_dim: width of the feed-forward hidden layer.
        n_heads: number of attention heads.
    """

    def __init__(self, embedding_dim, hidden_dim, n_heads):
        super(TransformerEncoderLayer, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads

        # TO DO: Replace with custom multihead attention class
        # batch_first=True: the notebook embeds (batch, seq) id tensors
        # without transposing, so activations are (batch, seq, embedding_dim).
        # The sequence-first default would treat the batch axis as time and
        # reject a (batch, seq) key_padding_mask.
        self.self_attention = nn.MultiheadAttention(
            self.embedding_dim, self.n_heads, batch_first=True)

        self.fc1 = Linear(self.embedding_dim, self.hidden_dim)
        self.fc2 = Linear(self.hidden_dim, self.embedding_dim)

        self.self_attention_layer_norm = nn.LayerNorm(self.embedding_dim)
        self.final_layer_norm = nn.LayerNorm(self.embedding_dim)

    def forward(self, x, source_mask=None):
        """Apply self-attention and the feed-forward block to ``x``.

        Args:
            x: input of shape (batch, seq, embedding_dim).
            source_mask: optional ``key_padding_mask`` of shape (batch, seq);
                per ``nn.MultiheadAttention``, True entries of a boolean mask
                mark key positions to ignore.

        Returns:
            Tensor with the same shape as ``x``.
        """
        r = x
        x = self.self_attention_layer_norm(x)
        x, _ = self.self_attention(
            query=x, key=x, value=x, key_padding_mask=source_mask)
        x = x + r

        r = x
        x = self.final_layer_norm(x)
        # The linear layers act on the last dimension, so no transpose is needed.
        x = self.fc2(self.fc1(x).relu())
        x = x + r

        return x


class TransformerDecoderLayer(nn.Module):
    """One decoder layer: self-attention, encoder attention, then feed-forward.

    Each sub-layer normalises its input with ``LayerNorm`` first and adds the
    un-normalised input back afterwards (residual connection).

    Args:
        embedding_dim: width of the layer input and output.
        hidden_dim: width of the feed-forward hidden layer.
        n_heads: number of attention heads in both attention modules.
    """

    def __init__(self, embedding_dim, hidden_dim, n_heads):
        super(TransformerDecoderLayer, self).__init__()
        self.embedding_dim = embedding_dim
        self.n_heads = n_heads
        self.hidden_dim = hidden_dim

        # batch_first=True on both attention modules: the notebook embeds
        # (batch, seq) id tensors without transposing, so activations are
        # (batch, seq, embedding_dim). With the sequence-first default the
        # batch axis would be read as time and the masks would not line up.
        self.self_attention = nn.MultiheadAttention(
            self.embedding_dim, self.n_heads, batch_first=True
        )
        self.self_attention_layer_norm = nn.LayerNorm(self.embedding_dim)

        # TO DO: Replace with custom multihead attention class
        self.enc_attention = nn.MultiheadAttention(
            self.embedding_dim, self.n_heads, batch_first=True
        )
        self.enc_attention_layer_norm = nn.LayerNorm(self.embedding_dim)

        self.fc1 = Linear(self.embedding_dim, self.hidden_dim)
        self.fc2 = Linear(self.hidden_dim, self.embedding_dim)

        self.final_layer_norm = nn.LayerNorm(self.embedding_dim)

    def forward(self, x, enc_out, source_mask=None, tgt_mask=None):
        """Run the three sub-layers on ``x``.

        Args:
            x: decoder input of shape (batch, tgt_seq, embedding_dim).
            enc_out: encoder output of shape (batch, src_seq, embedding_dim),
                used as key and value of the encoder attention.
            source_mask: optional ``key_padding_mask`` of shape
                (batch, src_seq) for the encoder attention.
            tgt_mask: optional ``attn_mask`` of shape (tgt_seq, tgt_seq) for
                the self-attention. Per ``nn.MultiheadAttention``, True
                entries of a boolean mask mark positions that may not be
                attended to.

        Returns:
            Tensor with the same shape as ``x``.
        """
        r = x
        x = self.self_attention_layer_norm(x)
        x, _ = self.self_attention(query=x, key=x, value=x, attn_mask=tgt_mask)
        x = x + r

        r = x
        x = self.enc_attention_layer_norm(x)
        x, _ = self.enc_attention(
            query=x, key=enc_out, value=enc_out, key_padding_mask=source_mask
        )
        x = x + r

        r = x
        x = self.final_layer_norm(x)
        x = self.fc2(self.fc1(x).relu())
        x = x + r

        return x

## Multi-head Attention

In [21]:
class MultiHeadAttention(nn.Module):
    """Unfinished placeholder for a custom multi-head attention module.

    The constructor only checks that ``d_model`` divides evenly by
    ``n_heads``; no projections are created and ``forward`` computes no
    attention -- it returns ``None``.
    """

    def __init__(self, n_heads, d_model):
        super().__init__()
        leftover = d_model % n_heads
        assert leftover == 0

    def forward(self, q, k, v, mask=None):
        """Insert an axis at position 1 of ``mask`` (if given); returns ``None``.

        TODO: the attention computation itself is still missing, so ``q``,
        ``k`` and ``v`` are currently unused.
        """
        mask = None if mask is None else mask.unsqueeze(1)
        return None

# Train

In [1]:
from torch.utils.data import DataLoader
from torch.nn.functional import pad
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import Multi30k
from torchtext.data.functional import to_map_style_dataset
import spacy

In [12]:
# Prefer the GPU, but fall back to the CPU so the notebook also runs on
# machines without CUDA (tensors are created on this device while batching).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [13]:
# Load the spaCy pipelines whose tokenizers split the German and English text.
spacy_de, spacy_en = (
    spacy.load("de_core_news_sm"),
    spacy.load("en_core_web_sm"),
)

In [5]:
def tokenize(text, tokenizer):
    """Split ``text`` with ``tokenizer.tokenizer`` and return the token strings.

    Args:
        text: the string to split.
        tokenizer: object exposing a callable ``tokenizer`` attribute that
            yields items with a ``.text`` attribute.

    Returns:
        List with the ``.text`` of every produced token, in order.
    """
    token_texts = []
    for token in tokenizer.tokenizer(text):
        token_texts.append(token.text)
    return token_texts


def yield_tokens(data_iter, tokenizer, index):
    """Lazily tokenize one field of every item in ``data_iter``.

    Args:
        data_iter: iterable of indexable items (e.g. sentence pairs).
        tokenizer: callable applied to the selected field.
        index: position of the field to tokenize inside each item.

    Yields:
        ``tokenizer(item[index])`` for each item, in iteration order.
    """
    yield from (tokenizer(sample[index]) for sample in data_iter)


def tokenize_de(text):
    """Tokenize ``text`` with the German spaCy pipeline ``spacy_de``."""
    german_tokens = tokenize(text, spacy_de)
    return german_tokens


def tokenize_en(text):
    """Tokenize ``text`` with the English spaCy pipeline ``spacy_en``."""
    english_tokens = tokenize(text, spacy_en)
    return english_tokens

In [2]:
# Multi30k sentence pairs; each item is indexed as (German, English) below.
train, val, test = Multi30k(
    language_pair=("de", "en"),
)

In [7]:
# Source (German, field 0) and target (English, field 1) vocabularies built
# from the training split. The specials come first, so their ids are
# <s>=0, </s>=1, <blank>=2, <unk>=3.
vocab_src = build_vocab_from_iterator(
    yield_tokens(train, tokenize_de, index=0),
    min_freq=2,
    specials=["<s>", "</s>", "<blank>", "<unk>"],
)

vocab_tgt = build_vocab_from_iterator(
    yield_tokens(train, tokenize_en, index=1),
    min_freq=2,
    specials=["<s>", "</s>", "<blank>", "<unk>"]
)

# min_freq=2 drops tokens seen only once, so even training sentences contain
# out-of-vocabulary tokens. Without a default index such lookups raise;
# map them to <unk> instead.
vocab_src.set_default_index(vocab_src["<unk>"])
vocab_tgt.set_default_index(vocab_tgt["<unk>"])

In [8]:
# Map-style (indexable) views of the training and validation splits for DataLoader.
train_map, val_map = (
    to_map_style_dataset(train),
    to_map_style_dataset(val),
)

In [9]:
def collate_batch(
    batch,
    src_tokenizer,
    tgt_tokenizer,
    src_vocab,
    tgt_vocab,
    device,
    max_padding=128,
    pad_id=2,
):
    """Turn a batch of (source, target) text pairs into two padded id tensors.

    Every sentence is tokenized, mapped to ids, wrapped in the start id 0 and
    the end id 1, and then right-padded with ``pad_id`` to ``max_padding``.

    Args:
        batch: iterable of ``(source_text, target_text)`` pairs.
        src_tokenizer: callable splitting a source sentence into tokens.
        tgt_tokenizer: callable splitting a target sentence into tokens.
        src_vocab: callable mapping a list of source tokens to a list of ids.
        tgt_vocab: callable mapping a list of target tokens to a list of ids.
        device: device on which all tensors are created.
        max_padding: length of every returned row.
        pad_id: id used to fill the padded positions.

    Returns:
        Tuple ``(src, tgt)`` of int64 tensors, each of shape
        ``(len(batch), max_padding)``.
    """
    start_marker = torch.tensor([0], device=device)  # <s> token id
    end_marker = torch.tensor([1], device=device)  # </s> token id

    def encode(sentence, tokenizer, vocab):
        """Tokenize, numericalize, add the markers and pad one sentence."""
        token_ids = torch.tensor(
            vocab(tokenizer(sentence)),
            dtype=torch.int64,
            device=device,
        )
        framed = torch.cat([start_marker, token_ids, end_marker], 0)
        # warning - when the sentence is longer than max_padding the pad
        # amount is negative, which crops the tail (including </s>) instead.
        return pad(framed, (0, max_padding - len(framed)), value=pad_id)

    src_rows = []
    tgt_rows = []
    for raw_src, raw_tgt in batch:
        src_rows.append(encode(raw_src, src_tokenizer, src_vocab))
        tgt_rows.append(encode(raw_tgt, tgt_tokenizer, tgt_vocab))

    return (torch.stack(src_rows), torch.stack(tgt_rows))

In [14]:
def collate_fn(batch):
    """Collate ``batch`` with the notebook-level tokenizers, vocabularies and device.

    Thin wrapper around ``collate_batch`` that reads ``max_padding`` from the
    notebook globals and pads with the id of the ``"<blank>"`` token.

    Args:
        batch: list of ``(source_text, target_text)`` pairs from the DataLoader.

    Returns:
        Whatever ``collate_batch`` returns for this batch.
    """
    return collate_batch(
        batch,
        tokenize_de,
        tokenize_en,
        vocab_src,
        vocab_tgt,
        device,
        max_padding=max_padding,
        # Direct lookup: get_stoi() would rebuild the full token->id dict
        # for every single batch.
        pad_id=vocab_src["<blank>"],
    )

In [15]:
# NOTE(review): DataLoader's batch_size counts examples (sentence pairs), not
# tokens -- confirm that 12000 pairs per batch is really intended.
batch_size = 12000
max_padding = 128  # read by collate_fn: every sentence is padded to this length

# Both loaders reshuffle every epoch and batch through collate_fn.
train_dataloader, valid_dataloader = (
    DataLoader(
        train_map,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    ),
    DataLoader(
        val_map,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    ),
)