In [1]:
%%capture
!pip install pyserini==0.9.4.0

import os
import nltk
import sys
import time
import pandas as pd
from pyserini.search import get_topics
from pyserini.search import SimpleSearcher
from pyserini.search import querybuilder
from nltk.corpus import stopwords
from nltk.corpus import wordnet
from IPython.display import clear_output

nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
nltk.download('universal_tagset')
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"

In [2]:
# Get Robust04 Dataset ~2 min
%%capture
!wget https://git.uwaterloo.ca/jimmylin/anserini-indexes/raw/master/index-robust04-20191213.tar.gz
# Backup URL: https://www.dropbox.com/s/s91388puqbxh176/index-robust04-20191213.tar.gz
!tar xvfz index-robust04-20191213.tar.gz

In [3]:
# Get MsMarcoPassage Dataset ~2 min
%%capture
!wget https://git.uwaterloo.ca/jimmylin/anserini-indexes/raw/master/index-msmarco-passage-20191117-0ed488.tar.gz
!tar xvfz index-msmarco-passage-20191117-0ed488.tar.gz

In [4]:
# Get MsMarcoDocument Dataset ~20 min
%%capture
!wget https://git.uwaterloo.ca/jimmylin/anserini-indexes/raw/master/index-msmarco-doc-20201117-f87c94.tar.gz
!tar xvfz index-msmarco-doc-20201117-f87c94.tar.gz

In [5]:
# Sanity check of Robust04: the extracted index directory should be about 2.1G
!du -h index-robust04-20191213

# Sanity check of MsMarcoPassage: the extracted index directory should be about 2.5G
!du -h index-msmarco-passage-20191117-0ed488

# Sanity check of MsMarcoDocument: the extracted index directory should be about 16G
!du -h index-msmarco-doc-20201117-f87c94

2.1G	index-robust04-20191213
2.5G	index-msmarco-passage-20191117-0ed488
16G	index-msmarco-doc-20201117-f87c94


In [6]:
# Write to .txt file to store analysis
def setStdOutToFile(path='stdout.txt'):
  """Redirect sys.stdout to a file opened in append mode.

  Parameters
  ----------
  path : str
      File that receives everything written to stdout from now on.
      Defaults to 'stdout.txt', the file the results cell reads back.

  Returns
  -------
  tuple
      (writer, old_stdout): the open file object now installed as
      sys.stdout, and the stream that was installed before. Pass both to
      resetStdOut to undo the redirection.
  """
  old_stdout = sys.stdout
  # Append mode: successive dataset runs accumulate in the same file.
  writer = open(path, 'a')
  sys.stdout = writer
  return writer, old_stdout

# Close writer and reset stdout
def resetStdOut(writer, old_stdout):
  """Undo setStdOutToFile: reinstall the previous stdout and close the file.

  Parameters
  ----------
  writer : file-like
      The object returned as `writer` by setStdOutToFile.
  old_stdout : file-like
      The stream to install as sys.stdout again.
  """
  # Restore first, so that a failure while closing the file can never leave
  # sys.stdout pointing at a closed writer.
  sys.stdout = old_stdout
  writer.close()

# Clear the output of a code block for nicer notebook
def clearOutput():
  """Clear the current cell's output and emit a single empty line."""
  # wait=True is passed to IPython's clear_output; the print that follows
  # immediately supplies the new (blank) output for the cell.
  clear_output(wait=True)
  print("", flush=True)

In [7]:
# Change the variable here to change the query expansion method
# 1 = Control Condition (the topic title is searched as-is)
# 2 = Wordnet-based Expansion (the topic title is rewritten by build_query before searching)
# 3 = RM3 (the topic title is searched as-is; the dataset cells call searcher.set_rm3 first)
QEMethod = 1

In [8]:
def build_query(query, limit=0, pos=True):
    """Expand a query with WordNet synonyms of its non-stop-word terms.

    Parameters
    ----------
    query : str
        Query string; it is split on whitespace.
    limit : int
        Maximum number of new synonyms added per query term,
        not restricted if limit=0. The limit is honoured whether or not
        `pos` is set (it used to be ignored when pos=False).
    pos : bool
        When True, a synonym is kept only if its universal part-of-speech
        tag (the synonym is tagged on its own) equals the tag the query
        term received within the query.

    Returns
    -------
    str
        The surviving query terms and their synonyms joined by single
        spaces, without duplicates, in first-seen order (each term followed
        by its synonyms). Empty string when no term survives stop word
        filtering.
    """
    stop_words = set(stopwords.words('english'))
    # Tag the whole query at once so each term is tagged in context.
    tagged_query = nltk.pos_tag(query.split(), tagset='universal')

    # A list keeps the output order deterministic between runs; the set gives
    # O(1) duplicate checks.
    expanded_words = []
    seen = set()

    for term, term_tag in tagged_query:
        if term in stop_words:
            continue
        if term not in seen:
            seen.add(term)
            expanded_words.append(term)

        lemma_names = (lemma.name()
                       for synset in wordnet.synsets(term)
                       for lemma in synset.lemmas())
        added = 0
        for synonym in lemma_names:
            # Stop as soon as this term has received its quota of synonyms.
            if limit and added >= limit:
                break
            if synonym in seen or synonym.lower() in stop_words:
                continue
            if pos:
                tagged_synonym = nltk.pos_tag(nltk.word_tokenize(synonym), tagset='universal')
                if not tagged_synonym or tagged_synonym[0][1] != term_tag:
                    continue
            seen.add(synonym)
            expanded_words.append(synonym)
            added += 1

    return " ".join(expanded_words)

In [9]:
#TODO: Implement RM3 query expansion

#TODO: Implement WordNet query expansion

#TODO: Control

# This means we need to build QueryGenerators,
# e.g. searcher.search(query, 1000, query_generator=wordnet_generator) or searcher.search(query, 1000, query_generator=rm3).
# For the wordnet_generator we need to write a new WordNetGenerator.java file that extends the BagOfWordsQueryGenerator.java file in
# https://github.com/castorini/anserini/blob/master/src/main/java/io/anserini/search/query/BagOfWordsQueryGenerator.java

def run_all_queries(file, topics, searcher, qe_method=None):
    """Run every topic against `searcher` and write a TREC-style run file.

    Parameters
    ----------
    file : str
        Path of the run file; it is overwritten.
    topics : dict
        Maps a topic id to a dict that has a 'title' entry (the query text).
    searcher : object
        Anything with a `search(query, k)` method returning a sequence of
        hits that expose `docid` and `score`.
    qe_method : int, optional
        Query expansion method: 1 = control, 2 = WordNet expansion via
        build_query, 3 = RM3 (the query is sent unchanged; RM3 has to be
        enabled on the searcher by the caller). Defaults to the notebook-wide
        QEMethod setting.

    Raises
    ------
    ValueError
        If the method is not 1, 2 or 3. Raised before the run file is
        opened, so an existing run file is not truncated by a bad setting.
    """
    method = QEMethod if qe_method is None else qe_method
    if method not in (1, 2, 3):
        raise ValueError('Unknown query expansion method: {!r} (expected 1, 2 or 3)'.format(method))

    print('Running {} queries in total'.format(len(topics)))
    with open(file, 'w') as runfile:
        for cnt, topic_id in enumerate(topics, start=1):
            query = topics[topic_id]['title']
            if method == 2:
                # WordNet expansion: at most one POS-matched synonym per term.
                query = build_query(query, limit=1, pos=True)

            hits = searcher.search(query, 1000)
            for rank, hit in enumerate(hits, start=1):
                runfile.write('{} Q0 {} {} {:.6f} Anserini\n'.format(topic_id, hit.docid, rank, hit.score))

            if cnt % 100 == 0:
                print(f'{cnt} queries completed')

In [10]:
##### Robust04 ##### ~21 sec
start = time.perf_counter()
searcher = SimpleSearcher('index-robust04-20191213')
if (QEMethod == 3):
  searcher.set_rm3(10, 10, 0.5)
topics = get_topics('robust04')
run_all_queries('run-robust04-bm25.txt', topics, searcher)
!wget -O jtreceval-0.0.5-jar-with-dependencies.jar https://search.maven.org/remotecontent?filepath=uk/ac/gla/dcs/terrierteam/jtreceval/0.0.5/jtreceval-0.0.5-jar-with-dependencies.jar
!wget https://raw.githubusercontent.com/castorini/anserini/master/src/main/resources/topics-and-qrels/qrels.robust04.txt
writer, old_stdout = setStdOutToFile()
print("Robust04")
print("time                  \tall\t", round(time.perf_counter()-start)) #Timer in seconds
!java -jar jtreceval-0.0.5-jar-with-dependencies.jar qrels.robust04.txt run-robust04-bm25.txt
resetStdOut(writer, old_stdout)
clearOutput()




In [11]:
##### MsMarcoPassage ##### ~11 min 
start = time.perf_counter()
searcher = SimpleSearcher('index-msmarco-passage-20191117-0ed488')
if (QEMethod == 3):
  searcher.set_rm3(10, 10, 0.5)
topics = get_topics('msmarco_passage_dev_subset')
run_all_queries('run-msmarco-passage-bm25.txt', topics, searcher)
!wget -O jtreceval-0.0.5-jar-with-dependencies.jar https://search.maven.org/remotecontent?filepath=uk/ac/gla/dcs/terrierteam/jtreceval/0.0.5/jtreceval-0.0.5-jar-with-dependencies.jar
!wget https://raw.githubusercontent.com/castorini/anserini/master/src/main/resources/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt
writer, old_stdout = setStdOutToFile()
print("MsMarcoPassage")
print("time                  \tall\t", round(time.perf_counter()-start)) #Timer in seconds
!java -jar jtreceval-0.0.5-jar-with-dependencies.jar qrels.msmarco-passage.dev-subset.txt run-msmarco-passage-bm25.txt
resetStdOut(writer, old_stdout)
clearOutput()




In [None]:
##### MsMarcoDoc ##### ~120 min 
start = time.perf_counter()
searcher = SimpleSearcher('index-msmarco-doc-20201117-f87c94')
if (QEMethod == 3):
  searcher.set_rm3(10, 10, 0.5)
topics = get_topics('msmarco_doc_dev')
run_all_queries('run-msmarco-doc-bm25.txt', topics, searcher)
!wget -O jtreceval-0.0.5-jar-with-dependencies.jar https://search.maven.org/remotecontent?filepath=uk/ac/gla/dcs/terrierteam/jtreceval/0.0.5/jtreceval-0.0.5-jar-with-dependencies.jar
!wget https://raw.githubusercontent.com/castorini/anserini/master/src/main/resources/topics-and-qrels/qrels.msmarco-doc.dev.txt
writer, old_stdout = setStdOutToFile()
print("MsMarcoDoc")
print("time                  \tall\t", round(time.perf_counter()-start)) #Timer in seconds
!java -jar jtreceval-0.0.5-jar-with-dependencies.jar qrels.msmarco-doc.dev.txt run-msmarco-doc-bm25.txt
resetStdOut(writer, old_stdout)
clearOutput()

Running 5193 queries in total
100 queries completed


In [None]:
# Convert .txt to table and print it 
# Make sure stdout is back on the notebook (and the log file is closed) before
# reading the log; needed when a dataset cell was interrupted mid-run.
resetStdOut(writer, old_stdout)

# Section headers as printed by the dataset cells.
datasets = ["Robust04", "MsMarcoPassage", "MsMarcoDoc"]
column_names = ["Dataset", "MAP", "Recip Rank", "P@5", "Num Rel", "Num Rel Ret" , "Time"]
# 1-based line positions inside one section of stdout.txt, counting the dataset
# name as line 1. In the order listed they are read as: time, num rel,
# num rel ret, MAP, recip rank, P@5 (see the row dict below).
keep_lines = [2, 6, 7, 8, 12, 24]
# A section is complete once this many lines have been seen.
max_line = 32

with open('stdout.txt', 'r') as stdout_file:
  lines = stdout_file.readlines()

rows = []
count = 0
metric_list = []
dataset_name = ""

for line in lines:
  stripped = line.strip()
  if stripped in datasets:
    # A new section starts: reset the position counter and collected values.
    count = 0
    metric_list = []
    dataset_name = stripped
  count += 1
  if count in keep_lines:
    # The value is the third whitespace-separated field of the line.
    metric_list.append(float(stripped.split()[2]))
  if count == max_line:
    rows.append({'Dataset': dataset_name,
                 'MAP': metric_list[3],
                 'P@5': metric_list[5],
                 'Recip Rank': metric_list[4],
                 'Num Rel': metric_list[1],
                 'Num Rel Ret': metric_list[2],
                 'Time': metric_list[0]}) # seconds

# Build the frame once from all rows; DataFrame.append no longer exists in
# pandas 2.x and was quadratic when used in a loop.
df = pd.DataFrame(rows, columns=column_names)

df

# New Section