<a href="https://colab.research.google.com/github/grill-lab/CAsT-Demo/blob/main/Interactive_CAsT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

## Overview

[The Conversational Assistance Track (CAsT)](www.treccast.ai) is run as part of the Text REtrieval Conference (TREC) workshop series. CAsT aims to advance research in conversational search systems by building a reusable benchmark for information-centric dialogues on which such systems can be evaluated and compared.

Successful CAsT systems try to overcome conversational search challenges that center on information ranking in context. Specifically, these systems:

* **Understand dialogue context** by tracking the evolution of information needs in a conversation, and identifying the salient information needed for the current turn in the conversation.

* **Retrieve Candidate Response Information** by performing retrieval over a large collection of paragraphs (or knowledge base content) to identify relevant information

## Objective

In this demo, we'll explore and build the components of a simple CAsT system. These components include a:


*   Query Rewriter
*   Document Retriever, and 
*   Passage Reranker

We'll evaluate this system on a subset of the document collection used in the third edition of CAsT.

## System Architecture

<center>
<img src="https://raw.githubusercontent.com/grill-lab/CAsT-Demo/main/assets/system_architecure.png" width="85%"/>
</center>

The diagram above shows how the components of our system interact. 

Given a query and conversation context, our system's ***Query Rewriter*** reformulates the query to resolve ambiguity. Next, the ***Document Retriever*** uses the reformulated query to retrieve the top N candidate documents from an index. Finally, passages are extracted from the candidate documents and ranked by the ***Passage Reranker***. The output of our system is a ranked list of relevant passages for the input query, based on the conversation context.


> **NOTE:** What we're building is the backend system of the Interactive [CAsT Searcher](http://3.94.55.111:5000/) interface. 

# Setup

Before putting our system together, let's download the benchmark, relevance judgements, and document collection subset.

> **NOTE:** Downloading and processing the entire CAsT year 3 document collection will take a while and likely use up the available resources in a free Colab instance. So We have preprocessed 20,000 documents from KILT and MARCO (as using WaPo requires a licence agreement) from the collection and made them available to you as compressed `.jsonlines` files for this tutorial. As a result of this, for many queries we will explore in this tutorial, we may not find very relevant documents. If you would like to process the collection yourself, follow the instructions in the [trec tools repository](https://github.com/grill-lab/trec-cast-tools).

In [1]:
!echo "Creating target directory.."
!mkdir -p files

!echo "Downloading Year 3 relevance judgements.."
!wget -c https://raw.githubusercontent.com/daltonj/treccastweb/master/2021/trec-cast-qrels-docs.2021.qrel -P files

!echo "Downloading Year 3 beanchmark.."
!wget -c https://raw.githubusercontent.com/daltonj/treccastweb/master/2021/2021_manual_evaluation_topics_v1.0.json -P files

!echo "Downloading Year 3 collection subset.."
!wget https://cast-y4-collection.s3.amazonaws.com/cast_tutorial_collection_subset.tar.gz -P files

!echo "Extracting collection subset from gzipped file.."
!tar -xvzf files/cast_tutorial_collection_subset.tar.gz

Creating target directory..
Downloading Year 3 relevance judgements..
--2022-07-10 21:24:15--  https://raw.githubusercontent.com/daltonj/treccastweb/master/2021/trec-cast-qrels-docs.2021.qrel
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 507215 (495K) [text/plain]
Saving to: ‘files/trec-cast-qrels-docs.2021.qrel’


2022-07-10 21:24:16 (21.3 MB/s) - ‘files/trec-cast-qrels-docs.2021.qrel’ saved [507215/507215]

Downloading Year 3 beanchmark..
--2022-07-10 21:24:16--  https://raw.githubusercontent.com/daltonj/treccastweb/master/2021/2021_manual_evaluation_topics_v1.0.json
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|

# Index Generation

<center>
<img src="https://raw.githubusercontent.com/grill-lab/CAsT-Demo/main/assets/index_generation.png" width="85%"/>
</center>


Now, we'll use the [Pyserini information retrieval toolkit](https://github.com/castorini/pyserini) to build sparse and dense indices from the Year 3 collection subset we just downloaded. Pyserini provides APIs for our indexing needs and supports retrieval using dense representations. A, you can also look at [PyTerrier](https://github.com/terrier-org/pyterrier) from the University of Glasgow, which has implementations of some of the newest retrieval methods.



In [2]:
# install dependencies
!pip install pyserini==0.16.0
!pip install faiss-cpu

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pyserini==0.16.0
  Downloading pyserini-0.16.0-py3-none-any.whl (84.6 MB)
[K     |████████████████████████████████| 84.6 MB 152 kB/s 
Collecting onnxruntime>=1.8.1
  Downloading onnxruntime-1.11.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.2 MB)
[K     |████████████████████████████████| 5.2 MB 51.6 MB/s 
Collecting sentencepiece>=0.1.95
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 54.0 MB/s 
Collecting transformers>=4.6.0
  Downloading transformers-4.20.1-py3-none-any.whl (4.4 MB)
[K     |████████████████████████████████| 4.4 MB 60.7 MB/s 
[?25hCollecting pyjnius>=1.4.0
  Downloading pyjnius-1.4.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[K     |████████████████████████████████| 1.3 MB 65.7 MB/s 
Collecting lightgbm>=3.3.2
 

### Preprocessing

Before indexing, we need to adapt the format of the collection into one that Pyserini can ingest for sparse and dense indexing. Currently our documents are in the following format, with a `contents` field holding the document text and it's passage splits:

```
{
  "id": "doc_1",
  "contents": [{'id': 'passage_1', 'body': 'passage text'}, {'id': 'passage_2', 'body': 'passage text'}]
}
```

However, we need to combine all the passage texts and save that to `contents` field, and keep the passage splits information in a separate field. After this, our documents will look as follows:

```
{
  "id": "doc_1",
  "contents": "document text",
  "passage_splits": [{'id': 'passage_1', 'body': 'passage text'}, {'id': 'passage_2', 'body': 'passage text'}]
}
```

Visit the [Pyserini](https://github.com/castorini/pyserini#sparse-indexes) repository to learn more about the document formats it is able to ingest.

In [3]:
import os
from pathlib import Path
import json
from tqdm import tqdm

def format_collection(collection_path, output_path):
  """
  Runs through all documents and reformats them for indexing 
  with Pyserini.
  """
  Path(output_path).mkdir(parents=True, exist_ok=True)
  # iterate through all files in directory
  for file in tqdm(os.listdir(collection_path)):
    # open file
    with open(f"{collection_path}/{file}") as document_jsonl_collection:
      file_basename = os.path.basename(file)
      with open(f"{output_path}/{file_basename}_reformated.jsonl", "a") as reformated_collection_file:
        for document in document_jsonl_collection:
          parsed_document = json.loads(document)
          parsed_document['passage_splits'] = parsed_document['contents']
          parsed_document['contents'] = " ".join(
              [passage['body'] for passage in parsed_document['contents']]
          ).replace('\n', ' ')

          # write to output file
          reformated_collection_file.write(json.dumps(parsed_document) + "\n")

In [4]:
format_collection("files/jsonlines/", "files/reformated_collection")

100%|██████████| 2/2 [00:06<00:00,  3.38s/it]


### Sparse Index Generation

Now, let's generate a sparse index from the reformated collection in `files/reformated_collection/`. The index will be stored in `files/index/sparse`.

In [5]:
# create an index from the reformated collection
# It should take about 30 seconds on 20k documents
!python -m pyserini.index.lucene \
  --collection JsonCollection \
  --input files/reformated_collection \
  --index files/index/sparse \
  --generator DefaultLuceneDocumentGenerator \
  --threads 8 \
  --storePositions --storeDocvectors --storeRaw

2022-07-10 21:25:09,371 INFO  [main] index.IndexCollection (IndexCollection.java:643) - Setting log level to INFO
2022-07-10 21:25:09,381 INFO  [main] index.IndexCollection (IndexCollection.java:646) - Starting indexer...
2022-07-10 21:25:09,381 INFO  [main] index.IndexCollection (IndexCollection.java:648) - DocumentCollection path: files/reformated_collection
2022-07-10 21:25:09,382 INFO  [main] index.IndexCollection (IndexCollection.java:649) - CollectionClass: JsonCollection
2022-07-10 21:25:09,382 INFO  [main] index.IndexCollection (IndexCollection.java:650) - Generator: DefaultLuceneDocumentGenerator
2022-07-10 21:25:09,383 INFO  [main] index.IndexCollection (IndexCollection.java:651) - Threads: 8
2022-07-10 21:25:09,383 INFO  [main] index.IndexCollection (IndexCollection.java:652) - Language: en
2022-07-10 21:25:09,383 INFO  [main] index.IndexCollection (IndexCollection.java:653) - Stemmer: porter
2022-07-10 21:25:09,384 INFO  [main] index.IndexCollection (IndexCollection.java:65

To check that our new sparse index works, let's try searching with it. The code below loads the index and searches for the term `albedo`. Feel free to change `albedo` to anything you like.

In [6]:
from pyserini.search.lucene import LuceneSearcher

sparse_searcher = LuceneSearcher('files/index/sparse')
search_term = 'albedo'
hits = sparse_searcher.search(search_term)

for i in range(len(hits)):
    print(f'{i+1:2} {hits[i].docid:4} {hits[i].score:.5f}')

 1 KILT_39 7.14260
 2 KILT_905872 5.86990
 3 MARCO_00_72080055 5.72480
 4 KILT_343225 5.50330
 5 KILT_1402629 4.54220
 6 KILT_207081 4.41020
 7 KILT_11372427 4.38890
 8 KILT_206525 4.32640
 9 MARCO_00_71963394 4.05210
10 MARCO_00_71999716 3.91280


We should see a ranked list of the ten most relevant documents and their scores for our input query. Let's see the contents of the best ranking document.

In [7]:
best_ranked_doc = sparse_searcher.doc(hits[0].docid)
parsed_doc = json.loads(best_ranked_doc.raw())
parsed_doc['contents']

' Albedo  Albedo () (, meaning \'whiteness\') is the measure of the diffuse reflection of solar radiation out of the total solar radiation received by an astronomical body (e.g. a planet like Earth).  It is dimensionless and measured on a scale from 0 (corresponding to a black body that absorbs all incident radiation) to 1 (corresponding to a body that reflects all incident radiation).   Surface albedo is defined as the ratio of radiosity to the irradiance (flux per unit area) received by a surface.  The proportion reflected is not only determined by properties of the surface itself, but also by the spectral and angular distribution of solar radiation reaching the Earth\'s surface.  These factors vary with atmospheric composition, geographic location and time (see position of the Sun).  While bi-hemispherical reflectance is calculated for a single angle of incidence (i.e., for a given position of the Sun), albedo is the directional integration of reflectance over all solar angles in a 

Recall that we also saved the passage splits for each document. This can be accessed as follows:

In [8]:
parsed_doc['passage_splits']

[{'body': " Albedo  Albedo () (, meaning 'whiteness') is the measure of the diffuse reflection of solar radiation out of the total solar radiation received by an astronomical body (e.g. a planet like Earth).  It is dimensionless and measured on a scale from 0 (corresponding to a black body that absorbs all incident radiation) to 1 (corresponding to a body that reflects all incident radiation).   Surface albedo is defined as the ratio of radiosity to the irradiance (flux per unit area) received by a surface.  The proportion reflected is not only determined by properties of the surface itself, but also by the spectral and angular distribution of solar radiation reaching the Earth's surface.  These factors vary with atmospheric composition, geographic location and time (see position of the Sun).  While bi-hemispherical reflectance is calculated for a single angle of incidence (i.e., for a given position of the Sun), albedo is the directional integration of reflectance over all solar angle

### Dense Index Generation

Now let's generate a dense index from the reformated collection in `files/reformated_collection/`. The index will be stored in `files/index/dense`.

In [9]:
# takes about 5 minutes
!python -m pyserini.encode \
  input   --corpus files/reformated_collection \
          --fields text \
  output  --embeddings files/index/dense \
          --to-faiss \
  encoder --encoder castorini/tct_colbert-v2-hnp-msmarco \
          --fields text \
          --batch 32 \
          --fp16

Downloading: 100% 559/559 [00:00<00:00, 540kB/s]
Downloading: 100% 418M/418M [00:07<00:00, 59.7MB/s]
Downloading: 100% 334/334 [00:00<00:00, 298kB/s]
Downloading: 100% 226k/226k [00:00<00:00, 926kB/s]
Downloading: 100% 112/112 [00:00<00:00, 99.0kB/s]
10000it [00:00, 18847.47it/s]
10000it [00:00, 13909.19it/s]
100% 625/625 [05:07<00:00,  2.03it/s]


And to test that our dense index works, let's try searching for the term `albedo` again. 

In [10]:
from pyserini.search import FaissSearcher

dense_searcher = FaissSearcher(
    'files/index/dense',
    'facebook/dpr-question_encoder-multiset-base'
)
hits = dense_searcher.search(search_term)

for i in range(0, 10):
    print(f'{i+1:2} {hits[i].docid:7} {hits[i].score:.5f}')

Downloading:   0%|          | 0.00/493 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/418M [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/dpr-question_encoder-multiset-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

 1 KILT_39 69.04213
 2 KILT_1805774 65.33679
 3 MARCO_00_43983796 64.97756
 4 KILT_1402274 64.88888
 5 KILT_12187433 64.87811
 6 KILT_905803 64.73970
 7 MARCO_00_19411205 64.67081
 8 KILT_1402670 64.61848
 9 MARCO_00_19393829 64.58357
10 KILT_1055837 64.57838


### Hybrid Search

Now that we have sparse and dense indices, we can try a hybrid retrieval method that leverages the sparse and dense reprentations of our document collection. We can do this with Pyserini's APIs. 

Hybrid retrieval assigns a new score to each document based on the retrieval scores from each of our retrieval method. The new `hybrid score` follows the following formula:

`hybrid_score = alpha * sparse_score + dense_score if not weight_on_dense else sparse_score + alpha * dense_score`

The retrieved documents are then reranked using these `hybrid_scores`.

In [11]:
from pyserini.search.hybrid import HybridSearcher

hybrid_searcher = HybridSearcher(dense_searcher, sparse_searcher)
hits = hybrid_searcher.search(search_term)

for i in range(0, 10):
    print(f'{i+1:2} {hits[i].docid:7} {hits[i].score:.5f}')

 1 KILT_39 69.75639
 2 KILT_1805774 65.72807
 3 MARCO_00_43983796 65.36884
 4 KILT_1402274 65.28016
 5 KILT_12187433 65.26939
 6 KILT_905872 65.16537
 7 MARCO_00_72080055 65.15086
 8 KILT_905803 65.13098
 9 KILT_343225 65.12871
10 MARCO_00_19411205 65.06209


Notice how the ordering of the retrieved documents differ across the different retrieval approaches. In general, `hybrid retrieval` works better than `dense retrieval`, which, in turn, performs better than `sparse retrieval`.

# Benchmark (not so deep) Deep Dive

The CAsT dataset (benchmark) is a collection of dialogues (**topics**). Each topic is identified by a `number` and comprises several **turns**. Let's take a closer look at a turn:

```
{
    "number": 1,
    "raw_utterance": "I just had a breast biopsy for cancer. What are the most common types?",
    "passage": "More research is needed. Types Breast cancer can be: Ductal carcinoma: This begins in the milk duct and is the most common type. Lobular carcinoma: This starts in the lobules. Invasive breast cancer is when the cancer cells break out from inside the lobules or ducts and invade nearby tissue, increasing the chance of spreading to other parts of the body. Non-invasive breast cancer is when the cancer is still inside its place of origin and has not broken out.",
    "manual_rewritten_utterance": "I just had a breast biopsy for cancer. What are the most common types of breast cancer?",
    "canonical_result_id": "MARCO_D59865",
    "passage_id": 7,
    "automatic_rewritten_utterance": "What are the most common types of cancer in regards to breast biopsy?"
}
```

* The `number` field identifies a turn within a topic. Globally, however, a turn 
is identified by a concatenation of the topic number and turn number i.e `106_1`.

* The `raw_utterance` refers to the user’s utterance or query to the system. A typical phenomenon with CAsT topics and real-life dialogues, in general, is that the user utterance becomes more and more ambiguous and faceted as the conversation progresses.

* The `passage` field contains a standard system response to the user’s query. These are carefully selected by the topic creators and are meant to serve as dialogue context for query rewriters.

* The `manual_rewritten_utterance` field is the user utterance rewritten by the topic creator without ambiguity. You can use this field to evaluate the quality of query rewrites or as input queries for manual systems

* Similarly, the `automatic_rewritten_utterance` is a baseline rewriting model's attempt at reformulating the user utterance without ambiguity. This field can be useful for comparing the output of your query rewriter.


The system we'll build in this demo will take a `turn`'s `raw_utterance`, understand it in relation to the topic history, and return passages relevant to that `turn`.

# Query Rewriting

CAsT topics mimic real-world dialogue phenomena. As a result, utterances within topics become increasingly ambiguous as the topic unfolds. On their own, these utterances likely won't return good candidates from our index, so we need to reformulate them using information from prior turns in the topic. 

<center>
<img src="https://raw.githubusercontent.com/grill-lab/CAsT-Demo/main/assets/query_rewriting.png" width="85%"/>
</center>

Let's examine the utterances in a topic to demonstrate the need for query rewriting.

In [12]:
import json

with open("files/2021_manual_evaluation_topics_v1.0.json") as cast_topics_file:
  topics = json.load(cast_topics_file)
  for topic in topics:
    print(f"Topic {topic['number']}")
    print("\n")
    for turn in topic['turn']:
      print(f"Turn {turn['number']}: {turn['raw_utterance']}")
      print(f"System Response: {turn['passage']}")
      print("\n")
    break

Topic 106


Turn 1: I just had a breast biopsy for cancer. What are the most common types?
System Response: More research is needed. Types Breast cancer can be: Ductal carcinoma: This begins in the milk duct and is the most common type. Lobular carcinoma: This starts in the lobules. Invasive breast cancer is when the cancer cells break out from inside the lobules or ducts and invade nearby tissue, increasing the chance of spreading to other parts of the body. Non-invasive breast cancer is when the cancer is still inside its place of origin and has not broken out.


Turn 2: Once it breaks out, how likely is it to spread?
System Response: Even though this condition doesn’t spread, it’s important to keep an eye on it. Between 20% to 40% of women with this condition will develop a separate invasive breast cancer -- one that will grow outside its original location -- within the next 15 years. Most of the time, these later cancers begin in the milk ducts, rather than the lobules. How is lobu

This topic begins with a query about cancer. However, if we took `Turn 3` out of the context of the conversation and tried to search with it, few, if any, results would relate to cancer. Feel free to check this for yourself.

Now, let's see how a query rewriter helps.

We'll use a [T5 query rewriter from HuggingFace](https://huggingface.co/castorini/t5-base-canard). It is finetuned on the [CANARD dataset](https://sites.google.com/view/qanta/projects/canard) but works effectively on CAsT queries. In previous years, we've observed that effective rewriters have also been trained using old CAsT (Year 1&2) queries, QRECC, and other conversational datasets

In [13]:
# Load model and tokenizer from HuggingFace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
rewriter = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard").to(device).eval()
rewriter_tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard")

Downloading:   0%|          | 0.00/1.31k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/850M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.81k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/773k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.74k [00:00<?, ?B/s]

 The model rewrites an utterance using that utterance and all previous utterances and system responses as input. The utterance and previous turn utterances and system responses should be separated by `|||` when building the input to the model.

For example, to rewrite the utterance at `Turn 2`, we will format the model input as follows:

`model_input = I just had a breast biopsy for cancer. What are the most common types? ||| More research is needed. Types Breast cancer can be: Ductal carcinoma: This begins in the milk duct and is the most common type. Lobular carcinoma: This starts in the lobules. Invasive breast cancer is when the cancer cells break out from inside the lobules or ducts and invade nearby tissue, increasing the chance of spreading to other parts of the body. Non-invasive breast cancer is when the cancer is still inside its place of origin and has not broken out. ||| Once it breaks out, how likely is it to spread?`

In [14]:
def get_turn_attribute(topics: list, global_turn_id: str, attribute: str) -> str:
  """
  Returns topic and turn attributes
  """
  first_topic_number = 106
  topic_number, turn_number = global_turn_id.strip().split("_")

  # Get topic
  topic_index = int(topic_number) - first_topic_number
  topic = topics[topic_index]

  turn_index = int(turn_number) - 1
  extracted_attribute = topic['turn'][turn_index][attribute]

  return topic, turn_index, extracted_attribute


def build_context(topics: list, global_turn_id: str) -> str:
  """
  Given a global_turn_id, build rewriter model input (context)
  with all previous turn utterances and system responses
  """
  topic, turn_index, current_turn_utterance = get_turn_attribute(
      topics, global_turn_id, 'raw_utterance'
  )
  
  previous_turn_context = [f"{turn['raw_utterance']} ||| {turn['passage']}" for turn in topic['turn'][:turn_index]]
  previous_turn_context = (" ||| ").join(previous_turn_context)
  
  # return generated model input
  if previous_turn_context:
    return f"{previous_turn_context} ||| {current_turn_utterance}"
  else:
    return current_turn_utterance

Let's build the context for `turn 3` in `topic 113`.

In [15]:
global_turn_id = "113_3"

# get raw utterance
_, _, raw_utterance = get_turn_attribute(topics, global_turn_id, 'raw_utterance')
print(f"Raw Utterance: {raw_utterance}")

# build context
model_input = build_context(topics, global_turn_id)
print(f"Turn Context: {model_input}")

Raw Utterance: What are the other types of diseases?
Turn Context: How do genes work? ||| Genes A gene is a short piece of DNA. Genes tell the body how to build specific proteins. There are about 20,000 genes in each cell of the human body. Together, they make up the blueprint for the human body and how it works. A person's genetic makeup is called a genotype. Information Genes are made of DNA. Strands of DNA make up part of your chromosomes. Chromosomes have matching pairs of 1 copy of a specific gene. The gene occurs in the same position on each chromosome. Genetic traits, such as eye color, are dominant or recessive: Dominant traits are controlled by 1 gene in the pair of chromosomes. Recessive traits need both genes in the gene pair to work together. Many personal characteristics, such as height, are determined by more than 1 gene. However, some diseases, such as sickle cell anemia, can be caused by a change in a single gene. ||| What other diseases are caused by a single change? |

> **NOTE:** Building the model input in this way may make the input too large for a later turn in longer topics. As an exercise, you can play around with different context truncating strategies. However, a simple approach would be removing earlier turn utterances and responses from the context if the input size is beyond the model's token limit.

Now, let's rewrite the query using our model. Feel free to play around with any other turns you're interested in rewriting.

In [16]:
def rewrite_query(context: str, model, tokenizer, device) -> str:
  tokenized_context = tokenizer.encode(context, return_tensors="pt").to(device)
  output_ids = model.generate(
      tokenized_context, max_length=200, num_beams=4, 
      repetition_penalty=2.5, length_penalty=1.0, 
      early_stopping=True).to(device)
      
  rewrite = tokenizer.decode(output_ids[0], skip_special_tokens=True)
  return rewrite

In [17]:
rewrite = rewrite_query(model_input, rewriter, rewriter_tokenizer, device)
print(f"Query Rewrite: {rewrite}")

Query Rewrite: Besides sickle cell anemia and cystic fibrosis, what are the other types of diseases?


### Rewriter Failures

Note that T5 rewriter is not a silver bullet and struggles with rewriting queries that exhibit certain dialogue phenomena. For example, let's take a look at turn `114-3`:

In [18]:
global_turn_id = "114_3"

# get turn attributes
_, _, failure_raw_utterance = get_turn_attribute(topics, global_turn_id, 'raw_utterance')
_, _, failure_manual_utterance = get_turn_attribute(topics, global_turn_id, 'manual_rewritten_utterance')
print(f"Raw Utterance: {failure_raw_utterance}")
print(f"Manual Utterance: {failure_manual_utterance}")

Raw Utterance: No, I meant to help my ferritin specifically.
Manual Utterance: No, I meant what should I consider changing in my diet to help my ferritin levels specifically.


This utterance is a feedback turn where the user gives the system feedback on the result retrieved in the prior turn (`No, I meant ...`), and reveals a little more information to guide the system towards retrieving a relevant result (`... to help my ferritin specifically`). Let's try to rewrite this utterance using our rewriter from above:

In [19]:
failure_model_input = build_context(topics, global_turn_id)
failure_rewrite = rewrite_query(failure_model_input, rewriter, rewriter_tokenizer, device)
print(f"Query Rewrite: {failure_rewrite}")

Query Rewrite: No, I meant to help my ferritin specifically.


Notice how the rewriter fails to rewrite this query. One of the objectives of CAsT is to identify these kinds of challenges and create rewriters that handle them robustly and effectively.

### Rewriter Evaluation

To quantify the quality of a query rewrite, you can compute the [ROUGE score](https://en.wikipedia.org/wiki/ROUGE_(metric)) of the rewrite against the manual rewrite for that turn.  

In [20]:
# Install ROUGE-score package
!pip install rouge-score

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting rouge-score
  Downloading rouge_score-0.0.4-py2.py3-none-any.whl (22 kB)
Installing collected packages: rouge-score
Successfully installed rouge-score-0.0.4


In [21]:
from rouge_score import rouge_scorer

# get manual rewrite for query
_, _, manual_rewritten_utterance = get_turn_attribute(topics, global_turn_id, 'manual_rewritten_utterance')
print(f"Manual Rewritten Utterance: {manual_rewritten_utterance}")

scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
scores = scorer.score(rewrite, manual_rewritten_utterance)
print("\n")
print("ROUGE Scores")
print(scores)

Manual Rewritten Utterance: No, I meant what should I consider changing in my diet to help my ferritin levels specifically.


ROUGE Scores
{'rouge1': Score(precision=0.058823529411764705, recall=0.07142857142857142, fmeasure=0.06451612903225808), 'rouge2': Score(precision=0.0, recall=0.0, fmeasure=0.0), 'rougeL': Score(precision=0.058823529411764705, recall=0.07142857142857142, fmeasure=0.06451612903225808)}


Note that there are numerous ways to adequately rewrite a query, so the ROUGE score may not truly reflect the quality of your rewrite.

# Document Retrieval

Once we have a rewritten query, we can search for documents. Recall that in the `Index Generation` section, we explored three retrieval approaches to retrieve documents from our indices.

<center>
<img src="https://raw.githubusercontent.com/grill-lab/CAsT-Demo/main/assets/document_retrieval.png" width="85%"/>
</center>

Moving forward, we'll use the `hybrid retrieval` method to retrieve documents with our rewritten query.

In [22]:
hits = hybrid_searcher.search(rewrite)
for i in range(len(hits)):
  print(f'{i+1:2} {hits[i].docid:4} {hits[i].score:.5f}')

 1 MARCO_00_5121943 68.80343
 2 MARCO_00_67514453 68.79907
 3 MARCO_00_51798579 68.75722
 4 KILT_2507552 68.70246
 5 MARCO_00_42843785 68.40486
 6 KILT_1804031 68.34913
 7 MARCO_00_49390315 68.28469
 8 MARCO_00_51920041 68.27667
 9 MARCO_00_52745124 68.26244
10 MARCO_00_51916994 68.25809


Let's examine the top scoring document.

In [23]:
# document lookup is only possible with the sparse retrieval method
best_ranked_doc = sparse_searcher.doc(hits[0].docid)
print(best_ranked_doc.raw())

{
  "id" : "MARCO_00_5121943",
  "url" : "http://16994238methodassessment3.weebly.com/macro-parasites.html",
  "title" : "Macro-parasites - PATHOGENS AND DISEASES",
  "contents" : " Macro-parasites - PATHOGENS AND DISEASES   MACRO-PARASITES What are macro-parasites?  Macro-parasites are multicellular, eukaryotic organisms that are large enough to be seen with the naked eye.  Macro-parasites, like other parasites are metabolically dependent on other living organisms, referred to as the host organism.  Most parasites grow inside the host but generally reproduce by infective stages outside of the host.  Type of macro-parasites.  There are many species of macro-parasites.  The most common of these include nematodes, ticks, mites and flatworms.  Macro-parasites can be either classed as endoparasites; parasites that live inside the host, or ectoparasites; parasites that live on the host.  Examples of endoparasites include flukes and tapeworms, while examples of ectoparasites include mosquito

# Passage Ranking


CAsT's core task is to return a ranked list of relevant passages for a given user utterance. The `Document Retrieval` step provides a list of candidate documents with potentially relevant passages; however, we need to extract those passages and find the most suited ones for our query.

<center>
<img src="https://raw.githubusercontent.com/grill-lab/CAsT-Demo/main/assets/passage_ranking.png" width="75%"/>
</center>

Our source `jsonlines` documents contain the canonical passage splits. As a first step, we'll collect these passages from our candidate documents.

In [24]:
def collect_passages(hits: list) -> dict:
  """
  Extracts and collects all passages from candidate documents
  """
  all_passages = []
  for hit in hits:
    doc_id = hit.docid
    document = sparse_searcher.doc(doc_id)
    passages = json.loads(document.raw())['passage_splits']
    for passage in passages:
        passage['id'] = f"{doc_id}-{passage['id']}" 
    all_passages.extend(passages)

  return all_passages

In [25]:
passages = collect_passages(hits)

# Examine one of the passages
print(len(passages))
passages[0]

52


{'body': ' Macro-parasites - PATHOGENS AND DISEASES   MACRO-PARASITES What are macro-parasites?  Macro-parasites are multicellular, eukaryotic organisms that are large enough to be seen with the naked eye.  Macro-parasites, like other parasites are metabolically dependent on other living organisms, referred to as the host organism.  Most parasites grow inside the host but generally reproduce by infective stages outside of the host.  Type of macro-parasites.  There are many species of macro-parasites.  The most common of these include nematodes, ticks, mites and flatworms.  Macro-parasites can be either classed as endoparasites; parasites that live inside the host, or ectoparasites; parasites that live on the host.  Examples of endoparasites include flukes and tapeworms, while examples of ectoparasites include mosquitoes, fleas, ticks, leeches and lice.  Examples of diseases caused by macro-parasites.  There are many diseases and ailments caused by macro-parasites as well as micro-paras

Now that we have the passages, we need to rank them according to their relevance to our rewritten query. Current [state-of-the-art methods](https://github.com/castorini/pygaggle/) treat this passage ranking task as a classification problem, where a language model predicts if a passage is relevant to a query or not, and provides a `relevance score` along with this classification. 

In this tutorial, we'll use a [Cross-Encoder](https://www.sbert.net/examples/applications/cross-encoder/README.html) from the SentenceTransformers library to produce an output value between 0 and 1 to indicate the similarity between our query rewrite and candidate passage. After this, we'll see how a [T5-based model](https://aclanthology.org/2020.findings-emnlp.63.pdf) finetuned on the MS MARCO dataset performs in comparison.

Like query rewriting, passage ranking is also an open problem, and one of objectives of CAsT is to identify and address challenges associated with this problem.

In [26]:
# install dependencies
!pip install sentence-transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sentence-transformers
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
[K     |████████████████████████████████| 85 kB 4.3 MB/s 
Building wheels for collected packages: sentence-transformers
  Building wheel for sentence-transformers (setup.py) ... [?25l[?25hdone
  Created wheel for sentence-transformers: filename=sentence_transformers-2.2.2-py3-none-any.whl size=125938 sha256=d05026b05eb234c878f803a94f1b93485c724b257f6b9d8fdd2c6637adf281d1
  Stored in directory: /root/.cache/pip/wheels/bf/06/fb/d59c1e5bd1dac7f6cf61ec0036cc3a10ab8fecaa6b2c3d3ee9
Successfully built sentence-transformers
Installing collected packages: sentence-transformers
Successfully installed sentence-transformers-2.2.2


### Cross Encoder

In [27]:
from sentence_transformers import CrossEncoder

# load cross encoder
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

Downloading:   0%|          | 0.00/794 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/86.7M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/316 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [28]:
def rank_passages_cross_encoder(rewrite: str, passages: list, cross_encoder: CrossEncoder) -> list:
  """
  Uses a cross encoder to compute the similarity between a query rewrite and a list of
  passages
  """
  query_passage_pairs = [[rewrite, passage['body']] for passage in passages]
  cross_scores = cross_encoder.predict(query_passage_pairs)

  for passage, score in zip(passages, cross_scores):
    passage['score'] = score
  
  ranked_passages = sorted(passages, key=lambda x: x['score'], reverse=True)
  return ranked_passages

In [29]:
ranked_passages = rank_passages_cross_encoder(rewrite, passages, cross_encoder)

Let's have a look at our top 10 passages. Recall our `raw_utterance`, `rewrite`, and `manual_rewrite`:

In [30]:
print(f"Raw Utterance: {raw_utterance}")
print(f"Rewrite: {rewrite}")
print(f"Manual Rewrite: {manual_rewritten_utterance}")

Raw Utterance: What are the other types of diseases?
Rewrite: Besides sickle cell anemia and cystic fibrosis, what are the other types of diseases?
Manual Rewrite: No, I meant what should I consider changing in my diet to help my ferritin levels specifically.


In [31]:
print("Top 10 relevant passages:")
print("\n")
for passage in ranked_passages[:10]:
  print(passage)

Top 10 relevant passages:


{'body': 'But at levels of between 10 and 20 percent a person can develop blue skin without any other symptoms. Most of blue Fugates never suffered any health effects and lived into their 80s and 90s.  "If you are between 1 percent and 10 percent, no one knows you have an abnormal level and this might be the case in a lot of unsuspecting patients," he said.  Many other recessive gene diseases, such as sickle cell anemia, Tay Sachs and cystic fibrosis can be lethal, he said.  "If I carry a bad recessive gene with a rare abnormality and married, the child probably wouldn\'t be sick, because it\'s very rare to meet another person with the [  same] bad gene and the most frequent cause therefore is in-breeding," Tefferi said.  Such was the case with the Fugates.  Martin Fugate came to Troublesome Creek from France in 1820 and family folklore says he was blue.  He married Elizabeth Smith, who also carried the recessive gene.  Of their seven children, four were rep

### T5

The T5 model works by generating relevance labels as `target tokens` for each passage, and uses the underlying logits of these target tokens as relevance probabilities for ranking.

> **NOTE:** If you get a `CUDA out of memory` error, reduce the number of candidate passages to rerank, until it works.

In [32]:
ranker_tokenizer = AutoTokenizer.from_pretrained('castorini/monot5-base-msmarco-10k')
ranker = AutoModelForSeq2SeqLM.from_pretrained('castorini/monot5-base-msmarco-10k', return_dict=True).to(device).eval()

Downloading:   0%|          | 0.00/1.81k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.30k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/773k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.74k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/850M [00:00<?, ?B/s]

In [33]:
prediction_tokens = ['▁false', '▁true']
token_false_id = ranker_tokenizer.get_vocab()[prediction_tokens[0]]
token_true_id  = ranker_tokenizer.get_vocab()[prediction_tokens[1]]

def rank_passages_t5(
    rewrite: str, passages: list, model: AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer):
    
    input_ids = tokenizer(
        [f"Query: {rewrite} Document: {passage['body']} Relevant: " for passage in passages], 
        return_tensors="pt", padding=True, truncation=True
    ).to(device).input_ids

    outputs = model.generate(
        input_ids,
        return_dict_in_generate=True, 
        output_scores=True 
    )
    
    scores = outputs.scores[0][:, [token_false_id, token_true_id]]
    scores = torch.nn.functional.softmax(scores, dim=1)
    probabilities = scores[:, 1].tolist()

    for passage, probability in zip(passages, probabilities):
        passage['score'] = probability

    ranked_passages = sorted(passages, key=lambda x: x['score'], reverse=True)
    return ranked_passages

In [34]:
ranked_passages_t5 = rank_passages_t5(rewrite, passages[:52], ranker, ranker_tokenizer)

In [35]:
print("Top relevant passages from T5:")
print("\n")
for passage in ranked_passages_t5[:10]:
  print(passage)

Top relevant passages from T5:


{'body': 'But at levels of between 10 and 20 percent a person can develop blue skin without any other symptoms. Most of blue Fugates never suffered any health effects and lived into their 80s and 90s.  "If you are between 1 percent and 10 percent, no one knows you have an abnormal level and this might be the case in a lot of unsuspecting patients," he said.  Many other recessive gene diseases, such as sickle cell anemia, Tay Sachs and cystic fibrosis can be lethal, he said.  "If I carry a bad recessive gene with a rare abnormality and married, the child probably wouldn\'t be sick, because it\'s very rare to meet another person with the [  same] bad gene and the most frequent cause therefore is in-breeding," Tefferi said.  Such was the case with the Fugates.  Martin Fugate came to Troublesome Creek from France in 1820 and family folklore says he was blue.  He married Elizabeth Smith, who also carried the recessive gene.  Of their seven children, four wer

Notice how that the passage rankings across our Cross Encoder and T5 based methods vary. As with query rewriting, one of the objectives with CAsT is to explore the most optimal passage ranking approaches  

Note that we only use a small subset (~0.001%) of the CAsT year three collection for this tutorial. As such, there may be lots of relevant documents that we don't have in our index!

# Putting it All Together

At this point, we have all the components of our simple CAsT system. Let's encapsulate them in a Python class to help us visualise how they fit together.

In [36]:
from abc import ABC, abstractmethod

class BaseCAsTSystem(ABC):

  @abstractmethod
  def rewrite_query(self, query: str, context: str) -> str:
    pass
  
  @abstractmethod
  def retrieve_docments(self, query: str) -> list:
    pass

  @abstractmethod
  def rank_passages(self, query: str, passages: list) -> list:
    pass

Our system will be made up of the following:  

*   T5 based rewriter as our query rewriter (i.e `rewrite_query`),
*   Hybrid Searcher as our document retriever (i.e `search_system`), and
*   Cross Encoder as our passage ranker (i.e `rank_passages`) <!-- works faster than T5 -->

Feel free to swap out and experiment with any of the components



In [37]:
class SimpleCAsTSystem(BaseCAsTSystem):

  def __init__(self, query_rewriter, query_rewriter_tokenizer, 
               search_system, passage_ranker, device) -> None:
    self.query_rewriter = query_rewriter
    self.query_rewriter_tokenizer = query_rewriter_tokenizer
    self.search_system = search_system
    self.passage_ranker = passage_ranker
    self.device = device
  
  def rewrite_query(self, query_and_context: str) -> str:
    """
    Simple CAsT System Query Rewriter
    """

    tokenized_context = self.query_rewriter_tokenizer.encode(
        query_and_context, return_tensors="pt"
    ).to(self.device)

    output_ids = self.query_rewriter.generate(
        tokenized_context, max_length=200, num_beams=4, 
        repetition_penalty=2.5, length_penalty=1.0, 
        early_stopping=True).to(device)

    rewrite = self.query_rewriter_tokenizer.decode(
        output_ids[0], skip_special_tokens=True
    )
    return rewrite 
  
  def retrieve_docments(self, query: str) -> list:
    """
    Simple CAsT System Document Retriever
    """
    hits = self.search_system.search(query)
    return hits
  
  def rank_passages(self, query: str, passages: list) -> list:
    """
    Simple CAsT System Passage Ranker
    """
    query_passage_pairs = [[query, passage['body']] for passage in passages]
    query_passage_scores = self.passage_ranker.predict(query_passage_pairs)

    for passage, score in zip(passages, query_passage_scores):
      passage['score'] = score
    
    ranked_passages = sorted(passages, key=lambda x: x['score'], reverse=True)
    return ranked_passages

Now, we can run through all the topics in the Year 3 benchmark and find relevant passages for each query from our indices. 

**EXERCISE:** Let's find the top 3 passages for each query using an instance of the `SimpleCAsTSystem` class.

In [38]:
# Recall that models and systems were instantiated earlier on in the notebook

# create system instance
simple_cast_system = SimpleCAsTSystem(
    query_rewriter=rewriter,
    query_rewriter_tokenizer=rewriter_tokenizer,
    search_system=hybrid_searcher,
    passage_ranker=cross_encoder,
    device=device
)

In [39]:
#@title
for topic in topics:
  for turn in topic['turn']:
    # get turn id
    turn_id = f"{topic['number']}_{turn['number']}"
    # Build rewriter input -- not ideal, but saves having to write logic again
    query_and_context = build_context(topics, turn_id)
    # rewrite query
    rewrite = simple_cast_system.rewrite_query(query_and_context)
    # retrieve candidate documents
    candidate_documents = simple_cast_system.retrieve_docments(rewrite)
    # collect passages from candidate documents
    extracted_passages = collect_passages(candidate_documents)
    # rank passages
    ranked_passages = simple_cast_system.rank_passages(rewrite, extracted_passages)
    # print output
    print(f"Turn ID: {turn_id}")
    print(f"Query: {turn['raw_utterance']}")
    print(f"Rewrite: {rewrite}")
    print(f"Manual Rewrite: {turn['manual_rewritten_utterance']}")
    print("\n")
    print("Top Passages:")
    for passage in ranked_passages[:3]:
      print(f"Passage ID: {passage['id']}")
      print(f"Score {passage['score']}")
      print(f"Text: {passage['body']}")
      print("\n")
    print("------------------")



Turn ID: 106_1
Query: I just had a breast biopsy for cancer. What are the most common types?
Rewrite: What are the most common types of cancer in regards to me?
Manual Rewrite: I just had a breast biopsy for cancer. What are the most common types of breast cancer?


Top Passages:
Passage ID: MARCO_00_43867598-3
Score 2.925729274749756
Text: For adults, the most common brain tumor types are astrocytoma, oligodendroglioma and meningioma. Primary brain tumors are named according to the type of cells or the part of the brain in which they begin.  For example, most primary brain tumors begin in glial cells.  This type of tumor is called a glioma.  Glioma: Gliomas begin from glial cells found in the supportive tissue of the brain.  There are several types of gliomas, categorized by where they are found, and where the tumor begins.  The following are gliomas: Astrocytoma: The tumor arises from star-shaped glial cells called astrocytes.  It can be any grade.  In adults, an astrocytoma most oft

Token indices sequence length is longer than the specified maximum sequence length for this model (624 > 512). Running this sequence through the model will result in indexing errors


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Text: Global warming will probably increase the risk of food insecurity for some vulnerable groups, such as the poor.  Animal husbandry is also responsible for greenhouse gas production of and a percentage of the world's methane, and future land infertility, and the displacement of wildlife.  Agriculture contributes to climate change by anthropogenic emissions of greenhouse gases, and by the conversion of non-agricultural land such as forest for agricultural use.  Agriculture, forestry and land-use change contributed around 20 to 25% to global annual emissions in 2010.  A range of policies can reduce the risk of negative climate change impacts on agriculture, and greenhouse gas emissions from the agriculture sector.   Section::::Environmental impact.:Sustainability.   Current farming methods have resulted in over-stretched water resources, high levels of erosion and reduced soil fertility.  There is not enough water to co

In [None]:
# Write code here

Nice. Click `Show Code` in the cell above if you'd like to see our approach. Again, remember that we are retrieving documents from a very small subset of the CAsT collection, so don't fret if the documents you retrieve are not all relevant. You can extend the system we've built here to work with the larger collection.

# Evaluation

Next, we'll evaulate the effectiveness of our system using the standard [`trec_eval` toolkit](https://github.com/usnistgov/trec_eval). As we've used a subset of the collection for retrieval, the numbers won't actually mean much. The objective here is just to demonstrate how it works.

## `.qrel` and `.run` files

Recall that we downloaded the `.qrel` (Relevance judgement) file from the setup step. We will now create a .run (Results file) file from our ranked passages and use the `trec_eval` to evaluate our runs.

The `.run` file contains a ranking of passages for each query by a retrieval system. This file is evaluated by trec_eval based on the gold standard rankings in the `.qrel` file. A `.run` file is formatted as follows:

```
query-id Q0 passage-id rank score run-name
query-id Q0 passage-id rank score run-name
query-id Q0 passage-id rank score run-name
query-id Q0 passage-id rank score run-name
...
```

*   The `query-id` field identifies the query
*   The `Q0` field is a dummy value ignored by trec_eval, but required nonetheless
*   The `passage-id` field identifies the retrieved document
*   The `rank` field represents the retrieved document's position within the ranked list for that query. This field is not evaluated but necessary.
*   The `score` field is an integer or float value that indicates the similarity between the document and query, therefore, higher ranked documents have better scores. 
*   The `run-name` field identifies your run.


Note that these fields can either be space separated or tab separated.




Now, let's generate a `.run` file for our system.

**EXERCISE**: Re-purpose the code from the ***Putting it Together*** exercise to create a run file from our ranked passages following the file format above.

In [None]:
from tqdm import tqdm

with open("files/simple_cast_system.run", "a") as run_file:
  for topic in tqdm(topics):
    for turn in topic['turn']:
      # get turn id
      turn_id = f"{topic['number']}_{turn['number']}"
      # Build rewriter input -- not ideal, but saves having to write logic again
      query_and_context = build_context(topics, turn_id)
      # rewrite query
      rewrite = simple_cast_system.rewrite_query(query_and_context)
      # retrieve candidate documents
      candidate_documents = simple_cast_system.retrieve_docments(rewrite)
      # collect passages from candidate documents
      extracted_passages = collect_passages(candidate_documents)
      # rank passages
      ranked_passages = simple_cast_system.rank_passages(rewrite, extracted_passages)
      # Write output to run file
      for rank, passage in enumerate(ranked_passages):
        run_file.write(f"{turn_id}\tQ0\t{passage['id']}\t{rank+1}\t{passage['score']}\tsimple-cast-system\n")
      # print(f"Wrote passage rankings for Turn {turn_id} to file")

100%|██████████| 26/26 [04:30<00:00, 10.39s/it]


In [None]:
# Write code here

## `trec_eval`

Now, let's install the `trec_eval` toolkit and evaluate our rankings

In [None]:
!git clone https://github.com/usnistgov/trec_eval.git
%cd trec_eval
!make
!mv trec_eval /usr/local/bin/
%cd ..

Cloning into 'trec_eval'...
remote: Enumerating objects: 763, done.[K
remote: Counting objects: 100% (14/14), done.[K
remote: Compressing objects: 100% (14/14), done.[K
remote: Total 763 (delta 5), reused 3 (delta 0), pack-reused 749[K
Receiving objects: 100% (763/763), 679.52 KiB | 1.50 MiB/s, done.
Resolving deltas: 100% (491/491), done.
/content/trec_eval
gcc -g -I.  -Wall -DVERSIONID=\"9.0.7\"  -o trec_eval trec_eval.c formats.c meas_init.c meas_acc.c meas_avg.c meas_print_single.c meas_print_final.c get_qrels.c get_trec_results.c get_prefs.c get_qrels_prefs.c get_qrels_jg.c form_res_rels.c form_res_rels_jg.c form_prefs_counts.c utility_pool.c get_zscores.c convert_zscores.c measures.c  m_map.c m_P.c m_num_q.c m_num_ret.c m_num_rel.c m_num_rel_ret.c m_gm_map.c m_Rprec.c m_recip_rank.c m_bpref.c m_iprec_at_recall.c m_recall.c m_Rprec_mult.c m_utility.c m_11pt_avg.c m_ndcg.c m_ndcg_cut.c m_Rndcg.c m_ndcg_rel.c m_binG.c m_G.c m_rel_P.c m_success.c m_infap.c m_map_cut.c m_gm_bpref.

In [None]:
# Year 3 relevance judgements contain document ranking, so need to deduplicate run
def deduplicate_run(run_file: str) -> list:

    # {'106_1': [], '106_2' : [], ... }
    document_ids = {}

    with open(run_file) as f:
        run_rows: list = []

        for line in f:
            line_dict: dict = {}
            line_content = line.split()
            line_content[2] = line_content[2].rsplit("-",1)[0]

            # check if turn_id is in the dictionary
            if document_ids.get(line_content[0]):
                # check if turn has document id in its list
                if line_content[2] not in document_ids[line_content[0]]:
                    document_ids[line_content[0]].append(line_content[2])

                    line_dict = {
                        "turn_id": line_content[0],
                        "dummy_value": line_content[1],
                        "doc_id": line_content[2],
                        "rank": line_content[3],
                        "score": line_content[4],
                        "run_name": line_content[5]
                    }
                    run_rows.append(line_dict)
            else:
                document_ids[line_content[0]] = []
                document_ids[line_content[0]].append(line_content[2])

                line_dict = {
                    "turn_id": line_content[0],
                    "dummy_value": line_content[1],
                    "doc_id": line_content[2],
                    "rank": line_content[3],
                    "score": line_content[4],
                    "run_name": line_content[5]
                }
                run_rows.append(line_dict)

        return run_rows


def adjust_run_ranking(run_rows: list) -> list:

    for i in range(1, len(run_rows)):
        if run_rows[i]["turn_id"] == run_rows[i-1]["turn_id"]:
            run_rows[i]["rank"] = str(int(run_rows[i-1]["rank"]) + 1)
        else:
            run_rows[i]['rank'] = str(1)

    return run_rows

In [None]:
deduplicated_run = deduplicate_run("files/simple_cast_system.run")
adjusted_run = adjust_run_ranking(deduplicated_run)

with open(f"files/simple_cast_system_deduped.run", "w") as deduped_run_file:
    for row in adjusted_run:
        deduped_run_file.write(
            f'{row["turn_id"]} {row["dummy_value"]} {row["doc_id"]} {row["rank"]} {row["score"]} {row["run_name"]}\n')

In [None]:
!trec_eval -q -m official files/trec-cast-qrels-docs.2021.qrel files/simple_cast_system_deduped.run

num_ret               	106_1	10
num_rel               	106_1	40
num_rel_ret           	106_1	0
map                   	106_1	0.0000
Rprec                 	106_1	0.0000
bpref                 	106_1	0.0000
recip_rank            	106_1	0.0000
iprec_at_recall_0.00  	106_1	0.0000
iprec_at_recall_0.10  	106_1	0.0000
iprec_at_recall_0.20  	106_1	0.0000
iprec_at_recall_0.30  	106_1	0.0000
iprec_at_recall_0.40  	106_1	0.0000
iprec_at_recall_0.50  	106_1	0.0000
iprec_at_recall_0.60  	106_1	0.0000
iprec_at_recall_0.70  	106_1	0.0000
iprec_at_recall_0.80  	106_1	0.0000
iprec_at_recall_0.90  	106_1	0.0000
iprec_at_recall_1.00  	106_1	0.0000
P_5                   	106_1	0.0000
P_10                  	106_1	0.0000
P_15                  	106_1	0.0000
P_20                  	106_1	0.0000
P_30                  	106_1	0.0000
P_100                 	106_1	0.0000
P_200                 	106_1	0.0000
P_500                 	106_1	0.0000
P_1000                	106_1	0.0000
num_ret               	106_10	10
num_rel 

# Future Directions

As an exercise, we encourage you to explore the different ways the various components of our simple CAsT system fails and why. Explore:


*   The categories of queries the rewriter struggles with
*   Other rewriting methods such as Query Expansion (i.e QuReTeC)
*   Document retrieval performance with different sparse parameters and/or different query/document encoders
*   Passage ranking methods with other methods such as DuoT5
*   Swapping out Pyserini with PyTerrier toolkit


# Looking Ahead (CAsT'22)

### Response Generation

In the 2022 edition of CAsT, participants will be asked to return submit a response to each query along with a ranked list of passages (also called provenance). 

<center>
<img src="https://raw.githubusercontent.com/grill-lab/CAsT-Demo/main/assets/extended_system.png" width="75%"/>
</center>

Let's explore one way this can be done, by framing the challenge as a summarisation problem. Specifically, we will generate an abstractive summary of the top three passages using an off the shelf [PEGASUS](https://huggingface.co/docs/transformers/model_doc/pegasus) model.

In [40]:
from transformers import PegasusForConditionalGeneration, PegasusTokenizer

top_three_passages = '\n'.join([passage['body'] for passage in passages[:3]])

model_name = "google/pegasus-xsum"
summarisation_tokenizer = PegasusTokenizer.from_pretrained(model_name)
summariser = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)

Downloading:   0%|          | 0.00/1.82M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/65.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/87.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.12G [00:00<?, ?B/s]

In [41]:
def generate_summary(passages, tokenizer, model, device):
    top_three_passages = ' '.join([passage['body'] for passage in passages[:3]])
    batch = summarisation_tokenizer(top_three_passages, truncation=True, padding="longest", return_tensors="pt").to(device)
    translated = summariser.generate(**batch)
    summary = summarisation_tokenizer.decode(translated[0], skip_special_tokens=True)

    return summary

In [43]:
for topic in topics:
  for turn in topic['turn']:
    # get turn id
    turn_id = f"{topic['number']}_{turn['number']}"
    print(f"Turn ID: {turn_id}")
    # Log Raw utterance
    _, _, raw_utterance = get_turn_attribute(topics, turn_id, 'raw_utterance')
    print(f"Utterance: {raw_utterance}")
    # rewrite query
    query_and_context = build_context(topics, turn_id)
    rewrite = simple_cast_system.rewrite_query(query_and_context)
    print(f"Rewrite: {rewrite}")
    # retrieve candidate documents
    candidate_documents = simple_cast_system.retrieve_docments(rewrite)
    # collect passages from candidate documents
    extracted_passages = collect_passages(candidate_documents)
    # rank passages
    ranked_passages = simple_cast_system.rank_passages(rewrite, extracted_passages)
    # generate summary
    summary = generate_summary(ranked_passages, summarisation_tokenizer, summariser, device)
    print(f"System Response: {summary}")
    # log provenance
    print(f"Provenance: {[passage['id'] for passage in ranked_passages[:3]]}")
    print("\n")

    break

Turn ID: 106_1
Utterance: I just had a breast biopsy for cancer. What are the most common types?
Rewrite: What are the most common types of cancer in regards to me?
System Response: For children, the most common brain tumor types are primary brain and meningioma.
Provenance: ['MARCO_00_43867598-3', 'MARCO_00_43867598-4', 'MARCO_00_67761412-4']


Turn ID: 107_1
Utterance: How do I build a cheap driveway?
Rewrite: How do I build a cheap driveway?
System Response: Question: How many vehicle cycles does a gate have per day? Answer: This question has so many variables it is impossible to answer directly.
Provenance: ['MARCO_00_38428636-4', 'MARCO_00_38428636-5', 'MARCO_00_66382206-3']


Turn ID: 108_1
Utterance: How can fires help an ecosystem?
Rewrite: How can fires help an ecosystem?
System Response: The tundra biome is one of the most fragile habitats on the planet.
Provenance: ['MARCO_00_4615711-2', 'MARCO_00_72402184-4', 'MARCO_00_72402184-1']


Turn ID: 109_1
Utterance: Why do cats ea

From the summaries generated by our PEGASUS model, we can see that there's a lot of headroom to be made on the response generation task. Systems not only have to find relevant passages, but also have to understand and extract salient information from them in response to a query.

### Mixed Initiative Sub-Task

Furthermore, Year 4 also brings a Mixed-Initiative sub-task where systems can optionally pose a question to a raw utterance instead of generating a rewrite. This allows these systems to gain additional context for retrieving relevant passages for the query. Systems can treat the mixed-initiative task as a question generation or question ranking problem.

<center>
<img src="https://raw.githubusercontent.com/grill-lab/CAsT-Demo/main/assets/Mixed%20Initiative.png" width="75%"/>
</center>

# Appendix

### CAsT Searcher

Our tool called [CAsT Searcher](http://3.94.55.111:5000/) enables you interact with the collection for CAsT Year 4 in its entirety. It also allows you do query rewriting, sparse retrieval, and passage ranking. We encourage you to use it to explore the collection and/or contribute to the [project's Github repository](https://github.com/grill-lab/Interactive-CAsT) however you see fit.

We look forward to your submissions for CAsT Year 4 and beyond!