In [1]:
import torch

In [2]:
class Model(torch.nn.Module):
    """GRU encoder-decoder that autoregressively emits ``max_len`` token ids.

    The encoder's final hidden state initialises the decoder; the embedding
    table is shared between encoder inputs and decoder inputs.

    Args:
        vocab_size: Number of rows in the embedding table and width of the
            output projection.
        embed_size: Embedding dimension (input size of both GRUs).
        hidden_size: Hidden size of both GRUs.
        num_layers: Number of stacked GRU layers in the encoder and decoder.
        max_len: Number of decoding steps performed by ``forward``.
    """

    def __init__(self, vocab_size=10000, embed_size=100, hidden_size=100,
                 num_layers=1, max_len=50):
        super().__init__()
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.embedding = torch.nn.Embedding(self.vocab_size, self.embed_size)
        self.num_layers = num_layers
        self.encoder = torch.nn.GRU(
            input_size=self.embed_size,
            hidden_size=self.hidden_size,
            batch_first=True,
            num_layers=self.num_layers
        )
        self.decoder = torch.nn.GRU(
            input_size=self.embed_size,
            hidden_size=self.hidden_size,
            batch_first=True,
            num_layers=self.num_layers
        )
        self.linear = torch.nn.Linear(self.hidden_size, self.vocab_size)
        self.max_len = max_len

    @staticmethod
    def _filter_probs(probs, k, p):
        """Restrict a ``[batch, vocab]`` distribution with top-k, then top-p.

        Returns a new tensor in which filtered-out tokens have probability 0.
        The result is renormalised after the top-k step (so ``p`` is measured
        against the surviving mass) but not after the top-p step, because
        ``torch.multinomial`` accepts unnormalised weights.
        """
        k = int(k)  # the public default is the float 0.0; topk needs an int
        if k > 0:
            k = min(k, probs.size(1))  # topk raises when k exceeds the vocab
            kth_best = torch.topk(probs, k, dim=1)[0][:, -1:]  # [batch, 1]
            probs = probs.masked_fill(probs < kth_best, 0.0)
            probs = probs / probs.sum(dim=1, keepdim=True)
        if p > 0:
            sorted_probs, sorted_idx = torch.sort(probs, dim=1, descending=True)
            exceeds = torch.cumsum(sorted_probs, dim=1) > p
            # Shift the mask one position right so the token that crosses the
            # threshold is kept; the most likely token therefore always
            # survives and the row can never become all zeros.
            sorted_drop = torch.zeros_like(exceeds)
            sorted_drop[:, 1:] = exceeds[:, :-1]
            # Map the mask from sorted order back to vocabulary order.
            drop = torch.zeros_like(sorted_drop).scatter(1, sorted_idx, sorted_drop)
            probs = probs.masked_fill(drop, 0.0)
        return probs

    def forward(self, inputs, mode="greedy", k=0.0, p=0.0):
        """Encode ``inputs`` and decode ``max_len`` token ids.

        Args:
            inputs: Integer token ids of shape ``[batch, time]``.
            mode: ``"greedy"`` takes the arg-max token at every step;
                ``"sampling"`` draws from the (optionally filtered)
                distribution.
            k: If > 0 (sampling only), keep the ``k`` most probable tokens.
            p: If > 0 (sampling only), keep the smallest set of most probable
                tokens whose cumulative probability exceeds ``p``.

        Returns:
            ``int64`` tensor of shape ``[batch, max_len]`` on the same device
            as ``inputs``.

        Raises:
            ValueError: If ``mode`` is neither ``"greedy"`` nor ``"sampling"``.
        """
        if mode not in ("greedy", "sampling"):
            raise ValueError(f"mode must be 'greedy' or 'sampling', got {mode!r}")

        # Follow the caller's device instead of assuming CUDA is present.
        device = inputs.device
        batch_size = inputs.size(0)
        enc_inputs = self.embedding(inputs)  # [batch, time, embed]
        _, enc_hidden = self.encoder(enc_inputs)  # [layers, batch, hidden]

        # Decoding starts from token id 1 (presumably the start-of-sequence
        # id -- confirm against the vocabulary).
        start_tokens = torch.ones(batch_size, 1, dtype=torch.int64, device=device)
        dec_inputs = self.embedding(start_tokens)  # [batch, 1, embed]
        dec_hidden = enc_hidden

        words = torch.zeros(batch_size, self.max_len, dtype=torch.int64, device=device)
        for i in range(self.max_len):
            dec_outputs, dec_hidden = self.decoder(dec_inputs, dec_hidden)  # [batch, 1, hidden], [layers, batch, hidden]
            logits = self.linear(dec_outputs.squeeze(dim=1))  # [batch, vocab]
            probs = torch.softmax(logits, dim=1)
            if mode == "greedy":
                word = probs.argmax(dim=1).unsqueeze(dim=1)  # [batch, 1]
            else:
                word = torch.multinomial(self._filter_probs(probs, k, p), 1)  # [batch, 1]
            words[:, i] = word.squeeze(dim=1)
            # Feed the chosen token back in as the next decoder input.
            dec_inputs = self.embedding(word)
        return words

In [3]:
# Seed before construction so the randomly initialised weights, and therefore
# the token ids printed by later cells, are the same on every Restart & Run All.
torch.manual_seed(0)
# Use the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model().to(device)

In [4]:
# Dummy batch of 4 identical sequences of length 10, created directly on the
# model's device (no CPU allocation followed by a copy, no hard-coded CUDA).
inputs = torch.ones(4, 10, dtype=torch.int64, device=next(model.parameters()).device)
outputs = model(inputs)  # default mode: greedy decoding
outputs

tensor([[6628, 5233, 6492, 4255, 1482, 2146, 8036, 2416, 1373, 4443, 9263, 8254,
         3566, 9409, 9409, 8028, 5183, 2269, 1871, 1484, 6148, 7053, 4691, 9415,
         9817, 1629, 5890, 9428,  703, 3509, 5231, 5125, 7913, 4551, 7514, 7109,
         2209, 3221, 8583, 5308,  973, 9249, 3997, 3557, 2969, 3763, 1145, 6121,
         2779, 6091],
        [6628, 5233, 6492, 4255, 1482, 2146, 8036, 2416, 1373, 4443, 9263, 8254,
         3566, 9409, 9409, 8028, 5183, 2269, 1871, 1484, 6148, 7053, 4691, 9415,
         9817, 1629, 5890, 9428,  703, 3509, 5231, 5125, 7913, 4551, 7514, 7109,
         2209, 3221, 8583, 5308,  973, 9249, 3997, 3557, 2969, 3763, 1145, 6121,
         2779, 6091],
        [6628, 5233, 6492, 4255, 1482, 2146, 8036, 2416, 1373, 4443, 9263, 8254,
         3566, 9409, 9409, 8028, 5183, 2269, 1871, 1484, 6148, 7053, 4691, 9415,
         9817, 1629, 5890, 9428,  703, 3509, 5231, 5125, 7913, 4551, 7514, 7109,
         2209, 3221, 8583, 5308,  973, 9249, 3997, 3557, 2969, 37

In [5]:
# Sampling is stochastic: fix the RNG so re-running this cell with the same
# model weights draws the same tokens.
torch.manual_seed(0)
# Dummy batch of 4 identical sequences of length 10 on the model's device.
inputs = torch.ones(4, 10, dtype=torch.int64, device=next(model.parameters()).device)
outputs = model(inputs, mode="sampling", k=30, p=0.9)  # top-k then top-p sampling
outputs

tensor([[5233, 7811, 9159, 1482, 5882, 8415, 9110, 7950, 4632, 4275, 9847, 1525,
         7183, 8251, 9009, 4828, 5477, 8014, 8973, 3923, 6916, 1713, 6828, 8285,
         7973, 4134, 1063, 5935,  626, 6078, 4866, 7212, 8295, 1616, 3929, 7500,
         7765, 1509, 1241, 7377, 3263,  936,  980, 6111, 5454, 9448, 7852, 2390,
         7260, 8526],
        [8221, 7499, 4755,  929, 4816, 1713, 7582, 8919, 6186, 2105, 2768, 5288,
         9947, 3952,  105,  346, 4015, 8280, 8040, 2675, 6619,   79, 3695,  525,
         1969, 4604, 9559, 6252,  107, 1573, 8099, 1651,  116, 7311, 1828, 9284,
         1899, 2604, 1195, 5592, 9350, 4028, 2160, 3818, 7851, 9454, 6872, 9239,
         2572, 6228],
        [ 851, 6819, 3388, 8546, 9038, 7900, 7851, 2439, 3499, 8768, 5403, 3830,
          876, 4731, 9415, 1482, 8031, 2890, 7414, 6852, 1142, 6508, 6291, 2798,
         2837, 5031, 8539, 7456, 5742, 5152, 9658, 7907, 5945, 4788, 9557, 8215,
         9197, 6659, 8297, 2179, 1450, 6737, 3368, 8274, 9237, 45