In [1]:
import torch

In [2]:
class Model(torch.nn.Module):
    """GRU encoder-decoder that generates a fixed-length sequence of token ids.

    The input token ids are embedded and run through the encoder GRU; its final
    hidden state initialises the decoder GRU, which is then unrolled for
    ``max_len`` steps starting from ``start_token``. Encoder and decoder share
    one embedding table.

    Args:
        vocab_size: Number of rows in the embedding table and width of the
            output projection.
        embed_size: Embedding dimension (input size of both GRUs).
        hidden_size: Hidden size of both GRUs.
        num_layers: Number of stacked GRU layers in encoder and decoder.
        max_len: Number of tokens decoded per sequence.
        start_token: Token id fed to the decoder at the first step.

    The defaults reproduce the previously hard-coded configuration, so
    ``Model()`` builds the same network as before.
    """

    def __init__(self, vocab_size=10000, embed_size=100, hidden_size=100,
                 num_layers=1, max_len=50, start_token=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        # Sub-modules are created in the original order so that, for a given
        # RNG seed, the initial weights are unchanged.
        self.embedding = torch.nn.Embedding(self.vocab_size, self.embed_size)
        self.num_layers = num_layers
        self.encoder = torch.nn.GRU(
            input_size=self.embed_size,
            hidden_size=self.hidden_size,
            batch_first=True,
            num_layers=self.num_layers
        )
        self.decoder = torch.nn.GRU(
            input_size=self.embed_size,
            hidden_size=self.hidden_size,
            batch_first=True,
            num_layers=self.num_layers
        )
        self.linear = torch.nn.Linear(self.hidden_size, self.vocab_size)
        self.max_len = max_len
        self.start_token = start_token

    @staticmethod
    def _top_k_filter(probs, k):
        """Zero every probability strictly below the k-th largest in its row."""
        k = min(int(k), probs.size(-1))  # torch.topk rejects k > vocab
        kth_largest = torch.topk(probs, k, dim=-1).values[:, -1:]  # [batch, 1]
        return probs.masked_fill(probs < kth_largest, 0.0)

    @staticmethod
    def _top_p_filter(probs, p):
        """Keep, per row, the most probable tokens up to and including the one
        whose cumulative probability first exceeds ``p``; zero the rest."""
        # Renormalise first so ``p`` is a fraction of the mass that is still
        # available (a preceding top-k step leaves rows summing to < 1).
        probs = probs / probs.sum(dim=-1, keepdim=True)
        sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
        exceeds = torch.cumsum(sorted_probs, dim=-1) > p
        # Shift right by one: the token that crosses the threshold is kept, so
        # the most probable token always survives and the row never becomes
        # all-zero (which torch.multinomial rejects).
        drop_sorted = torch.zeros_like(exceeds)
        drop_sorted[:, 1:] = exceeds[:, :-1]
        # Map the mask from sorted order back to vocabulary order.
        drop = torch.zeros_like(drop_sorted).scatter(-1, sorted_idx, drop_sorted)
        return probs.masked_fill(drop, 0.0)

    def forward(self, inputs, mode="greedy", k=0.0, p=0.0):
        """Encode ``inputs`` and decode ``max_len`` token ids per sequence.

        Args:
            inputs: int64 tensor of token ids, [batch, time].
            mode: ``"greedy"`` takes the arg-max token at every step;
                ``"sampling"`` draws from the (optionally filtered) softmax
                distribution.
            k: If > 0 (sampling only), restrict sampling to the k most
                probable tokens. Values above the vocabulary size are clamped.
            p: If > 0 (sampling only), restrict sampling to the nucleus of
                cumulative probability ``p``. Applied after top-k.

        Returns:
            int64 tensor [batch, max_len] of generated token ids, on the same
            device as ``inputs``.

        Raises:
            ValueError: If ``mode`` is neither ``"greedy"`` nor ``"sampling"``.
        """
        if mode not in ("greedy", "sampling"):
            raise ValueError(
                "mode must be 'greedy' or 'sampling', got {!r}".format(mode)
            )

        batch_size = inputs.size(0)
        device = inputs.device  # keep every new tensor on the inputs' device
        enc_inputs = self.embedding(inputs)  # [batch, time, embed]
        _, enc_hidden = self.encoder(enc_inputs)  # [layers, batch, hidden]

        # The first decoder input is the start token for every sequence.
        word = torch.full(
            (batch_size, 1), self.start_token, dtype=torch.int64, device=device
        )
        dec_hidden = enc_hidden

        words = torch.zeros(batch_size, self.max_len, dtype=torch.int64, device=device)
        for i in range(self.max_len):
            dec_inputs = self.embedding(word)  # [batch, 1, embed]
            dec_outputs, dec_hidden = self.decoder(dec_inputs, dec_hidden)  # [batch, 1, hidden], [layers, batch, hidden]
            logits = self.linear(dec_outputs.squeeze(dim=1))  # [batch, vocab]
            probs = torch.softmax(logits, dim=1)
            if mode == "greedy":
                word = probs.argmax(dim=1).unsqueeze(dim=1)  # [batch, 1]
            else:
                if k > 0:
                    probs = self._top_k_filter(probs, k)
                if p > 0:
                    probs = self._top_p_filter(probs, p)
                # multinomial normalises each row itself, so the zeroed-out
                # distribution can be passed as is.
                word = torch.multinomial(probs, 1)  # [batch, 1]
            words[:, i] = word.squeeze(dim=1)
        return words

In [3]:
# Seed torch's global RNG before construction so the randomly initialised
# weights are identical on every Restart & Run All.
torch.manual_seed(0)
model = Model()

In [4]:
# Greedy decoding demo on a dummy batch: 4 sequences of 10 tokens, all with id 1.
inputs = torch.ones((4, 10), dtype=torch.long)
outputs = model(inputs, mode="greedy")
outputs

tensor([[1941, 4321, 2671, 9853, 6813, 6510, 3636, 5522, 2278, 3361, 7396, 8294,
         3754,  120, 2268, 9773, 3785,  911, 9538,  911, 9538, 6668, 5420, 1204,
         5189, 2923, 9549, 8157, 1542, 5776, 6392, 7061, 3878, 8643, 5594, 2471,
         6618, 4434, 7196, 9222, 1728, 1396, 5829, 5592, 4174, 6464, 2603, 6286,
         5156, 1282],
        [1941, 4321, 2671, 9853, 6813, 6510, 3636, 5522, 2278, 3361, 7396, 8294,
         3754,  120, 2268, 9773, 3785,  911, 9538,  911, 9538, 6668, 5420, 1204,
         5189, 2923, 9549, 8157, 1542, 5776, 6392, 7061, 3878, 8643, 5594, 2471,
         6618, 4434, 7196, 9222, 1728, 1396, 5829, 5592, 4174, 6464, 2603, 6286,
         5156, 1282],
        [1941, 4321, 2671, 9853, 6813, 6510, 3636, 5522, 2278, 3361, 7396, 8294,
         3754,  120, 2268, 9773, 3785,  911, 9538,  911, 9538, 6668, 5420, 1204,
         5189, 2923, 9549, 8157, 1542, 5776, 6392, 7061, 3878, 8643, 5594, 2471,
         6618, 4434, 7196, 9222, 1728, 1396, 5829, 5592, 4174, 64

In [5]:
# Sampling draws from torch's global RNG; seed it here so the tokens shown
# below are the same every time this cell is run.
torch.manual_seed(0)
inputs = torch.ones(4, 10, dtype=torch.int64)
# Top-k (k=30) followed by top-p (p=0.9) sampling on the same dummy batch.
outputs = model(inputs, mode="sampling", k=30, p=0.9)
outputs

tensor([[9245, 9595, 7072, 5389, 7098, 2671, 3199, 7919, 3396, 6590, 8866, 4545,
         2285, 5052, 1502, 1724, 9764, 2592, 9765, 3730, 6967, 5363, 4635, 3270,
         4952, 9021, 2268, 5972,   57, 5420, 3545, 5910, 7632, 7038, 4979, 5048,
         2780, 8297, 6278, 6911, 3683, 4185, 2725, 4585, 6984, 3392, 9630, 4950,
         1157, 8340],
        [1399,  782,  566, 4068, 8798, 1054, 5087, 2784, 2331, 8797, 7428, 3831,
         7227, 4632, 3342, 8588, 5772, 8160,  994, 4039, 3533,  772, 4374, 8502,
         3707, 5753, 9286, 9877, 6536, 9336, 9633, 7431, 4837, 3408, 2170, 1708,
         1654, 4098,  417, 9155, 9969, 7112, 1009, 3976, 1289, 3125, 3081, 4943,
         1579, 5886],
        [4373, 7027, 7633, 3735, 9280, 5309, 6678, 1673, 1654, 6281, 9623, 9369,
         5725, 2552, 2943, 1820,  452, 9528, 9397, 9778, 2565, 9445, 1131, 9981,
         9029, 2307, 7747, 6292, 2565, 1723, 9353, 7695, 5495, 4891, 6821, 4496,
         6620, 2036, 7785, 4776, 6389, 4908, 7725, 7706, 7756, 64