In [1]:
from src.RawDataLoaders import *
from src.pipe_datasets import *
from src.models_and_transforms.run_file_models import Run_File_Searcher
from src.models_and_transforms.BERT_models import BERT_Reranker
from src.models_and_transforms.Longformer_models import Longformer_Reranker
from src.models_and_transforms.BM25_models import BM25_Ranker
from src.Experiments import CAsT_experiment, Ranking_Experiment, RUN_File_Transform_Exporter
from src.trainers import Model_Trainer
from src.models_and_transforms.complex_transforms import *
from src.models_and_transforms.text_transforms import *
from src.models_and_transforms.complex_transforms import BART_Query_Rewriter_Transform, BART_Full_Conversational_Rewriter_Transform

from transformers import LongformerConfig, LongformerModel, LongformerTokenizer, BertTokenizer, BertModel, BertForSequenceClassification
from pytorch_lightning import Trainer, Callback, seed_everything
import pickle
import random
import numpy as np
import json
import os
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import torch
import jsonlines
# Fix the global random seed up front (pytorch_lightning's seed_everything) so
# repeated runs of this notebook start from the same RNG state.
seed_everything(42)

from tqdm.auto import tqdm 
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

# Creating the collection

In [2]:
# Build the raw data loader once; every later cell reuses these handles:
#   get_query_fn -> handed to Query_Resolver_Transform
#   get_doc_fn   -> handed to the rerankers / relevance model
#   CAsT_q_rels  -> handed to Ranking_Experiment for scoring
CAsT_raw_data_loader = CAsT_RawDataLoader()
get_query_fn, get_doc_fn = CAsT_raw_data_loader.get_query, CAsT_raw_data_loader.get_doc
CAsT_q_rels = CAsT_raw_data_loader.q_rels

# Manual BM25 @ 1k

### Y1 data

In [5]:
# Y1 baseline: manual rewrites -> BM25 (1000 hits per query) -> score -> export.
samples = CAsT_raw_data_loader.get_topics("all")

query_resolver = Query_Resolver_Transform(get_query_fn, utterance_type="manual_rewritten_utterance")
samples = query_resolver(samples)

bm25_search = BM25_Search_Transform(
    index_dir='datasets/TREC_CAsT/CAsT_collection_with_meta.index',
    key_fields={'query_field': 'query', 'target_field': 'search_results'},
    hits=1000,
)
samples = bm25_search(samples)

# Score against the CAsT q-rels and print each metric to three decimals.
expr = Ranking_Experiment(CAsT_q_rels)
print("Manual Y1 BM25 @ 1k results")
metric_scores = expr(samples)
print([f"{metric}:{score:.3f}" for metric, score in metric_scores.items()])

# Persist the ranking as a RUN file so later cells can reload it.
run_file_exporter = RUN_File_Transform_Exporter(
    'saved_models/BM25/y1_manual_BM25_1k.run',
    model_name='manual_BM25',
)
run_file_exporter(samples)

HBox(children=(FloatProgress(value=0.0, description='Searching queries', max=163.0, style=ProgressStyle(descri…


Manual Y1 BM25 @ 1k results
['map:0.198', 'recip_rank:0.448', 'ndcg_cut_3:0.289', 'set_recall:0.844']


HBox(children=(FloatProgress(value=0.0, description='Writing to RUN file', max=163.0, style=ProgressStyle(desc…

Successfully written 163000 samples from 163 queries run to: saved_models/BM25/y1_manual_BM25_1k.run


### Y2 data

In [8]:
# Y2 ("eval" topics): manual rewrites -> BM25 (1000 hits per query) -> export.
# No Ranking_Experiment is run in this cell; only the RUN file is written.
samples = CAsT_raw_data_loader.get_topics("eval")

query_resolver = Query_Resolver_Transform(get_query_fn, utterance_type="manual_rewritten_utterance")
samples = query_resolver(samples)

bm25_search = BM25_Search_Transform(
    index_dir='datasets/TREC_CAsT/CAsT_collection_with_meta.index',
    key_fields={'query_field': 'query', 'target_field': 'search_results'},
    hits=1000,
)
samples = bm25_search(samples)

run_file_exporter = RUN_File_Transform_Exporter(
    'saved_models/BM25/y2_manual_BM25_1k.run',
    model_name='manual_BM25',
)
run_file_exporter(samples)

HBox(children=(FloatProgress(value=0.0, description='Searching queries', max=216.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Writing to RUN file', max=216.0, style=ProgressStyle(desc…

Successfully written 216000 samples from 216 queries run to: saved_models/BM25/y2_manual_BM25_1k.run


# Manual BM25@1k + monoBERT

### Y1 data

In [11]:
# Y1: manual rewrites -> BM25 (1000 hits) -> monoBERT rerank -> score -> export.
samples = CAsT_raw_data_loader.get_topics("all")

query_resolver = Query_Resolver_Transform(get_query_fn, utterance_type="manual_rewritten_utterance")
samples = query_resolver(samples)

bm25_search = BM25_Search_Transform(
    index_dir='datasets/TREC_CAsT/CAsT_collection_with_meta.index',
    key_fields={'query_field': 'query', 'target_field': 'search_results'},
    hits=1000,
)
samples = bm25_search(samples)

# The reranker is built only after BM25 has finished, as in the original flow.
mono_reranker = MonoBERT_ReRanker_Transform(
    'saved_models/monoBERT/',
    get_doc_fn,
    device="cuda:0",
    key_fields={'query_field': 'query', 'target_field': 'search_results'},
    batch_size=100,
)
samples = mono_reranker(samples)

expr = Ranking_Experiment(CAsT_q_rels)
print("Manual BM25+monoBERT @ 1k results")
metric_scores = expr(samples)
print([f"{metric}:{score:.3f}" for metric, score in metric_scores.items()])

run_file_exporter = RUN_File_Transform_Exporter(
    'saved_models/monoBERT/y1_manual_BM25_monoBERT_1k.run',
    model_name='manual_BM25_monoBERT',
)
run_file_exporter(samples)

HBox(children=(FloatProgress(value=0.0, description='Searching queries', max=3.0, style=ProgressStyle(descript…


Loading chekcpoint from saved_models/monoBERT/
MonoBERT ReRanker initialised on device cuda:0. Batch size 100


HBox(children=(FloatProgress(value=0.0, description='Reranking queries', max=3.0, style=ProgressStyle(descript…


Manual BM25+monoBERT @ 1k results
['map:0.480', 'recip_rank:0.833', 'ndcg_cut_3:0.844', 'set_recall:0.884']


HBox(children=(FloatProgress(value=0.0, description='Writing to RUN file', max=3.0, style=ProgressStyle(descri…

Successfully written 3000 samples from 3 queries run to: saved_models/monoBERT/y1_manual_BM25_monoBERT_1k.run


Manual BM25+monoBERT @ 1k results
['map:0.352', 'recip_rank:0.658', 'ndcg_cut_3:0.511', 'set_recall:0.844']

### Y2 data

In [None]:
# Y2 ("eval" topics): manual rewrites -> BM25 (1000 hits) -> monoBERT rerank -> export.
samples = CAsT_raw_data_loader.get_topics("eval")

query_resolver = Query_Resolver_Transform(get_query_fn, utterance_type="manual_rewritten_utterance")
samples = query_resolver(samples)

bm25_search = BM25_Search_Transform(
    index_dir='datasets/TREC_CAsT/CAsT_collection_with_meta.index',
    key_fields={'query_field': 'query', 'target_field': 'search_results'},
    hits=1000,
)
samples = bm25_search(samples)

mono_reranker = MonoBERT_ReRanker_Transform(
    'saved_models/monoBERT/',
    get_doc_fn,
    device="cuda:0",
    key_fields={'query_field': 'query', 'target_field': 'search_results'},
    batch_size=100,
)
samples = mono_reranker(samples)

# This RUN file is the input of the Y2 duoBERT cell further down.
run_file_exporter = RUN_File_Transform_Exporter(
    'saved_models/monoBERT/y2_manual_BM25_monoBERT_1k.run',
    model_name='manual_BM25_monoBERT',
)
run_file_exporter(samples)

# Manual BM25@1k + monoBERT + DuoBERT

### Y1

In [3]:
# Y1: reload the saved monoBERT run, rerank the top 10 with duoBERT, score, export.
# Depends on CAsT_raw_data_loader / get_query_fn / get_doc_fn / CAsT_q_rels from
# the loader cell, so that cell must have been executed first.
samples = CAsT_raw_data_loader.get_topics("all")

query_resolver = Query_Resolver_Transform(get_query_fn, utterance_type="manual_rewritten_utterance")
samples = query_resolver(samples)

run_file_search = RUN_File_Search_Transform(
    'saved_models/monoBERT/y1_manual_BM25_monoBERT_1k.run',
    hits=1000,
)
samples = run_file_search(samples)

duo_reranker = DuoBERT_ReRanker_Transform(
    "saved_models/duoBERT/",
    get_doc_fn,
    rerank_top=10,
    device="cuda:1",
)
samples = duo_reranker(samples)

# Scoring and export read "search_results", so copy the duoBERT output
# ("reranked_results") over it.
for sample in samples:
    sample["search_results"] = sample["reranked_results"]

expr = Ranking_Experiment(CAsT_q_rels)
print("Manual BM25+monoBERT+duoBERT @ 1k results")
metric_scores = expr(samples)
print([f"{metric}:{score:.3f}" for metric, score in metric_scores.items()])

run_file_exporter = RUN_File_Transform_Exporter(
    'saved_models/duoBERT/y1_manual_BM25_mono_duoBERT_1k.run',
    model_name='manual_BM25_mono_duoBERT',
)
run_file_exporter(samples)

NameError: name 'CAsT_raw_data_loader' is not defined

Manual BM25+monoBERT+duoBERT @ 1k results
['map:0.322', 'recip_rank:0.681', 'ndcg_cut_3:0.556', 'set_recall:0.844']

### Y2

In [9]:
# Y2 ("eval" topics): reload the saved monoBERT run, rerank the top 10 with duoBERT, export.
samples = CAsT_raw_data_loader.get_topics("eval")

query_resolver = Query_Resolver_Transform(get_query_fn, utterance_type="manual_rewritten_utterance")
samples = query_resolver(samples)

run_file_search = RUN_File_Search_Transform(
    'saved_models/monoBERT/y2_manual_BM25_monoBERT_1k.run',
    hits=1000,
)
samples = run_file_search(samples)

duo_reranker = DuoBERT_ReRanker_Transform(
    "saved_models/duoBERT/",
    get_doc_fn,
    rerank_top=10,
    device="cuda:2",
)
samples = duo_reranker(samples)

# The exporter is fed "search_results", so overwrite it with the duoBERT ordering.
for sample in samples:
    sample["search_results"] = sample["reranked_results"]

run_file_exporter = RUN_File_Transform_Exporter(
    'saved_models/duoBERT/y2_manual_BM25_mono_duoBERT_1k.run',
    model_name='manual_BM25_mono_duoBERT',
)
run_file_exporter(samples)

HBox(children=(FloatProgress(value=0.0, description='Searching queries', max=2.0, style=ProgressStyle(descript…


Loading chekcpoint from saved_models/duoBERT/
DuoBERT ReRanker initialised on cuda. Batch size 32


HBox(children=(FloatProgress(value=0.0, description='Reranking queries', max=2.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Writing to RUN file', max=2.0, style=ProgressStyle(descri…

Successfully written 2000 samples from 2 queries run to: saved_models/duoBERT/y2_manual_BM25_mono_duoBERT_1k.run


# Relevance Models query rewrites

### Y1

In [14]:
# Y1: derive relevance-model expansion terms from the saved duoBERT run and
# dump the resulting samples to JSON.
samples = CAsT_raw_data_loader.get_topics("all")
samples = Query_Resolver_Transform(get_query_fn, utterance_type="manual_rewritten_utterance")(samples)
samples = RUN_File_Search_Transform('saved_models/duoBERT/y1_manual_BM25_mono_duoBERT_1k.run', hits=1000)(samples)

# Both transforms are configured with top_k=20.
samples = Relevance_Model_Transform(get_doc_fn, top_k=20)(samples)
samples = Simple_Query_Expansion_Transform(top_k=20)(samples)

# Use a context manager so the file handle is flushed and closed deterministically;
# the previous `json.dump(samples, open(..., 'w'))` left the handle open.
with open('saved_models/Relevance Models CAsT/y1_manual_duoBERT_Relevance_Terms.json', 'w') as out_file:
    json.dump(samples, out_file)

HBox(children=(FloatProgress(value=0.0, description='Searching queries', max=163.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Relevance Model', max=163.0, style=ProgressStyle(descript…




### Y2

In [16]:
# Y2 ("eval" topics): derive relevance-model expansion terms from the saved
# duoBERT run and dump the resulting samples to JSON.
samples = CAsT_raw_data_loader.get_topics("eval")
samples = Query_Resolver_Transform(get_query_fn, utterance_type="manual_rewritten_utterance")(samples)
samples = RUN_File_Search_Transform('saved_models/duoBERT/y2_manual_BM25_mono_duoBERT_1k.run', hits=1000)(samples)

# Both transforms are configured with top_k=20.
samples = Relevance_Model_Transform(get_doc_fn, top_k=20)(samples)
samples = Simple_Query_Expansion_Transform(top_k=20)(samples)

# Use a context manager so the file handle is flushed and closed deterministically;
# the previous `json.dump(samples, open(..., 'w'))` left the handle open.
with open('saved_models/Relevance Models CAsT/y2_manual_duoBERT_Relevance_Terms.json', 'w') as out_file:
    json.dump(samples, out_file)

HBox(children=(FloatProgress(value=0.0, description='Searching queries', max=216.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Relevance Model', max=216.0, style=ProgressStyle(descript…




# monoBERT and duoBERT reranking on top of the Jeff SDM runs

### Y1

In [20]:
# Y1: take the externally produced SDM run as first stage, then monoBERT and
# duoBERT (top 10) reranking, score, export.
samples = CAsT_raw_data_loader.get_topics("all")

query_resolver = Query_Resolver_Transform(get_query_fn, utterance_type="manual_rewritten_utterance")
samples = query_resolver(samples)

sdm_run_search = RUN_File_Search_Transform(
    'saved_models/Jeff_SDM/y1_test_manual_cur_turn_sdm.run',
    hits=1000,
)
samples = sdm_run_search(samples)

mono_reranker = MonoBERT_ReRanker_Transform(
    'saved_models/monoBERT/',
    get_doc_fn,
    device="cuda:0",
    key_fields={'query_field': 'query', 'target_field': 'search_results'},
    batch_size=100,
)
samples = mono_reranker(samples)

duo_reranker = DuoBERT_ReRanker_Transform(
    "saved_models/duoBERT/",
    get_doc_fn,
    rerank_top=10,
    device="cuda:0",
)
samples = duo_reranker(samples)

# Scoring and export read "search_results"; replace it with the duoBERT output.
for sample in samples:
    sample["search_results"] = sample["reranked_results"]

expr = Ranking_Experiment(CAsT_q_rels)
print("Manual Jeff SDM+monoBERT+duoBERT @ 1k results")
metric_scores = expr(samples)
print([f"{metric}:{score:.3f}" for metric, score in metric_scores.items()])

run_file_exporter = RUN_File_Transform_Exporter(
    'saved_models/Jeff_SDM/y1_manual_SDM_mono_duoBERT_1k.run',
    model_name='manual_SDM_mono_duoBERT',
)
run_file_exporter(samples)

HBox(children=(FloatProgress(value=0.0, description='Searching queries', max=2.0, style=ProgressStyle(descript…


Loading chekcpoint from saved_models/monoBERT/
MonoBERT ReRanker initialised on device cuda:0. Batch size 100


HBox(children=(FloatProgress(value=0.0, description='Reranking queries', max=2.0, style=ProgressStyle(descript…


Loading chekcpoint from saved_models/duoBERT/
DuoBERT ReRanker initialised on cuda:0. Batch size 32


HBox(children=(FloatProgress(value=0.0, description='Reranking queries', max=2.0, style=ProgressStyle(descript…


Manual Jeff SDM+monoBERT+duoBERT @ 1k results
['map:0.476', 'recip_rank:1.000', 'ndcg_cut_3:1.000', 'set_recall:0.804']


HBox(children=(FloatProgress(value=0.0, description='Writing to RUN file', max=2.0, style=ProgressStyle(descri…

Successfully written 2000 samples from 2 queries run to: saved_models/Jeff_SDM/y1_manual_SDM_mono_duoBERT_1k.run


Manual Jeff SDM + mono + duoBERT @ 1k results
['map:0.311', 'recip_rank:0.673', 'ndcg_cut_3:0.550', 'set_recall:0.867']

### Y2

In [None]:
# Y2 ("eval" topics): SDM run as first stage, then monoBERT and duoBERT (top 10), export.
samples = CAsT_raw_data_loader.get_topics("eval")

query_resolver = Query_Resolver_Transform(get_query_fn, utterance_type="manual_rewritten_utterance")
samples = query_resolver(samples)

sdm_run_search = RUN_File_Search_Transform(
    'saved_models/Jeff_SDM/y2_test_manual_cur_turn_sdm.run',
    hits=1000,
)
samples = sdm_run_search(samples)

mono_reranker = MonoBERT_ReRanker_Transform(
    'saved_models/monoBERT/',
    get_doc_fn,
    device="cuda:2",
    key_fields={'query_field': 'query', 'target_field': 'search_results'},
    batch_size=100,
)
samples = mono_reranker(samples)

duo_reranker = DuoBERT_ReRanker_Transform(
    "saved_models/duoBERT/",
    get_doc_fn,
    rerank_top=10,
    device="cuda:2",
)
samples = duo_reranker(samples)

# The exporter is fed "search_results", so overwrite it with the duoBERT ordering.
for sample in samples:
    sample["search_results"] = sample["reranked_results"]

run_file_exporter = RUN_File_Transform_Exporter(
    'saved_models/Jeff_SDM/y2_manual_SDM_mono_duoBERT_1k.run',
    model_name='manual_SDM_mono_duoBERT',
)
run_file_exporter(samples)

HBox(children=(FloatProgress(value=0.0, description='Searching queries', max=216.0, style=ProgressStyle(descri…


Loading chekcpoint from saved_models/monoBERT/
MonoBERT ReRanker initialised on device cuda:2. Batch size 100


HBox(children=(FloatProgress(value=0.0, description='Reranking queries', max=216.0, style=ProgressStyle(descri…

# mono and duoBERT on Jeff y2_test_automatic_cur_turn_first_prevrm3.run

In [5]:
# Y2 automatic pipeline: BART-rewritten queries + an SDM run file, reranked by
# monoBERT and then duoBERT; a RUN file is written after each reranking stage.

# Load the saved rewrites and use the cleaned rewrite as the query for each sample.
samples = []
with jsonlines.open('saved_models/BART_Rewriter/clean_BART_y2_self_rewrites.jsonl') as reader:
    for sample_obj in reader:
        sample_obj["query"] = sample_obj["clean_rewritten_query"]
        samples.append(sample_obj)

# NOTE(review): this loads the *manual* SDM run, yet the RUN files written below
# are named "y2_auto_..."; confirm this is the intended first-stage run file.
sdm_run_search = RUN_File_Search_Transform(
    'saved_models/Jeff_SDM/y2_test_manual_cur_turn_sdm.run',
    hits=1000,
)
samples = sdm_run_search(samples)

mono_reranker = MonoBERT_ReRanker_Transform(
    'saved_models/monoBERT/',
    get_doc_fn,
    device="cuda:0",
    key_fields={'query_field': 'query', 'target_field': 'search_results'},
    batch_size=100,
)
samples = mono_reranker(samples)

# First export: the monoBERT-only ranking.
run_file_exporter = RUN_File_Transform_Exporter(
    'saved_models/Jeff_SDM/y2_auto_cleanBART_SDM_context_fuse_monoBERT_1k.run',
    model_name='grill_fuseMono',
)
run_file_exporter(samples)

duo_reranker = DuoBERT_ReRanker_Transform(
    "saved_models/duoBERT/",
    get_doc_fn,
    rerank_top=10,
    device="cuda:0",
)
samples = duo_reranker(samples)

# Swap in the duoBERT ordering before the second export.
for sample in samples:
    sample["search_results"] = sample["reranked_results"]

# Second export: the duoBERT ranking.
run_file_exporter = RUN_File_Transform_Exporter(
    'saved_models/Jeff_SDM/y2_auto_cleanBART_SDM_context_fuse_duoBERT_1k.run',
    model_name='grill_fuseDuo',
)
run_file_exporter(samples)

HBox(children=(FloatProgress(value=0.0, description='Searching queries', max=216.0, style=ProgressStyle(descri…


Loading chekcpoint from saved_models/monoBERT/
MonoBERT ReRanker initialised on device cuda:0. Batch size 100


HBox(children=(FloatProgress(value=0.0, description='Reranking queries', max=216.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Writing to RUN file', max=216.0, style=ProgressStyle(desc…

Successfully written 216000 samples from 216 queries run to: saved_models/Jeff_SDM/y2_auto_cleanBART_SDM_context_fuse_monoBERT_1k.run
Loading chekcpoint from saved_models/duoBERT/
DuoBERT ReRanker initialised on cuda:0. Batch size 32


HBox(children=(FloatProgress(value=0.0, description='Reranking queries', max=216.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Writing to RUN file', max=216.0, style=ProgressStyle(desc…

Successfully written 216000 samples from 216 queries run to: saved_models/Jeff_SDM/y2_auto_cleanBART_SDM_context_fuse_duoBERT_1k.run


# y2 auto -> cleanBART -> BM25 -> monoBERT -> duoBERT

In [3]:
# Y2 automatic pipeline: BART-rewritten queries -> BM25 (1000 hits) -> monoBERT
# -> duoBERT (top 10); a RUN file is written after each reranking stage.

# Load the saved rewrites and use the cleaned rewrite as the query for each sample.
samples = []
with jsonlines.open('saved_models/BART_Rewriter/clean_BART_y2_self_rewrites.jsonl') as reader:
    for sample_obj in reader:
        sample_obj["query"] = sample_obj["clean_rewritten_query"]
        samples.append(sample_obj)

bm25_search = BM25_Search_Transform(
    index_dir='datasets/TREC_CAsT/CAsT_collection_with_meta.index',
    key_fields={'query_field': 'query', 'target_field': 'search_results'},
    hits=1000,
)
samples = bm25_search(samples)

mono_reranker = MonoBERT_ReRanker_Transform(
    'saved_models/monoBERT/',
    get_doc_fn,
    device="cuda:2",
    key_fields={'query_field': 'query', 'target_field': 'search_results'},
    batch_size=100,
)
samples = mono_reranker(samples)

# First export: the monoBERT-only ranking.
run_file_exporter = RUN_File_Transform_Exporter(
    'saved_models/monoBERT/y2_auto_cleanBART_BM25_monoBERT_1k.run',
    model_name='grill_auto_monobart',
)
run_file_exporter(samples)

duo_reranker = DuoBERT_ReRanker_Transform(
    "saved_models/duoBERT/",
    get_doc_fn,
    rerank_top=10,
    device="cuda:2",
)
samples = duo_reranker(samples)

# Swap in the duoBERT ordering before the second export.
for sample in samples:
    sample["search_results"] = sample["reranked_results"]

# Second export: the duoBERT ranking.
run_file_exporter = RUN_File_Transform_Exporter(
    'saved_models/duoBERT/y2_auto_cleanBART_BM25_mono_duoBERT_1k.run',
    model_name='grill_auto_duobart',
)
run_file_exporter(samples)

HBox(children=(FloatProgress(value=0.0, description='Searching queries', max=216.0, style=ProgressStyle(descri…


Loading chekcpoint from saved_models/monoBERT/
MonoBERT ReRanker initialised on device cuda:2. Batch size 100


HBox(children=(FloatProgress(value=0.0, description='Reranking queries', max=216.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Writing to RUN file', max=216.0, style=ProgressStyle(desc…

Successfully written 216000 samples from 216 queries run to: saved_models/monoBERT/y2_auto_cleanBART_BM25_monoBERT_1k.run
Loading chekcpoint from saved_models/duoBERT/
DuoBERT ReRanker initialised on cuda:2. Batch size 32


HBox(children=(FloatProgress(value=0.0, description='Reranking queries', max=216.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Writing to RUN file', max=216.0, style=ProgressStyle(desc…

Successfully written 216000 samples from 216 queries run to: saved_models/monoBERT/y2_auto_cleanBART_BM25_mono_duoBERT_1k.run


# Fully automatic single query run

In [3]:
# Build every stage of the fully automatic pipeline once so the demo cells
# below can reuse them. Everything model-based is placed on cuda:0.

# First-stage retrieval; note hits=500 here, not the 1000 used in the batch runs above.
BM25_transform = BM25_Search_Transform(
    index_dir='datasets/TREC_CAsT/CAsT_collection_with_meta.index',
    key_fields={'query_field': 'query', 'target_field': 'search_results'},
    hits=500,
)

# Rerankers, loaded from their saved checkpoints.
MonoBERT_transform = MonoBERT_ReRanker_Transform(
    'saved_models/monoBERT/',
    get_doc_fn,
    device="cuda:0",
    key_fields={'query_field': 'query', 'target_field': 'search_results'},
    batch_size=100,
)
DuoBERT_transform = DuoBERT_ReRanker_Transform(
    "saved_models/duoBERT/",
    get_doc_fn,
    rerank_top=10,
    device="cuda:0",
)

# Conversational query rewriter.
BART_conv_transform = BART_Full_Conversational_Rewriter_Transform(
    "saved_models/BART_Rewriter/BART_save_dict.ckpt",
    device="cuda:0",
)

Loading chekcpoint from saved_models/monoBERT/
MonoBERT ReRanker initialised on device cuda:0. Batch size 100
Loading chekcpoint from saved_models/duoBERT/
DuoBERT ReRanker initialised on cuda:0. Batch size 32
BERT ReRanker initialised on cuda:0. Batch size 1


In [15]:
# Is there a chunky recipe?
test_samples = [{'previous_queries':['how do you make strawberry  jam?'],
                'unresolved_query':'Can it be made without pectin?'}]
eval_raw_samples = BART_conv_transform(test_samples)
resolved_query = eval_raw_samples[0]['full_rewritten_queries'][-1]
print("Output:", resolved_query)

HBox(children=(FloatProgress(value=0.0, description='BART self feeding rewrites', max=1.0, style=ProgressStyle…


Output:  Can strawberry jam be made without pectin?


In [14]:
# Run the rewritten query through the prebuilt stages in order and show the
# top result after duoBERT reranking.
samples = [{'query': resolved_query}]
for pipeline_stage in (BM25_transform, MonoBERT_transform, DuoBERT_transform):
    samples = pipeline_stage(samples)

# First field of the first reranked entry; presumably the document id, since
# it is handed straight to get_doc_fn.
top_result_key = samples[0]['reranked_results'][0][0]
get_doc_fn(top_result_key)

HBox(children=(FloatProgress(value=0.0, description='Searching queries', max=1.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Reranking queries', max=1.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Reranking queries', max=1.0, style=ProgressStyle(descript…




'For example, a peanut butter and jelly sandwich made with chunky peanut butter and strawberry jam on wheat bread is quite a bit different than one made with smooth peanut butter, grape jelly and white bread with the crusts cut off.'