In [619]:
import tqdm
import re
import numpy as np
from collections import Counter

In [605]:
class LanguageModel():
    def __init__(self, order=2):
        self.order=order

    def fit(self, corpus):
        self.counts = get_counts(corpus, self.order)
        
    def prob(self, string, log=False):
        return interpolate(self.counts, string, self.order, log=log)
        
    def product(nums):
        "Multiply the numbers together.  (Like `sum`, but with multiplication.)"
        result = 1
        for x in nums: result *= x
        return result

    def get_ngrams(tokens, n):
        return [' '.join(tokens[i:i+n]) for i, token in enumerate(tokens)]
    
    def get_counts(corpus, order):  
        counts = {'n' + str(i) : Counter(get_ngrams(corpus, n=i)) for i in range(1, order+1)}
        counts['n0'] = {'':len(corpus)}
        return counts
    
    def get_prob(counts, word, context=''):
        '''With Laplace shoothing as yet.
        Not for public use.'''
        order = len(context.split())+1
        separator = ' ' if order > 1 else ''
        return (counts['n'+str(order)][separator.join([context, word])] + 1) / \
               (counts['n'+str(order-1)][context] + len(counts['n'+str(order)]))
        
    def get_logprob(counts, word, context=''):
        return np.log(get_prob(counts, word, context))
    
    def get_following(counts, context):
        '''Slow as hell. 
        To optimize might use embedded dictionaries.'''
        order = len(context.split())+1
        return sorted(
            [(k.split()[-1], v, get_prob(counts, k.split()[-1], context)) \
            for k, v in counts['n'+str(order)].items()                    \
            if re.match(context+' '+'\w+', k)],                           \
            key=lambda x:x[1], reverse=True)   
    
    def get_string_probs(counts, string, order, log=True):
        prob_fun = get_logprob if log else get_prob
        tokens = string.split()
        probs = []
        for i in range(len(tokens)):
            context = ' '.join(tokens[i-order+1:i]) if i>=order else ' '.join(tokens[:i])
            prob = prob_fun(counts, word = tokens[i], context = context)
            probs.append(prob)
        return probs
    
    def interpolate(counts, string, order, log=True, lambdas='default'):
        lmbd = [0.3, 0.7] if lambdas == 'default' else lambdas
        aggregate = sum if log else product
        probs = [get_string_probs(counts, string, order=i, log=log) for i in range(1, order+1)]
        probs_interpolated = []
        for tup in zip(*probs):
            prob_token = 0
            for i in range(len(tup)):
                prob_token += tup[i] * lmbd[i]
            probs_interpolated.append(prob_token)
        return aggregate(probs_interpolated)

In [None]:
def cleanse(s, rgxp = '[\W\da-z]'):
    return re.sub(' +', ' ', re.sub(rgxp, ' ', s.lower()))

In [None]:
with open('lt1.txt', encoding='utf-8') as f:
    tokens = cleanse(f.read().lower()).split()

In [618]:
%%time
model = LanguageModel(order=2)
model.fit(tokens)

Wall time: 656 ms


In [622]:
%%time
model.prob('наташа')

Wall time: 0 ns


0.001646777278875629

In [623]:
%%time
model.prob('наташа и пьер не хотели ехать')

Wall time: 0 ns


5.515250993684362e-19

In [624]:
model.counts

{'n0': {'': 451046},
 'n1': Counter({'л': 15,
          'н': 15,
          'толстой': 11,
          'война': 64,
          'и': 21710,
          'мир': 51,
          'том': 780,
          'первый': 178,
          'часть': 84,
          'первая': 32,
          'е': 46,
          'к': 3629,
          'поместья': 1,
          'й': 1066,
          'мой': 317,
          'верный': 6,
          'раб': 7,
          'см': 16,
          'сноски': 32,
          'в': 11173,
          'конце': 72,
          'части': 51,
          'ну': 568,
          'здравствуйте': 13,
          'садитесь': 7,
          'рассказывайте': 2,
          'так': 2032,
          'говорила': 267,
          'июле': 2,
          'года': 112,
          'известная': 11,
          'анна': 208,
          'павловна': 100,
          'шерер': 9,
          'фрейлина': 3,
          'приближенная': 1,
          'императрицы': 10,
          'марии': 6,
          'феодоровны': 2,
          'встречая': 5,
          'важного': 12,
      