In [1]:
import torch
# Ask PyTorch whether a CUDA device is usable in this kernel; the value is
# shown as the cell output.
torch.cuda.is_available()

False

In [2]:
import pyterrier as pt
import numpy as np
import pandas as pd

from tqdm import tqdm
import zipfile
import glob
import ir_datasets
# Initialise PyTerrier only if it has not been started yet, so re-running
# this cell in a live kernel does not call pt.init() a second time.
if not pt.started():
  pt.init()

from pyterrier_t5 import MonoT5ReRanker

PyTerrier 0.9.2 has loaded Terrier 5.7 (built by craigm on 2022-11-10 18:30) and terrier-helper 0.0.7

No etc/terrier.properties, using terrier.default.properties for bootstrap configuration.


In [3]:
# ir_datasets handles for the "nyt" and "wapo/v2" collections.
# NOTE(review): neither handle is referenced again in the cells visible here.
dataset_nyt = ir_datasets.load("nyt")
dataset_wapo = ir_datasets.load("wapo/v2")

In [4]:
# References to the Terrier indices for each collection, located by path.
# NOTE(review): no indexing step appears in the visible cells, so these
# directories are presumably built beforehand — confirm.
index_ref_nyt = pt.IndexRef.of("/app/indices/nyt/")
index_ref_wapo = pt.IndexRef.of("/app/indices/wapo/")

In [5]:
# One MonoT5 re-ranker instance shared by every pipeline below; it reads the
# document text from the "body" column.
monoT5 = MonoT5ReRanker(text_field="body", batch_size=100, verbose=True)


def _bm25_wapo(num_results):
    """Return a BM25 retriever over the WaPo index that keeps `num_results` hits per query."""
    return pt.BatchRetrieve(index_ref_wapo, wmodel='BM25', num_results=num_results)


def _mono_rerank(first_stage):
    """Chain `first_stage` with a "body" text lookup on the WaPo index and the shared MonoT5 re-ranker."""
    return first_stage >> pt.text.get_text(index_ref_wapo, "body") >> monoT5


# `bm25` is kept as its own name because later cells evaluate it stand-alone.
bm25 = _bm25_wapo(200)
mono_pipeline = _mono_rerank(bm25)

# Same pipeline with a larger / smaller first-stage candidate pool.
mono_pipeline_500 = _mono_rerank(_bm25_wapo(500))
mono_pipeline_50 = _mono_rerank(_bm25_wapo(50))

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=True`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [5]:
# PyTerrier dataset wrappers (ir_datasets-backed, "irds:" prefix) for the two
# TREC Core test collections; later cells read topics and qrels from them.
ds1_nyt = pt.get_dataset('irds:nyt/trec-core-2017')
ds1_wapo = pt.get_dataset('irds:wapo/v2/trec-core-2018')

In [7]:
# Display the NYT relevance judgements (columns: qid, docno, label, iteration).
ds1_nyt.get_qrels()

Unnamed: 0,qid,docno,label,iteration
0,307,1001536,1,0
1,307,1002887,1,0
2,307,1005682,0,0
3,307,1007340,1,0
4,307,101295,1,0
...,...,...,...,...
30025,690,991749,0,0
30026,690,993504,0,0
30027,690,994345,0,0
30028,690,995425,0,0


In [8]:
# Write every NYT title query to its own file: topic.1, topic.2, ...
# Files are numbered by position in the topics frame, not by TREC qid.
# enumerate() is used instead of the DataFrame index label so the numbering
# stays 1..N even if the frame does not carry a default 0-based index.
for topic_number, query in enumerate(ds1_nyt.get_topics('title')['query'], start=1):
    # Explicit encoding: the written bytes must not depend on the platform locale.
    with open(f"/workspace/data/nyt/topics/topic.{topic_number}", "w", encoding="utf-8") as f:
        f.write(query)

[INFO] [starting] https://trec.nist.gov/data/core/core_nist.txt
[INFO] [finished] https://trec.nist.gov/data/core/core_nist.txt: [00:00] [24.4kB] [229kB/s]
                                                                         

In [8]:
# Display the 'title' variant of the NYT topics (columns: qid, query).
ds1_nyt.get_topics('title')

Unnamed: 0,qid,query
0,307,new hydroelectric projects
1,310,radio waves and brain cancer
2,321,women in parliaments
3,325,cult lifestyles
4,330,iran iraq cooperation
5,336,black bear attacks
6,341,airport security
7,344,abuses of e mail
8,345,overseas tobacco sales
9,347,wildlife extinction


In [7]:
# Display the NYT topics with all fields (title, description, narrative).
# Without a variant argument PyTerrier warns that there is no single
# `query` column.
ds1_nyt.get_topics()

There are multiple query fields available: ('title', 'description', 'narrative'). To use with pyterrier, provide variant or modify dataframe to add query column.


Unnamed: 0,qid,title,description,narrative
0,307,New Hydroelectric Projects,Identify hydroelectric projects proposed or un...,Relevant documents would contain as a minimum ...
1,310,Radio Waves and Brain Cancer,Evidence that radio waves from radio towers or...,Persons living near radio towers and more rece...
2,321,Women in Parliaments,Pertinent documents will reflect the fact that...,Pertinent documents relating to this issue wil...
3,325,Cult Lifestyles,Describe a cult by name and identify the cult ...,A relevant document would include the name of ...
4,330,Iran-Iraq Cooperation,This query is looking for examples of cooperat...,A relevant document would mention such things ...
5,336,Black Bear Attacks,A relevant document would discuss the frequenc...,It has been reported that food or cosmetics so...
6,341,Airport Security,A relevant document would discuss the effectiv...,A relevant document would contain reports on w...
7,344,Abuses of E-Mail,The availability of E-mail to many people thro...,"To be relevant, a document will concern dissat..."
8,345,Overseas Tobacco Sales,Health studies primarily in the U.S. have caus...,"To be relevant, an item will discuss either an..."
9,347,Wildlife Extinction,The spotted owl episode in America highlighted...,"A relevant item will specify the country, the ..."


In [11]:
# Dump the NYT title queries to a single file, one line per topic in the form
#   1,1,<qid>,<query>
# NOTE(review): the meaning of the two leading "1" fields is not visible in
# this notebook — confirm against whatever consumes the file.
title_queries = "".join(
    f"1,1,{row['qid']},{row['query']}\n"
    for _, row in ds1_nyt.get_topics('title').iterrows()
)

# Explicit encoding so the file content does not depend on the platform locale.
with open("/workspace/data/nyt/title_queries", "w", encoding="utf-8") as f:
    f.write(title_queries)

In [9]:
# Dump the complete NYT topics (title, description, narrative) to one file,
# one line per topic in the form
#   1,1,<qid>,<title>,<description>,<narrative>
# NOTE(review): description and narrative are free text and contain commas
# (visible in the topics preview), so the line cannot be split unambiguously
# on ",". The format is kept as-is because the consumer is not visible here;
# confirm whether it needs CSV quoting.
full_topics = "".join(
    f"1,1,{row['qid']},{row['title']},{row['description']},{row['narrative']}\n"
    for _, row in ds1_nyt.get_topics().iterrows()
)

# Explicit encoding so the file content does not depend on the platform locale.
with open("/workspace/data/nyt/full_topics", "w", encoding="utf-8") as f:
    f.write(full_topics)

There are multiple query fields available: ('title', 'description', 'narrative'). To use with pyterrier, provide variant or modify dataframe to add query column.


In [14]:
# Export the WaPo qrels as "<qid> 0 <docno> <label>" lines, collapsing every
# positive relevance grade to 1 and everything else to 0.
wapo_qrel_path = "/workspace/data/wapo/wapo_qrels1"
qrels = "".join(
    f"{judgement['qid']} 0 {judgement['docno']} {1 if int(judgement['label']) > 0 else 0}\n"
    for _, judgement in ds1_wapo.get_qrels().iterrows()
)

with open(wapo_qrel_path, "w") as qrels_file:
    qrels_file.write(qrels)

In [7]:
# Baseline: plain BM25 on the WaPo title queries.
pt.Experiment(
    retr_systems=[bm25],
    topics=ds1_wapo.get_topics('title'),
    qrels=ds1_wapo.get_qrels(),
    eval_metrics=["map", "recip_rank", "P_10", "ndcg_cut_10"],
    names=["BM25"],
    verbose=True,
)

[INFO] [starting] https://trec.nist.gov/data/core/topics2018.txt
[INFO] [finished] https://trec.nist.gov/data/core/topics2018.txt: [00:00] [24.1kB] [59.9MB/s]
[INFO] [starting] https://trec.nist.gov/data/core/qrels2018.txt            
[INFO] [finished] https://trec.nist.gov/data/core/qrels2018.txt: [00:00] [1.12MB] [1.28MB/s]
pt.Experiment: 100%|██████████| 1/1 [00:10<00:00, 10.95s/system]          


Unnamed: 0,name,map,recip_rank,P_10,ndcg_cut_10
0,BM25,0.168733,0.663436,0.404,0.37107


In [9]:
# MonoT5 re-ranking of the default BM25 candidate list, same topics and metrics
# as the BM25 baseline run.
pt.Experiment(
    retr_systems=[mono_pipeline],
    topics=ds1_wapo.get_topics('title'),
    qrels=ds1_wapo.get_qrels(),
    eval_metrics=["map", "recip_rank", "P_10", "ndcg_cut_10"],
    names=["MonoT5"],
    verbose=True,
)

monoT5: 100%|██████████| 100/100 [03:29<00:00,  2.10s/batches]
pt.Experiment: 100%|██████████| 1/1 [03:42<00:00, 222.32s/system]


Unnamed: 0,name,map,recip_rank,P_10,ndcg_cut_10
0,MonoT5,0.20841,0.670647,0.476,0.446702


In [None]:
# MonoT5 re-ranking using the `mono_pipeline_500` variant.
# NOTE(review): the run is labelled "MonoT5", the same name as the
# `mono_pipeline` run, so the two result rows cannot be told apart by name.
pt.Experiment(
    retr_systems=[mono_pipeline_500],
    topics=ds1_wapo.get_topics('title'),
    qrels=ds1_wapo.get_qrels(),
    eval_metrics=["map", "recip_rank", "P_10", "ndcg_cut_10"],
    names=["MonoT5"],
    verbose=True,
)

In [10]:
# Ad-hoc single-query check of the BM25 retriever; the ranked results are
# displayed as the cell output.
bm25.search("women in parliament")

Unnamed: 0,qid,docid,docno,rank,score,query
0,1,563223,f233ecdeb87a44a6aa9ac429999d2d4c,0,19.370159,women in parliament
1,1,486152,4a7c2970fd9bf65fe09c7cf46df7b06d,1,18.690675,women in parliament
2,1,486153,9171debc316e5e2782e0d2404ca7d09d,2,18.690675,women in parliament
3,1,352722,34d443eec1add515a2fbc4af2c8a3a57,3,18.372576,women in parliament
4,1,546574,f1ab493726e1e6f5dd90615d5a1b58b8,4,18.340060,women in parliament
...,...,...,...,...,...,...
195,1,351935,a64ab05765cb2b25a6f18c03b20f5a3a,195,13.074723,women in parliament
196,1,351936,8d0f44ec22604cd5c08d74fc8ffa7cf4,196,13.074723,women in parliament
197,1,351954,56b94f8fd53b63608931a373f813b7b1,197,13.074421,women in parliament
198,1,305768,583e232f566fe972d5d1c1e5652bbb75,198,13.058146,women in parliament


In [9]:
# Ad-hoc single-query run of the `mono_pipeline_50` variant; the result frame
# is kept in `res` rather than displayed.
res = mono_pipeline_50.search("women in parliament")

monoT5: 100%|██████████| 1/1 [00:21<00:00, 21.10s/batches]


In [None]:
# Quick run of the `mono_pipeline_50` variant on only the first two title
# topics; the qrels are passed unfiltered.
pt.Experiment(
    retr_systems=[mono_pipeline_50],
    topics=ds1_wapo.get_topics('title').head(2),
    qrels=ds1_wapo.get_qrels(),
    eval_metrics=["map", "recip_rank", "P_10", "ndcg_cut_10"],
    names=["MonoT5"],
    verbose=True,
)