<a href="https://colab.research.google.com/github/mrkarezina/research-heatmap/blob/master/huggingface_t5_testing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CORD-19 doc2query-T5 Experiments

This notebook explores using the [doc2query-T5 model](https://github.com/castorini/docTTTTTquery#data-and-trained-models) to retrieve documents relevant to the research questions for each topic in the COVID-19 [Open Research Dataset](https://pages.semanticscholar.org/coronavirus-research) (CORD-19).

In [0]:
from IPython.display import display, HTML

## Data Loading

First we will install dependencies and download the T5 model checkpoint. 

In [1]:
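# Pin the version: the cells below use the transformers 2.x API (e.g. the lm_labels argument)
!pip install transformers==2.8.0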

# docTTTTTquery T5-base checkpoint
!curl -o t5-base.zip "https://storage.googleapis.com/doctttttquery_git/t5-base.zip"
!unzip t5-base.zip -d "t5-model-tf"
!gsutil cp gs://t5-data/pretrained_models/base/operative_config.gin t5-model-tf/
!rm t5-base.zip

Collecting transformers
  Downloading https://files.pythonhosted.org/packages/a3/78/92cedda05552398352ed9784908b834ee32a0bd071a9b32de287327370b7/transformers-2.8.0-py3-none-any.whl (563kB)
Collecting sentencepiece
  Downloading https://files.pythonhosted.org/packages/74/f4/2d5214cbf13d06e7cb2c20d84115ca25b53ea76fa1f0ade0e3c9749de214/sentencepiece-0.1.85-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)
Collecting sacremoses
  Downloading https://files.pythonhosted.org/packages/99/50/93509f906a40bffd7d175f97fd75ea328ad9bd91f48f59c4bd084c94a25e/sacremoses-0.0.41.tar.gz (883kB)
Collecting tokenizers==0.5.2
  Downloading https://files.pythonhosted.org/packages/d1/3f/73c881ea4723e43c1e9acf317cf407fab3a278daab3a69c98dcac511c04f/tokenizers-0.5.2-cp36-cp36m-manylinux1_x86_64.whl (3.7MB)

Let's load the checkpoint of our doc2query-T5 model. We'll use the config and tokenizer from the Hugging Face T5-base model.

In [2]:
import torch
from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained('t5-base')
config = T5Config.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained('t5-model-tf/model.ckpt-1004000.index', from_tf=True, config=config)

Next, check which GPU the runtime was allocated. Ideally this is a Tesla P100-PCIE-16GB; on a GPU with only 8 GB of memory, some cells may fail with a "CUDA out of memory" error.

In [4]:
# Check for GPU
if torch.cuda.is_available():
    device = torch.device("cuda")
    print('Using: ', torch.cuda.get_device_name(0))
else:
    print('No GPU available.')
    device = torch.device("cpu")

model = model.to(device)
!nvidia-smi

Using:  Tesla P100-PCIE-16GB
Wed Apr 22 15:54:22 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.64.00    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   41C    P0    32W / 250W |   1739MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                     

### Loading CORD-19

We will also need to download the CORD-19 dataset.

In [0]:
%%capture
%%shell
DATE=2020-04-10
DATA_DIR=./covid-"${DATE}"
mkdir "${DATA_DIR}"

wget https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/"${DATE}"/comm_use_subset.tar.gz -P "${DATA_DIR}"
wget https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/"${DATE}"/noncomm_use_subset.tar.gz -P "${DATA_DIR}"
wget https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/"${DATE}"/custom_license.tar.gz -P "${DATA_DIR}"
wget https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/"${DATE}"/biorxiv_medrxiv.tar.gz -P "${DATA_DIR}"
wget https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/"${DATE}"/metadata.csv -P "${DATA_DIR}"

ls "${DATA_DIR}"/*.tar.gz | xargs -I {} tar -zxvf {} -C "${DATA_DIR}"

We'll need to load the JSON documents into a dataframe. The following preprocessing script is adapted from https://www.kaggle.com/maksimeren/covid-19-literature-clustering.

In [0]:
import numpy as np
import pandas as pd
import glob
import json

Collect the paths of all JSON documents.

In [7]:
root_path='./covid-2020-04-10'
all_json = glob.glob(f'{root_path}/**/*.json', recursive=True)
len(all_json)

59311

Load meta-data.

In [8]:
metadata_path = f'{root_path}/metadata.csv'
meta_df = pd.read_csv(metadata_path, dtype={
    'pubmed_id': str,
    'Microsoft Academic Paper ID': str, 
    'doi': str
})
meta_df.head()

Unnamed: 0,cord_uid,sha,source_x,title,doi,pmcid,pubmed_id,license,abstract,publish_time,authors,journal,Microsoft Academic Paper ID,WHO #Covidence,has_pdf_parse,has_pmc_xml_parse,full_text_file,url
0,xqhn0vbp,1e1286db212100993d03cc22374b624f7caee956,PMC,Airborne rhinovirus detection and effect of ul...,10.1186/1471-2458-3-5,PMC140314,12525263,no-cc,"BACKGROUND: Rhinovirus, the most common cause ...",2003-01-13,"Myatt, Theodore A; Johnston, Sebastian L; Rudn...",BMC Public Health,,,True,True,custom_license,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC1...
1,gi6uaa83,8ae137c8da1607b3a8e4c946c07ca8bda67f88ac,PMC,Discovering human history from stomach bacteria,10.1186/gb-2003-4-5-213,PMC156578,12734001,no-cc,Recent analyses of human pathogens have reveal...,2003-04-28,"Disotell, Todd R",Genome Biol,,,True,True,custom_license,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC1...
2,le0ogx1s,,PMC,A new recruit for the army of the men of death,10.1186/gb-2003-4-7-113,PMC193621,12844350,no-cc,"The army of the men of death, in John Bunyan's...",2003-06-27,"Petsko, Gregory A",Genome Biol,,,False,True,custom_license,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC1...
3,fy4w7xz8,0104f6ceccf92ae8567a0102f89cbb976969a774,PMC,Association of HLA class I with severe acute r...,10.1186/1471-2350-4-9,PMC212558,12969506,no-cc,BACKGROUND: The human leukocyte antigen (HLA) ...,2003-09-12,"Lin, Marie; Tseng, Hsiang-Kuang; Trejaut, Jean...",BMC Med Genet,,,True,True,custom_license,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2...
4,0qaoam29,5b68a553a7cbbea13472721cd1ad617d42b40c26,PMC,A double epidemic model for the SARS propagation,10.1186/1471-2334-3-19,PMC222908,12964944,no-cc,BACKGROUND: An epidemic of a Severe Acute Resp...,2003-09-10,"Ng, Tuen Wai; Turinici, Gabriel; Danchin, Antoine",BMC Infect Dis,,,True,True,custom_license,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2...


In [9]:
class FileReader:
    def __init__(self, file_path):
        with open(file_path) as file:
            content = json.load(file)
            self.paper_id = content['paper_id']
            self.abstract = []
            self.body_text = []
            # Dataset changed: papers without an abstract have no 'abstract' key
            try:
                for entry in content['abstract']:
                    self.abstract.append(entry['text'])
            except KeyError:
                self.abstract.append("")
            # Body text
            for entry in content['body_text']:
                self.body_text.append(entry['text'])
            self.abstract = '\n'.join(self.abstract)
            self.body_text = '\n'.join(self.body_text)
    def __repr__(self):
        return f'{self.paper_id}: {self.abstract[:200]}... {self.body_text[:200]}...'

def get_breaks(content, length):
    data = ""
    words = content.split(' ')
    total_chars = 0

    # insert an HTML line break once more than `length` characters of words have accumulated
    for i in range(len(words)):
        total_chars += len(words[i])
        if total_chars > length:
            data = data + "<br>" + words[i]
            total_chars = 0
        else:
            data = data + " " + words[i]
    return data

first_row = FileReader(all_json[0])
print(first_row)

1d466cbca75a16becf32b83b5117e70817587aae: ... Selection pressure for increased production has caused producers to remove cows based on factors that include reproductive failure, structural issues, poor health, and disease. Producers emphasize imp...


Load all of the documents, including the full body text, into a dataframe.

In [20]:
from tqdm.notebook import tqdm

dict_ = {'paper_id': [], 'doi': [], 'abstract': [], 'body_text': [], 'authors': [], 'title': [], 'journal': [], 'abstract_summary': []}
for entry in tqdm(all_json):

    try:
        content = FileReader(entry)
    except Exception:
        continue  # invalid paper format, skip

    # get metadata information (this filters the whole of meta_df once per paper)
    meta_data = meta_df.loc[meta_df['sha'] == content.paper_id]
    # no metadata, skip this paper
    if len(meta_data) == 0:
        continue

    dict_['abstract'].append(content.abstract)
    dict_['paper_id'].append(content.paper_id)
    dict_['body_text'].append(content.body_text)

    # also create a column for the summary of abstract to be used in a plot
    if len(content.abstract) == 0:
        # no abstract provided
        dict_['abstract_summary'].append("Not provided.")
    elif len(content.abstract.split(' ')) > 100:
        # abstract is too long for the plot: keep the first 100 words and append "..."
        info = content.abstract.split(' ')[:100]
        summary = get_breaks(' '.join(info), 40)
        dict_['abstract_summary'].append(summary + "...")
    else:
        # abstract is short enough
        summary = get_breaks(content.abstract, 40)
        dict_['abstract_summary'].append(summary)

    # authors are ';'-separated in the metadata
    try:
        authors = meta_data['authors'].values[0].split(';')
        if len(authors) > 2:
            # more than 2 authors: join them and add line breaks so the plot label stays readable
            dict_['authors'].append(get_breaks('. '.join(authors), 40))
        else:
            # authors will fit in plot
            dict_['authors'].append(". ".join(authors))
    except AttributeError:
        # authors missing (NaN has no .split), keep the raw value
        dict_['authors'].append(meta_data['authors'].values[0])

    # add the title information, add breaks when needed
    try:
        title = get_breaks(meta_data['title'].values[0], 40)
        dict_['title'].append(title)
    except AttributeError:
        # title missing (NaN), keep the raw value
        dict_['title'].append(meta_data['title'].values[0])

    # add the journal information
    dict_['journal'].append(meta_data['journal'].values[0])

    # add doi
    dict_['doi'].append(meta_data['doi'].values[0])

df_covid = pd.DataFrame(dict_, columns=['paper_id', 'doi', 'abstract', 'body_text', 'authors', 'title', 'journal', 'abstract_summary'])
df_covid.head()

 15%|█▍        | 8823/59311 [01:14<08:05, 103.96it/s]

KeyboardInterrupt: ignored
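The loop above runs `meta_df.loc[meta_df['sha'] == content.paper_id]` for every paper, which scans the whole metadata table each time. It is also an exact string match, so if a `sha` field lists several SHAs (assumed here to be `;`-separated, as in some CORD-19 releases) the paper would be skipped. A minimal sketch of a one-time lookup table follows; `build_sha_index` is a hypothetical helper that works on plain records such as `meta_df.to_dict('records')`.

```python
import math
from typing import Any, Dict, Iterable, Mapping


def build_sha_index(records: Iterable[Mapping[str, Any]]) -> Dict[str, Mapping[str, Any]]:
    """Map each paper SHA to its metadata record.

    Assumption: the `sha` field may hold several SHAs separated by ';', and
    missing values show up as None or float NaN (as pandas produces).
    """
    index: Dict[str, Mapping[str, Any]] = {}
    for rec in records:
        sha_field = rec.get('sha')
        if not isinstance(sha_field, str):  # skips None and NaN
            continue
        for sha in sha_field.split(';'):
            sha = sha.strip()
            if sha and sha not in index:  # keep the first record seen for a SHA
                index[sha] = rec
    return index


# Toy records (not real CORD-19 rows)
example = build_sha_index([
    {'sha': 'aaa; bbb', 'title': 'Paper 1'},
    {'sha': math.nan, 'title': 'Paper 2'},
    {'sha': 'ccc', 'title': 'Paper 3'},
])
```

Inside the loop, `sha_index.get(content.paper_id)` would then replace the per-paper filter; it returns a dict rather than a one-row DataFrame, so the `.values[0]` accesses would become plain key lookups.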

In [11]:
# Drop papers with empty abstracts (assign back instead of a chained inplace replace)
df_covid['abstract'] = df_covid['abstract'].replace('', np.nan)
df_covid = df_covid.dropna(subset=['abstract'])

df_covid.shape

(26305, 8)

## Scoring CORD-19 Questions

In this section we query different types of questions against CORD-19 documents to see whether the loss scores meaningfully reflect how relevant each question is to the document.

In [0]:
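batch_size = 5
# T5 accepts at most 512 tokens; truncating avoids the
# "Token indices sequence length is longer than the specified maximum" warning
max_sequence_len = 512

def encode(doc):
  # encode_plus truncates to max_length and returns a (1, seq_len) tensor of token ids
  return tokenizer.encode_plus(doc, max_length=max_sequence_len, return_tensors="pt")["input_ids"]

def eval(document, questions, target_ques=None):
  # Note: this name shadows Python's built-in eval()
  display(HTML(f"<b>Doc Sample:</b> {document[:500]}"))

  # The document is the same for every question, so encode it once
  input_ids = encode(f"{document} </s>").to(device)

  scores = []
  with torch.no_grad():  # inference only: no autograd graph, much less GPU memory
    for q in questions:
      question_ids = encode(f"{q} </s>").to(device)
      outputs = model(input_ids, lm_labels=question_ids)
      # outputs[0] is the loss; .item() turns it into a plain float
      scores.append([outputs[0].item(), q])

  # Lower loss = the model finds the question more likely given the document
  scores = sorted(scores, key=lambda x: x[0])
  for loss, q in scores:
    if q == target_ques:
      display(HTML(f"<p><b>Loss: {loss:.4f} Target question: {target_ques}</b></p>"))
    else:
      display(HTML(f"<p>Loss: {loss:.4f}  Question: {q}</p>"))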

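To get an idea of whether the loss scores make sense, we can check whether relevant questions rank higher (i.e. get a lower loss) than random ones. The loss here is the cross-entropy loss.

Consider the input document, denoted "input", and the "target" labels, which are the tokens in the question, denoted $(w_1, w_2, \ldots, w_n)$. The loss is:

$$\text{loss} = -\frac{1}{n}\left[\log P(w_1 \mid \text{input}) + \log P(w_2 \mid w_1, \text{input}) + \cdots + \log P(w_n \mid w_{n-1}, \ldots, w_1, \text{input})\right]$$

where $P(w_i \mid w_{i-1}, \ldots, w_1, \text{input})$ is the probability the model (decoder) assigns to token $w_i$ when fed the "input" document and the previous tokens $w_{i-1}, w_{i-2}, \ldots, w_1$. The factor $\frac{1}{n}$ is there because the Hugging Face model averages the token-level cross-entropy over the target tokens rather than summing it. The loss is therefore the average negative log-probability of the model producing the question's tokens given the document: the lower the loss, the more likely the question. Because it is a per-token average, long and short questions are compared on the same per-token scale.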
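The per-token averaging can be made concrete with a small plain-Python sketch. It assumes you already have the decoder's pre-softmax scores for each target position (in the notebook these come back as `outputs[1]` when `lm_labels` is passed). `mean_token_nll` is a hypothetical helper meant to mirror `torch.nn.CrossEntropyLoss` with its default mean reduction and `ignore_index=-100`; it is not part of the transformers API.

```python
import math
from typing import List, Sequence

IGNORE_INDEX = -100  # label value skipped by the loss (CrossEntropyLoss's default ignore_index)


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one step's vocabulary scores."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def mean_token_nll(step_logits: Sequence[Sequence[float]], target_ids: Sequence[int]) -> float:
    """Mean negative log-likelihood of target_ids, i.e. the averaged loss above.

    step_logits[i] are the decoder's pre-softmax scores at position i; after a
    softmax they give P(w_i | w_{i-1}, ..., w_1, input).
    """
    total, count = 0.0, 0
    for logits, target in zip(step_logits, target_ids):
        if target == IGNORE_INDEX:  # e.g. padding in a batch: contributes nothing
            continue
        total -= log_softmax(logits)[target]
        count += 1
    return total / count


# Two-token vocabulary: step 1 is a coin flip, step 2 gives token 1 probability 0.75,
# step 3 is ignored. Result: -(log 0.5 + log 0.75) / 2, roughly 0.490
loss = mean_token_nll([[0.0, 0.0], [0.0, math.log(3.0)], [0.0, 0.0]], [0, 1, IGNORE_INDEX])
```

Dropping the division by `count` would give the summed form instead, which grows with question length.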

In [13]:
# Sample queries from MS-MARCO + random questions
questions = ["what was the goal of the manhattan project",
             "who was briefed by president on the manhattan project", 
             "what was the manhattan project", 
             "who led the development of the atomic bomb",

             # Random questions
             "Efforts to support sustained education, access, and capacity building in the area of ethics",
             "does she like apples",
             "what organs are in the pancreas",
             "what is my favorite color", 
             "how many days until christmas"]
document = 'The Manhattan Project was the name for a project conducted during World War II, to develop the first atomic bomb. It refers specifically to the period of the project from 1942-1946 under the control of the U.S. Army Corps of Engineers, under the administration of General Leslie R. Groves.'
eval(document, questions)

print("\n")

document = 'Manhattan Project. The Manhattan Project was a research and development undertaking during World War II that produced the first nuclear weapons. It was led by the United States with the support of the United Kingdom and Canada. From 1942 to 1946, the project was under the direction of Major General Leslie Groves of the U.S. Army Corps of Engineers. Nuclear physicist Robert Oppenheimer was the director of the Los Alamos Laboratory that designed the actual bombs. The Army component of the project was designated the'
eval(document, questions)

It looks like the relevant queries rank higher. Now we can check whether queries related to specific documents in the CORD-19 dataset rank higher than unrelated queries. The top document for each query was retrieved from [Covidex](https://covidex.ai/).

Some of the long-form queries are also broken down into groups of more specific "who, what, how" questions, closer to the style of the MS-MARCO queries the doc2query-T5 model was trained on.

In [18]:
questions = [
            # Queries specific to documents
            "Tools and studies to monitor phenotypic change and potential adaptation of the virus",

            "Research on the economic impact of this or any pandemic. This would include identifying policy and programmatic alternatives that lessen/mitigate risks to critical government services",
            "What is the economic impact of a pandemic",
            "What is the financial impact of a pandemic",
            "How can the economic impact of pandemic be reduced",

            "Real-time tracking of whole genomes and a mechanism for coordinating the rapid dissemination of that information to inform the development of diagnostics and therapeutics and to track variations of the virus over time",
            "How to track the variations of the virus over time",
            "How can monitoring of whole genomes help the development of diagnostics and therapeutics",

            "Methods evaluating potential complication of Antibody-Dependent Enhancement (ADE) in vaccine recipients.",

            "Public health mitigation measures that could be effective for control",

            "Best telemedicine practices, barriers and faciitators, and specific actions to remove/expand them within and across state boundaries",
            "What are the best practices for telemedicine",
            "What are the barriers for telemedicine",


            # Several unrelated queries from MS-MARCO
            "what was the goal of the manhattan project",
            "why do we use cookies at chipotle",
            "who plays the team in quidditch",
             ]


# Pairs of target query and doc
for document in [("Tools and studies to monitor phenotypic change and potential adaptation of the virus", "Moving away from genome scan methods used for human GWAS (ultimately inappropriate for the short highly polymorphic genomes of RNA viruses), our work shows the power and potential of multi-class machine learning algorithms in inferring the functional genetic changes associated with phenotypic change (e.g. crossing a species barrier). We show that even distantly related viruses within a viral family share highly conserved genetic signatures of host specificity; reinforce how fitness landscapes of host adaptation are shaped by host phylogeny; and highlight the evolutionary trajectories of RNA viruses in rapid expansion and under great evolutionary pressure. We do so by (for each dataset) unveiling a set of phenotype characteristic mutations which are shown to be functionally relevant, thus providing new insights into phenotypic relationships between RNA viruses. These methods also provide a solid statistical framework with which the degree of host adaptation can be inferred, thus serving as a valuable tool for studying host transition events with particular relevance for emerging infectious diseases. These methods can then serve as rigorous tools of emergence potential assessment, specifically in scenarios where rapid host classification of newly emerging viruses can be more important than identifying putative functional sites."),
                ("Research on the economic impact of this or any pandemic. This would include identifying policy and programmatic alternatives that lessen/mitigate risks to critical government services", "Mitigation of a severe influenza pandemic can be achieved using a range of interventions to reduce transmission. Interventions can reduce the impact of an outbreak and buy time until vaccines are developed, but they may have high social and economic costs. The non-linear effect on the epidemic dynamics means that suitable strategies crucially depend on the precise aim of the intervention. National pandemic influenza plans rarely contain clear statements of policy objectives or prioritization of potentially conflicting aims, such as minimizing mortality (depending on the severity of a pandemic) or peak prevalence or limiting the socio-economic burden of contact-reducing interventions. We use epidemiological models of influenza A to investigate how contact-reducing interventions and availability of antiviral drugs or pre-pandemic vaccines contribute to achieving particular policy objectives. Our analyses show that the ideal strategy depends on the aim of an intervention and that the achievement of one policy objective may preclude success with others, e.g., constraining peak demand for public health resources may lengthen the duration of the epidemic and hence its economic and social impact. Constraining total case numbers can be achieved by a range of strategies, whereas strategies which additionally constrain peak demand for services require a more sophisticated intervention. If, for example, there are multiple objectives which must be achieved prior to the availability of a pandemic vaccine (i.e., a time-limited intervention), our analysis shows that interventions should be implemented several weeks into the epidemic, not at the very start."),
             ("Real-time tracking of whole genomes and a mechanism for coordinating the rapid dissemination of that information to inform the development of diagnostics and therapeutics and to track variations of the virus over time", "In recent decades, many infectious diseases have significantly increased in incidence and/or geographic range, in some cases impacting heavily on human, animal or plant populations. Some of these ‘emerging infectious diseases’ are associated with pathogens that have appeared in populations for the first time as a result of cross-species transmission (e.g. human immunodeficiency virus—acquired immunodeficiency syndrome (HIV-AIDS), severe acute respiratory syndrome (SARS)), while others were previously known but are rapidly increasing in incidence or geographic range as a result of underlying epidemiological changes (e.g. multi-drug resistant Staphylococcus aureus (MRSA) infection, dengue, West Nile encephalitis, foot and mouth disease, cassava mosaic disease). The latter include prominent diseases as tuberculosis, malaria and yellow fever that were once on the decline but are now ‘re-emerging diseases’."),
             ("Methods evaluating potential complication of Antibody-Dependent Enhancement (ADE) in vaccine recipients.", "Immune enhancement (antibody-dependent enhancement, ADE) has been clearly shown to occur in experimental laboratory infections of cats previously infected by natural or experimental infection, and of cats previously vaccinated with Primucell FIP vaccine, experimental MLV vaccines, experimental inactivated vaccines, and experimental recombinant vaccines containing the S gene (McArdle et al., 1992, 1995; Ngichabe, 1992; Scott et al., 1992, 1995a,b; Weiss and Scott, 1981). Antibodies to the S protein produced by the host result in enhanced infection of macrophages via Fc receptors, and the infected macrophages then transport the virus throughout the body. In the enhanced infection there is a decrease in incubation time—as short as 1–2 days—after exposure to virulent FIPV. The relative amount of virus and antibodies is important in order for ADE to occur. Higher concentrations of antibody neutralize the virus, but as the concentration of antibody decreases a concentration occurs where enhanced infection results. Other related coronaviruses can cause enhanced FCoV infection in the cat, including CCV."),
                           ("Public health mitigation measures that could be effective for control", "The novel coronavirus disease 2019 (COVID-19) outbreak on the Diamond Princess ship has caused over 634 cases as of February 20, 2020. We model the transmission process on the ship with a stochastic model and estimate the basic reproduction number at 2.2 (95%CI: 2.1−2.4). We estimate a large dispersion parameter than other coronaviruses, which implies that the virus is difficult to go extinction. The epidemic doubling time is at 4.6 days (95%CI: 3.0−9.3), and thus timely actions were crucial. The lesson learnt on the ship is generally applicable in other settings."),
              ("Best telemedicine practices, barriers and faciitators, and specific actions to remove/expand them within and across state boundaries", "Even before the arrival of COVID-19, telemedicine was increasingly being adopted to bring specialty-palliative care into the homes of seriously ill patients and their families. Patients who receive palliative care by telemedicine are typically very satisfied with the convenience and timesaving of video care. Telemedicine also saves valuable drive-time for home-visiting palliative care clinicians and increases capacity at brick-and-mortar clinics.1 With the emergence of COVID-19, telemedicine has been catapulted into the role of a critically essential service for patients to help mitigate the spread of COVID19 and preserve valuable personal protective equipment. For example, the University of California, SanFrancisco (UCSF) has mandated telemedicine be used to care for palliative care and nonpalliative care patients in ambulatory settings, whenever possible. Similarly, many hospice agencies are currently offering most, if not all, social work and chaplaincy support by telemedicine. For hospitals, strict limitations on visitors have meant that some inpatient palliative care consult programs are performing family meetings and consults virtually. To support these changes, many telemedicine")
             ]:
  eval(document[1], questions, target_ques=document[0])
  print("")

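Rather than reading each displayed list, each document/target pair can be summarized by the rank of the target question, and several pairs by their mean reciprocal rank. The sketch assumes `eval` is changed to return its `[loss, question]` pairs (tuples or two-element lists both work); `target_rank` and `mean_reciprocal_rank` are hypothetical helpers, and the losses below are illustrative, not results from this notebook.

```python
from typing import Sequence, Tuple


def target_rank(scores: Sequence[Tuple[float, str]], target: str) -> int:
    """1-based rank of `target` when questions are sorted by ascending loss."""
    ranked = sorted(scores, key=lambda pair: pair[0])
    for rank, (_, question) in enumerate(ranked, start=1):
        if question == target:
            return rank
    raise ValueError(f"target question was not scored: {target!r}")


def mean_reciprocal_rank(ranks: Sequence[int]) -> float:
    """Average of 1/rank over pairs; 1.0 means the target always ranked first."""
    return sum(1.0 / r for r in ranks) / len(ranks)


# Illustrative (loss, question) pairs, not real model output
scores = [(2.1, "What is the economic impact of a pandemic"),
          (1.4, "what was the goal of the manhattan project"),
          (3.0, "who plays the team in quidditch")]
rank = target_rank(scores, "What is the economic impact of a pandemic")  # 2
```

An MRR close to 1.0 across the six pairs would support the claim that the target question ranks first; a low MRR would show how often unrelated queries win.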
We can also query different types of questions over many CORD-19 documents to see whether there are certain queries the model favors.

We see that some queries from the MS-MARCO training set, as well as the "childhood death" query (which is not from the training set), consistently rank near the top.

In [0]:
questions = [
             # Specific to documents
             "what is the leading cause of childhood death in the world", # Doc 3
             "how can regulation help reduce foreign pathogens", # Doc 2
             
             # CORD-19 topic questions
             "what new drugs are being developed?", 
             "effectiveness of drugs being developed and tried to treat COVID-19 patients.", 
             "exploration of use of best animal models and their predictive value for a human vaccine.",
             "capabilities to discover a therapeutic (not vaccine) for the disease, and clinical effectiveness studies to discover therapeutics, to include antiviral agents.",
             "natural history of the virus and shedding of it from an infected person",
             "implementation of diagnostics and products to improve clinical processes",

            # What, how style questions
            "what is the incubation period of COVID-19?",
            "what is the effectiveness of chloroquine for COVID-19?",
            "what is the duration of viral shedding for COVID-19?",
            "how does COVID-19 bind to the ACE2 receptor?",
            "how do weather conditions affect the transmission of COVID-19?",
            "tell me about IgG and IgM tests for COVID-19.",
            "what is the prognostic value of IL-6 levels in COVID-19?",
             
             # Predicted queries from MS-MARCO
             "what was the goal of the manhattan project",
             "who was briefed by president on the manhattan project",
             "why do we use cookies at chipotle",
             "who plays the team in quidditch",
             
             # Random
             "what is my favorite color", 
             "how many days until christmas"]


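# Score the questions against the first 200 abstracts
# (index labels are not contiguous after dropping empty abstracts, so use head())
for _, row in df_covid.head(200).iterrows():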
  display(HTML(f"<p><b>Title:</b> {row['title']}</p>"))
  document = row['abstract']
  eval(document, questions)
  print("")

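To back up the observation that some queries rank near the top regardless of the document, each question's rank can be averaged over all scored abstracts: a question with a low mean rank is favored across documents. This is a minimal sketch assuming `eval` is changed to return its `(loss, question)` pairs for each document; `mean_rank_per_question` is a hypothetical helper and the toy losses are made up.

```python
from collections import defaultdict
from typing import Dict, Iterable, Sequence, Tuple


def mean_rank_per_question(per_doc_scores: Iterable[Sequence[Tuple[float, str]]]) -> Dict[str, float]:
    """Average 1-based rank of each question across documents (lower = favored more often)."""
    rank_sums: Dict[str, float] = defaultdict(float)
    counts: Dict[str, int] = defaultdict(int)
    for scores in per_doc_scores:
        ranked = sorted(scores, key=lambda pair: pair[0])  # ascending loss
        for rank, (_, question) in enumerate(ranked, start=1):
            rank_sums[question] += rank
            counts[question] += 1
    return {q: rank_sums[q] / counts[q] for q in rank_sums}


# Toy scores for three documents and two questions (illustrative numbers)
mean_ranks = mean_rank_per_question([
    [(1.0, "what is my favorite color"), (2.0, "what is the incubation period of COVID-19?")],
    [(3.0, "what is my favorite color"), (0.5, "what is the incubation period of COVID-19?")],
    [(0.1, "what is my favorite color"), (0.2, "what is the incubation period of COVID-19?")],
])
```

Sorting `mean_ranks` by value then lists the questions the model prefers independently of the document.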
Now let's compute the mean loss of a question over part of the dataset. This lets us check whether, for example, "what is the financial impact of a pandemic" or "what was the goal of the manhattan project" gets the lower loss on average.

This will also allow us to estimate the inference speed for scoring one query over the entire CORD-19 dataset.

In [37]:
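import time
from tqdm.notebook import tqdm

test_size = 10000

question = 'what is the financial impact of a pandemic'
# Append </s> as in eval(); the question is the same for every doc, so encode it once
question_ids = encode(f"{question} </s>").to(device)
n_docs = min(test_size, len(df_covid))
total_loss = 0.0
start = time.time()

# Inference only: without no_grad() every loss keeps its autograd graph alive on the GPU
with torch.no_grad():
  for doc in tqdm(df_covid['abstract'].iloc[:n_docs]):
    # encode() already truncates to max_sequence_len tokens, so no character slicing is needed
    input_ids = encode(f"{doc} </s>").to(device)

    outputs = model(input_ids, lm_labels=question_ids)

    # Outputs of forward pass
    # Prediction scores for each vocabulary token before SoftMax
    # lm_labels provided to return loss
    loss, prediction_scores = outputs[:2]
    total_loss += loss.item()  # a plain float, so no GPU tensors accumulate

elapsed = time.time() - start
print(f"Question: {question} Mean loss: {total_loss / n_docs}")
print(f"Scored {n_docs} docs in {elapsed:.1f} sec ({n_docs / elapsed:.1f} docs/sec)")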


RuntimeError: ignored
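Turning a timed run into an estimate for the whole corpus is a linear extrapolation: seconds per document, times the number of documents, times the number of queries. A minimal sketch, assuming the per-document cost is roughly constant (abstract lengths vary, so this is only a rough figure); `estimate_total_seconds` is a hypothetical helper and the timings in the example are illustrative, not measured.

```python
def estimate_total_seconds(docs_scored: int, elapsed_seconds: float,
                           total_docs: int, num_queries: int = 1) -> float:
    """Linearly extrapolate the time to score num_queries queries against total_docs documents."""
    if docs_scored <= 0:
        raise ValueError("docs_scored must be positive")
    seconds_per_doc = elapsed_seconds / docs_scored
    return seconds_per_doc * total_docs * num_queries


# Illustrative: if 1,000 abstracts took 50 s, the 26,305 abstracts left after
# dropping empty ones would take about 1,315 s (~22 min) per query.
one_query = estimate_total_seconds(1000, 50.0, 26305)
ten_queries = estimate_total_seconds(1000, 50.0, 26305, num_queries=10)
```

Plugging in the document count and elapsed seconds from an actual run gives the real estimate.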

Next, gauge the speed-up from batching.

In [31]:
import time
import torch.nn.functional as F

# # For CUDA out of memory error
# torch.cuda.empty_cache()
# import gc
# gc.collect()

question = 'Effectiveness of movement control strategies to prevent secondary transmission in health care and community settings'
# Encode the question once; shape (1, question_len)
single_question_ids = encode(f"{question} </s>")

n_docs = min(test_size, len(df_covid))
doc_losses = []
start = time.time()

with torch.no_grad():  # inference only: don't build the autograd graph
  for i in tqdm(range(0, n_docs, batch_size)):
    docs = df_covid['abstract'].iloc[i:i + batch_size]

    # Encode each doc (dropping the batch dimension) and right-pad to the longest doc in the batch
    all_input_ids = [encode(f"{d} </s>")[0] for d in docs]
    all_input_ids = torch.nn.utils.rnn.pad_sequence(
        all_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    # 1 for real tokens, 0 for padding, so the encoder ignores the padded positions
    attention_mask = (all_input_ids != tokenizer.pad_token_id).long()

    # One copy of the question per doc; the last batch can be smaller than batch_size
    question_ids = single_question_ids.repeat(all_input_ids.size(0), 1).to(device)

    # All inputs have shape (batch_size, sequence_length)
    outputs = model(all_input_ids.to(device),
                    attention_mask=attention_mask.to(device),
                    lm_labels=question_ids)
    # outputs[0] is the loss averaged over the whole batch; outputs[1] holds the
    # prediction scores for each vocabulary token before SoftMax
    logits = outputs[1]  # (batch, question_len, vocab_size)

    # Per-document loss: unreduced token-level cross-entropy, averaged over the question tokens
    token_losses = F.cross_entropy(logits.transpose(1, 2), question_ids, reduction='none')
    doc_losses.extend(token_losses.mean(dim=1).tolist())

elapsed = time.time() - start
print(f"Question: {question} Mean loss: {sum(doc_losses) / len(doc_losses):.4f}")
print(f"Scored {n_docs} docs in {elapsed:.1f} sec ({n_docs / elapsed:.1f} docs/sec)")

RuntimeError: ignored
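Batched scoring only matches one-at-a-time scoring if the encoder is also told which positions are padding. This plain-Python sketch shows what `torch.nn.utils.rnn.pad_sequence(..., batch_first=True)` does to the encoded documents (right-padding; 0 is both its default `padding_value` and T5's `<pad>` id) together with the attention mask that would accompany it. `pad_batch` is a hypothetical helper for illustration, not a transformers or PyTorch function.

```python
from typing import List, Sequence, Tuple

T5_PAD_ID = 0  # id of T5's <pad> token


def pad_batch(seqs: Sequence[Sequence[int]],
              pad_id: int = T5_PAD_ID) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad token-id sequences to a common length and build the matching attention mask.

    Mask entries are 1 for real tokens and 0 for padding, the convention
    Hugging Face models use for `attention_mask`.
    """
    max_len = max(len(s) for s in seqs)
    padded = [list(s) + [pad_id] * (max_len - len(s)) for s in seqs]
    mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in seqs]
    return padded, mask


# Two "documents" of different lengths; 1 stands in for the </s> id
padded, mask = pad_batch([[5, 6, 1], [7, 1]])
```

Without the mask, the encoder attends to the padded positions, so a short abstract's loss would depend on which longer abstracts happen to share its batch.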

## Testing the Hugging Face API

As a sanity check, we can see whether the Hugging Face API with the doc2query-T5 checkpoint predicts queries similar to the ones generated by the [doc2query model and T5 CLI](https://github.com/castorini/docTTTTTquery#t5-inference-predicting-queries-from-passages).

In [0]:
%%capture
# Predicted questions
!curl -o predicted_queries_topk_sampling.zip "https://storage.googleapis.com/doctttttquery_git/predicted_queries_topk_sampling.zip"
!unzip predicted_queries_topk_sampling.zip -d "predicted_queries"
!rm -f predicted_queries_topk_sampling.zip

# MS MARCO Dataset
!curl "https://storage.googleapis.com/doctttttquery_git/collection.tar.gz" --output collection.tar.gz
!tar -xvf collection.tar.gz
!rm collection.tar.gz

In [0]:
import pandas as pd

df = pd.read_csv("collection.tsv",sep='\t', header=None)

In [0]:
pd.options.display.max_colwidth = 200
df.head(20)

Compare the generated queries.

In [0]:
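docs_to_test = 5
num_questions = 5

for i, doc in enumerate(df[1][:docs_to_test]):
  # Append </s> as in the scoring cells and move the ids to the model's device
  doc_token_ids = tokenizer.encode(f"{doc} </s>", return_tensors="pt").to(device)

  # Top-k sampling (not greedy decoding): each step samples from the 10 most likely tokens
  sampled_outputs = model.generate(
    doc_token_ids,
    do_sample=True,
    max_length=64,
    top_k=10,
    num_return_sequences=num_questions
  )

  print(" ---- Hugging Face predictions ---- ")
  for j, sample_output in enumerate(sampled_outputs):
    print("{}: {}".format(j, tokenizer.decode(sample_output, skip_special_tokens=True)))

  print("\n ---- doc2query-T5 predictions ---- \n")
  # Assumes line i of each released prediction file belongs to passage i of the collection
  for j in range(num_questions):
    with open(f"predicted_queries/predicted_queries_topk_sample00{j}.txt000-1004000", 'r') as qs:
      for line_num, q in enumerate(qs):
        if line_num == i:
          print(q.strip())
          break

  print("-"*50)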