## Imports

In [46]:
import math
import unittest

import numpy as np
import torch

## Toy Model Class

In [47]:
class Model:
  """Toy stand-in for a language model that scores tokens at random."""

  def __init__(self, vocab_size):
    """Remember how many tokens the vocabulary holds."""
    self.vocab_size = vocab_size

  def forward(self, sequence):
    """Return one random score per vocabulary entry.

    `sequence` is accepted for interface compatibility but is not read; the
    scores are drawn from NumPy's global RNG with `np.random.uniform` using
    its default bounds.
    """
    output_shape = (self.vocab_size,)
    return np.random.uniform(size = output_shape)

## Decoding Class

In [48]:
class BeamSearch:
  """Beam-search decoder on top of a model that scores the next token.

  The model must expose `forward(sequence)` returning a 1-D collection of
  per-token probabilities, indexed by token id. Sequences are ranked by the
  sum of the log-probabilities of their tokens.

  Attributes populated by `sample` (kept public so they can be inspected):
    beams: the surviving `(sequence, log_probability)` pairs, best first.
    all_candidates: every expansion scored during the most recent step.
    count: number of decoding steps performed.
  """

  def __init__(self, model, vocab_size = 1000, beam_size = 5):
    """Store the model and the search width.

    Raises:
      ValueError: if `beam_size` is smaller than 1.
    """
    if beam_size < 1:
      raise ValueError("beam_size must be at least 1")
    self.vocab_size = vocab_size
    self.beam_size = beam_size
    self.tokens = list(range(vocab_size))
    self.model = model

  def get_top_k(self, sequence):
    """Return the best next tokens for `sequence` and their probabilities.

    Returns:
      A `(tokens, probs)` pair of equal-length lists ordered from most to
      least probable. At most `beam_size` entries are returned; fewer when
      the model scores fewer tokens than that.
    """
    probs = torch.as_tensor(self.model.forward(sequence), dtype = torch.float)
    # topk raises when k exceeds the number of scores, so clamp it.
    k = min(self.beam_size, probs.size(-1))
    topk_probs, topk_tokens = probs.topk(k)
    return topk_tokens.tolist(), topk_probs.tolist()

  def sample(self, prompt, seq_len):
    """Extend `prompt` by `seq_len` tokens and return the best sequence.

    Args:
      prompt: initial token ids; copied, never mutated.
      seq_len: number of tokens to generate (non-negative).

    Returns:
      A list holding the prompt followed by the `seq_len` generated tokens of
      the highest-scoring beam.

    Raises:
      ValueError: if `seq_len` is negative.
    """
    if seq_len < 0:
      raise ValueError("seq_len must be non-negative")
    # Start from a single hypothesis. Seeding the search with beam_size copies
    # of the prompt makes every beam expand identically, which collapses the
    # search to greedy decoding for any deterministic model.
    self.beams = [(list(prompt), 0.0)]
    self.all_candidates = []
    self.count = 0
    while self.count < seq_len:
      self.all_candidates = []
      for sequence, logprob in self.beams:
        tokens, probs = self.get_top_k(sequence)
        for token, prob in zip(tokens, probs):
          # Accumulate in log space; a non-positive probability can never win.
          step_logprob = math.log(prob) if prob > 0 else -math.inf
          self.all_candidates.append(
              (sequence + [token],
              logprob + step_logprob)
          )
      self.beams = sorted(self.all_candidates, key = lambda x: x[1], reverse = True)[:self.beam_size]
      self.count += 1
    # The beams are kept sorted best-first, so no second sort is needed.
    return self.beams[0][0]



## Tests

In [49]:
class TestClass(unittest.TestCase):

  def setUp(self):
    self.beam_size = 5
    self.vocab_size = 1000
    self.model = Model(self.vocab_size)
    self.decoder = BeamSearch(self.model, vocab_size = self.vocab_size, beam_size = self.beam_size)
    self.prompt = [1]
    self.seq_len = 20

  def test_decoder_output(self):
    output = self.decoder.sample(self.prompt, self.seq_len)
    self.assertEqual(len(output), self.seq_len + 1)

  def test_beam_sizes(self):
    output = self.decoder.sample(self.prompt, self.seq_len)
    self.assertEqual(len(self.decoder.beams), self.beam_size)

  def test_candidate_sizes(self):
    output = self.decoder.sample(self.prompt, self.seq_len)
    self.assertEqual(len(self.decoder.all_candidates), self.beam_size**2)

  def test_counter(self):
    output = self.decoder.sample(self.prompt, self.seq_len)
    self.assertEqual(self.decoder.count, self.seq_len)

  def test_token_range(self):
    output = self.decoder.sample(self.prompt, self.seq_len)
    self.assertGreater(self.vocab_size, max(output))
    self.assertLess(0, min(output))



## Run Tests

In [50]:

# Run the suite in-process. The placeholder argv keeps unittest from parsing
# the kernel's own command-line arguments, and exit = False keeps it from
# calling sys.exit() once the run finishes.
if __name__ == "__main__":
  unittest.main(
      exit = False,
      verbosity = 2,
      argv = [""],
  )


test_beam_sizes (__main__.TestClass.test_beam_sizes) ... ok
test_candidate_sizes (__main__.TestClass.test_candidate_sizes) ... ok
test_counter (__main__.TestClass.test_counter) ... ok
test_decoder_output (__main__.TestClass.test_decoder_output) ... ok
test_token_range (__main__.TestClass.test_token_range) ... ok

----------------------------------------------------------------------
Ran 5 tests in 0.056s

OK
