In [159]:
from copy import deepcopy

In [160]:
# Vocabulary shared by the whole notebook: maps each whitespace-separated
# token to the number of times it has been seen in the files passed to
# update_dict() so far.
words = {}

# Kept so that any other cell referring to these module-level names still
# finds them; update_dict() does not read or modify either one.
line = None
count  = 0


def update_dict(file_path: str) -> None:
    """Add the token counts of one text file to the module-level ``words`` dict.

    The file is read line by line and each line is split on whitespace; every
    resulting token increments its entry in ``words`` (created at 1 on first
    sight). Calling the function for several files accumulates their counts.

    Args:
        file_path: Path of the text file to read; it is opened as UTF-8.

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(file_path, 'r', encoding='UTF_8') as f:
        # Iterating the file object yields one line at a time, so no manual
        # readline()/EOF check is needed and no module-level name is touched.
        for text_line in f:
            for word in text_line.split():
                words[word] = words.get(word, 0) + 1

In [161]:
# Accumulate token counts from the three training files into `words`.
for corpus_path in ('ferdowsi_train.txt', 'hafez_train.txt', 'molavi_train.txt'):
    update_dict(corpus_path)

In [162]:
# Collect the tokens seen fewer than 3 times first, then delete them: the
# dict cannot be shrunk while it is being iterated.
temp = [word for word in words if words[word] < 3]
for word in temp:
    del words[word]

# Register the boundary marker so that it counts as a known token.
words["<s>"] = 1

In [163]:
def unigram(file_path: str):
    """Estimate unigram probabilities for one text file.

    Each line is split on whitespace and every token adds one to the total
    token count. Tokens found in the module-level ``words`` vocabulary are
    counted; any other token is recorded with a count of 0, so it appears in
    the results with probability 0.

    Args:
        file_path: Path of the text file to read; it is opened as UTF-8.

    Returns:
        A tuple ``(prob_dict, lines_count, words_count)`` where

        * ``prob_dict`` maps each token to ``count / total tokens`` and also
          holds ``"</s>"`` mapped to ``lines / total tokens``;
        * ``lines_count`` is the number of lines read;
        * ``words_count`` maps each token to its raw count.

        When the file holds no tokens at all, ``prob_dict`` is
        ``{"</s>": 0.0}`` instead of the division raising ZeroDivisionError.

    Raises:
        OSError: If the file cannot be opened.
    """
    lines_count = 0
    total_words = 0
    words_count = {}
    with open(file_path, 'r', encoding='UTF_8') as f:
        for line in f:
            lines_count += 1
            for word in line.split():
                total_words += 1
                if word in words:
                    words_count[word] = words_count.get(word, 0) + 1
                else:
                    # Out-of-vocabulary token: keep an entry, but with no mass.
                    words_count[word] = 0

    if total_words:
        prob_dict = {word: occurrences / total_words
                     for word, occurrences in words_count.items()}
        prob_dict["</s>"] = lines_count / total_words
    else:
        # Empty file or only blank lines: there is nothing to normalise by.
        prob_dict = {"</s>": 0.0}

    return prob_dict, lines_count, words_count

In [164]:
# One unigram model per training file. Each call also returns the line count
# and the raw token counts, which bigram() takes as arguments.
(ferdowsi_unigram, ferdowsi_lines, ferdowsi_words) = unigram(file_path='ferdowsi_train.txt')
(hafez_unigram, hafez_lines, hafez_words) = unigram(file_path='hafez_train.txt')
(moalvi_unigram, molavi_lines, molavi_words) = unigram(file_path='molavi_train.txt')

In [165]:
def bigram(file_path: str, lines_count: int, words_count: dict):
    """Estimate bigram probabilities for one text file.

    Every line is split on whitespace and wrapped in ``"<s>"`` markers on both
    ends. Adjacent token pairs are counted only when both tokens belong to the
    module-level ``words`` vocabulary.

    Args:
        file_path: Path of the text file to read; it is opened as UTF-8.
        lines_count: Denominator used for pairs whose previous token is
            ``"<s>"``.
        words_count: Maps a token to the denominator used for pairs in which
            that token is the previous one.

    Returns:
        Dict keyed by ``(current_token, previous_token)`` whose value is the
        pair count divided by the matching denominator.
    """
    pair_table = {}
    with open(file_path, 'r', encoding='UTF_8') as corpus:
        for raw_line in corpus:
            tokens = ["<s>"] + raw_line.split() + ["<s>"]
            for previous, current in zip(tokens, tokens[1:]):
                if current in words and previous in words:
                    key = (current, previous)
                    pair_table[key] = pair_table.get(key, 0) + 1

    # Turn raw counts into conditional relative frequencies. Only values of
    # existing keys are replaced, so iterating the dict meanwhile is safe.
    for key, occurrences in pair_table.items():
        previous = key[1]
        denominator = lines_count if previous == "<s>" else words_count[previous]
        pair_table[key] = occurrences / denominator

    return pair_table

In [166]:
# One bigram model per training file, normalised with the line and token
# counts returned by the matching unigram() call.
ferdowsi_bigram = bigram(
    file_path='ferdowsi_train.txt',
    lines_count=ferdowsi_lines,
    words_count=ferdowsi_words,
)
hafez_bigram = bigram(
    file_path='hafez_train.txt',
    lines_count=hafez_lines,
    words_count=hafez_words,
)
moalvi_bigram = bigram(
    file_path='molavi_train.txt',
    lines_count=molavi_lines,
    words_count=molavi_words,
)

In [172]:
# lambda3 = 0.8
# lambda2 = 0.15
# lambda1 = 0.1
# epsilon = 0.7

In [173]:
def backoff_model(unigram: dict, bigram: dict, line: list, lambda3: float, lambda2: float, lambda1: float, epsilon: float):
    """Score a token sequence with an interpolated bigram/unigram model.

    For every token from the second one onwards the factor

        lambda3 * P(token | previous) + lambda2 * P(token) + lambda1 * epsilon

    is multiplied into the result. A probability missing from ``bigram`` or
    ``unigram`` is taken as 0.

    Args:
        unigram: Maps a token to its unigram probability.
        bigram: Maps ``(current_token, previous_token)`` to its probability.
        line: The tokens to score, in order.
        lambda3: Weight of the bigram probability.
        lambda2: Weight of the unigram probability.
        lambda1: Weight of the constant ``epsilon`` term.
        epsilon: Constant added (scaled by ``lambda1``) to every factor.

    Returns:
        The product of the per-token factors; ``1`` when ``line`` has fewer
        than two tokens.
    """
    # NOTE(review): the first token of `line` is never scored (no "<s>"
    # context is prepended here) — confirm that this is intended.
    probability = 1

    for previous, current in zip(line, line[1:]):
        bigram_prob = bigram.get((current, previous), 0)
        unigram_prob = unigram.get(current, 0)
        probability *= (lambda3 * bigram_prob) + (lambda2 * unigram_prob) + (lambda1 * epsilon)

    return probability

In [174]:
def test(file_path: str, lambda3: int, lambda2: int, lambda1: int, epsilon: int):
    correct_case = 0
    total_case = 0
    
    with open(file_path, 'r', encoding='UTF_8') as f:
        while True:
            total_case += 1
            result = 1
            line = f.readline()
            if not line:
                break
            
            ferdowsi_probability = backoff_model(ferdowsi_unigram, ferdowsi_bigram, line.split()[1:], lambda3, lambda2, lambda1, epsilon)
            hafez_probability = backoff_model(hafez_unigram, hafez_bigram, line.split()[1:], lambda3, lambda2, lambda1, epsilon)
            moalvi_probability = backoff_model(moalvi_unigram, moalvi_bigram, line.split()[1:], lambda3, lambda2, lambda1, epsilon)
            
            if max(ferdowsi_probability, hafez_probability, moalvi_probability) == hafez_probability:
                result = 2
            elif max(ferdowsi_probability, hafez_probability, moalvi_probability) == moalvi_probability:
                result = 3
            if result == int(line.split()[0]):
                correct_case += 1
    
    return correct_case/total_case

In [305]:
# Accuracy on the labelled file for one hand-picked set of weights.
test('test_file.txt', lambda3=0.800, lambda2=0.18, lambda1=0.02, epsilon=0.000005)

0.7987649836541955

In [None]:
# Exhaustive search over the interpolation weights, scored with test() on
# 'test_file.txt'. All loop variables are integer thousandths and are divided
# by 1000 only when passed to test(); the *_max names keep the integer
# thousandths of the best combination found.
best = 0
lambda3_max = lambda2_max = lambda1_max = epsilon_max = 0

for lambda3 in range(800, 802):
    for lambda2 in range(177, 1000 - lambda3):
        # lambda1 takes whatever is left so that the three weights sum to 1000.
        lambda1 = 1000 - lambda3 - lambda2
        for epsilon in range(0, 1000):
            result = test('test_file.txt', lambda3/1000, lambda2/1000, lambda1/1000, epsilon/1000)
            # Strict comparison: the first combination reaching a score is kept.
            if result > best:
                best = result
                lambda3_max, lambda2_max = lambda3, lambda2
                lambda1_max, epsilon_max = lambda1, epsilon