In [1]:
import pickle
from typing import List

from nltk.tokenize import RegexpTokenizer
import numpy as np
from tqdm import tqdm

In [2]:
class BigramModel:
    def __init__(self):
        self.tokenizer = RegexpTokenizer(r'\w+')

        self.vocab = set()
        self.occur_dict = []
        self.freq_dict = []
        
    @property
    def vocab_size(self):
        return len(self.vocab)
    
    def token_pos(self, token: str):
        return self.vocab.index(token)
        
    def _tokenize(self, texts: List[str]):
        result = []
        for text in texts:
            tokens = self.tokenizer.tokenize(text.lower())
            result.extend(tokens)
            
        return result
    
    def fit(self, data: List[str]):
        tokens = self._tokenize(data)
        self.vocab = list(set(tokens))

        vocab_dict = {self.vocab[i]: i for i in range(self.vocab_size)}
        occur_dict = [[1] * self.vocab_size] * self.vocab_size
        freq_dict = [[1] * self.vocab_size] * self.vocab_size
#         occur_dict = {v: dict(vocab_dict) for v in self.vocab}
#         freq_dict = {v: dict(vocab_dict) for v in self.vocab}
        
        tokens_inds = [vocab_dict[token] for token in tqdm(tokens)]
        for i in tqdm(range(0, len(tokens_inds)-1)):
#             w_1, w_2 = self.token_pos(tokens[i]), self.token_pos(tokens[i+1])
            w_1, w_2 = tokens_inds[i], tokens_inds[i+1]
            occur_dict[w_1][w_2] += 1
                
        c_w = [sum(occur_dict[self.token_pos(v)]) for v in tqdm(self.vocab)]
        for v in tqdm(range(self.vocab_size)):
            v_c_w = c_w[v]
            for b in range(self.vocab_size):
                freq_dict[v][b] = occur_dict[v][b] / v_c_w
        
        self.occur_dict = occur_dict
        self.freq_dict = freq_dict
        
    def sentence_prob(self, sentence: str) -> float:
        tokens = self._tokenize([sentence])
        
        result = 1
#         probs = []
        tokens_inds = [self.token_pos(token) for token in tokens]
        for i in range(0, len(tokens_inds)-1, 2):
#             w_1, w_2 = self.token_pos(tokens[i]), self.token_pos(tokens[i+1])
            w_1, w_2 = tokens_inds[i], tokens_inds[i+1]
            bigram_prob = self.freq_dict[w_1][w_2]
#             bigram_prob = self.freq_dict[tokens[i]][tokens[i+1]] # np.log()
            result *= bigram_prob
#             probs.append(bigram_prob)

        return result
    
    def top_freq(self, word: str, count: int = 10):
        vocab_dict = {}
        for v_word, freq in zip(self.vocab, self.freq_dict[self.token_pos(word)]):
            vocab_dict[v_word] = freq

        vocab_dict = sorted(
            vocab_dict.items(), key=lambda x: x[1], reverse=True
        )

        return vocab_dict[:count]
    
    def predict(self, sentence: str, count: int):
        new_sentence = sentence
        for _ in range(count):
            last_word = new_sentence.split(' ')[-1]
            word_pos = self.token_pos(last_word)
            probs = self.freq_dict[word_pos]
            word = np.random.choice(self.vocab, p=probs)
            new_sentence += " " + word
            
        return new_sentence

In [3]:
with open('data/all_news.pkl', 'rb') as f:
    news = pickle.load(f)
# news = np.random.choice(news, 100)
data = [topic['body'] for topic in news]

In [4]:
model = BigramModel()

In [5]:
model.fit(data)

100%|██████████| 3840719/3840719 [00:01<00:00, 3115049.09it/s]
100%|██████████| 3840718/3840718 [00:01<00:00, 3035580.35it/s]
100%|██████████| 190719/190719 [22:09<00:00, 143.40it/s]
100%|██████████| 190719/190719 [1:00:51<00:00, 52.23it/s]


In [6]:
model.sentence_prob("сегодня я как")

0.0017978700894991041

In [7]:
model.top_freq("я")

[('в', 0.044158943820776564),
 ('и', 0.02211271067859922),
 ('на', 0.017701628476396877),
 ('что', 0.012350930945962941),
 ('с', 0.010259121003255166),
 ('по', 0.00996890190768205),
 ('не', 0.009288499361393964),
 ('из', 0.006127591724737358),
 ('за', 0.005252965629873417),
 ('о', 0.004848891350652385)]

In [8]:
model.predict("сегодня я увидел как", 10)

'сегодня я увидел как без защиту там пострадавшей 900 зеркальных на являемся с и'