In [None]:
import random
import time
import torch
from torch import nn
from torch.functional import F

# Create Train Data

In [None]:
# Vocabulary tables. Ids follow the order of the token list below: the three
# special tokens come first (so '<PAD>' is id 0), then the digits, then the
# two operators.
char_to_id = {
    token: index
    for index, token in enumerate(
        ['<PAD>', '<BOS>', '<EOS>'] + list('0123456789') + ['+', '-']
    )
}

# Inverse mapping (id -> token), derived from the table above.
id_to_char = {index: token for token, index in char_to_id.items()}

In [None]:
def id_list_to_sequence(sequence):
    """Decode an iterable of token ids into a single string.

    Each id is looked up in ``id_to_char`` and the resulting tokens are
    concatenated without any separator.
    """
    return ''.join(map(id_to_char.get, sequence))

def sequence_to_id_list(sequence):
    """Encode a string character by character into a list of token ids.

    Lookups go through ``char_to_id.get``, so a character that is missing
    from the vocabulary yields ``None`` in the returned list.
    """
    ids = []
    for char in sequence:
        ids.append(char_to_id.get(char))
    return ids

def create_dataset(size, num_digit=5, ops=('+', '-')):
    """Generate random integer addition/subtraction problems as strings.

    Args:
        size: Number of (problem, answer) pairs to generate.
        num_digit: Maximum number of decimal digits of each operand; operands
            are drawn uniformly from ``[0, 10**num_digit - 1]``.
        ops: Sequence of operators to sample from. Only ``'+'`` and ``'-'``
            are supported.

    Returns:
        A tuple ``(source_sequences, target_sequences)`` of two equally long
        lists of strings, e.g. ``'12+7'`` and ``'19'``. Subtraction results
        may be negative and are then rendered with a leading ``'-'``.

    Raises:
        ValueError: If ``ops`` contains an operator other than '+' or '-'.
    """
    unsupported = [op for op in ops if op not in ('+', '-')]
    if unsupported:
        raise ValueError('Unsupported operator(s): {}'.format(unsupported))

    # randint is inclusive on both ends, so the largest num_digit-digit
    # number is 10**num_digit - 1.
    max_operand = 10**num_digit - 1

    source_sequences = []
    target_sequences = []

    for _ in range(size):
        a = random.randint(0, max_operand)
        b = random.randint(0, max_operand)
        op = random.choice(ops)

        result = a + b if op == '+' else a - b

        source_sequences.append('{}{}{}'.format(a, op, b))
        target_sequences.append('{}'.format(result))

    return source_sequences, target_sequences


def tokenize(sequences, bos=False, eos=False):
    """Convert strings into one right-padded LongTensor of token ids.

    Args:
        sequences: Iterable of strings to encode with ``sequence_to_id_list``.
        bos: If true, prepend the '<BOS>' id to every sequence.
        eos: If true, append the '<EOS>' id to every sequence.

    Returns:
        A batch-first tensor in which shorter rows are padded with 0.
    """
    prefix = [char_to_id['<BOS>']] if bos else []
    suffix = [char_to_id['<EOS>']] if eos else []

    encoded = [
        torch.LongTensor(prefix + sequence_to_id_list(text) + suffix)
        for text in sequences
    ]
    return torch.nn.utils.rnn.pad_sequence(
        encoded, batch_first=True, padding_value=0)

In [None]:
# Size of the synthetic training set and of each mini-batch.
train_data_num = 25600
batch_size = 256

train_source_sequences, train_target_sequences = create_dataset(train_data_num)

# Each example is a triple of tensors: the tokenized problem, the answer
# prefixed with <BOS>, and the answer suffixed with <EOS>.
data_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(
        tokenize(train_source_sequences),
        tokenize(train_target_sequences, bos=True),
        tokenize(train_target_sequences, eos=True),
    ),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Show the first ten generated problems next to their answers.
for idx in range(10):
    problem = train_source_sequences[idx]
    answer = train_target_sequences[idx]
    print(problem, '=', answer)

# Implementation of Transformer

In [None]:
# Model hyperparameters.
vocab_size = len(char_to_id)  # one output class per vocabulary entry
num_blocks = 2                # encoder / decoder blocks per stack
num_hidden_size = 128         # model (embedding) width
num_heads = 8                 # attention heads
dropout_rate = 0.1

# Training schedule.
num_epochs = 100
num_batches = train_data_num // batch_size  # batches per epoch

# Always a torch.device (rather than an int ordinal on GPU and a string on
# CPU), so every `.to(device)` call receives the same kind of object.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Additive bias used for attention positions that must be ignored (padding and
# future tokens). It is added to the attention logits before the softmax, so a
# large finite negative value drives the corresponding weights to ~0.
NEG_INF = -1e9


def get_position_encoding(length, hidden_size, device, dtype=torch.float32):
    """Build a sinusoidal position-encoding table.

    Args:
        length: Number of positions (rows) to encode.
        hidden_size: Width of each encoding; ``hidden_size // 2`` frequencies
            are used, each contributing a sine and a cosine column.
        device: Device on which the table is created.
        dtype: Floating point dtype of the table.

    Returns:
        A ``(length, hidden_size)`` tensor whose columns alternate
        ``sin, cos, sin, cos, ...`` over the frequencies.
    """
    positions = torch.arange(length, dtype=dtype, device=device)
    channels = torch.arange(hidden_size // 2, dtype=dtype, device=device)

    inverse_frequency = 1 / (10000 ** ((2 * channels) / hidden_size))
    # Outer product via broadcasting: (length, 1) * (1, hidden_size // 2).
    angles = positions.unsqueeze(1) * inverse_frequency.unsqueeze(0)

    # Stacking on a new last axis and flattening interleaves sin and cos.
    sin_cos = torch.stack([torch.sin(angles), torch.cos(angles)], dim=2)
    return sin_cos.view(length, hidden_size)


def get_padding_bias(x):
    """Build an additive attention bias that masks padding tokens.

    Positions where ``x`` equals 0 receive ``NEG_INF``; the result gets two
    singleton axes inserted after the first axis so it can broadcast against
    attention logits.
    """
    is_padding = x == 0
    return (is_padding * NEG_INF)[:, None, None, :]


def get_decoder_self_attention_bias(length, device):
    """Build the causal (look-ahead) bias for decoder self-attention.

    Returns a ``(1, 1, length, length)`` tensor that holds ``NEG_INF`` strictly
    above the diagonal, so a position cannot attend to later positions.
    """
    causal_mask = torch.ones(length, length, device=device).tril()
    blocked = 1.0 - causal_mask[None, None, :, :]
    return NEG_INF * blocked

In [None]:
class EmbeddingSharedWeights(nn.Module):
    """Token embedding whose weight matrix doubles as the output projection.

    In 'embedding' mode token ids are mapped to scaled embedding vectors; in
    'linear' mode hidden states are multiplied by the transposed embedding
    matrix to produce one logit per vocabulary entry.
    """

    def __init__(self, vocab_size, hidden_size):
        """
        Args:
            vocab_size: Number of rows of the embedding table.
            hidden_size: Width of each embedding vector.
        """
        super(EmbeddingSharedWeights, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        # Size and scale the table from the constructor arguments so the
        # module does not depend on any module-level variable.
        self.embedding_weights = nn.Parameter(
            torch.normal(mean=0., std=hidden_size**-0.5, size=(vocab_size, hidden_size))
        )

    def forward(self, inputs, mode='embedding'):
        """Dispatch to the embedding lookup or the output projection.

        Args:
            inputs: Token ids for 'embedding' mode, hidden states whose last
                dimension is ``hidden_size`` for 'linear' mode.
            mode: Either 'embedding' or 'linear'.

        Raises:
            ValueError: If ``mode`` is not one of the two supported values.
        """
        if mode == 'embedding':
            return self._embedding(inputs)
        if mode == 'linear':
            return self._linear(inputs)
        raise ValueError("mode must be 'embedding' or 'linear', got {!r}".format(mode))

    def _embedding(self, inputs):
        """Look up embeddings for ``inputs`` and scale them by sqrt(hidden_size)."""
        embeddings = F.embedding(inputs, self.embedding_weights, padding_idx=0)
        # Out-of-place scaling leaves the lookup result untouched for autograd.
        return embeddings * self.hidden_size ** 0.5

    def _linear(self, inputs):
        """Project hidden states onto the vocabulary using the shared weights."""
        return inputs @ self.embedding_weights.transpose(1, 0)

In [None]:
class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Dropout -> Linear."""

    def __init__(self, hidden_size, filter_size, dropout_rate) -> None:
        """
        Args:
            hidden_size: Width of the block's input and output.
            filter_size: Width of the intermediate (expanded) representation.
            dropout_rate: Dropout probability applied after the activation.
        """
        super().__init__()
        self.filter_dense_layer = nn.Linear(hidden_size, filter_size)
        self.output_dense_layer = nn.Linear(filter_size, hidden_size)
        self.gelu = nn.GELU()
        self.dropout_layer = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Expand ``x``, apply GELU and dropout, then project back."""
        expanded = self.gelu(self.filter_dense_layer(x))
        return self.output_dense_layer(self.dropout_layer(expanded))

In [None]:
class Attention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries are computed from ``x``; keys and values are computed from ``y``.
    """

    def __init__(self, hidden_size, num_heads, dropout_rate):
        """
        Args:
            hidden_size: Width of the inputs and of the output.
            num_heads: Number of attention heads; must divide ``hidden_size``.
            dropout_rate: Dropout probability applied to the attention weights.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by ``num_heads``.
        """
        super(Attention, self).__init__()
        if hidden_size % num_heads != 0:
            # Fail here with a clear message instead of an opaque view()
            # error on the first forward pass.
            raise ValueError(
                'hidden_size ({}) must be divisible by num_heads ({})'.format(
                    hidden_size, num_heads))

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.depth = hidden_size // num_heads  # per-head width

        self.q_dense_layer = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_dense_layer = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_dense_layer = nn.Linear(hidden_size, hidden_size, bias=False)
        self.output_dense_layer = nn.Linear(hidden_size, hidden_size, bias=False)
        self.dropout_layer = nn.Dropout(dropout_rate)

    def forward(self, x, y, bias):
        """Attend from ``x`` over ``y``.

        Args:
            x: Query source of shape (batch, query_length, hidden_size).
            y: Key/value source of shape (batch, key_length, hidden_size).
            bias: Additive bias broadcastable to the attention logits of shape
                (batch, num_heads, query_length, key_length).

        Returns:
            Tensor of shape (batch, query_length, hidden_size).
        """
        q = self._split_heads(self.q_dense_layer(x))
        k = self._split_heads(self.k_dense_layer(y))
        v = self._split_heads(self.v_dense_layer(y))

        # Scale the queries so the dot products do not grow with the head width.
        q = q * self.depth ** -0.5

        logits = q @ k.transpose(-2, -1)
        logits = logits + bias
        weights = F.softmax(logits, dim=-1)
        weights = self.dropout_layer(weights)

        attention_output = self._combine_heads(weights @ v)
        return self.output_dense_layer(attention_output)

    def _split_heads(self, x):
        """Reshape (batch, length, hidden) to (batch, heads, length, depth)."""
        batch_size, length, _ = x.size()
        return x.view(batch_size, length, self.num_heads, self.depth).transpose(1, 2)

    def _combine_heads(self, x):
        """Reshape (batch, heads, length, depth) back to (batch, length, hidden)."""
        batch_size, _, length, _ = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, length, self.hidden_size)

In [None]:
class SelfAttention(Attention):
    """Attention in which queries, keys and values all come from one input."""

    def __call__(self, x, bias):
        # Delegate to the regular module call with `x` used as both the query
        # source and the key/value source.
        return super().__call__(x, x, bias)

In [None]:
class PrePostProcessingWrapper(nn.Module):
    """Wrap a sub-layer with layer norm before it and dropout + residual after."""

    def __init__(self, layer, hidden_size, dropout_rate):
        """
        Args:
            layer: The sub-layer to wrap; it is called with the normalized input.
            hidden_size: Size of the dimension that is layer-normalized.
            dropout_rate: Dropout probability applied to the sub-layer output.
        """
        super().__init__()
        self.layer = layer
        self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.dropout_layer = nn.Dropout(dropout_rate)

    def forward(self, x, *args, **kwargs):
        """Return ``x + dropout(layer(layer_norm(x), *args, **kwargs))``."""
        normalized = self.layer_norm(x)
        residual = self.dropout_layer(self.layer(normalized, *args, **kwargs))
        return x + residual

In [None]:
class EncoderStack(nn.Module):
    """Stack of encoder blocks followed by a final layer normalization.

    Each block is a self-attention sub-layer and a feed-forward sub-layer,
    both wrapped in ``PrePostProcessingWrapper``.
    """

    def __init__(self, num_blocks, hidden_size, num_heads, dropout_rate):
        """
        Args:
            num_blocks: Number of encoder blocks.
            hidden_size: Model width.
            num_heads: Attention heads per self-attention layer.
            dropout_rate: Dropout probability used throughout the stack.
        """
        super().__init__()

        def wrap(sublayer):
            return PrePostProcessingWrapper(sublayer, hidden_size, dropout_rate)

        blocks = []
        for _ in range(num_blocks):
            attention = SelfAttention(hidden_size, num_heads, dropout_rate)
            # The feed-forward inner width is four times the model width.
            feed_forward = FeedForwardNetwork(hidden_size, hidden_size * 4, dropout_rate)
            blocks.append(nn.ModuleList([wrap(attention), wrap(feed_forward)]))

        self.layers = nn.ModuleList(blocks)
        self.output_normalization = nn.LayerNorm(hidden_size, eps=1e-6)

    def forward(self, encoder_inputs, attention_bias):
        """Run the inputs through every block and normalize the result.

        Args:
            encoder_inputs: Input representations fed to the first block.
            attention_bias: Additive bias passed to every self-attention layer.
        """
        hidden = encoder_inputs
        for self_attention_layer, feed_forward_network in self.layers:
            hidden = self_attention_layer(hidden, attention_bias)
            hidden = feed_forward_network(hidden)
        return self.output_normalization(hidden)

In [None]:
class DecoderStack(nn.Module):
    """Stack of decoder blocks followed by a final layer normalization.

    Each block is a self-attention sub-layer, an encoder-decoder attention
    sub-layer and a feed-forward sub-layer, all wrapped in
    ``PrePostProcessingWrapper``.
    """

    def __init__(self, num_blocks, hidden_size, num_heads, dropout_rate):
        """
        Args:
            num_blocks: Number of decoder blocks.
            hidden_size: Model width.
            num_heads: Attention heads per attention layer.
            dropout_rate: Dropout probability used throughout the stack.
        """
        super().__init__()

        def wrap(sublayer):
            return PrePostProcessingWrapper(sublayer, hidden_size, dropout_rate)

        blocks = []
        for _ in range(num_blocks):
            self_attention = SelfAttention(hidden_size, num_heads, dropout_rate)
            cross_attention = Attention(hidden_size, num_heads, dropout_rate)
            # The feed-forward inner width is four times the model width.
            feed_forward = FeedForwardNetwork(hidden_size, hidden_size * 4, dropout_rate)
            blocks.append(nn.ModuleList([
                wrap(self_attention),
                wrap(cross_attention),
                wrap(feed_forward),
            ]))

        self.layers = nn.ModuleList(blocks)
        self.output_normalization = nn.LayerNorm(hidden_size, eps=1e-6)

    def forward(
        self,
        decoder_inputs,
        encoder_outputs,
        decoder_self_attention_bias,
        attention_bias
    ):
        """Run the decoder inputs through every block and normalize the result.

        Args:
            decoder_inputs: Input representations fed to the first block.
            encoder_outputs: Encoder outputs attended to by each block.
            decoder_self_attention_bias: Additive bias for self-attention.
            attention_bias: Additive bias for encoder-decoder attention.
        """
        hidden = decoder_inputs
        for self_attention_layer, enc_dec_attention_layer, feed_forward_network in self.layers:
            hidden = self_attention_layer(hidden, decoder_self_attention_bias)
            hidden = enc_dec_attention_layer(hidden, encoder_outputs, attention_bias)
            hidden = feed_forward_network(hidden)
        return self.output_normalization(hidden)

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer with a shared embedding/output matrix.

    The same ``EmbeddingSharedWeights`` layer embeds encoder tokens, embeds
    decoder tokens and projects decoder outputs to vocabulary logits.
    """

    def __init__(self, vocab_size, num_blocks, hidden_size, num_heads, dropout_rate):
        """
        Args:
            vocab_size: Number of tokens in the shared vocabulary.
            num_blocks: Number of blocks in each of the two stacks.
            hidden_size: Model width.
            num_heads: Attention heads per attention layer.
            dropout_rate: Dropout probability used throughout the model.
        """
        super().__init__()
        self.num_blocks = num_blocks
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate

        self.embedding_softmax_layer = EmbeddingSharedWeights(vocab_size, hidden_size)
        self.encoder_stack = EncoderStack(num_blocks, hidden_size, num_heads, dropout_rate)
        self.decoder_stack = DecoderStack(num_blocks, hidden_size, num_heads, dropout_rate)

        self.encoder_dropout_layer = nn.Dropout(dropout_rate)
        self.decoder_dropout_layer = nn.Dropout(dropout_rate)

    def forward(self, encoder_inputs, decoder_inputs):
        """Return vocabulary logits for every decoder position.

        Args:
            encoder_inputs: Source token ids.
            decoder_inputs: Target token ids fed to the decoder.
        """
        # The padding bias built from the source tokens is used both inside
        # the encoder and for the decoder's attention over encoder outputs.
        source_padding_bias = get_padding_bias(encoder_inputs)
        memory = self.encode(encoder_inputs, source_padding_bias)
        return self.decode(decoder_inputs, memory, source_padding_bias)

    def _embed_with_positions(self, token_ids, dropout_layer):
        """Embed token ids, add position encodings and apply ``dropout_layer``."""
        embedded = self.embedding_softmax_layer(token_ids)
        positions = get_position_encoding(
            embedded.size(1), self.hidden_size, embedded.device)
        return dropout_layer(embedded + positions)

    def encode(self, inputs, attention_bias):
        """Embed the source tokens and run them through the encoder stack."""
        encoder_inputs = self._embed_with_positions(inputs, self.encoder_dropout_layer)
        return self.encoder_stack(encoder_inputs, attention_bias)

    def decode(self, inputs, encoder_outputs, attention_bias):
        """Embed the target tokens, run the decoder stack and project to logits."""
        decoder_inputs = self._embed_with_positions(inputs, self.decoder_dropout_layer)

        causal_bias = get_decoder_self_attention_bias(
            decoder_inputs.size(1), decoder_inputs.device)
        decoder_outputs = self.decoder_stack(
            decoder_inputs, encoder_outputs, causal_bias, attention_bias)

        return self.embedding_softmax_layer(decoder_outputs, mode='linear')

In [None]:
# Build the model from the hyperparameters above and move it to the target device.
transformer = Transformer(
    vocab_size=vocab_size,
    num_blocks=num_blocks,
    hidden_size=num_hidden_size,
    num_heads=num_heads,
    dropout_rate=dropout_rate,
)
transformer = transformer.to(device)

# Adam over all model parameters.
optimizer = torch.optim.Adam(
    params=transformer.parameters(),
    lr=0.0002,
    betas=(0.9, 0.97),
)

In [None]:
def loss_function(real, pred):
    """Mean cross-entropy between logits and target ids, ignoring id 0.

    Args:
        real: Target token ids; positions equal to 0 do not contribute.
        pred: Logits whose last dimension is the number of classes.

    Returns:
        A scalar loss tensor.
    """
    num_classes = pred.size(-1)
    flat_logits = pred.view(-1, num_classes)
    flat_targets = real.view(-1)
    return F.cross_entropy(flat_logits, flat_targets, ignore_index=0)


def accuracy_function(real, pred):
    """Fraction of non-padding positions whose argmax prediction is correct.

    Args:
        real: Target token ids; positions equal to 0 are excluded.
        pred: Logits whose last dimension is the number of classes.

    Returns:
        A scalar tensor with the same dtype as ``pred``.
    """
    dtype = pred.dtype
    is_correct = (pred.argmax(-1) == real).type(dtype)
    is_token = (real != 0).type(dtype)

    num_correct = (is_correct * is_token).sum()
    return num_correct / is_token.sum()


def train_step(dataset_inputs):
    """Run one optimization step on a single batch.

    Args:
        dataset_inputs: A triple ``(encoder_inputs, decoder_inputs,
            decoder_targets)`` of tensors as yielded by the data loader.

    Returns:
        A tuple ``(loss, accuracy)`` of scalar tensors for this batch.
    """
    source, target_in, target_out = dataset_inputs
    source = source.to(device)
    target_in = target_in.to(device)
    target_out = target_out.to(device)

    # Clear gradients left over from the previous step before backpropagating.
    optimizer.zero_grad()
    logits = transformer(source, target_in)
    loss = loss_function(target_out, logits)
    loss.backward()
    optimizer.step()

    return loss, accuracy_function(target_out, logits)

In [None]:
template = '{}/{} (epoch {}), Train Loss: {:.4f}, Train Accuracy: {:.4f}, Elapsed Time: {:.2f}'

# Metrics and timer are accumulated over a reporting window of 100 steps and
# reset after each report.
train_losses = []
train_accuracies = []
total_steps = num_epochs * num_batches
start = time.time()

for e in range(num_epochs):
    transformer.train()
    for i, batch_train_data in enumerate(data_loader):
        loss, accuracy = train_step(batch_train_data)
        train_losses.append(loss.item())
        train_accuracies.append(accuracy.item())

        # 1-based count of optimization steps taken so far.
        global_step = e * num_batches + i + 1
        if global_step % 100 != 0:
            continue

        print(template.format(
            global_step,
            total_steps,
            e + 1,
            sum(train_losses) / len(train_losses),
            sum(train_accuracies) / len(train_accuracies),
            time.time() - start
        ))

        train_losses = []
        train_accuracies = []
        start = time.time()


# Validation

In [None]:
# Size of the freshly generated validation set and of each evaluation batch.
valid_data_num = 10240
valid_batch_size = 1024

valid_source_sequences, valid_target_sequences = create_dataset(valid_data_num)

# Same triple layout as the training data: problem, <BOS>+answer, answer+<EOS>.
# Order is kept and the last batch is not dropped.
valid_data_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(
        tokenize(valid_source_sequences),
        tokenize(valid_target_sequences, bos=True),
        tokenize(valid_target_sequences, eos=True),
    ),
    batch_size=valid_batch_size,
    shuffle=False,
    drop_last=False,
)

In [None]:
# Switch to evaluation mode so dropout is disabled.
transformer.eval()
valid_losses = []
valid_accuracies = []

# No parameters are updated here, so skip autograd bookkeeping; this avoids
# building a graph for every 1024-example batch.
with torch.no_grad():
    for encoder_inputs, decoder_inputs, decoder_targets in valid_data_loader:
        encoder_inputs = encoder_inputs.to(device)
        decoder_inputs = decoder_inputs.to(device)
        decoder_targets = decoder_targets.to(device)

        logits = transformer(encoder_inputs, decoder_inputs)
        loss = loss_function(decoder_targets, logits)
        accuracy = accuracy_function(decoder_targets, logits)

        valid_losses.append(loss.item())
        valid_accuracies.append(accuracy.item())

# Report the unweighted mean of the per-batch metrics.
print('Valid Loss: {:.4f}, Valid Accuracy: {:.4f}'.format(
    sum(valid_losses) / len(valid_losses),
    sum(valid_accuracies) / len(valid_accuracies),
))