In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class Transformer(nn.Module):
    """Transformer encoder that maps token ids to per-position log-probabilities.

    Pipeline: ``nn.Embedding`` -> ``PositionalEncoding`` -> stacked
    ``nn.TransformerEncoder`` -> ``nn.Linear`` back to ``input_dim`` classes
    -> ``log_softmax`` over the last axis.

    The encoder layers are built with PyTorch's default ``batch_first=False``,
    so inputs are sequence-first: token ids shaped ``(seq_len, batch)``.

    Args:
        input_dim: Vocabulary size. It is used both as the embedding table
            size and as the number of output classes.
        hidden_dim: Model width (``d_model``). PyTorch's multi-head attention
            requires it to be divisible by ``num_heads``.
        num_layers: Number of stacked encoder layers.
        num_heads: Number of attention heads per encoder layer.
        dropout: Dropout probability used by the positional encoding and by
            every encoder layer.
        dim_feedforward: Hidden width of each encoder layer's feed-forward
            sub-network (PyTorch's default is 2048).
    """

    def __init__(self, input_dim, hidden_dim, num_layers, num_heads, dropout,
                 dim_feedforward=2048):
        super().__init__()

        self.embedding = nn.Embedding(input_dim, hidden_dim)
        self.positional_encoding = PositionalEncoding(hidden_dim, dropout)
        # The third positional parameter of TransformerEncoderLayer is
        # dim_feedforward, not dropout. Pass both by keyword so the dropout
        # probability is never mistaken for a layer width.
        encoder_layer = nn.TransformerEncoderLayer(
            hidden_dim,
            num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.decoder = nn.Linear(hidden_dim, input_dim)

    def forward(self, x, mask=None, src_key_padding_mask=None):
        """Return log-probabilities over the vocabulary for every position.

        Args:
            x: Integer token ids, sequence-first ``(seq_len, batch)``.
            mask: Optional attention mask forwarded to the encoder
                (e.g. a causal mask for language modelling).
            src_key_padding_mask: Optional ``(batch, seq_len)`` padding mask
                forwarded to the encoder.

        Returns:
            Tensor of shape ``(seq_len, batch, input_dim)`` holding
            log-probabilities (``log_softmax`` over the last axis).
        """
        x = self.embedding(x)
        x = self.positional_encoding(x)
        x = self.encoder(x, mask=mask, src_key_padding_mask=src_key_padding_mask)
        x = self.decoder(x)
        return F.log_softmax(x, dim=-1)

In [3]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first input, then apply dropout.

    Feature column ``2k`` at position ``pos`` receives
    ``sin(pos / 10000**(2k / hidden_dim))``, and column ``2k + 1`` receives the
    matching ``cos``. The table is precomputed for ``max_len`` positions and
    stored as a buffer of shape ``(max_len, 1, hidden_dim)``. That shape
    broadcasts over the batch axis of a ``(seq_len, batch, hidden_dim)`` input.

    Args:
        hidden_dim: Size of the input's last (feature) axis. May be odd.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Longest sequence the precomputed table supports.
    """

    def __init__(self, hidden_dim, dropout, max_len=1000):
        super().__init__()

        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # Computes 1 / 10000**(2k / hidden_dim) in log space, for every even column index 2k.
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2, dtype=torch.float)
            * (-math.log(10000.0) / hidden_dim)
        )
        pos_enc = torch.zeros(max_len, hidden_dim)
        pos_enc[:, 0::2] = torch.sin(position * div_term)
        # An odd hidden_dim has one fewer odd column than even column, so the
        # cosine half uses only the first hidden_dim // 2 frequencies.
        pos_enc[:, 1::2] = torch.cos(position * div_term[: hidden_dim // 2])
        pos_enc = pos_enc.unsqueeze(1)  # (max_len, 1, hidden_dim)

        self.register_buffer('pos_enc', pos_enc)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Add the first ``x.size(0)`` position encodings to ``x`` and apply dropout.

        Args:
            x: Sequence-first tensor ``(seq_len, batch, hidden_dim)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(0)
        max_len = self.pos_enc.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        x = x + self.pos_enc[:seq_len]
        return self.dropout(x)