#### imports

In [19]:
import math
import lucene
import time
import nltk
import itertools
import numpy as np
from tqdm import tqdm
from java.io import File
import xml.etree.ElementTree as ET
from collections import defaultdict
from org.apache.lucene.store import FSDirectory
from org.apache.lucene.util import BytesRefIterator
from org.apache.lucene.index import DirectoryReader, Term
from org.apache.lucene.analysis.en import EnglishAnalyzer
from org.apache.lucene.analysis.core import WhitespaceAnalyzer
from org.apache.lucene.queryparser.classic import QueryParser
from org.apache.lucene.search import IndexSearcher, BooleanQuery, BooleanClause, TermQuery, BoostQuery
from org.apache.lucene.search.similarities import BM25Similarity, LMJelinekMercerSimilarity, LMDirichletSimilarity
lucene.initVM()

<jcc.JCCEnv at 0x7f5f52a54790>

In [20]:
# Identifier of the topic set to run; it is interpolated into the topic file
# path and into the result file names further down in this notebook.
q_name = 'trec6'

In [21]:
# Input locations: the Lucene index directory, the topic XML file for the
# selected `q_name`, and the relevance-judgement (qrel) file.
index_path = '../../../index/'
topicFilePath = f'../../../{q_name}.xml'
qrel_file = '../../../trec678_robust.qrel'

# Open the on-disk index once; `indexReader` is read as a module-level global
# by the query-expansion and re-retrieval functions defined later.
directory = FSDirectory.open(File(index_path).toPath())
indexReader = DirectoryReader.open(directory)

In [22]:
def query_topics(xml_file):
    """Read a topic XML file and map every topic number to its title.

    Parameters
    ----------
    xml_file : str or file object
        Anything accepted by ``ET.parse``.

    Returns
    -------
    dict
        ``{num: title}`` built from the ``<num>`` and ``<title>`` children of
        each ``<top>`` element directly under the root; both strings are
        stripped of surrounding whitespace. A repeated ``num`` keeps the
        title of its last occurrence.
    """
    topic_nodes = ET.parse(xml_file).getroot().findall('top')
    return {
        node.find('num').text.strip(): node.find('title').text.strip()
        for node in topic_nodes
    }

In [23]:
def makeRelJudgeDict(qrelFilePath):
    """Load a qrel file into a nested relevance-judgement dictionary.

    Every non-blank line is split on whitespace. Field 0 is used as the query
    id, field 2 as the document id and field 3 (parsed as ``int``) as the
    judgement; field 1 is ignored.

    Parameters
    ----------
    qrelFilePath : str
        Path of the qrel file to read.

    Returns
    -------
    dict
        ``{qid: {docid: judgement, ...}, ...}``. If a (qid, docid) pair
        occurs more than once, the last judgement wins.
    """
    relJudgeDict = {}
    with open(qrelFilePath, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank / whitespace-only lines (e.g. a trailing newline at
                # end of file) carry no judgement and must not crash the load.
                continue
            qid, docid, judgement = fields[0], fields[2], int(fields[3])
            relJudgeDict.setdefault(qid, {})[docid] = judgement
    return relJudgeDict

def isTrueRelevant(qid, docid, relJudgeDict):
    """Tell whether ``docid`` is judged relevant (judgement == 1) for ``qid``.

    Parameters
    ----------
    qid, docid : hashable
        Query id and document id, as used for the keys of ``relJudgeDict``.
    relJudgeDict : dict
        ``{qid: {docid: judgement}}`` mapping.

    Returns
    -------
    bool
        ``True`` only when a judgement exists and equals 1. Unknown queries,
        unjudged documents and every other judgement value give ``False``
        (never ``None``).
    """
    # NOTE(review): graded judgements above 1 are still NOT counted as
    # relevant here, matching the previous truthiness -- confirm whether the
    # qrel file in use contains such grades.
    return relJudgeDict.get(qid, {}).get(docid) == 1

def isTrueNonRelevant(qid, docid, relJudgeDict):
    """Tell whether ``docid`` is judged non-relevant (judgement == 0) for ``qid``.

    Parameters
    ----------
    qid, docid : hashable
        Query id and document id, as used for the keys of ``relJudgeDict``.
    relJudgeDict : dict
        ``{qid: {docid: judgement}}`` mapping.

    Returns
    -------
    bool
        ``True`` only when a judgement exists and equals 0. Unknown queries
        and unjudged documents give ``False`` (an unjudged document is not
        treated as non-relevant), as does every other judgement value; the
        result is never ``None``.
    """
    # `.get(docid)` yields None for unjudged documents, and `None == 0` is
    # False, so only an explicit 0 judgement counts.
    return relJudgeDict.get(qid, {}).get(docid) == 0

# Load the judgements once at module level; `search` reads this global when it
# filters the ranked list into judged-relevant and judged-non-relevant documents.
relJudgeDict = makeRelJudgeDict(qrel_file)

In [24]:
# {topic number: title} for the selected topic set; read as a module-level
# global by expanded_query_BM25 and run_RM3.
query_all = query_topics(topicFilePath)

In [25]:
def getDocumentVector(luceneDocid, indexReader):
    """Return the length-normalised term-frequency vector of one document.

    Parameters
    ----------
    luceneDocid : int
        Internal Lucene document number.
    indexReader :
        Reader whose ``getTermVector(docid, 'CONTENTS')`` is queried.

    Returns
    -------
    dict
        ``{term: tf / D}`` where ``tf`` is the term's frequency reported by
        the term-vector iterator and ``D`` is the sum of all those
        frequencies. An empty dict is returned when the reader has no term
        vector for the document's ``CONTENTS`` field.
    """
    terms = indexReader.getTermVector(luceneDocid, 'CONTENTS')
    if terms is None:
        # No term vector available for this document/field: treat it as an
        # empty document instead of failing on `terms.iterator()`.
        return {}

    rawCounts = {}
    iterator = terms.iterator()
    for term in BytesRefIterator.cast_(iterator):
        # The iterator is positioned on `term`, so its frequency is read from
        # the iterator itself rather than from the yielded bytes.
        rawCounts[term.utf8ToString()] = iterator.totalTermFreq()

    docLength = sum(rawCounts.values())
    return {t: tf / docLength for t, tf in rawCounts.items()}


### RM3


In [26]:
def search(indexReader, query, similarity, top_rel_doc, qid, tpd, tnd):
    """Run ``query`` and pick judged feedback documents from the ranked list.

    The top ``top_rel_doc`` hits are scanned in rank order. Each hit's stored
    ``ID`` field is looked up in the module-level ``relJudgeDict`` to decide
    whether it is judged relevant or judged non-relevant for ``qid``.

    Parameters
    ----------
    indexReader :
        Lucene index reader to search.
    query :
        Already-parsed Lucene query object.
    similarity :
        Similarity installed on the searcher before searching.
    top_rel_doc : int
        Number of hits to retrieve.
    qid :
        Query id used for the relevance look-ups.
    tpd, tnd : int
        Maximum number of judged-relevant / judged-non-relevant documents
        to keep.

    Returns
    -------
    tuple
        ``(set_cont, docids, ndocid)``: the set of terms occurring in the
        kept relevant documents, the Lucene doc numbers of those relevant
        documents, and the Lucene doc numbers of the kept non-relevant
        documents.
    """
    searcher = IndexSearcher(indexReader)
    searcher.setSimilarity(similarity)

    scoreDocs = searcher.search(query, top_rel_doc).scoreDocs

    relevant_docs = []
    nonrel_docs = []
    for scoreDoc in scoreDocs:
        luceneDocid = scoreDoc.doc
        externalId = searcher.doc(luceneDocid).get('ID')
        if isTrueRelevant(qid, externalId, relJudgeDict):
            relevant_docs.append(luceneDocid)
        if isTrueNonRelevant(qid, externalId, relJudgeDict):
            nonrel_docs.append(luceneDocid)
        # Only the first tpd / tnd entries are kept below, so once both
        # quotas are filled further hits cannot change the result and their
        # stored fields need not be loaded.
        if 0 <= tpd <= len(relevant_docs) and 0 <= tnd <= len(nonrel_docs):
            break

    docids = relevant_docs[:tpd]
    ndocid = nonrel_docs[:tnd]

    # Candidate expansion terms: every term of every kept relevant document.
    set_cont = {term for doc in docids for term in getDocumentVector(doc, indexReader)}

    return set_cont, docids, ndocid

In [27]:
def RM3_term_selection(Query, set_ET, docs, ndocid, indexReader, alpha, lamb, expanded_query_terms):
    """Score candidate expansion terms and interpolate them with the query.

    For every candidate term the score is its mean normalised term frequency
    over ``docs`` minus its mean normalised term frequency over ``ndocid``.
    Only terms with a strictly positive score are kept; the
    ``expanded_query_terms`` best ones are normalised to sum to 1 and then
    mixed with the original query's term distribution::

        weight[w] = alpha * feedback[w] + (1 - alpha) * count(w in Q) / len(Q)

    Parameters
    ----------
    Query : str
        Whitespace-separated (already analysed) query terms.
    set_ET : set
        Candidate expansion terms.
    docs, ndocid : list
        Lucene doc numbers of the relevant / non-relevant feedback documents.
    indexReader :
        Passed through to ``getDocumentVector``.
    alpha : float
        Interpolation weight of the feedback model.
    lamb : float
        Accepted for interface compatibility; it does not influence the
        returned weights.
        NOTE(review): the previous code used it to compute a smoothed
        query likelihood per document that was never applied to the term
        scores. If that factor is meant to weight the documents, it has to
        be multiplied into the per-document term probability.
    expanded_query_terms : int
        Number of feedback terms kept before interpolation.

    Returns
    -------
    dict
        ``{term: weight}`` ordered by decreasing weight. With no feedback
        documents only the query terms are returned; an empty query
        contributes nothing to the weights.
    """
    Q = Query.split()

    docVectors = {d: getDocumentVector(d, indexReader) for d in docs}
    ndocVectors = {d: getDocumentVector(d, indexReader) for d in ndocid}

    def _mean_tf(term, doc_ids, vectors):
        # Mean normalised tf of `term` over `doc_ids`; 0 when there are none,
        # which also avoids dividing by an empty feedback set.
        if not doc_ids:
            return 0
        return sum(vectors[d].get(term, 0) for d in doc_ids) / len(doc_ids)

    weight = {}
    for w in set_ET:
        score = _mean_tf(w, docs, docVectors) - _mean_tf(w, ndocid, ndocVectors)
        if score > 0:
            weight[w] = score

    # Keep the best-scoring terms and turn their scores into a distribution.
    ranked = sorted(weight.items(), key=lambda x: x[1], reverse=True)
    weight = dict(ranked[:expanded_query_terms])
    norm = sum(weight.values())
    weight = {w: weight[w] / norm for w in weight}

    # Interpolate with the maximum-likelihood model of the original query.
    query_len = len(Q)
    for w in weight.keys() | set(Q):
        query_prob = Q.count(w) / query_len if query_len else 0
        weight[w] = (alpha * weight.get(w, 0)) + (1 - alpha) * query_prob

    return dict(sorted(weight.items(), key=lambda x: x[1], reverse=True))

In [28]:
def expanded_query_BM25(search, RM3_term_selection, k1, b, alpha, top_rel_doc, expanded_query_terms, mu, tpd, tnd):
    """Build one weighted, expanded Boolean query per topic in ``query_all``.

    Each topic title is escaped, parsed against the ``CONTENTS`` field with
    an ``EnglishAnalyzer``, run through ``search`` (BM25) to collect feedback
    documents and candidate terms, and finally turned into a
    ``BooleanQuery`` of boosted ``SHOULD`` term clauses using the weights
    returned by ``RM3_term_selection``.

    Parameters
    ----------
    search, RM3_term_selection : callable
        The feedback-retrieval and term-weighting functions to use.
    k1, b : float
        BM25 parameters for the initial retrieval.
    alpha : float
        Forwarded to ``RM3_term_selection``.
    top_rel_doc : int
        Number of hits the initial retrieval inspects.
    expanded_query_terms : int
        Forwarded to ``RM3_term_selection``.
    mu : float
        Forwarded to ``RM3_term_selection`` as its ``lamb`` argument.
    tpd, tnd : int
        Forwarded to ``search``.

    Returns
    -------
    list
        Expanded queries in the iteration order of the module-level
        ``query_all``.
    """
    analyzer = EnglishAnalyzer()
    similarity = BM25Similarity(k1, b)
    parser = QueryParser('CONTENTS', analyzer)
    # Number of characters cut from the front of every token of
    # `query.toString()`, i.e. the 'CONTENTS:' field prefix.
    field_prefix_len = len('CONTENTS:')

    # Static setting on BooleanQuery, so it is applied once instead of once
    # per expansion term.
    BooleanQuery.setMaxClauseCount(4096)

    expanded_q = []
    for qid, q in tqdm(query_all.items(), colour='red', desc='Expanding Queries', leave=False):
        # A few titles contain '/' which the query parser would otherwise
        # interpret as syntax, so the raw title is escaped before parsing.
        query = parser.parse(parser.escape(q))

        query_terms = [term.strip()[field_prefix_len:] for term in query.toString().split()]
        parsed_q = ' '.join(query_terms)

        expension_term_set, docids, ndocid = search(indexReader, query, similarity, top_rel_doc, qid, tpd, tnd)
        weights = RM3_term_selection(parsed_q, expension_term_set, docids, ndocid, indexReader, alpha, mu, expanded_query_terms)

        builder = BooleanQuery.Builder()
        for term_text, term_weight in weights.items():
            boosted = BoostQuery(TermQuery(Term('CONTENTS', term_text)), float(term_weight))
            builder.add(boosted, BooleanClause.Occur.SHOULD)

        expanded_q.append(builder.build())

    return expanded_q

In [29]:
def search_retrived(indexReader, Query, Qid, similarity, out_name):
    """Retrieve the top 1000 hits for ``Query`` and format them as run lines.

    Parameters
    ----------
    indexReader :
        Lucene index reader to search.
    Query :
        Lucene query object to execute.
    Qid :
        Query id written in the first column.
    similarity :
        Similarity installed on the searcher before searching.
    out_name :
        Run name written in the last column.

    Returns
    -------
    str
        One tab-separated, newline-terminated line per hit with the columns
        ``Qid, 'Q0', stored ID field, rank (1-based), score, out_name``;
        the empty string when there are no hits.
    """
    searcher = IndexSearcher(indexReader)
    searcher.setSimilarity(similarity)

    hits = searcher.search(Query, 1000).scoreDocs

    run_lines = []
    for rank, hit in enumerate(hits, start=1):
        stored = searcher.doc(hit.doc)
        columns = (
            str(Qid),
            'Q0',
            str(stored.get('ID')),
            str(rank),
            str(hit.score),
            str(out_name),
        )
        run_lines.append('\t'.join(columns) + '\n')

    return ''.join(run_lines)

In [30]:
def run_RM3(top_PRD, expanded_query_terms, alpha, lamb, tpd, tnd):
    """Expand every topic, re-retrieve with BM25 and write one result file.

    Reads the module-level ``k1``, ``b``, ``q_name``, ``query_all`` and
    ``indexReader``. The result file is written below
    ``./res_TRF/<q_name>/`` with the parameter values encoded in its name;
    the directory is created when missing.

    Parameters
    ----------
    top_PRD : int
        Number of initial hits inspected for feedback documents.
    expanded_query_terms : int
        Number of feedback terms kept.
    alpha, lamb : float
        Forwarded to ``expanded_query_BM25`` (``lamb`` as its ``mu``).
    tpd, tnd : int
        Number of relevant / non-relevant feedback documents.

    Returns
    -------
    None
    """
    # Local import: only this function needs it, to create the output folder.
    import os

    expand_q = expanded_query_BM25(search, RM3_term_selection, k1, b, alpha, top_PRD, expanded_query_terms, lamb, tpd, tnd)

    sim = BM25Similarity(k1, b)
    name = 'prm_' + 'BM25_' + str(k1) + '_' + str(b)

    file_name = f'./res_TRF/{q_name}/{tpd}_{tnd}_{q_name}_lamb_' + str(lamb) +'_docs_' + str(top_PRD) + '_terms_' + str(expanded_query_terms) + '_alpha_' + str(alpha) +'_tf' +'.txt'
    os.makedirs(os.path.dirname(file_name), exist_ok=True)

    # expand_q was built by iterating query_all, so both are in the same order.
    paired = zip(query_all.keys(), expand_q)
    results = [
        search_retrived(indexReader, expanded, qid, sim, name)
        for qid, expanded in tqdm(paired, total=len(query_all), colour='cyan', desc = 'Re-retrival', leave=False)
    ]

    # The file is opened only after retrieval has finished, so a failure
    # above does not truncate an existing result file, and `with` guarantees
    # the handle is closed.
    with open(file_name, "w") as out_file:
        out_file.write(''.join(results))

In [31]:
# BM25 parameters; run_RM3 reads these two names as module-level globals.
k1 = 0.8
b = 0.4

# Parameter grid: each list holds the values to try for one argument of
# run_RM3; every combination is executed once.
tpd = [20]
tnd = [0, 5 ,20]
top_PRD = [1000]
expanded_query_terms = [50]
alpha = [0.8]
lamb = [0.7]

# Tuples are ordered (top_PRD, expanded_query_terms, alpha, lamb, tpd, tnd),
# matching the positional parameters of run_RM3.
parameters = list(itertools.product(top_PRD, expanded_query_terms, alpha, lamb, tpd, tnd))

# NOTE(review): the loop targets `alpha`, `lamb`, `tpd` and `tnd` rebind the
# grid lists defined above to scalars; after this loop those names no longer
# hold lists. Re-running this whole cell resets them.
for num_doc, num_q, alpha, lamb, tpd, tnd in tqdm(parameters, colour='red'):
    run_RM3(num_doc, num_q, alpha, lamb, tpd, tnd)

100%|[31m██████████[0m| 3/3 [01:23<00:00, 27.79s/it]
