### imports

In [1]:
import math
import lucene
from java.io import File
from org.apache.lucene.store import FSDirectory
from org.apache.lucene.util import BytesRefIterator
from org.apache.lucene.index import DirectoryReader, Term
from org.apache.lucene.analysis.en import EnglishAnalyzer
from org.apache.lucene.analysis.core import WhitespaceAnalyzer
from org.apache.lucene.queryparser.classic import QueryParser
from org.apache.lucene.search import IndexSearcher, BooleanQuery, BooleanClause, TermQuery, BoostQuery
from org.apache.lucene.search.similarities import BM25Similarity, LMJelinekMercerSimilarity, LMDirichletSimilarity
lucene.initVM()

<jcc.JCCEnv at 0x7f95e6efe150>

In [15]:
import xml.etree.ElementTree as ET

# indexPath = '/mnt/c/Users/priya/Desktop/MS/documents_index/'
indexPath = '../../TREC678/documents_index/'
q_name = 'trec6'
topicFilePath = f'../../{q_name}.xml'  # 50 queries
tree = ET.parse(topicFilePath)
topics = tree.getroot()
index_path = indexPath
directory = FSDirectory.open(File(index_path).toPath())
indexReader = DirectoryReader.open(directory)

In [16]:
FIELDNAME = 'CONTENTS'       # Lucene index field name
# calculating avgdl for queries. Used in BM25_query().
analyzer = EnglishAnalyzer()
query_lens = []
for topic in topics:
    queryKeywordsField = 'title'     # other fields are 'desc'and 'narr'
    q = topic.find(queryKeywordsField).text.strip()
    escaped_q = QueryParser(FIELDNAME, analyzer).escape(q)      # a few titles had '/' in them which
    # EnglishAnalyzer was not able to parse
    # without escaping those special characters
    query = QueryParser(FIELDNAME, analyzer).parse(escaped_q)
    query_terms = [term.strip()[len(FIELDNAME)+1:]
                   for term in query.toString().split()]
    query_lens.append(len(query_terms))
avgdl_query = sum(query_lens)/len(query_lens)
# calculating avgdl for the corpus. Used in BM25_docVec().
N = indexReader.numDocs()
avgdl_collection = indexReader.getSumTotalTermFreq(FIELDNAME)/N

In [17]:
print(N)

528155


In [18]:
def makeRelJudgeDict(qrelFilePath):
    # returns a nested dictionary representation of trec678 qrel file
    # for faster rel judgement checks during Rocchio query expansion.
    # Nested Dict strcture is like,
    # {qid1:{docid1:0/1,docid2:0/1,...}, qid2:{docid2:0/1,docid4:0/1,...},...}
    relJudgeDict = {}
    with open(qrelFilePath, 'r') as f:
        for line in f:
            l = line.split()
            qid, docid, judgement = l[0], l[2], int(l[3])
            if qid not in relJudgeDict:
                relJudgeDict[qid] = {docid: judgement}
            else:
                relJudgeDict[qid][docid] = judgement
    return relJudgeDict


def isTrueRelevant(qid, docid, relJudgeDict):
    # returns if the doc is True relevant, for the given query, according to the judgment file
    if qid not in relJudgeDict:
        return False
    if docid not in relJudgeDict[qid]:
        return False
    if relJudgeDict[qid][docid] == 1:   # 1 -> Relevant TRF
        return True
    if relJudgeDict[qid][docid] == 0:
        return False


def isTrueNonRelevant(qid, docid, relJudgeDict):
    # returns if the doc is NOT true relevant, for the given query, according to the judgment file
    if qid not in relJudgeDict:
        return False
    if docid not in relJudgeDict[qid]:
        return False
    if relJudgeDict[qid][docid] == 0:   # 0 -> Non-relevant TRF
        return True
    if relJudgeDict[qid][docid] == 1:
        return False


# SET this to the relevance judgment file path
# qrelPath = '/mnt/c/Users/priya/Desktop/MS/trec678_robust.qrel'
qrelPath = '../../trec678_robust.qrel'
# making a nested dictionary representation of judgement file for faster access
relJudgeDict = makeRelJudgeDict(qrelPath)


In [19]:
d = {}
for k in relJudgeDict:
    i = 0
    for key, values in relJudgeDict[k].items():
        if values == 2:
            i += 1
    # print(i)
    d[k] = i


In [20]:
sum(d.values())

1031

In [21]:
d

{'301': 0,
 '302': 0,
 '303': 0,
 '304': 0,
 '305': 0,
 '306': 0,
 '307': 0,
 '308': 0,
 '309': 0,
 '310': 0,
 '311': 0,
 '312': 0,
 '313': 0,
 '314': 0,
 '315': 0,
 '316': 0,
 '317': 0,
 '318': 0,
 '319': 0,
 '320': 0,
 '321': 0,
 '322': 0,
 '323': 0,
 '324': 0,
 '325': 0,
 '326': 0,
 '327': 0,
 '328': 0,
 '329': 0,
 '330': 0,
 '331': 0,
 '332': 0,
 '333': 0,
 '334': 0,
 '335': 0,
 '336': 0,
 '337': 0,
 '338': 0,
 '339': 0,
 '340': 0,
 '341': 0,
 '342': 0,
 '343': 0,
 '344': 0,
 '345': 0,
 '346': 0,
 '347': 0,
 '348': 0,
 '349': 0,
 '350': 0,
 '351': 0,
 '352': 0,
 '353': 0,
 '354': 0,
 '355': 0,
 '356': 0,
 '357': 0,
 '358': 0,
 '359': 0,
 '360': 0,
 '361': 0,
 '362': 0,
 '363': 0,
 '364': 0,
 '365': 0,
 '366': 0,
 '367': 0,
 '368': 0,
 '369': 0,
 '370': 0,
 '371': 0,
 '372': 0,
 '373': 0,
 '374': 0,
 '375': 0,
 '376': 0,
 '377': 0,
 '378': 0,
 '379': 0,
 '380': 0,
 '381': 0,
 '382': 0,
 '383': 0,
 '384': 0,
 '385': 0,
 '386': 0,
 '387': 0,
 '388': 0,
 '389': 0,
 '390': 0,
 '391': 0,

In [22]:

def tf_idf_query(term, query_terms):
    # returns TF-IDF weight for the given term in query
    D = len(query_terms)
    N = indexReader.numDocs()
    tf = query_terms.count(term)
    df = indexReader.docFreq(Term(FIELDNAME, term))
    weight = (tf/D)*(math.log(N/(df+1)))
    return weight


def tf_idf_docVec(docVec, D):
    # tf-idf weight calculation for all the terms in the document vector
    N = indexReader.numDocs()       # no. of total docs in the corpus
    for t in docVec:
        tf = docVec[t][0]
        df = docVec[t][1]
        idf = math.log(N/(df+1))
        docVec[t] = (tf/D)*idf
    return docVec


def BM25_query(term, query_terms, k1=0.8, b=0.4):
    # returns Okapi BM25 weight for the given term in query
    D = len(query_terms)
    N = indexReader.numDocs()
    tf = query_terms.count(term)
    df = indexReader.docFreq(Term(FIELDNAME, term))
    idf = math.log(1+((N-df+0.5)/(df+0.5)))
    weight = ((tf*(1+k1))/(tf+k1*((1-b)+(b*D/avgdl_query))))*idf
    return weight


def BM25_docVec(docVec, D, k1=0.8, b=0.4):
    # Okapi BM25 weight calculation for all the terms in the document vector
    N = indexReader.numDocs()       # no. of total docs in the corpus
    for t in docVec:
        tf = docVec[t][0]
        df = docVec[t][1]
        idf = math.log(1+((N-df+0.5)/(df+0.5)))
        docVec[t] = ((tf*(1+k1))/(tf+k1*((1-b)+(b*D/avgdl_collection))))*idf

    return docVec


def getDocumentVector(luceneDocid, weightScheme):
    # returns document vector in dictionary form with tf-idf weights

    docVec = {}

    D = 0                           # doc length, i.e., total no. of tokens in the doc
    terms = indexReader.getTermVector(luceneDocid, FIELDNAME)
    iterator = terms.iterator()
    for term in BytesRefIterator.cast_(iterator):
        t = term.utf8ToString()
        tf = iterator.totalTermFreq()                           # termFreq of term,t
        # docFreq of term,t
        df = indexReader.docFreq(Term(FIELDNAME, t))
        D += tf
        docVec[t] = [tf, df]

    if weightScheme == 'TFIDF':
        docVec = tf_idf_docVec(docVec, D)
    elif weightScheme == 'BM25':
        docVec = BM25_docVec(docVec, D)


    docVec = {key: value/sum(docVec.values()) for key, value in docVec.items()}

    return docVec


In [23]:
def rocchio_TRF(query, qid, top_k_docs, searcher, N, alpha, beta, gamma, tpd, tnd, weightScheme):
    """Implements Rocchio's relevance feedback and returns a modified query

    Args:
        query (org.apache.lucene.search.Query): lucene parsed version of the initial/original query
        top_k_docs (lucene._lucene.JArray_object): scoreDocs returned after performing search with top k results
        N (int): number of terms to be in the returned modified query
        alpha (float): weight for original query
        beta (float): weight for positive feedback
        weightScheme (string): TFIDF or BM25 for term weighting

    Returns:
        list: expanded/modified query list of string query terms
    """

    # processing JQuery object to extract query terms in form of a list
    query_terms = [term.strip()[len(FIELDNAME)+1:]
                   for term in query.toString().split()]

    # creating query vector Q0
    Q0_vector = {}
    for term in query_terms:
        if weightScheme == 'TFIDF':
            Q0_vector[term] = tf_idf_query(term, query_terms)
        elif weightScheme == 'BM25':
            Q0_vector[term] = BM25_query(term, query_terms)

    Q0_vector = {key: value/sum(Q0_vector.values())
                 for key, value in Q0_vector.items()}

    # Rel for Relevant, NRel for Non-relevant
    sumRelDocsVector, sumNRelDocsVector = {}, {}
    numRel, numNRel = 0, 0
    Rellist, NRellist = [], []

    for scoreDoc in top_k_docs:
        doc = searcher.doc(scoreDoc.doc)
        if isTrueRelevant(qid, doc.get('ID'), relJudgeDict):
            Rellist.append(scoreDoc.doc)
        if gamma == 0:
            pass
        else:
            if isTrueNonRelevant(qid, doc.get('ID'), relJudgeDict):
                NRellist.append(scoreDoc.doc)
    print(Rellist)
    for relDoc in Rellist:
        if numRel < tpd:
            docVec = getDocumentVector(relDoc, weightScheme)
            numRel += 1
            # vector addition of sumRelDocsVector and docVec
            sumRelDocsVector = {term: sumRelDocsVector.get(
                term, 0) + docVec.get(term, 0) for term in set(sumRelDocsVector) | set(docVec)}

    for nreldoc in NRellist:
        if gamma != 0 and numNRel < tnd:
            docVec = getDocumentVector(nreldoc, weightScheme)
            numNRel += 1
            # vector addition of sumNRelDocsVector and docVec
            sumNRelDocsVector = {term: sumNRelDocsVector.get(
                term, 0) + docVec.get(term, 0) for term in set(sumNRelDocsVector) | set(docVec)}
    

    if numRel == 0 or len(Rellist) == 0:
        newlist = [doc for doc in relJudgeDict[qid]
                   if relJudgeDict[qid][doc] == 1][:tpd]
        for rd in newlist:
            q = QueryParser('ID', analyzer).parse(rd)
            scoreDocs = searcher.search(q, 1).scoreDocs
            for scoreDoc in scoreDocs:
                docVec = getDocumentVector(scoreDoc.doc, weightScheme)
                sumRelDocsVector = {term: sumRelDocsVector.get(
                    term, 0) + docVec.get(term, 0) for term in set(sumRelDocsVector) | set(docVec)}

        numRel = len(newlist)

        # print(f'Nrel_vec_{sumNRelDocsVector}')

    if numNRel == 0 or len(NRellist) == 0:
        newNlist = [doc for doc in relJudgeDict[qid]
                   if relJudgeDict[qid][doc] == 0][:tnd]
        for nd in newNlist:
            q = QueryParser('ID', analyzer).parse(nd)
            scoreDocs = searcher.search(q, 1).scoreDocs
            for scoreDoc in scoreDocs:
                docVec = getDocumentVector(scoreDoc.doc, weightScheme)
                sumNRelDocsVector = {term: sumNRelDocsVector.get(
                    term, 0) + docVec.get(term, 0) for term in set(sumNRelDocsVector) | set(docVec)}

        numNRel = len(newNlist)
        

    # normlaized Relevant Docs Vector
    r = {term: sumRelDocsVector[term]/numRel for term in sumRelDocsVector}
    # normlaized Non-Relevant Docs Vector
    nr = {term: sumNRelDocsVector[term]/numNRel for term in sumNRelDocsVector}

    # final Rocchio formula for Qm
    # expanded_query = [[term, alpha*Q0_vector.get(term, 0) + beta*r.get(term, 0) - gamma*nr.get(term, 0)] for term in set(Q0_vector) | set(r) | set(nr)]
    expanded_query = [[term, alpha*Q0_vector.get(term, 0) + beta*r.get(term, 0) - gamma*nr.get(term, 0)] for term in set(Q0_vector) | set(r)]

    # sorted (descending) the expanded query list as per term scores
    expanded_query.sort(key=lambda x: x[1], reverse=True)
    # selecting top N expanded query terms
    Qm_with_scores = expanded_query[:int(N)]

    # weighting expanded query terms
    booleanQuery = BooleanQuery.Builder()
    for item in Qm_with_scores:
        if item[1] >= 0:
            t = Term(FIELDNAME, item[0])
            tq = TermQuery(t)
            boostedTermQuery = BoostQuery(tq, item[1])
            BooleanQuery.setMaxClauseCount(4096)
            booleanQuery.add(boostedTermQuery, BooleanClause.Occur.SHOULD)
    modifiedQuery = booleanQuery.build()

    return modifiedQuery   # modified query


In [24]:
def bm25_rocchio(numPRD, N, alpha, beta, gamma, tpd, tnd, weightScheme='TFIDF'):
    """ Performs bm25 search with Rocchio pseudo relevance feedback 
        on a set of queries and output the result in a file

    Args:
        numPRD: no. of pseudo relevant docs
        N: no. of expansion terms
        alpha, beta: Rocchio model parameters
        weightScheme (string): TFIDF or BM25 for term weighting
        
    Returns:
        None
    """

    model = 'bm25'
#     LAMBDA = 0.4   # LM-JM baseline lambda parameter
#     similarityModel = LMJelinekMercerSimilarity(LAMBDA)

    k1 = 0.8
    b = 0.4
    similarityModel = BM25Similarity(k1, b)

    # change result file path below
    if weightScheme == 'BM25' or weightScheme == 'TFIDF':
        rocchioOutputPath = f"./Rocchio_TRF/{q_name}/{tpd}_{tnd}_{q_name}_BM25_Rocchio_numPRD={numPRD}_N={N}_alpha={alpha}_beta={beta}_gamma={gamma}_{weightScheme}.res"
    else:
        print('Warning: weightScheme entered not a valid parameter value. Taking default weightScheme: TFIDF')
        weightScheme = 'TFIDF'
        rocchioOutputPath = f"./Rocchio_TRF/{weightScheme}/TREC6_BM25_Rocchio_numPRD={numPRD}_N={N}_alpha={alpha}_beta={beta}_{weightScheme}.res"

    f = open(rocchioOutputPath, 'w')

    # setting up the searcher
    analyzer = EnglishAnalyzer()    # used same analyzer as indexer
#     index_path = './index/'
    index = index_path
    directory = FSDirectory.open(File(index_path).toPath())
    searcher = IndexSearcher(DirectoryReader.open(directory))
    # setting the similarity model
    searcher.setSimilarity(similarityModel)

    # print('\nRetrieving ...')

    # search on 50 queries from the topic file 'trec6.xml'
    for topic in topics:
        qidField = 'num'
        queryKeywordsField = 'title'     # other fields are 'desc'and 'narr'

        qid = topic.find(qidField).text.strip()
        q = topic.find(queryKeywordsField).text.strip()

        escaped_q = QueryParser(FIELDNAME, analyzer).escape(
            q)      # a few titles had '/' in them which
        # EnglishAnalyzer was not able to parse
        # without escaping those special characters
        query = QueryParser(FIELDNAME, analyzer).parse(escaped_q)

        # print(f'Rocchio {weightScheme}, numPRD = {numPRD}, N = {N}, alpha = {alpha}, beta = {beta}; qid = {qid}, retrieving & writing ...', end=' ')

        # getting the top pseudo relevant docs using the searcher
        scoreDocs = searcher.search(query, numPRD).scoreDocs

        # Rocchio expanded query retrieval
        modified_query = rocchio_TRF(query, qid, scoreDocs, searcher=searcher,
                                     N=N, alpha=alpha, beta=beta, gamma=gamma, weightScheme=weightScheme, tpd=tpd, tnd=tnd)

        # getting the top k search results using the searcher
        k = 1000
        scoreDocs = searcher.search(modified_query, k).scoreDocs

        # writing all k doc results in a .res file in TREC format
        rank = 0
        results = ''
        for scoreDoc in scoreDocs:
            rank += 1
            doc = searcher.doc(scoreDoc.doc)
            # f.write(f"{qid}\tQ0\t{doc.get('DOCID')}\t{rank}\t{scoreDoc.score}\taman_lmjm_{LAMBDA}-rocchio_{alpha}_{beta}\n")
            results += f"{qid}\tQ0\t{doc.get('ID')}\t{rank}\t{scoreDoc.score}\tBM25_{k1}-{b}-rocchio_{alpha}_{beta}\n"

        f.write(results)

        # print('complete!')

    f.close()
    # print('Search completed! Search results exported to a .res file in the current directory.\n')


In [25]:
numPRD = 1000
N = 80
alpha = 1
beta = 20
gamma = 6
tpd = 20
tnd = 6
bm25_rocchio(numPRD=numPRD, N=N, alpha=alpha, beta=beta, gamma=gamma, tpd=tpd, tnd=tnd, weightScheme='BM25')


[341670]
[337668, 156081, 418471, 311120, 241404, 347609, 171977, 171874, 239464, 391861, 264449]
[263791, 186136, 310265, 323741, 422056, 424047]
[397314, 314206, 191875, 261277]
[186144]
[431179, 433263, 426639, 368089, 432606, 176836, 353560, 338617, 319488, 369583, 370315, 369589, 333057, 359672, 348681, 196678, 351863, 353816, 327829, 327495, 342930, 348775, 256557, 325410, 370574, 333059, 359243, 406597, 358476, 426848, 434193, 176393, 376192, 335450, 369325, 255535, 410058, 154032, 319291, 353182, 177593, 331766, 257304, 381319, 346732, 410223, 425246, 161990, 173504]
[315413, 409355, 240866, 361763, 398225, 305117, 166340, 432674, 156305, 170809, 427561, 157692, 366170, 421449, 383714, 391134, 202112, 361762, 186886, 339768, 358891, 418910, 389834, 167151, 421303, 310738, 401781, 260078, 161935, 375257, 419257, 191270, 403278, 372517, 182480, 359548, 155678, 156949, 159212, 189273, 158638, 303762, 162289, 156958, 402434, 165696, 315018, 156727, 242474, 398144, 361225]
[169393, 