In [1]:
import xml.etree.ElementTree as Et
from tqdm import tqdm
from collections import defaultdict
from pprint import pprint
import subprocess

import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from sklearn.metrics import ndcg_score


In [2]:
from pyserini.search import pysearch
from pyserini.search import pyquerybuilder
from pyserini.index import pyutils
from pyserini.analysis import pyanalysis
from pyserini.pyclass import autoclass
from pyserini.analysis.pyanalysis import get_lucene_analyzer

#Mirrors of old indices are archived here
#https://github.com/castorini/anserini/blob/master/docs/experiments-cord19.md
index_loc = '/home/tmschoegje/Desktop/caos-19/lucene-index-cord19-paragraph-2020-05-19/'
searcher = pysearch.SimpleSearcher(index_loc)
index_utils = pyutils.IndexReaderUtils(index_loc)

#Additionally, you need the metadata.csv of the corresponding index, which is included in the CORD-19 releases
#https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/historical_releases.html
metadatafile = "/home/tmschoegje/Desktop/caos-19/metadata.csv"


docidfile = '/home/tmschoegje/Desktop/caos-19/trecdata/docids-rnd2.txt'
topicsfile = "/home/tmschoegje/Desktop/caos-19/trecdata/topics-rnd2.xml"
qrelname = "/home/tmschoegje/Desktop/caos-19/trecdata/qrels-rnd2.txt"

In [3]:
#manual topic classification into tasks
rnd3classes = [2, 0, 3, 0, 3, 7, 7, 7, 6, 5, 4, 5, 0, 0, 0, 0, 4, 5, 5, 1, 0, 1, 1, 1, 1, 7, 7, 3, 3, 3, 2, 2, 3, 3, 9, 2, 2, 3, 3, 2]
rnd3confidence = [1, 1, 1, 0.5, 0.5, 1, 1, 0.5, 0.5, 1, 0, 0.5, 1, 0, 1, 1, 0.75, 0.25, 0.25, 0.5, 0.5, 0.75, 0.75, 0.75, 0.75, 1, 0.5, 1, 1, 1, 0, 1, 1, 0, 0, 0.5, 1, 0, 0, 1]

In [9]:
#Used to read in metadata for the docids this round
def prepTREC(fname):
    #get valid TREC ids for this round
    TRECids = []
    f = open(fname)
    for line in f.readlines():
        if line[-1] == '\n':
            line = line[:-1]
        TRECids.append(line)
    f.close()
    
    metadata = pd.read_csv(metadatafile)
    #now we filter all TREC ids we don't need
    metadata = metadata[metadata.cord_uid.isin(TRECids)]
    metadata.drop_duplicates(subset='cord_uid', keep='first', inplace=True)
    
    return metadata

#Used to read journal priors from the doc
def prepJournals(fname):
    f = open(fname)
    journals = dict()
    for line in f.readlines():
        if line[-1] == '\n':
            line = line[:-1]
        ls = line.split(" ")
        journals[ls[1]] = ls[0]
    f.close()

    return journals

#Used to get a specific journal's prior value
#cord_uid is id of document, metadata contains metadata.csv, 
#journals is a list of journals from journalpriors.txt (see prepJournals)
def getJPrior(cord_uid, metadata, journals):
    #Get journal for this item
    journal = metadata[metadata['cord_uid'] == cord_uid]['journal']
    if journal.to_string(index=False).strip() in journals:
        return journals[journal.to_string(index=False).strip()]
    else:
        #if we have no knowledge, we assume the relevance is 0 (neutral)
        return 0

#Used to read TREC topics
def readTopics(fname):
    root = Et.parse(fname).getroot()
    topics = []
    for num, topic in enumerate(root):
        #print(topic[0].text) #query
        topics.append([topic[0].text, rnd3classes[num]])
        #print(topic[1].text) #question
        #print(topic[2].text) #narrative
    return topics

#Used to read in a run's ranking
def readAnserini(fname):
    res=[]
    f = open(fname)
    #f.readline()
    for line in f.readlines():
        vals = line.strip().split(" ")
        #topic, rank, cord_id, score
        res.append([vals[0], vals[3], vals[2], vals[4]])
    return res

#Used to prepare results in submission format
def writeBM25results(results, runtitle):
    f = open(runtitle, "w")
    #topic, rank, cord_id, score
    for result in results:
        #print(result)
        f.write(result[0] + " Q0 " + result[2] + " 1 " + str(result[3]) + " " + runtitle + "\n")
    f.close()
    

#read the qrels file
def getqrels(fname):
    qrels = []
    for line in open(fname).readlines():
        vals = line.strip().split(" ")
        #topic, cord_uid, qrel, assessround
        qrels.append([int(vals[0]), vals[3], float(vals[4]), float(vals[1])])

    qrels = np.array(qrels, dtype="O")
    return qrels

#find qrel for a cord uid
def get_qrel(cord_uid, topic_id, qrels):

    topicrels = qrels[qrels[:,0] == topic_id]

    qrel_uids = [qrel[1] for qrel in topicrels]
    #print(qrel_uids.index(cord_uid))
    index = qrel_uids.index(cord_uid)
    #print(qrels[index,2])
    if(qrels[index,2] > 0):
        return 1
    else:
        return 0

In [17]:
#ndcg after filtering unknown docs - sakai 2007 says this is more stable than bpref

#Note: it's nicer to do NDCG over all known qrels. 
#Implementation of this was limited - so we only considered the qrels in the top 30k documents

def ndcg(runname, qrelname):
    qrels = getqrels(qrelname)
    
    preds = []
    #first parse predictions
    for line in open(runname).readlines():
        #topic, unused, cord_uid, rank, score, runname
        vals = line.strip().split(" ")
        #topic, cord_uid, score, rank
        preds.append([int(vals[0]), vals[2], float(vals[4]), int(vals[3])])
        
    #print('hi')
    #print(len(preds))
    #print(len(qrels))
    knownpreds = []
    
    for num, pred in enumerate(preds):
        #get qrels for the given topic
        qrels_topic = qrels[qrels[:,0] == pred[0]]
        qrel_uids = [val[1] for val in qrels_topic]

        #filter all preds not in qrels
        if(pred[1] in qrel_uids):
            #add known prediction with predicted score and real score
            knownpreds.append([pred[0], pred[1], pred[2], get_qrel(pred[1], pred[0], qrels_topic), pred[3]])
            
    knownpreds = np.array(knownpreds, dtype="O")
    
    #TODO update for round2+ topics
    ndcgs = []
    for t in range(1, 31):
        knownpreds_t = knownpreds[knownpreds[:,0] == t]
        
        #cross validation on knownpreds_t
        #If this topic has at least 5 documens with known qrels, we compute it using 5x cross validation
        #Otherwise, we ignore the ndcg for this topic. Afterwards, average for all topics
        
        n_splits = 5
        if(len(knownpreds_t) > 5):
            kf = KFold(n_splits)
            for train_index, test_index in kf.split(knownpreds_t):
        
                sortedqrel = []
                sortedqpred = []
                for pred_ind, pred in enumerate(knownpreds_t):
                    if pred_ind in train_index:
                        #get corresponding pred's qrel
                        sortedqrel.append(pred[3])#get_qrel(pred[1], pred[0], qrels))
                        #ground truth
                        sortedqpred.append(pred[2])        
            
                if(len(sortedqrel) > 1):
                    ndcgs.append(ndcg_score(np.asarray([sortedqrel]), np.asarray([sortedqpred]), k=10))
                else:
                    print('how did i get here')
           
        else:
            pass           
        
    return np.mean(ndcgs)

In [19]:
# Rerank by journal
def rerank(results, topics, mixer, journals):
    metadata = prepTREC(docidfile)
    
    jscores = []
    scores = []
    for result in results:
        jscores.append(getJPrior(result[2], metadata, journals))
        scores.append(float(result[3]))
    
    for i, val in enumerate(jscores):
        results[i][3] = mixer * float(jscores[i]) + scores[i]
    
    #Some ugly/quick sorting
    def sort_key0(item):
        return item[3]
    def sort_key1(item):
        return item[0]

    results = sorted(results, key=sort_key0, reverse=True)
    results = sorted(results, key=sort_key1, reverse=False)
    
    return results

# Let's see what linear combination between the run score these two values makes sense
for m3 in np.linspace(0, 0.5, 10):
    
    #This is currently the best run. Differs from the submitted runfile because it is a longer list of ranked
    #qrels that we can use to tune with ndcg
    #results = readAnserini('/home/tmschoegje/Desktop/caos-19/runs/testrun-best-rnd3.run')
    
    #Currently testing on the baseline (using query terms)
    results = readAnserini('/home/tmschoegje/Desktop/caos-19/runs/testrun-baseline-rnd3.run')
    
    #print(len(results))
    journals = prepJournals('/home/tmschoegje/Desktop/caos-19/round3/journalpriors.txt')
    results_reranked = rerank(results, readTopics(topicsfile), m3, journals)
    writeBM25results(results_reranked, "/home/tmschoegje/Desktop/caos-19/runs/testrun-" + str(m3) + '.run')
    print(ndcg("/home/tmschoegje/Desktop/caos-19/runs/testrun-" + str(m3) + '.run', qrelname))

  exec(code_obj, self.user_global_ns, self.user_ns)


0.5225502184047945


  exec(code_obj, self.user_global_ns, self.user_ns)


0.5226202172483354


  exec(code_obj, self.user_global_ns, self.user_ns)


0.5357320329022079


  exec(code_obj, self.user_global_ns, self.user_ns)


0.5268356431706291


  exec(code_obj, self.user_global_ns, self.user_ns)


0.5158641439101874


  exec(code_obj, self.user_global_ns, self.user_ns)


0.508383664126815


  exec(code_obj, self.user_global_ns, self.user_ns)


0.49849152805133884


  exec(code_obj, self.user_global_ns, self.user_ns)


0.4995929562186969


  exec(code_obj, self.user_global_ns, self.user_ns)


0.4882150231210602


  exec(code_obj, self.user_global_ns, self.user_ns)


0.491663131643002
