In [1]:
from tqdm import tqdm
import math
import numpy as np
query_list_path = "q_100_d_10000/query_list.txt"
doc_list_path = "q_100_d_10000/doc_list.txt"

with open(query_list_path, "r") as f:
    q_list = f.read().split('\n')[:-1]
with open(doc_list_path, "r") as f:
    d_list = f.read().split('\n')[:-1]
d_dict = {doc:doc_index for doc_index,doc in enumerate(d_list)}

def get_query_word(q):
    with open("q_100_d_10000/queries/{}.txt".format(q),'r') as f:
        words = f.read().split(' ')
    return words
def get_doc_word(d):
    with open("q_100_d_10000/docs/{}.txt".format(d),'r') as f:
        words = f.read().split(' ')
    return words
def get_random_probability_matrix(event_num,condition_num):
    matrix=np.random.random_sample((event_num,condition_num))
    for i in range(condition_num):
        temp_sum = matrix[:,i].sum()
        matrix[:,i]/=temp_sum #set sum to 1
    return matrix

class parameter_retriever:
    def __init__(self,topic_num):
        self.index_term_dict=dict()
        self.index_term_num = 0
        self.topic_num = topic_num
        self.doc_num = len(d_list)

        self.c_wd = [dict()]*self.index_term_num
        self.doc_length = [0]*len(d_list)
        self.c_w = [0]*self.index_term_num

        self.P_w_T = []
        self.P_T_d = []


    def create_index_term_set(self):
        print("creating index term set")
        print(" creating index term set from query")
        index_term_set_q = set()
        for q in tqdm(q_list):
            words = get_query_word(q)
            index_term_set_q = index_term_set_q.union(set(words))

        print(" creating index term set from doc")
        index_term_set_d = set()
        for d in tqdm(d_list):
            words = get_doc_word(d)
            index_term_set_d = index_term_set_d.union(set(words))

        index_term_set = index_term_set_q.union(index_term_set_d)
        self.index_term_dict={index_term:index_term_index for index_term_index,index_term in enumerate(index_term_set)}
        self.index_term_num = len(index_term_set)
        print("number of words in index_term_set: {}".format(self.index_term_num))
        print("...done")

    def get_word_count_in_doc(self):
        print("getting word counts in doc/doc_length/BG")
        self.c_wd = [dict() for _ in range(self.index_term_num)]
        self.doc_length = [0]*len(d_list)
        self.c_w = [0]*self.index_term_num

        for doc_index,doc in tqdm(enumerate(d_list)):
            words = get_doc_word(doc)
            for word in set(words):
                if word in self.index_term_dict:
                    self.c_w[self.index_term_dict[word]]+=1
                    self.c_wd[self.index_term_dict[word]][doc_index]=0
            for word in words:
                if word in self.index_term_dict:
                    self.c_wd[self.index_term_dict[word]][doc_index]+=1
                    self.doc_length[doc_index]+=1
        print("...done")

    def clean_index_term_set(self):
        print("cleaning index term whose c_w<=2")
        index_term_set = {index_term for index_term_index,index_term in enumerate(self.index_term_dict) if self.c_w[index_term_index]>2}
        self.index_term_dict={index_term:index_term_index for index_term_index,index_term in enumerate(index_term_set)}
        self.index_term_num = len(self.index_term_dict)
        print(" index_term_num after cleaning: {}".format(self.index_term_num))
        self.get_word_count_in_doc()
        print("...done")

    def initPossibilities(self):
        print("initializing possibilities")
        random_matrix = get_random_probability_matrix(self.index_term_num,self.topic_num)
        for index_term_idx in range(self.index_term_num):
            self.P_w_T.append([probability for probability in random_matrix[index_term_idx,:].tolist()])
        
        random_matrix = get_random_probability_matrix(self.topic_num,self.doc_num)
        for topic_idx in range(self.topic_num):
            self.P_T_d.append([probability for probability in random_matrix[topic_idx,:].tolist()])
        print("...done")


    def calculate_P_T_wd(self,topic_idx,index_term_index,doc_index):
        buffer = self.P_w_T[index_term_index][topic_idx]*self.P_T_d[topic_idx][doc_index]
        sum=0
        for k in range(self.topic_num):
            sum+=self.P_w_T[index_term_index][k]*self.P_T_d[k][doc_index]
        return buffer/sum
        
    def iter(self):
        # print("start E_step")
        sum_of_cP_in_D=[[0]*self.topic_num]*self.index_term_num
        sum_of_cP_in_V=[[0]*len(d_list)]*self.topic_num
        for index_term_index in tqdm(range(self.index_term_num),leave=False):
            for doc_index in self.c_wd[index_term_index].keys():
                for topic_idx in range(self.topic_num):
                    cP=self.c_wd[index_term_index][doc_index]*self.calculate_P_T_wd(topic_idx,index_term_index,doc_index)
                    sum_of_cP_in_D[index_term_index][topic_idx]+=cP
                    sum_of_cP_in_V[topic_idx][doc_index]+=cP
        # print("start M_step")
        # print(" process P_w_T")
        for topic_index in range(self.topic_num):
            sum_of_topic_k = 0
            for index_term_index in range(self.index_term_num):
                sum_of_topic_k+=sum_of_cP_in_D[index_term_index][topic_index]
            for index_term_index in range(self.index_term_num):
                self.P_w_T[index_term_index][topic_index]=sum_of_cP_in_D[index_term_index][topic_index]/sum_of_topic_k
        # print(" process P_T_d")
        for doc_index in range(len(d_list)):
            if self.doc_length[doc_index]==0:
                for topic_index in range(self.topic_num):
                    self.P_T_d[topic_index][doc_index] = 1/self.topic_num
            else:
                for topic_index in range(self.topic_num):
                    self.P_T_d[topic_index][doc_index]=sum_of_cP_in_V[topic_index][doc_index]/self.doc_length[doc_index]
        # print("...done")



class PLSA:
    def __init__(self,topic_num,alpha,beta):
        self.topic_num=topic_num
        self.param = parameter_retriever(topic_num)
        
        self.alpha=alpha
        self.beta=beta
    
    def init_param(self):
        self.param.create_index_term_set()
        self.param.get_word_count_in_doc()
        self.param.clean_index_term_set()
        self.param.initPossibilities()

    def train(self,n):
        for i in tqdm(range(n)):
            # print("iter: {}".format(i))
            self.param.iter()

    def get_sim(self,idx_doc,q):
        if self.param.doc_length[idx_doc]==0:
            return -999999999
        logsum=0
        for word in get_query_word(q):
            if word in self.param.index_term_dict:
                word_index=self.param.index_term_dict[word]
                if idx_doc in self.param.c_wd[word_index]:
                    first = self.alpha*self.param.c_wd[word_index][idx_doc]/self.param.doc_length[idx_doc]
                    first = math.log(first)
                else:
                    first = -15

                second = 0
                for topic_idx in range(self.topic_num):
                    second+=self.param.P_w_T[word_index][topic_idx]*self.param.P_T_d[topic_idx][idx_doc]
                second*=self.beta
                second = math.log(second)
                
                third = (1-self.alpha-self.beta)*self.param.c_w[word_index]/self.param.doc_num
                third = math.log(third)

                temp = np.logaddexp(first, second)
                temp = np.logaddexp(temp, third)
                logsum+=temp
        return logsum
    def query(self,q):
        sim={}
        for idx_doc,doc in enumerate(d_list):

            sim[doc] = self.get_sim(idx_doc,q)
        sim = sorted(sim.items(), key=lambda x:x[1],reverse=True)
        ans = ""
        for i in sim:
            ans+=i[0]+' '
        return ans


In [2]:
#time: < 1min15sec
#creating index term set ~47sec
#cleaning index term whose c_w<=2 <10sec
plsa=PLSA(32,0.3,0.1)
plsa.init_param()

creating index term set
 creating index term set from query


100%|██████████| 100/100 [00:00<00:00, 14283.34it/s]


 creating index term set from doc


100%|██████████| 10000/10000 [00:47<00:00, 210.57it/s]


number of words in index_term_set: 93065
...done
getting word counts in doc/doc_length/BG


10000it [00:09, 1043.22it/s]


...done
cleaning index term whose c_w<=2
 index_term_num after cleaning: 31371
getting word counts in doc/doc_length/BG


10000it [00:09, 1063.82it/s]


...done
...done
initializing possibilities
...done


In [3]:
# 1Epoch
# total<2min45sec
plsa.train(100)

 11%|█         | 11/100 [1:06:43<8:55:56, 361.31s/it]
  6%|▋         | 1993/31371 [00:21<06:54, 70.85it/s] 

In [4]:
f=open("ans.txt","w")
f.write("Query,RetrievedDocuments\n")
for q in tqdm(q_list):
    ranking=plsa.query(q)
    f.writelines(q+","+ranking+'\n')
f.close()

 19%|█▉        | 19/100 [00:16<01:11,  1.13it/s]


KeyboardInterrupt: 

In [5]:
import pickle
with open('230_10.pickle', 'wb') as f:
    pickle.dump(plsa, f)