# In [1]:
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    A table of shape ``(max_len, 1, d_model)`` is computed once at
    construction and stored as the non-trainable buffer ``pe``. Even feature
    columns hold ``sin(pos * freq)`` and odd feature columns hold
    ``cos(pos * freq)``, with ``freq = exp(-2i * ln(10000) / d_model)`` for
    the i-th column pair.

    Args:
        d_model: Size of the feature (last) dimension. Both even and odd
            values are supported.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model: int, max_len: int = 5000) -> None:
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd column than even columns, so the
        # surplus last frequency is dropped for the cosine half. For even
        # d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Singleton middle dimension lets the table broadcast over dim 1 of
        # the input in forward().
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the encodings of its first ``x.size(0)`` positions.

        Args:
            x: Tensor whose first dimension indexes sequence position and
                whose last dimension is ``d_model``; the table is broadcast
                across the middle dimension. The table only holds
                ``max_len`` positions.

        Returns:
            A tensor of the same shape as ``x``.
        """
        return x + self.pe[:x.size(0), :]

class TransformerModel(nn.Module):
    """Token-level encoder-decoder model built around ``nn.Transformer``.

    Source and target token ids share a single embedding table. Embeddings
    are scaled by ``sqrt(d_model)``, passed through the positional encoder,
    run through the transformer, and projected to ``ntoken`` logits.

    Args:
        ntoken: Vocabulary size (embedding rows and output logits).
        d_model: Embedding / transformer feature dimension.
        nhead: Number of attention heads.
        nhid: Feed-forward hidden dimension inside each transformer layer.
        nlayers: Layer count used for both the encoder and the decoder stack.
        dropout: Dropout probability handed to ``nn.Transformer``.
    """

    def __init__(self, ntoken, d_model, nhead, nhid, nlayers, dropout=0.5):
        super().__init__()
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(d_model)
        self.encoder = nn.Embedding(ntoken, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=nlayers,
            num_decoder_layers=nlayers,
            dim_feedforward=nhid,
            dropout=dropout,
        )
        self.decoder = nn.Linear(d_model, ntoken)
        self.d_model = d_model
        self.init_weights()

    def init_weights(self):
        """Re-initialise embedding and output projection.

        Weights are drawn uniformly from [-0.1, 0.1]; the projection bias is
        zeroed.
        """
        bound = 0.1
        nn.init.uniform_(self.encoder.weight, -bound, bound)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -bound, bound)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Map source/target token ids to per-position vocabulary logits.

        Args:
            src: Integer tensor of source token ids.
            tgt: Integer tensor of target token ids.
            src_mask: Optional mask forwarded to the transformer for ``src``.
            tgt_mask: Optional mask forwarded to the transformer for ``tgt``.

        Returns:
            The transformer output projected to ``ntoken`` logits along the
            last dimension.
        """
        scale = math.sqrt(self.d_model)
        src_features = self.pos_encoder(self.encoder(src) * scale)
        tgt_features = self.pos_encoder(self.encoder(tgt) * scale)
        hidden = self.transformer(src_features, tgt_features, src_mask, tgt_mask)
        return self.decoder(hidden)

# Example usage: build the model and push one random batch through it.
ntokens = 1000  # vocabulary size
d_model = 512  # embedding / model feature dimension
nhead = 8  # number of heads in multi-head attention
nhid = 2048  # feed-forward hidden dimension inside each transformer layer
nlayers = 6  # layer count for the encoder stack and for the decoder stack
dropout = 0.5  # dropout probability

model = TransformerModel(
    ntoken=ntokens,
    d_model=d_model,
    nhead=nhead,
    nhid=nhid,
    nlayers=nlayers,
    dropout=dropout,
)

# Random token ids laid out as (sequence length, batch size).
src = torch.randint(low=0, high=ntokens, size=(10, 32))
tgt = torch.randint(low=0, high=ntokens, size=(20, 32))

output = model(src, tgt)
print(output.shape)  # (target sequence length, batch size, ntokens)



# Output: torch.Size([20, 32, 1000])
