In [10]:
from nltk.tokenize import word_tokenize
from collections import defaultdict
from sortedcontainers import SortedDict
import nltk.corpus

In [None]:
class GenerationStrategy:
    """Choose the next token from a ranked list of candidate words.

    The candidates are expected to be a list of ``(word, frequency)`` 2-tuples
    ordered by frequency, most frequent first. Only greedy selection is
    implemented so far; the other strategies are placeholders.
    """

    def __init__(self, sorted_word_freq_pairs):
        # Kept by reference: the list is neither copied nor re-sorted here.
        self.candidates = sorted_word_freq_pairs

    def greedy_select(self):
        """Return the word of the first (highest-frequency) candidate.

        Raises IndexError when there are no candidates.
        """
        best_pair = self.candidates[0]
        return best_pair[0]

    def temperature_based_select(self):
        """Placeholder for temperature-scaled sampling; currently returns None."""
        return None

    def randomly_select(self):
        """Placeholder for random sampling; currently returns None."""
        return None

In [None]:
class N_Gram:
    """Count-based n-gram language model.

    Two lookup tables are kept, both keyed by a *context*: a tuple of the
    ``n - 1`` tokens preceding a position (the empty tuple when ``n == 1``).

    * ``lookup_table_by_word``: context -> {next_token: count}
    * ``lookup_table_by_count``: context -> [(next_token, count), ...], sorted by
      count with the most frequent first (ties keep first-seen order). This is
      the candidate format that ``GenerationStrategy`` expects.
    """

    def __init__(self, corpus, tokenizer, n, start_sym="<s>", end_sym="<s/>"):
        """
        Args:
            corpus: raw text stored for reference; ``train`` works on tokens.
            tokenizer: callable mapping a string to a list of tokens.
            n: order of the model (context length is ``n - 1``); must be >= 1.
            start_sym: padding token prepended before the first real token.
            end_sym: token marking the end of a text; sampling stops on it.

        Raises:
            ValueError: if ``n < 1``.
        """
        if n < 1:
            raise ValueError(f"n must be >= 1, got {n!r}")
        self.corpus = corpus
        self.tokenizer = tokenizer
        self.n = n
        self.lookup_table_by_word = {}  # dict(context tuple -> dict(potential_word -> count))
        self.lookup_table_by_count = {}  # dict(context tuple -> list(tuple(potential_word, count)))
        self.token_seq = []
        self.start_sym = start_sym
        # NOTE(review): the conventional end marker is "</s>"; the "<s/>" default
        # is kept unchanged so existing callers/data stay compatible.
        self.end_sym = end_sym

    def tokenize(self, corpus):
        """Tokenize ``corpus`` and pad it with start and end symbols.

        ``n - 1`` start symbols are prepended so the first real token has a full
        context. At least one end symbol is always appended. Without one, a
        unigram model (``n == 1``) would never learn to emit ``end_sym`` and
        ``sample`` could not terminate.
        """
        start_pad = [self.start_sym] * (self.n - 1)
        end_pad = [self.end_sym] * max(1, self.n - 1)
        return start_pad + self.tokenizer(corpus) + end_pad

    def train(self, token_seq):
        """Count the n-grams in ``token_seq`` and rebuild the sorted table.

        ``token_seq`` should already contain adequate start and end tokens, e.g.
        the output of ``tokenize``. Counts accumulate across repeated calls.
        """
        ctx_len = self.n - 1
        for i in range(ctx_len, len(token_seq)):
            # Context is the ctx_len tokens right before position i. It is stored
            # as a tuple because list slices are unhashable and cannot be dict keys.
            context = tuple(token_seq[i - ctx_len:i])
            counts = self.lookup_table_by_word.get(context)
            if counts is None:
                counts = self.lookup_table_by_word[context] = defaultdict(int)
            counts[token_seq[i]] += 1

        # Rebuild the frequency-ordered view for every context seen so far.
        for context, counts in self.lookup_table_by_word.items():
            self.lookup_table_by_count[context] = sorted(
                counts.items(), key=lambda item: item[1], reverse=True
            )

    def sample(self, strategy, max_tokens=None):
        """Generate a text autoregressively, starting from start symbols only.

        Args:
            strategy: callable that receives the sorted ``(word, count)``
                candidate list for the current context and returns the next
                token, e.g. ``lambda c: GenerationStrategy(c).greedy_select()``.
            max_tokens: optional cap on the number of generated tokens, not
                counting the leading start symbols. ``None`` means no cap, and
                generation runs until ``end_sym`` is produced.

        Returns:
            The token list: leading start symbols, then the generated tokens
            (ending with ``end_sym`` unless ``max_tokens`` stopped generation first).

        Raises:
            KeyError: if a context was never seen during training.
        """
        ctx_len = self.n - 1
        text_tokens = [self.start_sym] * ctx_len
        generated_token = None
        generated_count = 0

        while generated_token != self.end_sym:
            if max_tokens is not None and generated_count >= max_tokens:
                break
            # Explicit start index instead of [-ctx_len:]: with ctx_len == 0,
            # [-0:] would return the whole list rather than an empty context.
            context = tuple(text_tokens[len(text_tokens) - ctx_len:])
            try:
                potential_words = self.lookup_table_by_count[context]
            except KeyError:
                raise KeyError(
                    f"context {context!r} was never seen during training"
                ) from None
            generated_token = strategy(potential_words)
            text_tokens.append(generated_token)
            generated_count += 1

        return text_tokens

[1, 2]