In [None]:
import numpy as np

class BigramLM:
    """Maximum-likelihood bigram language model over whitespace-split tokens.

    Attributes:
        vocab: every token seen in any training sentence.
        bigram_counts: ``{word1: {word2: count}}`` of adjacent token pairs.
        bigram_probs: ``{word1: {word2: P(word2 | word1)}}``; for each
            ``word1`` the follower counts are normalised by their total.
    """

    def __init__(self):
        self.vocab = set()
        self.bigram_counts = {}
        self.bigram_probs = {}

    def learn(self, dataset):
        """Update the model with the sentences in *dataset*.

        Each sentence is tokenised with ``str.split()``. Counts accumulate
        across calls, and probabilities are recomputed from all counts seen
        so far.

        Args:
            dataset: iterable of sentence strings.
        """
        # Step 1: collect vocabulary and count bigrams.
        for sentence in dataset:
            tokens = sentence.split()
            # Add every token, not only those that appear in a bigram, so
            # single-word sentences still contribute to the vocabulary.
            self.vocab.update(tokens)
            for word1, word2 in zip(tokens, tokens[1:]):
                followers = self.bigram_counts.setdefault(word1, {})
                followers[word2] = followers.get(word2, 0) + 1

        # Step 2: normalise each word's follower counts into probabilities.
        for word1, followers in self.bigram_counts.items():
            total_count = sum(followers.values())
            self.bigram_probs[word1] = {
                word2: count / total_count for word2, count in followers.items()
            }

    def generate_next_word(self, word, rng=None):
        """Sample a successor of *word* from the learned bigram distribution.

        Args:
            word: the preceding token.
            rng: optional ``numpy.random.Generator`` (or any object exposing a
                compatible ``choice(a, p=...)``) for reproducible sampling.
                When ``None``, NumPy's global random state is used, so
                ``np.random.seed`` still controls the result.

        Returns:
            The sampled next word as a plain ``str``, or ``None`` if *word*
            never appeared as the first token of a bigram.
        """
        dist = self.bigram_probs.get(word)
        if dist is None:
            return None
        next_words = list(dist.keys())
        probabilities = list(dist.values())
        chooser = np.random if rng is None else rng
        # NumPy's choice yields numpy.str_; convert to a builtin str.
        return str(chooser.choice(next_words, p=probabilities))
