In [3]:
import globals

import re
from rank_bm25 import BM25Okapi
import math

import sqlite3
from sqlite3 import Error

import spacy
from spacy.tokens import DocBin
# Initialize spacy 'en' model, keeping only tagger component needed for lemmatization
nlp = spacy.load('en_core_web_sm', disable=['parser', 'ner'])

from pathlib import Path
from shutil import rmtree
import os
from os import listdir
from os.path import isfile, join

import nltk
from nltk import sent_tokenize, tokenize, word_tokenize
nltk.download("punkt")

import sys
import tqdm
import json
from wasabi import msg

[nltk_data] Downloading package punkt to /home/liamca/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [4]:
# Reset bm25 dirs
globals.resetDir(globals.bm25_dir)
globals.resetDir(globals.bm25_tmp_dir)



In [5]:
# Merge all files into a single file
merged_text_path = Path(globals.merged_text_dir)
if merged_text_path.exists():
    rmtree(merged_text_path)
merged_text_path.mkdir(parents=True)

inFiles = globals.getFilesInDir(os.path.join(globals.processed_text_dir, globals.blob_container_path))

counter = 0
with open(os.path.join(globals.merged_text_dir, globals.merged_text_file_name), "wb") as mergedFile:
    for fileToProcess in inFiles:
        counter += 1
        if counter % 1000 == 0:
            print ("Merged file count:", counter)
        with open(os.path.join(globals.processed_text_dir, globals.blob_container_path, fileToProcess), "rb") as infile:
            mergedFile.write(infile.read())
print ("Merged file count:", counter)


Merged file count: 1000
Merged file count: 2000
Merged file count: 3000
Merged file count: 4000
Merged file count: 5000
Merged file count: 6000
Merged file count: 7000
Merged file count: 8000
Merged file count: 9000
Merged file count: 10000
Merged file count: 11000
Merged file count: 12000
Merged file count: 13000
Merged file count: 14000
Merged file count: 15000
Merged file count: 16000
Merged file count: 17000
Merged file count: 18000
Merged file count: 19000
Merged file count: 20000
Merged file count: 21000
Merged file count: 22000
Merged file count: 23000
Merged file count: 24000
Merged file count: 25000
Merged file count: 26000
Merged file count: 27000
Merged file count: 28000
Merged file count: 29000
Merged file count: 29315


In [6]:
# Get the word counts over the merged file
# Get the avg word count 

print ("Get word counts over merged files...")

in_doc_counts = dict() 

total_word_count = 0.0
avg_word_count = 0.0
# Open the file in read mode 
text = open(os.path.join(globals.merged_text_dir, globals.merged_text_file_name), "r", encoding='utf-8') 
  
# This is the total bag of words counts across all docs
word_counts = dict() 

# This is the total word counts / doc
in_doc_total_word_counts = dict() 
  
# Loop through each line of the file 
doc_counter = 0
for line in text: 
    doc_counter += 1
    
    # Remove the leading spaces and newline character 
    line = line.strip() 
  
    # Split the line into words 
    words = line.split(" ") 
    total_word_count += len(words)
  
    in_doc_total_word_counts[doc_counter] = len(words)

    # Iterate over each word in line 
    sc = {}
    for word in words: 
        if word in word_counts: 
            word_counts[word] = word_counts[word] + 1
        else: 
            word_counts[word] = 1
            
        if word in sc:
            sc[word] = sc[word] + 1
        else:
            sc[word] = 1
        
    in_doc_counts[doc_counter] = sc

print ("Get avg word counts...")

avg_word_count = total_word_count / doc_counter

totalDocCount = doc_counter

Get word counts over merged files...
Get avg word counts...


In [7]:
bm25Values = dict() 

count1 = 0
count2 = 0
print ("Calculating avg BM25 values...")

word_count_keys_as_list = list(word_counts.keys())
word_count_keys_len = len(word_count_keys_as_list)
word_count_key_counter = 0
print ('Total terms to process:', word_count_keys_len)

in_doc_counts_keys_as_list = list(in_doc_counts.keys())
# Get the number of sentences 
in_doc_counts_len = len(in_doc_counts_keys_as_list)
in_doc_counts_counter = 0



Calculating avg BM25 values...
Total terms to process: 1357726


In [8]:
# convert the in_doc_counts to a dataframe so it can be processed faster
print ("Writing doc term counts to text...")

counter = 0
with open(os.path.join(globals.bm25_tmp_dir, "doc_terms.txt"), "w", encoding="utf-8") as outfile: 
    for doc_key in list(in_doc_counts.keys()):
        counter += 1
        if counter % 100000 == 0:
            print ("Completed", counter, "of", in_doc_counts_len, "...")
        for doc_term in in_doc_counts[doc_key]:
            outfile.write(str(doc_key) + '\t' + str(doc_term) + '\t' + str(in_doc_counts[doc_key][doc_term]) + '\r\n') 
print ("Completed", counter, "of", in_doc_counts_len, "...")



Writing doc term counts to text...
Completed 100000 of 1110994 ...
Completed 200000 of 1110994 ...
Completed 300000 of 1110994 ...
Completed 400000 of 1110994 ...
Completed 500000 of 1110994 ...
Completed 600000 of 1110994 ...
Completed 700000 of 1110994 ...
Completed 800000 of 1110994 ...
Completed 900000 of 1110994 ...
Completed 1000000 of 1110994 ...
Completed 1100000 of 1110994 ...
Completed 1110994 of 1110994 ...


In [9]:
# load the doc terms into a indexed sqlite db
print ("Loading doc term counts into indexed db...")

conn = sqlite3.connect(os.path.join(globals.bm25_tmp_dir,"bm25.sqlite"))
try:
    sql = "drop table if exists doc_terms"
    conn.execute(sql)
    sql = "create table doc_terms (doc_key int, term text, count int)"
    conn.execute(sql)
except Error as e:
    print(e)

Loading doc term counts into indexed db...


In [10]:
# load the doc terms
count = 0
rows = []
c = conn.cursor()
with open(os.path.join(globals.bm25_tmp_dir, "doc_terms.txt"), encoding='utf-8') as fp: 
    while True: 
        line = fp.readline() 
        if not line: 
            break

        fields = line.split('\t')
        rows.append((fields[0], fields[1], fields[2]))
        count += 1
        if count % 100000 == 0: 
            sql = 'insert into doc_terms (doc_key, term, count) values (?,?,?)'
            c.executemany(sql, rows)
            conn.commit()
            print ("Inserted:", str(count))
            rows = []
if len(rows) > 0:
    sql = 'insert into doc_terms (doc_key, term, count) values (?,?,?)'
    c.executemany(sql, rows)
    conn.commit()
    print ("Inserted:", str(count))
    
print ("Creating indexes...")
sql = "create index idx_doc_terms_doc_key on doc_terms (doc_key)"
conn.execute(sql)
sql = "create index idx_doc_terms_terms on doc_terms (term)"
conn.execute(sql)



Inserted: 100000
Inserted: 200000
Inserted: 300000
Inserted: 400000
Inserted: 500000
Inserted: 600000
Inserted: 700000
Inserted: 800000
Inserted: 900000
Inserted: 1000000
Inserted: 1100000
Inserted: 1200000
Inserted: 1300000
Inserted: 1400000
Inserted: 1500000
Inserted: 1600000
Inserted: 1700000
Inserted: 1800000
Inserted: 1900000
Inserted: 2000000
Inserted: 2100000
Inserted: 2200000
Inserted: 2300000
Inserted: 2400000
Inserted: 2500000
Inserted: 2600000
Inserted: 2700000
Inserted: 2800000
Inserted: 2900000
Inserted: 3000000
Inserted: 3100000
Inserted: 3200000
Inserted: 3300000
Inserted: 3400000
Inserted: 3500000
Inserted: 3600000
Inserted: 3700000
Inserted: 3800000
Inserted: 3900000
Inserted: 4000000
Inserted: 4100000
Inserted: 4200000
Inserted: 4300000
Inserted: 4400000
Inserted: 4500000
Inserted: 4600000
Inserted: 4700000
Inserted: 4800000
Inserted: 4900000
Inserted: 5000000
Inserted: 5100000
Inserted: 5200000
Inserted: 5300000
Inserted: 5400000
Inserted: 5500000
Inserted: 5600000
I

<sqlite3.Cursor at 0x7f3f46050500>

In [16]:
counter = 0
for key in word_count_keys_as_list: 
    counter += 1
    if counter % 100 == 0:
        print ("Completed", counter, "of", word_count_keys_len, "...")
    uniqueWord = key
#     print ('uniqueWord: ', uniqueWord)
    
    bm25Total = 0.0
    wordCounter = 0.0
    avgBM25 = 0.0

    try:
        # get all the docs that contain this uniqueword
        cur = conn.cursor()
        cur.execute("SELECT doc_key, count FROM doc_terms WHERE term=?", (uniqueWord,))
        rows = cur.fetchall()

        for row in rows:
            wordCountOfThisDoc = row[0]
            termFreqInDocument = row[1]
            termFreqInIndex = word_counts[key]
            this_bm25 = math.log((totalDocCount - termFreqInIndex + 0.5) / (termFreqInIndex + 0.5)) * (termFreqInDocument * (globals.k1 + 1)) / (termFreqInDocument + globals.k1 * (1 - globals.b + (globals.b * wordCountOfThisDoc / avg_word_count)))
            bm25Total += this_bm25
            wordCounter += 1
        avgBM25 = bm25Total / wordCounter
        if avgBM25 < 0:
            avgBM25 = 0
         
    except Exception as e:
        ## as long as it is not a divid b 0, print error
        if str(e) != 'math domain error':
            print("error:", e)
        bm25Values[uniqueWord] = 0
    bm25Values[uniqueWord] = avgBM25


Completed 100 of 1357726 ...
Completed 200 of 1357726 ...
Completed 300 of 1357726 ...
Completed 400 of 1357726 ...
Completed 500 of 1357726 ...
Completed 600 of 1357726 ...
Completed 700 of 1357726 ...
Completed 800 of 1357726 ...
Completed 900 of 1357726 ...
Completed 1000 of 1357726 ...
Completed 1100 of 1357726 ...
Completed 1200 of 1357726 ...
Completed 1300 of 1357726 ...
Completed 1400 of 1357726 ...
Completed 1500 of 1357726 ...
Completed 1600 of 1357726 ...
Completed 1700 of 1357726 ...
Completed 1800 of 1357726 ...
Completed 1900 of 1357726 ...
Completed 2000 of 1357726 ...
Completed 2100 of 1357726 ...
Completed 2200 of 1357726 ...
Completed 2300 of 1357726 ...
Completed 2400 of 1357726 ...
Completed 2500 of 1357726 ...
Completed 2600 of 1357726 ...
Completed 2700 of 1357726 ...
Completed 2800 of 1357726 ...
Completed 2900 of 1357726 ...
Completed 3000 of 1357726 ...
Completed 3100 of 1357726 ...
Completed 3200 of 1357726 ...
Completed 3300 of 1357726 ...
Completed 3400 of 1

KeyboardInterrupt: 

In [None]:
# write avg bm25 values to file
print ("Writing avg BM25 values...")

with open(os.path.join(globals.bm25_dir, globals.bm25_file), 'w', encoding='utf-8') as f:
    for bm25_key in list(bm25Values.keys()):
        f.write(bm25_key + '\t' + str(bm25Values[bm25_key]) + '\r\n')
