In [1]:
import datetime
import os
import pickle
from collections import Counter

import numpy as np
import pandas as pd
import scipy.sparse
from numba import jit
from tqdm import tqdm

In [12]:
# When True, later cells reload cached artefacts (word lists, docs_len.npy, *.npz) instead of
# recomputing them from docs/ and queries/.
load_from_file = False

In [3]:
# Document IDs, one per line; each ID names a file docs/<id>.txt read below.
with open('doc_list.txt') as doc_list_file:
    doc_list = doc_list_file.read().splitlines()

In [4]:
# Load each listed document once: keep a per-document word Counter and build the vocabulary.
if not load_from_file:
    docs_counter = []
    words = set()
    for doc in tqdm(doc_list):
        with open('docs/' + doc + '.txt') as f:
            doc_words = f.read().split()
        docs_counter.append(Counter(doc_words))
        # set.update grows the vocabulary in place; `words.union(...)` copied the whole
        # (eventually ~100k-word) set once per document.
        words.update(doc_words)

100%|██████████| 14955/14955 [01:00<00:00, 245.25it/s]


In [6]:
# Query IDs, one per line; each ID names a file queries/<id>.txt read below.
with open('query_list.txt') as query_list_file:
    query_list = query_list_file.read().splitlines()

In [7]:
# Load every query as a list of tokens. When recomputing, query words are also added to the
# global vocabulary `words` and collected into `queries_words`.
queries = []
queries_words = set()
for query in tqdm(query_list):
    with open('queries/' + query + '.txt') as f:
        query_words = f.read().split()
    queries.append(query_words)
    if not load_from_file:
        # in-place update instead of rebuilding both sets with union() per query
        words.update(query_words)
        queries_words.update(query_words)

if load_from_file:
    # load query words from file
    with open('query_word_list.txt') as f:
        queries_words = f.read().split()
else:
    # save query words (same iteration order as the list created right after)
    with open('query_word_list.txt', 'w') as f:
        f.write(' '.join(queries_words))

    queries_words = list(queries_words)

100%|██████████| 100/100 [00:00<00:00, 3173.78it/s]


In [8]:
# Freeze the vocabulary into a list. The text file stores it in the same order as the list,
# so a later run with load_from_file=True gets identical word indices.
if not load_from_file:
    with open('word_list.txt', 'w') as word_file:
        word_file.write(' '.join(words))
    words = list(words)
else:
    with open('word_list.txt') as word_file:
        words = word_file.read().split()

In [9]:
# Corpus dimensions used throughout the notebook.
docs_amount, words_amount, query_word_amount = len(doc_list), len(words), len(queries_words)

print(docs_amount, words_amount, query_word_amount)

14955 111449 226


In [10]:
# Document length |d| = total token count of each document, cached in docs_len.npy.
if not load_from_file:
    docs_len = np.array([sum(docs_counter[j].values()) for j in tqdm(range(docs_amount))])
    np.save('docs_len', docs_len)
else:
    docs_len = np.load('docs_len.npy')

In [11]:
# Full-vocabulary count matrix c(w, d) and in-document frequency c(w, d) / |d|.
# Both are stored sparse with words as rows and documents as columns (CSR built per document,
# then transposed).
if load_from_file:
    cwd = scipy.sparse.load_npz('cwd.npz')
    pwd = scipy.sparse.load_npz('pwd.npz')
else:
    word_to_index = {word: i for i, word in enumerate(words)}
    indptr = [0]
    indices = []
    cwd_data = []
    pwd_data = []

    for j in tqdm(range(docs_amount)):
        doc_len = docs_len[j]
        # Visit only the words that occur in this document, sorted by vocabulary index so the
        # CSR indices stay canonical. The old loop probed the Counter with every vocabulary word.
        doc_entries = sorted(
            (word_to_index[word], count)
            for word, count in docs_counter[j].items()
            if count != 0 and word in word_to_index
        )
        for i, word_count in doc_entries:
            indices.append(i)
            cwd_data.append(word_count)
            pwd_data.append(word_count / doc_len)
        indptr.append(len(indices))

    # Pass the shape explicitly. Otherwise scipy infers the word dimension from the largest
    # index seen and silently drops trailing vocabulary words that occur in no document
    # (e.g. query-only words).
    cwd_shape = (docs_amount, words_amount)
    cwd = scipy.sparse.csr_matrix((cwd_data, indices, indptr), shape=cwd_shape, dtype=np.float32).transpose()
    pwd = scipy.sparse.csr_matrix((pwd_data, indices, indptr), shape=cwd_shape, dtype=np.float32).transpose()

    scipy.sparse.save_npz('cwd', cwd)
    scipy.sparse.save_npz('pwd', pwd)

In [13]:
# Reduced vocabulary: the most frequent corpus words plus every query word.
if load_from_file:
    with open('slim_word_list_10000.txt') as f:
        slim_words = f.read().split()
else:
    # Total corpus count of each vocabulary word (row sums of the word x document matrix),
    # computed in one sparse reduction instead of iterating the matrix row by row.
    words_count_list = np.asarray(cwd.sum(axis=1)).ravel()

    most_word_index = np.flip(np.argsort(words_count_list))

    # Cap at the number of available words so a small corpus cannot cause an IndexError.
    slim_words_amount = min(10000, len(most_word_index))
    slim_words = [words[i] for i in most_word_index[:slim_words_amount]]

    # Query words must always be scorable, so add them; set() removes duplicates.
    slim_words = list(set(slim_words + queries_words))

    # save slim words
    with open('slim_word_list_10000.txt', 'w') as f:
        f.write(' '.join(slim_words))

# update slim words amount
slim_words_amount = len(slim_words)
print(slim_words_amount)

111449it [00:13, 8365.09it/s]10020



In [14]:
slim_words.__class__  # sanity check: the slim vocabulary should be a list

list

In [15]:
# Count matrix c(w, d) and in-document frequency c(w, d) / |d| restricted to the slim
# vocabulary. Rows are slim words and columns are documents; both are kept dense for the EM
# kernels. Note that |d| is the full document length, not the slim-vocabulary length.
if load_from_file:
    # .toarray() instead of .A: the .A shortcut was removed from scipy sparse matrices (SciPy 1.14)
    slim_cwd = scipy.sparse.load_npz('slim_cwd_10000.npz').toarray()
    slim_pwd = scipy.sparse.load_npz('slim_pwd_10000.npz').toarray()
else:
    slim_word_to_index = {word: i for i, word in enumerate(slim_words)}
    indptr = [0]
    indices = []
    cwd_data = []
    pwd_data = []

    for j in tqdm(range(docs_amount)):
        doc_len = docs_len[j]
        # Only this document's own words that belong to the slim vocabulary, in index order.
        doc_entries = sorted(
            (slim_word_to_index[word], count)
            for word, count in docs_counter[j].items()
            if count != 0 and word in slim_word_to_index
        )
        for i, word_count in doc_entries:
            indices.append(i)
            cwd_data.append(word_count)
            pwd_data.append(word_count / doc_len)
        indptr.append(len(indices))

    # Explicit shape: a slim or query word with no occurrences at the highest index would
    # otherwise shrink the matrix, and the EM kernels would index past its last row.
    slim_shape = (docs_amount, slim_words_amount)
    slim_cwd = scipy.sparse.csr_matrix((cwd_data, indices, indptr), shape=slim_shape, dtype=np.float32).transpose()
    slim_pwd = scipy.sparse.csr_matrix((pwd_data, indices, indptr), shape=slim_shape, dtype=np.float32).transpose()

    scipy.sparse.save_npz('slim_cwd_10000', slim_cwd)
    scipy.sparse.save_npz('slim_pwd_10000', slim_pwd)

    slim_cwd = slim_cwd.toarray()
    slim_pwd = slim_pwd.toarray()

100%|██████████| 14955/14955 [01:44<00:00, 143.55it/s]


In [16]:
# Background (collection) language model: total count of each slim word divided by the
# total number of tokens in the corpus.
bg_model_cd = docs_len.sum()
bg = np.array([word_row.sum() / bg_model_cd for word_row in tqdm(slim_cwd)])

100%|██████████| 10020/10020 [00:01<00:00, 5400.60it/s]


In [14]:
# def E_step():
#     ptwd_CD = np.matmul(pwt, ptd) # Common Denominator
#     for i in range(slim_words_amount):
#         for j in range(docs_amount):
#             if ptwd_CD[i][j] != 0:
#                 for k in range(topic_k):
#                     ptwd[k][i][j] = pwt[i][k] * ptd[k][j] / ptwd_CD[i][j]
#             else:
#                 ptwd[:,i,j] = 0

In [17]:
@jit(nopython=True)
def nb_E_step(pwt, ptd, cwd, topic_amount, word_amount, doc_amount):
    """PLSA E-step: posterior p(t | w, d) for every (topic, word, document) triple.

    Parameters
    ----------
    pwt : 2-D array indexed [word, topic] (p(w | t)).
    ptd : 2-D array indexed [topic, doc] (p(t | d)).
    cwd : 2-D array indexed [word, doc] of counts; pairs with a zero count get an
        all-zero posterior.
    topic_amount, word_amount, doc_amount : loop bounds for the three axes.

    Returns
    -------
    float64 array of shape (topic_amount, word_amount, doc_amount). A pair whose joint
    p(w | t) * p(t | d) sums to zero over topics also gets an all-zero posterior.
    """
    posterior = np.zeros((topic_amount, word_amount, doc_amount))

    for w in range(word_amount):
        for d in range(doc_amount):
            if cwd[w, d] == 0:
                continue  # unobserved pair: posterior stays zero

            norm = 0.0
            for t in range(topic_amount):
                joint = pwt[w, t] * ptd[t, d]
                posterior[t, w, d] = joint
                norm += joint

            if norm == 0:
                posterior[:, w, d] = 0
            else:
                for t in range(topic_amount):
                    posterior[t, w, d] /= norm
    return posterior

In [18]:
# # @jit
# def M_step():
#     # p(w/t)
#     for k in range(topic_k):
#         single_wt = np.multiply(cwd, ptwd[k])
#         single_wt_sum = single_wt.sum()
#         if single_wt_sum != 0:
#             for i in range(len(words)):
#                 pwt[i][k] = single_wt[i].sum() / single_wt_sum
#         else:
#             for i in range(len(words)):
#                 pwt[i][k] = 0
    
#     for i in range(len(pwt)):
#         pwt[i] /= pwt[i].sum()

#     # p(t/d)
#     for k in range(topic_k):
#         single_k_cwd_ptwd = np.multiply(cwd, ptwd[k])
#         for j in range(len(docs)):
#             if docs_len[j] != 0:
#                 ptd[k][j] = single_k_cwd_ptwd[:,j].sum() / docs_len[j]
#             else:
#                 ptd[k][j] = 0

In [19]:
def M_step(times):
    """Pure-NumPy PLSA M-step that updates the notebook globals ``pwt`` and ``ptd`` in place.

    Reads the globals ``ptwd`` (E-step posterior), ``slim_cwd``, ``docs_len``, ``topic_k``,
    ``slim_words_amount`` and ``docs_amount``. The training loop uses ``nb_M_step``; this
    version is kept for reference and debugging.

    Parameters
    ----------
    times : int
        Iteration counter; not used by the computation (the debug print that used it is gone).
    """
    for k in range(topic_k):
        # Expected counts c(w, d) * p(t=k | w, d). This must use slim_cwd: the
        # full-vocabulary sparse `cwd` does not match the slim shape of ptwd / pwt.
        single_topic_wd = np.multiply(slim_cwd, ptwd[k])

        # p(w | t): normalise the word marginal of the expected counts
        single_wt_sum = single_topic_wd.sum()
        if single_wt_sum != 0:
            pwt[:, k] = single_topic_wd.sum(axis=1) / single_wt_sum
        else:
            pwt[:, k] = 1 / slim_words_amount

        # p(t | d) before renormalisation; empty documents get 0 instead of inf/nan
        topic_doc_mass = single_topic_wd.sum(axis=0)
        ptd[k] = np.divide(topic_doc_mass, docs_len,
                           out=np.zeros(docs_amount), where=docs_len != 0)

    # Renormalise each document's topic distribution; documents with no mass become uniform.
    col_sums = ptd.sum(axis=0)
    has_mass = col_sums != 0
    ptd[:, has_mass] /= col_sums[has_mass]
    ptd[:, ~has_mass] = 1 / topic_k

In [20]:
@jit(nopython=True)
def nb_M_step(ptwd, cwd, docs_len, topic_amount, word_amount, doc_amount):
    """PLSA M-step: re-estimate p(w | t) and p(t | d) from the E-step posterior.

    Parameters
    ----------
    ptwd : array indexed [topic, word, doc], the posterior returned by ``nb_E_step``.
    cwd : 2-D array indexed [word, doc] of counts.
    docs_len : 1-D array of document lengths used as the p(t | d) denominator.
    topic_amount, word_amount, doc_amount : sizes of the three axes.

    Returns
    -------
    (pwt, ptd) : float64 arrays of shapes (word_amount, topic_amount) and
        (topic_amount, doc_amount).

    Notes
    -----
    p(t | d) is divided by ``docs_len`` (the full document length). When ``cwd`` covers only
    a reduced vocabulary, the columns of ``ptd`` therefore sum to less than 1.
    """
    pwt = np.empty((word_amount, topic_amount))
    ptd = np.empty((topic_amount, doc_amount))

    for k in range(topic_amount):
        # expected counts c(w, d) * p(t=k | w, d)
        single_topic_wd = np.multiply(cwd, ptwd[k])

        # p(w | t)
        single_wt_sum = single_topic_wd.sum()
        if single_wt_sum != 0:
            for i in range(word_amount):
                pwt[i, k] = single_topic_wd[i].sum() / single_wt_sum
        else:
            # Use the parameter: a notebook global read here would be frozen at compile time.
            pwt[:, k] = 1.0 / word_amount

        # p(t | d); an empty document would otherwise raise ZeroDivisionError in nopython mode
        for j in range(doc_amount):
            if docs_len[j] != 0:
                ptd[k, j] = single_topic_wd[:, j].sum() / docs_len[j]
            else:
                ptd[k, j] = 0.0

    return pwt, ptd

In [21]:
def loss(times, counts=None, word_topic=None, topic_doc=None):
    """Print and return the PLSA log-likelihood sum_{w,d} c(w, d) * log p(w | d).

    p(w | d) is modelled as ``word_topic @ topic_doc``. Only pairs with a non-zero count
    contribute. This avoids the 0 * log(0) = NaN that appears for words absent from the
    corpus.

    Parameters
    ----------
    times : int
        Step number shown in the printed message.
    counts : array or scipy sparse matrix, optional
        Word x document counts; defaults to the global ``slim_cwd``, which matches the shape
        of ``pwt @ ptd`` (the full-vocabulary ``cwd`` does not).
    word_topic, topic_doc : arrays, optional
        Default to the globals ``pwt`` and ``ptd``.

    Returns
    -------
    float : the log-likelihood.
    """
    counts = slim_cwd if counts is None else counts
    word_topic = pwt if word_topic is None else word_topic
    topic_doc = ptd if topic_doc is None else topic_doc

    if scipy.sparse.issparse(counts):
        counts = counts.toarray()
    counts = np.asarray(counts)

    model = np.matmul(np.asarray(word_topic), np.asarray(topic_doc))
    observed = counts != 0
    value = float((counts[observed] * np.log(model[observed])).sum())
    print("\nStep", times, "loss: ", value)
    return value

In [22]:
@jit(nopython=True)
def nb_loss(times, cwd, pwt, ptd):
    """Numba version of the PLSA log-likelihood sum_{w,d} c(w, d) * log((pwt @ ptd)[w, d]).

    Only entries with a non-zero count are accumulated, so a word whose model probability
    is 0 and whose count is 0 does not produce 0 * log(0) = NaN. Prints the value and
    returns it.
    """
    model = np.dot(pwt, ptd)
    total = 0.0
    for i in range(cwd.shape[0]):
        for j in range(cwd.shape[1]):
            count = cwd[i, j]
            if count != 0:
                total += count * np.log(model[i, j])
    print("\nStep", times, "loss: ", total)
    return total

In [23]:
# PLSA training and retrieval hyper-parameters
topic_k, EPOCH = 48, 30   # number of latent topics, number of EM iterations
# Query-likelihood mixture weights: alpha for the document unigram model, beta for the PLSA
# model, and the remaining (1 - alpha - beta) for the background model.
alpha, beta = 0.7, 0.1

In [24]:
# EM initialisation: random p(w | t) with each topic column normalised to sum to 1, and
# uniform p(t | d).
SEED = 42  # fixed seed so Restart & Run All reproduces the same topics
rng = np.random.default_rng(SEED)
pwt = rng.random(size=(slim_words_amount, topic_k))
pwt /= pwt.sum(axis=0, keepdims=True)

ptd = np.full((topic_k, docs_amount), 1 / topic_k)

# Placeholder for the posterior; nb_E_step allocates and returns a fresh array each epoch.
ptwd = np.empty((topic_k, slim_words_amount, docs_amount))

In [78]:
# Run EPOCH rounds of EM; each round recomputes the posterior and then the parameters.
for i in tqdm(range(EPOCH)):
    # E-step: p(t | w, d), shape (topic_k, slim_words_amount, docs_amount)
    ptwd = nb_E_step(pwt, ptd, slim_cwd,
                     topic_k, slim_words_amount, docs_amount)
    # M-step: re-estimate p(w | t) and p(t | d) from that posterior
    pwt, ptd = nb_M_step(ptwd, slim_cwd, docs_len,
                         topic_k, slim_words_amount, docs_amount)

100%|██████████| 30/30 [1:31:07<00:00, 182.26s/it]


In [79]:
plsa_EM_final = pwt @ ptd  # PLSA word x document model sum_t p(w | t) p(t | d)

In [80]:
np.shape(plsa_EM_final)  # (slim words, documents)

(5042, 14955)

In [81]:
# Score every document for every query with the interpolated query likelihood
#   prod_w [alpha * p_ml(w|d) + beta * p_plsa(w|d) + (1 - alpha - beta) * p_bg(w)]
# The product is vectorised over documents. Each query word's row is found through a dict
# lookup, replacing the O(V) slim_words.index() call that ran inside the per-document loop.
word_to_slim_index = {word: i for i, word in enumerate(slim_words)}
bg_weight = 1 - alpha - beta

queries_result = []
for query in tqdm(queries):
    scores = np.ones(docs_amount)
    for word in query:
        word_index = word_to_slim_index[word]
        scores *= (alpha * slim_pwd[word_index].astype(np.float64)
                   + beta * plsa_EM_final[word_index]
                   + bg_weight * bg[word_index])
    queries_result.append(scores.tolist())

100%|██████████| 100/100 [03:07<00:00,  1.87s/it]


In [82]:
## sort and export result
# Build the score table with queries as rows, then transpose so rows are documents and
# columns are queries.
sim_df = pd.DataFrame(queries_result, index=query_list, columns=doc_list).transpose()

In [83]:
# Preview the document x query score table (pandas truncates the displayed rows/columns).
sim_df

Unnamed: 0,301,302,303,304,305,306,307,308,309,310,...,391,392,393,394,395,396,397,398,399,400
FBIS3-1001,2.093690e-11,1.061319e-15,4.588218e-16,1.046446e-14,8.949977e-09,3.627352e-11,2.990166e-12,3.262359e-12,1.845634e-10,1.332322e-18,...,1.114330e-07,0.000026,6.548698e-10,2.407025e-08,0.000077,1.499156e-14,6.246422e-10,5.649983e-14,9.684796e-11,9.004595e-13
FBIS3-10014,2.738967e-11,9.201556e-16,4.931589e-16,1.051770e-14,8.672704e-09,8.421297e-13,2.796414e-12,3.416711e-12,2.524660e-10,1.702369e-18,...,1.677498e-07,0.000026,1.881855e-09,2.689350e-08,0.000060,2.100064e-14,6.314947e-10,4.756985e-14,9.564623e-11,1.213751e-14
FBIS3-10035,3.988780e-09,1.641297e-15,2.165017e-14,3.117691e-14,8.650170e-09,3.154298e-13,3.225514e-12,2.048372e-11,3.064716e-10,1.759899e-17,...,2.032955e-07,0.000026,4.604987e-10,3.071536e-08,0.000093,1.560984e-13,5.072881e-10,1.336222e-14,9.916510e-11,1.916171e-14
FBIS3-1007,2.804547e-11,6.129312e-16,4.047519e-15,1.013780e-14,7.371243e-09,1.705485e-11,3.654089e-12,1.076838e-11,7.227223e-10,3.781390e-18,...,2.800157e-07,0.000027,1.209230e-09,4.749952e-08,0.000067,3.204450e-14,9.382187e-10,1.846709e-14,9.475427e-11,7.704983e-14
FBIS3-10082,9.782491e-10,3.892662e-16,5.615827e-16,1.455064e-14,7.234331e-09,2.121981e-13,2.972856e-12,3.290567e-12,8.018253e-09,1.153053e-18,...,2.730137e-06,0.000027,8.124797e-10,1.861541e-08,0.000052,2.053503e-14,7.066957e-10,1.639840e-13,9.376218e-11,1.291680e-14
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
LA123190-0062,1.365360e-11,9.120537e-16,4.394229e-16,1.188231e-14,1.456308e-08,3.300053e-08,2.611248e-12,4.311811e-12,1.814736e-10,2.339024e-18,...,9.547320e-08,0.000026,3.770814e-09,3.902565e-08,0.000071,2.306106e-14,1.527491e-09,7.430432e-15,2.272709e-10,1.181416e-13
LA123190-0065,1.800242e-10,6.566182e-15,4.932551e-16,8.809623e-14,1.380302e-08,5.055919e-12,2.249842e-11,6.810621e-12,1.015781e-08,2.282887e-18,...,1.373443e-06,0.000027,5.865147e-08,6.052140e-08,0.000060,2.808765e-14,1.965433e-09,1.480000e-14,9.823764e-11,1.781525e-14
LA123190-0069,4.836468e-12,2.658358e-16,4.584159e-16,9.968444e-15,6.329533e-09,1.357133e-13,3.009999e-12,8.135422e-12,2.619919e-10,2.499100e-18,...,1.080084e-07,0.000027,7.241724e-10,4.387245e-07,0.000089,4.988678e-12,8.394233e-10,8.338528e-15,9.277491e-11,1.335501e-14
LA123190-0089,5.962427e-12,2.836151e-16,1.592784e-15,7.253025e-14,1.531723e-08,4.314870e-13,2.951151e-11,4.995984e-12,2.086370e-10,8.075692e-19,...,1.084485e-07,0.000029,2.834273e-09,5.983662e-08,0.000037,3.136418e-14,2.553189e-09,7.724964e-15,1.055725e-10,2.630736e-14


In [84]:
# Save results: for each query, a CSV row with the 1000 highest-scoring document IDs.
now = datetime.datetime.now()
save_filename = (
    f'results/result_topic{topic_k}_EPOCH{EPOCH}_a{alpha}_b{beta}'
    f'_word{slim_words_amount}{now.strftime("_%y%m%d_%H%M")}.txt'
)
print(save_filename)

# open(..., 'w') fails when the output directory does not exist yet
os.makedirs('results', exist_ok=True)

with open(save_filename, 'w') as f:
    f.write('Query,RetrievedDocuments\n')
    for query in query_list:
        query_sim_df = sim_df[query].sort_values(ascending=False)
        f.write(query + ',' + ' '.join(query_sim_df.index[:1000].to_list()) + '\n')

results/result_topic48_EPOCH30_a0.7_b0.1_word5042_201125_2113.txt


In [85]:
# Wrap the learned dense matrices as CSR so they can be written with save_npz.
sparse_pwt, sparse_ptd = (scipy.sparse.csr_matrix(matrix) for matrix in (pwt, ptd))

In [86]:
# Both files share one run-specific suffix (topics, epochs, vocabulary size, timestamp).
model_suffix = f'_topic{topic_k}_EPOCH{EPOCH}_word{slim_words_amount}' + now.strftime("_%y%m%d_%H%M")
scipy.sparse.save_npz('sparse_pwt' + model_suffix, sparse_pwt)
scipy.sparse.save_npz('sparse_ptd' + model_suffix, sparse_ptd)

In [22]:
# Reload a previously saved model as dense arrays (.toarray(); the .A shortcut was removed
# from scipy sparse matrices in SciPy 1.14).
# NOTE(review): these filenames are hard-coded to one earlier run and lack the "_word<N>" part
# that the save cell writes; update them before re-running.
pwt = scipy.sparse.load_npz('sparse_pwt_topic48_EPOCH30_201125_0207.npz').toarray()
ptd = scipy.sparse.load_npz('sparse_ptd_topic48_EPOCH30_201125_0207.npz').toarray()

In [54]:
# Topic mass of p(t | d) for the first 100 documents, one value per line.
print(*(ptd[:, doc].sum() for doc in range(100)), sep='\n')

0.04477611940298507
0.15999999999999998
0.14022140221402213
0.17361111111111108
0.14285714285714285
0.1698841698841699
0.14054054054054055
0.18957345971563985
0.1330049261083744
0.13876651982378851
0.1235294117647059
0.147887323943662
0.14529914529914534
0.12698412698412698
0.12422360248447205
0.14942528735632182
0.15602836879432624
0.15559772296015179
0.15602836879432624
0.1375968992248062
0.13941018766756033
0.2030075187969925
0.159822633506844
0.1307420494699647
0.2985074626865672
0.2030075187969925
0.14625850340136057
0.1473684210526316
0.1310160427807487
0.14625850340136057
0.1473684210526316
0.17647058823529416
0.1488095238095238
0.1746724890829694
0.1746724890829694
0.07499999999999998
0.2113821138211382
0.19827586206896552
0.2
0.0898876404494382
0.14534883720930233
0.1329113924050633
0.15819209039548024
0.1698924731182796
0.1910828025477707
0.14029850746268657
0.18461538461538463
0.17329545454545456
0.22716049382716047
0.16091954022988506
0.22857142857142854
0.17472118959107807

In [55]:
# Each topic's p(w | t) column should sum to ~1; print one value per topic.
print(*(pwt[:, topic].sum() for topic in range(topic_k)), sep='\n')

1.0000000000001894
1.000000000000139
1.0000000000002431
1.0000000000002203
1.0000000000002818
1.0000000000001545
1.0000000000001794
1.0000000000000693
1.0000000000001346
1.000000000000212
1.0000000000002442
1.0000000000002427
1.0000000000001878
1.0000000000001168
1.0000000000001161
1.0000000000001636
1.0000000000001248
1.0000000000002793
1.0000000000001825
1.0000000000001847
1.000000000000178
1.0000000000002687
1.0000000000001852
1.0000000000001803
1.0000000000002158
1.0000000000001839
1.0000000000002534
1.0000000000001994
1.0000000000001774
1.0000000000001097
1.0000000000001585
1.0000000000001865
1.0000000000002327
1.0000000000000837
1.0000000000001044
1.000000000000179
1.0000000000001266
1.0000000000001728
1.0000000000002345
1.000000000000186
1.0000000000001639
1.0000000000001956
1.0000000000001616
1.0000000000001847
1.0000000000002212
1.0000000000001728
1.0000000000002045
1.0000000000001932
