In [0]:
import json
import random

In [0]:
class Sample:
  '''
  A prompt paired with its ground-truth completion and generated alternatives.

  `select` prints the ground truth and the generated samples in shuffled
  order and asks the user to pick the best one.

  Attributes:
    prompt: the prompt text shown to the user.
    truth: the ground-truth completion (candidate index 0).
    samples: list of generated completions (candidate indices 1..len(samples)).
    length: number of candidates, i.e. 1 + len(samples).
    selection: candidate index chosen by the user (0 == truth), or None if no
      choice has been recorded.
  '''

  def __init__(self, prompt, truth, samples=None, selection=None):
    self.prompt = prompt
    self.truth = truth
    # Copy so the caller's list is not aliased and any iterable is accepted.
    self.samples = list(samples) if samples else []
    # Ground truth plus every sample; must track `samples` so that Samples
    # loaded with existing samples offer all of them in `select`.
    self.length = 1 + len(self.samples)
    self.selection = selection

  @staticmethod
  def load_file(path):
    '''
    Create a list of Samples from a JSON file.

    The file may hold a single sample object or a list of them; see
    `load_dict` for the expected keys.

    Returns:
      list of Sample.
    '''
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(path, encoding='utf-8') as f:
      data = json.load(f)
    if isinstance(data, dict):
      data = [data]
    return [Sample.load_dict(d) for d in data]

  @staticmethod
  def load_dict(d):
    '''
    Create a Sample from a dict with 'prompt' and 'truth' keys and optional
    'samples' and 'selection' keys (the format produced by `__str__`).

    Raises:
      KeyError: if 'prompt' or 'truth' is missing.
    '''
    return Sample(d['prompt'], d['truth'], d.get('samples'), d.get('selection'))

  def generate(self, n):
    '''
    Replace `samples` with `n` placeholder completions.

    TODO: use GPT-2 to generate samples (may require reworking `prompt`).
    '''
    self.samples = [f'generated {i+1}' for i in range(n)]
    # Derived from the list so a non-positive `n` cannot produce length <= 0.
    self.length = len(self.samples) + 1

  @staticmethod
  def _print_header(text=None, c='='):
    '''Print a rule made of `c`, with `text` in the middle if given.'''
    if not text:
      print(c * 62)
      return
    # For text up to 40 chars the line is exactly 62 wide; longer text keeps
    # at least 10 fill characters on each side.
    left = max(30 - len(text) // 2, 10)
    right = max(30 - (len(text) + 1) // 2, 10)
    print(c * left, text, c * right)

  def select(self):
    '''
    Interactively ask the user to pick the best completion for the prompt.

    The ground truth and the generated samples are printed in random order
    under numbered headers. The user types a number (re-prompted until it is
    valid) or anything starting with 'q'/'Q' to quit.

    Returns:
      True if the user picked the ground truth, False if they picked a
      generated sample, or None if they quit (`selection` is then unchanged).
    '''
    self._print_header('Prompt')
    print(self.prompt)
    completions = [self.truth] + self.samples
    n_choices = len(completions)
    # shuffled_indices[k] is the index into `completions` shown as choice k+1.
    shuffled_indices = list(range(n_choices))
    random.shuffle(shuffled_indices)
    for si, ci in enumerate(shuffled_indices):
      self._print_header(f'Completion {si+1}', c='-')
      print(completions[ci])

    while True:
      self._print_header(c='-')
      answer = input(f'Choose best completion (1-{n_choices}) (q to quit) >>> ').strip()
      if answer.lower().startswith('q'):
        return None
      try:
        si = int(answer) - 1
      except ValueError:
        print('Not a number, try again.')
        continue
      if 0 <= si < n_choices:
        break
      print('Invalid number, try again.')
    self.selection = shuffled_indices[si]
    return self.selection == 0

  def __str__(self):
    '''Serialize to pretty-printed JSON in the format `load_dict` accepts.'''
    return json.dumps({
        'prompt': self.prompt,
        'truth': self.truth,
        'samples': self.samples,
        'selection': self.selection,
    }, indent=4)

In [13]:
# Demo: generate placeholder completions, let the user pick one, then check
# that the sample survives a JSON round-trip.
s = Sample('the prompt', 'the truth')
s.generate(8)
picked_truth = s.select()  # True/False, or None if the user quit

restored = Sample.load_dict(json.loads(str(s)))
assert str(restored) == str(s), 'JSON round-trip changed the sample'
print(restored)

the prompt
------------------------ Completion 1 ------------------------
generated 8
------------------------ Completion 2 ------------------------
the truth
------------------------ Completion 3 ------------------------
generated 6
------------------------ Completion 4 ------------------------
generated 4
------------------------ Completion 5 ------------------------
generated 3
------------------------ Completion 6 ------------------------
generated 5
------------------------ Completion 7 ------------------------
generated 7
------------------------ Completion 8 ------------------------
generated 1
------------------------ Completion 9 ------------------------
generated 2
--------------------------------------------------------------
Choose best completion (1-9) (q to quit) >>> 9
{
    "prompt": "the prompt",
    "truth": "the truth",
    "samples": [
        "generated 1",
        "generated 2",
        "generated 3",
        "generated 4",
        "generated 5",
        "generated