In [1]:
import datetime
import numpy as np
import pandas as pd
from tqdm import tqdm
from numba import jit
import pickle
import scipy.sparse
from collections import Counter

In [28]:
load_from_file = False

In [3]:
# load doc list
with open('doc_list.txt') as f:
    doc_list = f.read().splitlines()

In [4]:
# load doc from list
if ~load_from_file:
  docs_counter = []
  words = set()
  for doc in tqdm(doc_list):
      with open('docs/' + doc + '.txt') as f:
          doc_words = f.read().split()
          docs_counter.append(Counter(doc_words))
          words = words.union(set(doc_words))

100%|██████████| 14955/14955 [00:51<00:00, 288.84it/s]


In [5]:
# load query list
with open('query_list.txt') as f:
    query_list = f.read().splitlines()

In [24]:
# load query from list
queries = []
queries_words = set()
for query in tqdm(query_list):
    with open('queries/' + query + '.txt') as f:
        query_words = f.read().split()
        queries.append(query_words)
        if load_from_file == False:
            words = words.union(set(query_words))
            queries_words = queries_words.union(set(query_words))

if load_from_file:
    # load query words from file
    with open('query_word_list.txt') as f:
        queries_words = f.read().split()
else:
    # save query words
    with open('query_word_list.txt', 'w') as f:
        f.write(' '.join(queries_words))
    
    queries_words = list(queries_words)

100%|██████████| 100/100 [00:00<00:00, 4364.52it/s]


In [25]:
if load_from_file:
    # load words dict from file
    with open('word_list.txt') as f:
        words = f.read().split()
else:
    # save words
    with open('word_list.txt', 'w') as f:
        f.write(' '.join(words))

    words = list(words)

In [26]:
docs_amount = len(doc_list)
words_amount = len(words)
query_word_amount = len(queries_words)

print(docs_amount, words_amount, query_word_amount)

14955 111449 226


In [27]:
# document length
if load_from_file:
    docs_len = np.load('docs_len.npy')
else:
    docs_len = []

    for j in tqdm(range(docs_amount)):
        docs_len.append(sum(docs_counter[j].values()))

    docs_len = np.array(docs_len)
    np.save('docs_len', docs_len)

In [85]:
# all words count in documents and probability
if load_from_file:
    cwd = scipy.sparse.load_npz('cwd.npz')
    pwd = scipy.sparse.load_npz('pwd.npz')
else:
    indptr = [0]
    indices = []
    cwd_data = []
    pwd_data = []

    for j in tqdm(range(docs_amount)):
        doc_len = docs_len[j]

        for i in range(words_amount):
            word_count = docs_counter[j][words[i]]
            if word_count != 0:
                indices.append(i)
                cwd_data.append(word_count)
                pwd_data.append(word_count / doc_len)
        indptr.append(len(indices))

    cwd = scipy.sparse.csr_matrix((cwd_data, indices, indptr), dtype=np.float32).transpose()
    pwd = scipy.sparse.csr_matrix((pwd_data, indices, indptr), dtype=np.float32).transpose()

    scipy.sparse.save_npz('cwd', cwd)
    scipy.sparse.save_npz('pwd', pwd)

100%|██████████| 14955/14955 [10:29<00:00, 23.76it/s]


In [86]:
# process slim words
if load_from_file:
    with open('slim_word_list.txt') as f:
        slim_words = f.read().split()
else:
    words_count_list = []
    for word_row in tqdm(cwd):
        words_count_list.append(word_row.sum())
        
    most_word_index = np.flip(np.argsort(words_count_list))

    slim_words_amount = 10000
    slim_words = []
    for word_index in range(slim_words_amount):
        slim_words.append(words[word_index])
    
    slim_words = slim_words + queries_words
    slim_words = list(set(slim_words))

    # save slim words
    with open('slim_word_list.txt', 'w') as f:
        f.write(' '.join(slim_words))

# update slim words amount
slim_words_amount = len(slim_words)
print(slim_words_amount)

111449it [00:11, 9962.48it/s]
10208


In [87]:
# slim words count in documents and probability
if load_from_file:
    slim_cwd = scipy.sparse.load_npz('slim_cwd.npz').A
    slim_pwd = scipy.sparse.load_npz('slim_pwd.npz').A
else:
    indptr = [0]
    indices = []
    cwd_data = []
    pwd_data = []

    for j in tqdm(range(docs_amount)):
        doc_len = docs_len[j]
        for i in range(slim_words_amount):
            word_count = docs_counter[j][slim_words[i]]
            if word_count != 0:
                indices.append(i)
                cwd_data.append(word_count)
                pwd_data.append(word_count / doc_len)
        indptr.append(len(indices))

    slim_cwd = scipy.sparse.csr_matrix((cwd_data, indices, indptr), dtype=np.float32).transpose()
    slim_pwd = scipy.sparse.csr_matrix((pwd_data, indices, indptr), dtype=np.float32).transpose()

    scipy.sparse.save_npz('slim_cwd', slim_cwd)
    scipy.sparse.save_npz('slim_pwd', slim_pwd)

    slim_cwd = slim_cwd.A
    slim_pwd = slim_pwd.A

100%|██████████| 14955/14955 [00:59<00:00, 251.19it/s]


In [88]:
# background language model
bg = []
bg_model_cd = docs_len.sum()

for word_row in tqdm(slim_cwd):
    bg.append(word_row.sum() / bg_model_cd)

bg = np.array(bg)

100%|██████████| 10208/10208 [00:01<00:00, 7198.81it/s]


In [None]:
# def E_step():
#     ptwd_CD = np.matmul(pwt, ptd) # Common Denominator
#     for i in range(slim_words_amount):
#         for j in range(docs_amount):
#             if ptwd_CD[i][j] != 0:
#                 for k in range(topic_k):
#                     ptwd[k][i][j] = pwt[i][k] * ptd[k][j] / ptwd_CD[i][j]
#             else:
#                 ptwd[:,i,j] = 0

In [89]:
@jit(nopython=True)
def nb_E_step(pwt, ptd, cwd, topic_amount, word_amount, doc_amount):
    # empty matrix
    ptwd = np.empty((topic_amount, word_amount, doc_amount))

    # Common Denominator
    # ptwd_CD = np.dot(pwt, ptd) 

    for i in range(word_amount):
        for j in range(doc_amount):
            if cwd[i][j] != 0: 
                ptwd_CD = 0
                for k in range(topic_amount):
                    single_ptwd = pwt[i][k] * ptd[k][j]
                    ptwd[k][i][j] = single_ptwd
                    ptwd_CD += single_ptwd
                if ptwd_CD != 0:
                    for k in range(topic_amount):
                        ptwd[k][i][j] /= ptwd_CD
                else:
                    ptwd[:,i,j] = 0
            else:
                ptwd[:,i,j] = 0
    return ptwd

In [None]:
# # @jit
# def M_step():
#     # p(w/t)
#     for k in range(topic_k):
#         single_wt = np.multiply(cwd, ptwd[k])
#         single_wt_sum = single_wt.sum()
#         if single_wt_sum != 0:
#             for i in range(len(words)):
#                 pwt[i][k] = single_wt[i].sum() / single_wt_sum
#         else:
#             for i in range(len(words)):
#                 pwt[i][k] = 0
    
#     for i in range(len(pwt)):
#         pwt[i] /= pwt[i].sum()

#     # p(t/d)
#     for k in range(topic_k):
#         single_k_cwd_ptwd = np.multiply(cwd, ptwd[k])
#         for j in range(len(docs)):
#             if docs_len[j] != 0:
#                 ptd[k][j] = single_k_cwd_ptwd[:,j].sum() / docs_len[j]
#             else:
#                 ptd[k][j] = 0

In [None]:
# @jit
def M_step(times):
    for k in range(topic_k):
        single_topic_wd = np.multiply(cwd, ptwd[k])

        # p(w/t)
        single_wt_sum = single_topic_wd.sum()
        if single_wt_sum != 0:
            for i in range(slim_words_amount):
                pwt[i][k] = single_topic_wd[i].sum() / single_wt_sum
        else:
            pwt[:,k] = 1 / slim_words_amount

        # p(t/d)
        for j in range(docs_amount):
            ptd[k][j] = single_topic_wd[:,j].sum() / docs_len[j]
    
    # # norm to 1
    # for k in range(topic_k):
    #     if np.isnan(pwt[:,k].sum()):
    #         print(times, "norm ", k)
    #     pwt[:,k] /= pwt[:,k].sum()
    for j in range(docs_amount):
        deno = ptd[:,j].sum()
        if deno != 0:
            ptd[:,j] /= deno
        else:
            ptd[:,j].fill(1 / topic_k) 

In [90]:
@jit(nopython=True)
def nb_M_step(ptwd, cwd, docs_len, topic_amount, word_amount, doc_amount):
    # empty matrix
    pwt = np.empty((word_amount, topic_amount))
    ptd = np.empty((topic_amount, doc_amount))

    for k in range(topic_amount):
        single_topic_wd = np.multiply(cwd, ptwd[k])

        # p(w/t)
        single_wt_sum = single_topic_wd.sum()
        if single_wt_sum != 0:
            for i in range(word_amount):
                pwt[i][k] = single_topic_wd[i].sum() / single_wt_sum
        else:
            pwt[:,k] = 1 / slim_words_amount

        # p(t/d)
        for j in range(doc_amount):
            ptd[k][j] = single_topic_wd[:,j].sum() / docs_len[j]
    
    # # norm to 1
    # for k in range(topic_k):
    #     if np.isnan(pwt[:,k].sum()):
    #         print(times, "norm ", k)
    #     pwt[:,k] /= pwt[:,k].sum()
    for j in range(doc_amount):
        deno = ptd[:,j].sum()
        if deno != 0:
            ptd[:,j] /= deno
        else:
            ptd[:,j].fill(1 / topic_amount) 
    
    return pwt, ptd

In [91]:
def loss(times):
    loss = np.multiply(cwd, np.log(np.matmul(pwt, ptd))).sum()
    print("\nStep", times, "loss: ", loss)

In [92]:
@jit(nopython=True)
def nb_loss(times, cwd, pwt, ptd):
    loss = np.multiply(cwd, np.log(np.dot(pwt, ptd))).sum()
    print("\nStep", times, "loss: ", loss)

In [179]:
# topic

topic_k = 32
EPOCH = 50
alpha = 0.6
beta = 0.2

In [134]:
# EM Step Initial (normal)
pwt = np.random.random(size = (slim_words_amount, topic_k))

for k in range(topic_k):
    pwt[:,k] /= pwt[:,k].sum()

ptd = np.full((topic_k, docs_amount), 1 / topic_k)

ptwd = np.empty((topic_k, slim_words_amount, docs_amount))

In [165]:
for i in tqdm(range(EPOCH)):
    # nb_E_step(pwt, ptd, cwd, topic_amount, word_amount, doc_amount)
    ptwd = nb_E_step(pwt, ptd, slim_cwd, topic_k, slim_words_amount, docs_amount)
    # nb_M_step(ptwd, cwd, docs_len, topic_amount, word_amount, doc_amount)
    pwt, ptd = nb_M_step(ptwd, slim_cwd, docs_len, topic_k, slim_words_amount, docs_amount)
    # nb_loss(times, cwd, pwt, ptd)
    # nb_loss(i + 1, slim_cwd, pwt, ptd)

100%|██████████| 20/20 [1:07:43<00:00, 203.20s/it]


In [166]:
plsa_EM_final = np.matmul(pwt, ptd)

In [167]:
plsa_EM_final.shape

(10208, 14955)

In [180]:
queries_result = []

for query in tqdm(queries):
    query_result = []
    for doc_index in range(docs_amount):
        plsa_result = 1
        for word in query:
            word_index = slim_words.index(word)
            unigram_pwd = slim_pwd[word_index][doc_index]
            plsa_result = plsa_result * (alpha * unigram_pwd + beta * plsa_EM_final[word_index][doc_index] + (1 - alpha - beta) * bg[word_index])
            
            # plsa_result = plsa_result * (alpha * unigram_pwd + (1 - alpha - beta) * bg[word_index])
        query_result.append(plsa_result)
    queries_result.append(query_result)

100%|██████████| 100/100 [05:45<00:00,  3.45s/it]


In [181]:
## sort and export result
sim_df = pd.DataFrame(queries_result)
sim_df = sim_df.transpose()
sim_df.index = doc_list
sim_df.columns = query_list

In [182]:
sim_df

Unnamed: 0,301,302,303,304,305,306,307,308,309,310,...,391,392,393,394,395,396,397,398,399,400
FBIS3-1001,7.844099e-09,7.849734e-16,1.274814e-15,3.106111e-09,7.052925e-08,2.063657e-11,1.185478e-11,2.879457e-12,2.755302e-09,1.977412e-18,...,9.075182e-08,0.000026,3.095598e-09,3.459057e-08,0.000036,2.231322e-14,1.440622e-09,3.589058e-14,1.373540e-07,1.066275e-14
FBIS3-10014,3.297196e-10,1.283606e-15,4.899966e-13,1.019672e-14,8.877052e-08,9.617417e-12,2.439827e-11,7.079869e-12,1.215097e-09,5.179315e-17,...,9.075182e-08,0.000026,2.493364e-09,1.006079e-06,0.000036,5.369635e-14,3.489751e-09,5.728465e-12,4.802940e-10,1.066257e-14
FBIS3-10035,5.797906e-09,4.765186e-12,2.694060e-12,1.571261e-14,5.635946e-08,1.330651e-12,4.351448e-11,4.781515e-09,2.453849e-09,5.070105e-14,...,1.075632e-07,0.000026,1.953508e-09,6.252042e-07,0.001147,9.522312e-11,3.667022e-09,1.873587e-12,9.243940e-11,3.567190e-11
FBIS3-1007,5.566566e-09,1.601474e-13,7.229363e-16,9.946346e-15,2.894528e-08,8.093462e-11,3.850425e-11,2.771099e-10,6.058272e-09,1.074080e-15,...,3.724922e-06,0.000026,3.475826e-09,4.957312e-08,0.000036,2.918694e-12,5.914333e-09,5.948465e-13,9.237538e-11,1.066257e-14
FBIS3-10082,1.344726e-08,2.077823e-15,1.442651e-15,9.946419e-15,3.632382e-08,3.084406e-12,1.356564e-11,2.895015e-12,3.495033e-08,1.388209e-17,...,5.764371e-06,0.000026,1.999834e-08,1.180914e-07,0.000886,1.993157e-13,1.585943e-08,2.865337e-13,9.238008e-11,1.066257e-14
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
LA123190-0062,1.155658e-09,2.213033e-15,4.999048e-15,2.390150e-10,5.537168e-08,8.670230e-09,1.293670e-10,3.662383e-12,7.882543e-09,1.515021e-17,...,2.626347e-07,0.000026,1.690260e-08,3.054801e-07,0.000234,2.461188e-13,2.443372e-08,4.620844e-13,2.885301e-08,1.113715e-14
LA123190-0065,2.955998e-10,8.496913e-15,5.232009e-16,9.946078e-15,4.554302e-08,2.207681e-10,1.143272e-11,2.879954e-12,3.806054e-08,1.526408e-17,...,2.505451e-06,0.000026,1.299618e-06,1.621377e-07,0.000036,7.005154e-14,1.921439e-08,2.082375e-13,9.238721e-11,1.106068e-14
LA123190-0069,1.818882e-11,1.134254e-15,5.869964e-16,1.003980e-14,7.457401e-08,1.608521e-11,9.436688e-12,9.551267e-10,4.095120e-09,2.850526e-17,...,9.076959e-08,0.000026,1.133633e-07,1.082940e-06,0.000036,1.171648e-11,2.285321e-08,1.227615e-13,2.374386e-10,1.125836e-14
LA123190-0089,2.116689e-10,9.811996e-16,7.381630e-16,9.946079e-15,2.775759e-08,6.707682e-13,9.382999e-08,7.447472e-11,2.736146e-09,1.787428e-17,...,9.084109e-08,0.000026,2.858647e-09,2.269726e-07,0.000036,1.517991e-12,1.315771e-08,3.309848e-14,9.239268e-11,1.122154e-14


In [183]:
# save results
now = datetime.datetime.now()
save_filename = 'results/result' + '_' + 'topic' + str(topic_k) + '_EPOCH' + str(EPOCH) + '_a' + str(alpha) + '_b' + str(beta) + now.strftime("_%y%m%d_%H%M") + '.txt'
print(save_filename)

with open(save_filename, 'w') as f:
    f.write('Query,RetrievedDocuments\n')
    for query in query_list:
        f.write(query + ",")
        query_sim_df = sim_df[query].sort_values(ascending=False)
        f.write(' '.join(query_sim_df[:1000].index.to_list()) + '\n')

results/result_topic32_EPOCH50_a0.6_b0.2_201124_1603.txt


In [172]:
sparse_pwt = scipy.sparse.csr_matrix(pwt)
sparse_ptd = scipy.sparse.csr_matrix(ptd)

In [173]:
scipy.sparse.save_npz('sparse_pwt' + '_' + 'topic' + str(topic_k) + '_EPOCH' + str(EPOCH), sparse_pwt)
scipy.sparse.save_npz('sparse_ptd' + '_' + 'topic' + str(topic_k) + '_EPOCH' + str(EPOCH), sparse_ptd)

In [148]:
for i in range(100):
    print(ptd[:,i].sum())

1.0
1.0
1.0000000000000004
1.0
1.0
0.9999999999999999
0.9999999999999999
1.0
0.9999999999999999
1.0
0.9999999999999998
0.9999999999999999
0.9999999999999998
0.9999999999999997
1.0
1.0
1.0000000000000002
1.0
1.0
1.0
1.0000000000000002
1.0
1.0
1.0
1.0000000000000002
1.0
1.0
1.0
1.0000000000000002
1.0
1.0
1.0
1.0
1.0
1.0
0.9999999999999999
0.9999999999999999
1.0
1.0
0.9999999999999999
1.0000000000000002
0.9999999999999999
1.0
1.0
1.0
1.0000000000000002
1.0000000000000002
1.0
1.0000000000000002
1.0
1.0000000000000004
1.0
1.0
1.0000000000000002
1.0
1.0000000000000002
1.0
0.9999999999999999
1.0
1.0000000000000002
1.0
1.0000000000000002
0.9999999999999999
0.9999999999999999
1.0000000000000002
1.0
0.9999999999999999
0.9999999999999999
1.0000000000000002
1.0000000000000004
0.9999999999999998
0.9999999999999999
1.0
0.9999999999999998
1.0000000000000002
1.0
0.9999999999999998
1.0
1.0
0.9999999999999999
1.0
1.0000000000000002
0.9999999999999999
0.9999999999999993
1.0
0.9999999999999998
1.0
1.0
1.0

In [147]:
for i in range(topic_k):
    print(pwt[:,i].sum())

1.0000000000002123
1.0000000000001819
1.0000000000001932
1.0000000000002278
1.0000000000002331
1.0000000000001326
1.0000000000001597
1.000000000000249
1.0000000000001825
1.0000000000001943
1.0000000000003209
1.0000000000001708
1.000000000000161
1.000000000000195
1.000000000000146
1.000000000000146
1.0000000000002833
1.0000000000001665
1.0000000000001719
1.0000000000001577
1.000000000000209
1.0000000000001965
1.0000000000001972
1.0000000000002218
1.0000000000002027
1.000000000000149
1.000000000000151
1.0000000000001195
1.0000000000001688
1.0000000000002178
1.0000000000001923
1.0000000000002072
