In [6]:
import torch
import torch.nn as nn

In [19]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first tensor.

    The table is built once in ``__init__`` and stored as the buffer ``pe``
    with shape ``[1, max_len, d_model]``. Even feature columns hold sines and
    odd feature columns hold cosines of ``position * div_term``.

    :param d_model: Size of the last (feature) dimension of the input.
        Both even and odd values are supported.
    :param max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        # Compute the positional encodings once in log space.
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / d_model)
        )
        # One column of angles per even feature index: [max_len, ceil(d_model / 2)].
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # When d_model is odd there is one fewer odd column than even column,
        # so trim the angle table to the number of odd columns. For even
        # d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # A buffer (not a Parameter): saved in state_dict and moved by .to(),
        # but not returned by parameters().
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """
        :param x: Tensor of shape [batch_size, seq_length, d_model]
        :return: Tensor with positional encoding added.
        """
        # Slice the table to the input's sequence length; the leading
        # singleton dimension broadcasts over the batch.
        return x + self.pe[:, :x.size(1)]

In [22]:
# --- Model hyperparameters ---
vocab_size = 32_768  # number of distinct token ids
latent_dim = 128     # embedding width, also used as the transformer d_model
num_heads = 2        # attention heads per encoder layer
num_layers = 2       # stacked encoder layers

# --- Shape of the dummy input batch ---
batch_size = 32
timesteps = 64       # sequence length

In [23]:
# Token-id -> vector lookup table with vocab_size rows of width latent_dim.
embed = nn.Embedding(vocab_size, latent_dim)

# Positional encoding module sized to the embedding width.
pos = PositionalEncoding(latent_dim)

# A single encoder layer serves as the template that TransformerEncoder
# stacks num_layers times. batch_first is left at its default here.
encoder_layer_in = nn.TransformerEncoderLayer(latent_dim, num_heads)
text_transformer_encoder = nn.TransformerEncoder(encoder_layer_in, num_layers)

In [31]:
# Dummy batch: random token ids plus an all-ones attention mask.
# NOTE(review): the mask is assumed to follow the 1 = real token,
# 0 = padding convention -- confirm against the real data pipeline.
input_seq = torch.randint(
    low=0, high=vocab_size, size=(batch_size, timesteps))
input_attention_masks = torch.ones(
    (batch_size, timesteps), dtype=torch.float32)
print(input_seq.shape, input_attention_masks.shape)

# embed/pos work on [batch, seq, dim]; the encoder was built without
# batch_first=True, so move the sequence axis to the front for it.
x = pos(embed(input_seq)).permute(1, 0, 2)

# src_key_padding_mask keeps the [batch, seq] layout (no transpose needed),
# but for a boolean mask True means "ignore this position". A float mask is
# instead *added* to the attention logits, so passing the 1/0 attention mask
# directly would never mask anything. Convert it: padding (0) -> True.
src_key_padding_mask = input_attention_masks == 0

print(x.shape, src_key_padding_mask.shape)
x = text_transformer_encoder(x, src_key_padding_mask=src_key_padding_mask)
print(x.permute(1,0,2).shape)

torch.Size([32, 64]) torch.Size([32, 64])
torch.Size([64, 32, 128]) torch.Size([32, 64])
torch.Size([32, 64, 128])


In [37]:
# Table of bottleneck vectors, one latent_dim-wide row per slot.
num_bottleneck_tokens = 4
bottleneck_embed = nn.Embedding(num_bottleneck_tokens, latent_dim)

# Keep the module under its own name so its weight Parameter stays reachable
# (e.g. to pass to an optimizer); `bottleneck` is only the per-batch copy of
# that weight, shaped [batch_size, num_bottleneck_tokens, latent_dim].
bottleneck = bottleneck_embed.weight.unsqueeze(0).repeat(batch_size, 1, 1)
print(bottleneck.shape)

torch.Size([32, 4, 128])
