# Query Reformulation with A2A
This example shows how A2A facilitates query reformulation research. It implements the approach outlined by Lin et al. at TREC 2021, using docTTTTTquery (https://cs.uwaterloo.ca/~jimmylin/publications/Nogueira_Lin_2019_docTTTTTquery-v2.pdf) to generate the reformulated queries.

In [None]:
%pip install --quiet sentencepiece transformers

In [None]:
import os
import sys
import re
import requests
import argparse
import warnings
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch
import sentencepiece
from transformers import T5Tokenizer, T5ForConditionalGeneration
device = "cuda" if torch.cuda.is_available() else "cpu"

import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
# Tokens dropped from generated queries: NLTK's English stopwords plus a few
# extra terms added by hand for this task.
sw_nltk = set(stopwords.words('english') + ['presents', 'presented', 'patient', 'show', 'shows', 'year', 'yo', 'old'])
# Single-character punctuation tokens dropped from generated queries.
# The backslash is written as "\\" because "\-" is an invalid escape sequence
# (SyntaxWarning on recent Python versions); the resulting set is unchanged.
punct = set(",.?()/\\-+'\"")

# set seed for reproducibility
torch.manual_seed(0)
np.random.seed(0)

# hyperparameters
n_queries = 10  # number of reformulated queries generated per topic
k = 60  # constant passed to fuse_rankings for reciprocal rank fusion

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


## Data Wrangling Functions and Model Class Definition

In [None]:
def clean_topic(topic):
    """Blank out punctuation in a topic string.

    Every character that is not a word character, whitespace or a period is
    replaced by a single space; periods are kept as they are.
    """
    disallowed = re.compile(r"[^\w\s.]")
    return disallowed.sub(" ", topic)

def store_topic_as_xml(topics, q_index=0):
    """Write one query variant per topic to an A2A topics XML file.

    topics: list where each entry is a list of query strings for one topic.
    q_index: which query variant of every topic to write.

    Topics are numbered from 1 in list order and each query is passed through
    clean_topic. The file is named "reformulated_topics_[<q_index>].xml".
    """
    parts = ['<topics task="2021 TREC Clinical Trials">\n']
    for number, variants in enumerate(topics, start=1):
        parts.append('\t<topic number="{}">\n'.format(number))
        parts.append('\t\t<user_query>{}</user_query>\n'.format(clean_topic(variants[q_index])))
        parts.append('\t</topic>\n')
    parts.append('</topics>\n')

    out_path = "reformulated_topics_[{}].xml".format(q_index)
    with open(out_path, "w", encoding="utf-8") as file:
        file.write("".join(parts))

def generate_queries(topics, n_queries=1, remove_stopwords=False, top_k=10):
    """Generate replacement queries for each topic.

    topics: iterable of dicts that each carry a "description" string.
    n_queries: number of queries sampled per topic.
    remove_stopwords: if True, use the cleaned token lists (stopwords and
        punctuation removed) joined by spaces instead of the raw generated text.
    top_k: top-k sampling cut-off forwarded to the model.

    Returns a list with one entry per topic: the original description followed
    by the generated queries.
    """
    model = docT5query()
    new_topics = []
    for topic in topics:
        description = topic["description"]
        clean_queries, queries = model.get_queries(description, top_k=top_k, num_queries=n_queries)

        if remove_stopwords:
            # use the expansion terms, i.e. the tokens left after removing stopwords and punctuation
            candidates = [" ".join(terms) for terms in clean_queries]
        else:
            candidates = queries

        # Empty queries cause an error downstream, so fall back to the original
        # description. This is applied in both modes: a raw generation can be
        # blank just as a cleaned one can.
        query_strs = [query if query.strip() else description for query in candidates]

        new_topics.append([description] + query_strs)
    return new_topics

def fuse_rankings(rankings, n_topics, n_voters=1, k=60):
    """Fuse several rankings per topic with reciprocal rank fusion.

    See https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf

    rankings: indexable by voter; rankings[voter][str(topic_number)] is a
        sequence of rows already sorted by rank, whose first element is the
        document id.
    n_topics: number of topics; topic numbers run from 1 to n_topics.
    n_voters: number of rankings to fuse.
    k: constant added to every rank before taking the reciprocal.

    Returns one DataFrame per topic, indexed by document id, with a rank
    column per voter plus "fusion_score" and "fusion_rank", sorted by
    descending fusion score.
    """
    fused = []
    for topic_idx in range(n_topics):
        topic_key = str(topic_idx + 1)

        # rows: documents, columns: the rank each voter gave that document
        table = pd.DataFrame()
        for voter in range(n_voters):
            doc_ids = [row[0] for row in rankings[voter][topic_key]]
            voter_ranks = pd.DataFrame(doc_ids, columns=["doc_id"]).set_index("doc_id")
            voter_ranks["rank_{}".format(voter)] = np.arange(1, len(doc_ids) + 1)

            # Outer join keeps documents seen by any voter; a document a voter
            # did not return gets rank 10000, i.e. a very small contribution.
            table = pd.concat([table, voter_ranks], axis=1, join="outer").fillna(10000)

        # fusion score: sum over voters of 1 / (k + rank)
        table["fusion_score"] = (1 / (table + k)).sum(axis=1)

        # best score first, then number the rows 1..n
        table = table.sort_values(by="fusion_score", ascending=False)
        table["fusion_rank"] = np.arange(1, table.shape[0] + 1)
        fused.append(table)
    return fused

In [None]:
# Model Definition
class docT5query(torch.nn.Module):
    """Wraps a pretrained T5 doc2query checkpoint and samples queries for a text."""

    def __init__(self, pretrained_model_path="castorini/doc2query-t5-base-msmarco", max_input_size=512, max_output_size=64):
        """
        pretrained_model_path: checkpoint name or path used for BOTH the
            tokenizer and the model weights.
        max_input_size: inputs are truncated to this many tokens.
        max_output_size: maximum length of each generated query, in tokens.
        """
        super().__init__()
        # Load the tokenizer from the same checkpoint as the model so that a
        # custom pretrained_model_path is not paired with a different vocabulary.
        self.tokenizer = T5Tokenizer.from_pretrained(pretrained_model_path)
        self.model = T5ForConditionalGeneration.from_pretrained(pretrained_model_path).to(device)
        self.max_input_size = max_input_size
        self.max_output_size = max_output_size

    def get_queries(self, doc_text, top_k=10, num_queries=10):
        """Sample num_queries queries for doc_text.

        Returns (clean_queries, queries): queries holds the decoded strings,
        clean_queries holds, for each of them, the list of lower-cased tokens
        that are neither in sw_nltk nor in punct.
        """
        # Generate the queries using top_k sampling
        with torch.no_grad():
            input_ids = self.tokenizer.encode(doc_text, return_tensors='pt', truncation=True, max_length=self.max_input_size).to(device)
            outputs = self.model.generate(  # sampling is capped at max_output_size tokens
                input_ids=input_ids,
                max_length=self.max_output_size,
                do_sample=True,
                top_k=top_k,
                num_return_sequences=num_queries)

        # Clean them, removing punctuation and stopwords
        clean_queries = []
        queries = []
        for output in outputs:
            sent = self.tokenizer.decode(output, skip_special_tokens=True)
            queries.append(sent)
            # split into lower-cased tokens: runs of letters, runs of digits, or single other characters
            tokens = [m.group().lower() for m in re.finditer(r"[^\W\d_]+|\d+|\S", sent)]
            clean_queries.append([tok for tok in tokens if tok not in sw_nltk and tok not in punct])  # keep the clean tokens

        return clean_queries, queries

In [None]:
# Evaluation Functions
def df_to_trec(df, topic_id=0, file_name=None, rank_col="rank", run_name="ct2021_test"):
    """
    Convert a single ranking dataframe to the TREC results format for scoring.

    df: dataframe indexed by document id with a rank column named rank_col.
    topic_id: zero-based topic position; it is written out as topic_id + 1.
    file_name: if given, the result is also written there as headerless TSV.

    Returns the converted dataframe (at most the 1000 best rows) in every case.
    The input dataframe is not modified.
    """
    assert not np.any(df[rank_col] == 0), "ranks starting at 0 cause a division by zero error!"

    # Turn the rank into a score so that any ranking column can be used
    scored = df.copy()
    scored["score"] = 1 / scored[rank_col]
    best = scored.sort_values("score", ascending=False).head(1000)

    # Attach the remaining TREC columns
    best = best.assign(
        Q0="Q0",
        run_name=run_name,
        document_id=best.index,
        topic_id=topic_id + 1,
    )

    trec_columns = ["topic_id", "Q0", "document_id", rank_col, "score", "run_name"]
    trec_df = best.reset_index().reindex(columns=trec_columns)  # fixed column order

    if file_name is not None:
        trec_df.to_csv(file_name, sep="\t", header=False, index=False)
    return trec_df

def dfs_to_trec(dfs, file_name, rank_col="fusion_rank", run_name="ct2021_test"):
    """Convert a list of per-topic dataframes to the TREC results format and write them to file_name.

    The dataframes must be in topic order: the dataframe at position i is
    written with topic id i + 1. Output is tab-separated without a header.
    """
    per_topic = [
        df_to_trec(frame, topic_id=position, rank_col=rank_col, run_name=run_name)
        for position, frame in enumerate(dfs)
    ]
    pd.concat(per_topic, axis=0).to_csv(file_name, sep="\t", header=False, index=False)

## Reformulate Queries, then Perform and Evaluate Retrieval

In [None]:
# Reformulate Queries
topics_response = requests.get('https://a2a.csiro.au/api/topics/ct2021') # Download topics
topics_response.raise_for_status() # fail with the HTTP error instead of a confusing JSON/key error
topics = topics_response.json()
new_topics = generate_queries(topics, n_queries=n_queries)

# Perform retrieval for all reformulated queries
rankings = []
for i in range(n_queries + 1): # the original query is also used
    store_topic_as_xml(new_topics, q_index=i)
    # open the topic xml as binary; the with-block closes the handle once the upload is done
    with open('reformulated_topics_[{}].xml'.format(i), 'rb') as topic_xml:
        topic_file = {'file': topic_xml}
        response = requests.post('https://a2a.csiro.au/api/bm25/ct2021?t=$user_query', files=topic_file) # Retrieve using BM25 on the reformulated queries
    response.raise_for_status()
    rankings.append(response.json()['rankings'])
query_dfs = fuse_rankings(rankings, len(topics), n_voters=n_queries + 1, k=k) # aggregate rankings of each reformulation with rank fusion

# Evaluate rankings
dfs_to_trec(query_dfs, 'reformulated_fusion.results')
with open('reformulated_fusion.results', 'rb') as results_file:
    files = {'file': results_file}
    eval_response = requests.post('https://a2a.csiro.au/api/eval/ct2021', files=files)
eval_response.raise_for_status()
reranked_results_graded = eval_response.json()
print(f"Results: {reranked_results_graded}")

Results: {'all': {'bpref': 0.4487, 'ndcg_cut_10': 0.3035}, 'per_topic': {'1': {'bpref': 0.7612, 'ndcg_cut_10': 0.3471}, '2': {'bpref': 0.7671, 'ndcg_cut_10': 0.4525}, '3': {'bpref': 0.4113, 'ndcg_cut_10': 0.0331}, '4': {'bpref': 0.4427, 'ndcg_cut_10': 0.5965}, '5': {'bpref': 0.6014, 'ndcg_cut_10': 0.7631}, '6': {'bpref': 0.3073, 'ndcg_cut_10': 0.1625}, '7': {'bpref': 0.3757, 'ndcg_cut_10': 0.1635}, '8': {'bpref': 0.6033, 'ndcg_cut_10': 1.0}, '9': {'bpref': 0.5629, 'ndcg_cut_10': 0.1973}, '10': {'bpref': 0.7542, 'ndcg_cut_10': 0.4619}, '11': {'bpref': 0.7616, 'ndcg_cut_10': 0.5}, '12': {'bpref': 0.6668, 'ndcg_cut_10': 0.4186}, '13': {'bpref': 0.3848, 'ndcg_cut_10': 0.2646}, '14': {'bpref': 0.1871, 'ndcg_cut_10': 0.11}, '15': {'bpref': 0.2556, 'ndcg_cut_10': 0.4905}, '16': {'bpref': 0.2488, 'ndcg_cut_10': 0.2201}, '17': {'bpref': 0.6384, 'ndcg_cut_10': 0.055}, '18': {'bpref': 0.5211, 'ndcg_cut_10': 0.5347}, '19': {'bpref': 0.5575, 'ndcg_cut_10': 0.3293}, '20': {'bpref': 0.2782, 'ndcg_cut