In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, seq_len, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.position = torch.arange(0, seq_len)
        self.emb = nn.Embedding(seq_len, d_model)

    def forward(self, x):
        x = x + self.emb(self.position).unsqueeze(1) # broadcasting across batches
        return self.dropout(x)

In [3]:
seq_len=10
bs=4
embed_dim=6
x = torch.randn(size=(seq_len, bs, embed_dim))

In [4]:
pe = PositionalEncoding(d_model=embed_dim, seq_len=seq_len)
output = pe(x)

x=tensor([[-1.3666, -0.9245, -1.1971, -1.0691, -0.0745, -1.1086],
        [-1.4757, -0.6124, -0.3600, -0.3302, -2.2309,  0.7564],
        [-1.2283,  0.1618,  1.1980, -0.3296,  0.0476,  1.2692],
        [-2.5438,  0.0488, -0.8552, -0.5329, -0.4584,  1.2226]])torch.Size([10, 4, 6])
embedding.shape=torch.Size([10, 1, 6])
x=tensor([[ 0.2638, -1.4541, -1.7296, -1.3162, -0.1708, -0.9043],
        [ 0.1547, -1.1419, -0.8926, -0.5773, -2.3271,  0.9607],
        [ 0.4021, -0.3677,  0.6655, -0.5767, -0.0487,  1.4735],
        [-0.9134, -0.4807, -1.3877, -0.7800, -0.5546,  1.4269]],
       grad_fn=<SliceBackward>)torch.Size([10, 4, 6])


In [72]:
class TransformerModel(nn.Module):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, seq_len, dropout=0.5):
        super(TransformerModel, self).__init__()
        self.src_mask = None
        self.dropout = nn.Dropout(p=dropout)
        encoder_layers = nn.TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)
        self.pos_encoder = nn.Embedding(seq_len, ninp)
        self.word_encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)
        self.init_weights()
        self.position = torch.arange(0, seq_len).unsqueeze(1)

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.word_encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, has_mask=True):
        print(f'has_mask=={has_mask}')
        if has_mask:
            device = src.device
            if self.src_mask is None or self.src_mask.size(0) != len(src):
                mask = self._generate_square_subsequent_mask(len(src)).to(device)
                self.src_mask = mask
        else:
            self.src_mask = None
        word_emb = self.word_encoder(src) * math.sqrt(self.ninp)
        print(word_emb.shape)
        pos_emb = self.pos_encoder(self.position) #.unsqueeze(1)
        print(pos_emb.shape)
        
        src = self.dropout(word_emb + pos_emb)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        return F.log_softmax(output, dim=-1)

In [73]:
ntoken = 1000
lm = TransformerModel(ntoken=ntoken, ninp=6, nhead=2, nhid=10, nlayers=1, seq_len=10)

In [74]:
x2 = torch.randint(high=ntoken, size=(seq_len, bs))

In [75]:
x2.shape

torch.Size([10, 4])

In [76]:
lm(x2).shape

has_mask==True
torch.Size([10, 4, 6])
torch.Size([10, 1, 6])


torch.Size([10, 4, 1000])

In [67]:
torch.arange(0, 10).unsqueeze(1).shape

torch.Size([10, 1])