In [None]:
from src.RawDataLoaders import MS_Marco_RawDataLoader, CAsT_RawDataLoader
from src.pipe_datasets import Manual_Query_BM25_Reranking_Dataset, Reranking_Validation_Dataset, Manual_Query_RUN_File_Reranking_Dataset
from src.models_and_transforms.run_file_models import Run_File_Searcher
from src.models_and_transforms.BERT_models import BERT_Reranker
from src.models_and_transforms.BM25_models import BM25_Ranker
from src.Experiments import CAsT_experiment, Ranking_Experiment, RUN_File_Transform_Exporter
from src.trainers import Model_Trainer
from src.models_and_transforms.complex_transforms import BERT_Score_Transform, BERT_ReRanker_Transform, BM25_Search_Transform, \
                                                        Oracle_ReRanker_Transform, RUN_File_Search_Transform
from src.models_and_transforms.text_transforms import Query_Resolver_Transform, Document_Resolver_Transform

from transformers import LongformerConfig, LongformerModel, LongformerTokenizer, BertTokenizer, BertModel
from pytorch_lightning import Trainer, Callback
import pickle
import random
import numpy as np
import os
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import torch

from tqdm.auto import tqdm 
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [2]:
def seed_everything(seed=1234):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
seed_everything()

In [10]:
raw_data_loader = MS_Marco_RawDataLoader(from_pickle=True)
get_query_fn = raw_data_loader.get_query
get_doc_fn = raw_data_loader.get_doc

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='loading train queries', max=1.0, style=…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='loading dev queries', max=1.0, style=Pr…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='loading eval queries', max=1.0, style=P…

In [4]:
train_raw_samples = raw_data_loader.get_topics("train")

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='loading train queries', max=1.0, style=…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='loading q_rels', max=1.0, style=Progres…

In [5]:
searcher = Run_File_Searcher("saved_models/MARCO_train_BM25_full.run")

In [8]:
bm25_searcher = BM25_Ranker("datasets/MS_MARCO/MARCO_anserini")

In [20]:
bm25_searcher.predict('wonderful.in.spanish')

[]

#### Removing probllematic samples

In [6]:
bad_samples = [sample for sample in train_raw_samples if sample['q_id'] not in searcher.query_doc_mapping]

In [7]:
for sample in bad_samples:
    print(sample)
    train_raw_samples.remove(sample)

{'q_id': '140329', 'q_rel': ['MARCO_6542451']}
{'q_id': '1078982', 'q_rel': ['MARCO_4115897']}
{'q_id': '502557', 'q_rel': ['MARCO_1271975']}
{'q_id': '48509', 'q_rel': ['MARCO_2063851']}
{'q_id': '56573', 'q_rel': ['MARCO_3198289']}
{'q_id': '129844', 'q_rel': ['MARCO_7817031']}
{'q_id': '197820', 'q_rel': ['MARCO_5510763']}
{'q_id': '522517', 'q_rel': ['MARCO_3075528']}
{'q_id': '205266', 'q_rel': ['MARCO_5143713']}


In [10]:
len(bad_samples)

0

In [8]:
train_dataset = Manual_Query_RUN_File_Reranking_Dataset(train_raw_samples, get_query_fn, get_doc_fn, "saved_models/MARCO_train_BM25_full.run", hits=100, num_neg_samples=50)

HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Searching queries', max=502930.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Sampling ± query-doc pairs', max=502930.0, style=Progress…





In [37]:
train_dataset[500001]

{'q_id': '700091',
 'q_rel': ['MARCO_1725292'],
 'query': 'what is a simulcast',
 'd_id': 'MARCO_1725292',
 'label': 1,
 'doc': 'Simulcast Simulcast refers to the process of transmitting the same signal from different tower locations over the same frequency at the same time. For public safety communications, this typically means multiple towers: configured to transmit the exact same communications, on the exact same frequencies, at precisely the same time.',
 'input_text': 'what is a simulcast [SEP] Simulcast Simulcast refers to the process of transmitting the same signal from different tower locations over the same frequency at the same time. For public safety communications, this typically means multiple towers: configured to transmit the exact same communications, on the exact same frequencies, at precisely the same time.',
 'input_ids': [101,
  2054,
  2003,
  1037,
  20525,
  102,
  20525,
  20525,
  5218,
  2000,
  1996,
  2832,
  1997,
  23820,
  1996,
  2168,
  4742,
  2013,
  

In [None]:
train_dataset = Manual_Query_BM25_Reranking_Dataset(train_raw_samples[:2000000], get_query_fn, get_doc_fn, hits=150, num_neg_samples=100, index_dir="datasets/MS_MARCO/MARCO_anserini")

HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Searching queries', max=502939.0, style=ProgressStyle(des…

In [28]:
train_dataloader = train_dataset.to_dataloader(16, num_workers=32, shuffle=True)

In [27]:
model = BERT_Reranker()

In [None]:
torch.save(model.state_dict(), "saved_models/BERT_ReRanker_MARCO.ckpt")

In [30]:
model.load_state_dict(torch.load("saved_models/BERT_reranker_q100k_h100_checkpoints/BERT_ReRanker_MARCO_from_valid_0.38551378521302254.ckpt"))

<All keys matched successfully>

In [53]:
samples = [{"input_ids":[55,66,33]},{"input_ids":[45,76,33]}]
score_transform = BERT_Score_Transform("saved_models/BERT_ReRanker_MARCO.ckpt")
score_transform(samples)

<All keys matched successfully>
BERT ReRanker initialised on cuda. Batch size 64


[{'input_ids': [55, 66, 33], 'score': 0.044405270367860794},
 {'input_ids': [45, 76, 33], 'score': 0.044405270367860794}]

In [132]:
samples = [{'q_id':"121352","query":"define extreme", 'search_results':[('MARCO_6237152', 0.6), ('MARCO_2912794', 0.6)]}]
rerank_transform = BERT_ReRanker_Transform("saved_models/BERT_reranker_q10k_h100_checkpoints/BERT_ReRanker_MARCO_from_valid_0.3465380570542395.ckpt", get_doc_fn)
rerank_transform(samples)

<All keys matched successfully>
BERT ReRanker initialised on cuda. Batch size 64


HBox(children=(FloatProgress(value=0.0, description='Reranking queries', max=1.0, style=ProgressStyle(descript…




[{'q_id': '121352',
  'query': 'define extreme',
  'search_results': [('MARCO_6237152', 0.6), ('MARCO_2912794', 0.6)],
  'reranked_results': [('MARCO_2912794', 0.992748498916626),
   ('MARCO_6237152', -0.01869625225663185)]}]

In [11]:
valid_q_rels = raw_data_loader.q_rels("dev")
valid_raw_samples = raw_data_loader.get_topics("dev")[:400]
valid_samples = Query_Resolver_Transform(get_query_fn)(valid_raw_samples)
valid_BM25_results = RUN_File_Search_Transform('saved_models/MARCO_dev_BM25.run', hits=100)(valid_raw_samples)
val_dataset = Reranking_Validation_Dataset(valid_BM25_results, get_query_fn, get_doc_fn)
val_dataloader = val_dataset.to_dataloader(128, num_workers=32, shuffle=False)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='loading q_rels', max=1.0, style=Progres…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='loading dev queries', max=1.0, style=Pr…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='loading q_rels', max=1.0, style=Progres…

HBox(children=(FloatProgress(value=0.0, description='Searching queries', max=400.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Flattening search results', max=400.0, style=ProgressStyl…





In [12]:
# experiment
expr = Ranking_Experiment(valid_q_rels)
print("BM25 results")
print(expr(valid_BM25_results))

print("ORACLE results")
valid_oracle_rerank_results = Oracle_ReRanker_Transform(valid_q_rels)(valid_BM25_results)
for sample in valid_oracle_rerank_results:
    sample["search_results"] = sample["reranked_results"]
print(expr(valid_oracle_rerank_results))

print("BERT+BM25 results")
valid_BERT_rerank_results = BERT_ReRanker_Transform("saved_models/BERT_reranker_q100k_h100_checkpoints/BERT_ReRanker_MARCO_from_valid_0.38551378521302254.ckpt", get_doc_fn, device="cuda:2", batch_size=256)(valid_BM25_results)
for sample in valid_BERT_rerank_results:
    sample["search_results"] = sample["reranked_results"]
print(expr(valid_BERT_rerank_results))

BM25 results
{'map': 0.18098515570583398, 'recip_rank': 0.18098515570583398, 'ndcg': 0.27305565934364895, 'set_recall': 0.645}
ORACLE results
{'map': 0.645, 'recip_rank': 0.645, 'ndcg': 0.645, 'set_recall': 0.645}
BERT+BM25 results
cpu
<All keys matched successfully>
BERT ReRanker initialised on cuda:2. Batch size 256


HBox(children=(FloatProgress(value=0.0, description='Reranking queries', max=400.0, style=ProgressStyle(descri…


{'map': 0.3138712999797035, 'recip_rank': 0.3138712999797035, 'ndcg': 0.38551378521302254, 'set_recall': 0.645}


In [17]:
expr = Ranking_Experiment(valid_q_rels)
print(expr(valid_BERT_rerank_results))

TypeError: Unable to resolve all measures.

In [1]:
%debug

ERROR:root:No traceback has been produced, nothing to debug.


In [32]:
my_trainer = Model_Trainer(gpus=[0])
my_trainer.train(model, train_dataloader)

Detected 8 GPUS available, using [0].
Main device is: cuda:0


HBox(children=(FloatProgress(value=0.0, max=12417.0), HTML(value='')))


Keyboard Interrupt!


{'train_loss': [tensor(0.7658, device='cuda:0', grad_fn=<MseLossBackward>),
  tensor(0.5080, device='cuda:0', grad_fn=<MseLossBackward>),
  tensor(0.3244, device='cuda:0', grad_fn=<MseLossBackward>),
  tensor(0.3060, device='cuda:0', grad_fn=<MseLossBackward>),
  tensor(0.2623, device='cuda:0', grad_fn=<MseLossBackward>),
  tensor(0.4784, device='cuda:0', grad_fn=<MseLossBackward>),
  tensor(0.5412, device='cuda:0', grad_fn=<MseLossBackward>),
  tensor(0.4063, device='cuda:0', grad_fn=<MseLossBackward>),
  tensor(0.4163, device='cuda:0', grad_fn=<MseLossBackward>),
  tensor(0.5151, device='cuda:0', grad_fn=<MseLossBackward>),
  tensor(0.3588, device='cuda:0', grad_fn=<MseLossBackward>),
  tensor(0.4145, device='cuda:0', grad_fn=<MseLossBackward>),
  tensor(0.3521, device='cuda:0', grad_fn=<MseLossBackward>),
  tensor(0.2835, device='cuda:0', grad_fn=<MseLossBackward>),
  tensor(0.3656, device='cuda:0', grad_fn=<MseLossBackward>),
  tensor(0.5360, device='cuda:0', grad_fn=<MseLossBackwa

In [31]:
checkpoint_callback = ModelCheckpoint(
    filepath='saved_models/BERT_reranker_q500k_h150_checkpoints/test_saves/',
    save_top_k=3,
    verbose=True,
    monitor='ndcg',
    mode='min',
    prefix='BERT_reranker_500k_queries'
)

In [32]:
model.set_validation_q_rels(valid_q_rels)

In [38]:
wandb_logger = WandbLogger(name='mega500k',project='pytorchlightning')

ImportError: You want to use `wandb` logger which is not installed yet, install it with `pip install wandb`.

In [33]:
trainer = Trainer(gpus=1, profiler=True, 
                  print_nan_grads=True, 
                  num_sanity_val_steps=0,#len(val_dataloader), 
                  val_check_interval=0.005,
#                   logger= wandb_logger,
                  checkpoint_callback=checkpoint_callback)
trainer.fit(model, train_dataloader, val_dataloader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type      | Params
---------------------------------------------
0 | BERT_for_class | BertModel | 109 M 
1 | dropout        | Dropout   | 0     
2 | proj_layer     | Linear    | 769   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…



Profiler Report

Action              	|  Mean duration (s)	|  Total time (s) 
-----------------------------------------------------------------
on_train_start      	|  0.01525        	|  0.01525        
on_epoch_start      	|  0.0017096      	|  0.0017096      
get_train_batch     	|  0.78545        	|  66.764         
on_batch_start      	|  1.1894e-05     	|  0.001011       
model_forward       	|  0.059203       	|  5.0323         
model_backward      	|  0.11257        	|  9.5688         
on_after_backward   	|  2.2661e-06     	|  0.00019036     
optimizer_step      	|  0.017945       	|  1.5074         
on_batch_end        	|  0.0019552      	|  0.16424        
on_train_end        	|  0.0027203      	|  0.0027203      






1

In [20]:
trainer.save_checkpoint("example.ckpt")

In [None]:
%debug

In [36]:
outputs = model.validation_step(batch, 0)

In [37]:
outputs["valid_outputs"]

[0.64140784740448,
 0.650149941444397,
 0.6211758255958557,
 0.650149941444397,
 0.6418268084526062,
 0.650149941444397,
 0.6800584197044373,
 0.650149941444397,
 0.6111412048339844,
 0.650149941444397,
 0.6355796456336975,
 0.650149941444397,
 0.6755321025848389,
 0.650149941444397,
 0.6149528622627258,
 0.650149941444397]

In [148]:
train_raw_sampleget_topicsw_data_loader.get_topics("train")
train_dataset = Manual_Query_BM25_Reranking_Dataset(train_raw_samples, get_query_fn, get_doc_fn, hits=100)
train_dataloader = train_dataset.to_dataloader(2, num_workers=32)

NameError: name 'train_raw_sampleget_topicsw_data_loader' is not defined

In [41]:
BERT_BM25_reranker = BERT_BM25_Reranker(raw_data_loader.get_doc, raw_data_loader.get_query)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…




In [43]:
reranking_DataProcessor = Resolved_Query_Reranking_DataProcessor(raw_data_loader.get_doc, 
                                                  raw_data_loader.get_query, 
                                                  raw_data_loader.get_topics("train"), 
                                                  BERT_BM25_reranker.first_pass_model.predict, 
                                                  numericalizer,
                                                  max_length=512)

In [7]:
dataloader = reranking_DataProcessor.to_dataloader(2, num_workers=0)

In [96]:
len(train_dataset)

25352

In [147]:
for batch in tqdm.tqdm(train_dataloader):
    print(batch["input_ids"])
    break

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))




KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/nfs/phd_by_carlos/notebooks/src/DataProcessors.py", line 493, in __getitem__
    samples = transform(samples)
  File "/nfs/phd_by_carlos/notebooks/src/DataProcessors.py", line 465, in __call__
    samples = transform(samples)
  File "/nfs/phd_by_carlos/notebooks/src/DataProcessors.py", line 421, in __call__
    sample_obj["doc"] = self.get_doc_fn(sample_obj["d_id"])
  File "/nfs/phd_by_carlos/notebooks/src/dataset_loaders.py", line 256, in get_doc
    return self.collection[d_id]
KeyError: 'CAR_cce9dc23154a5887bbe92bfff13a4437b8ab2256'


In [41]:
BM25_Ranker(get_query_fn).predict("32_4")

'BM25_Ranker'

In [9]:
def eval_call(model):
    print("called evaluation")

In [10]:
model = Longformer_Reranker()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=725.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=597257159.0, style=ProgressStyle(descri…




In [11]:
model.eval_callback = eval_call

In [45]:
trainer = Trainer(gpus=[0], profiler=True, gradient_clip_val=0.5, distributed_backend='dp', check_val_every_n_epoch=1)
trainer.fit(model, dataloader)

NameError: name 'dataloader' is not defined

In [10]:
dataloader = torch.utils.data.DataLoader(torch.tensor([]), batch_size=1)

# Running an experiment on Y2 data

In [30]:
CAsT_raw_data_loader = CAsT_RawDataLoader()
get_query_fn = CAsT_raw_data_loader.get_query
get_doc_fn = CAsT_raw_data_loader.get_doc
eval_raw_samples = CAsT_raw_data_loader.get_topics("train")
CAsT_q_rels = CAsT_raw_data_loader.q_rels

In [31]:
eval_samples = Query_Resolver_Transform(get_query_fn, utterance_type="manual_rewritten_utterance")(eval_raw_samples)
eval_BM25_results = BM25_Search_Transform(index_dir='datasets/TREC_CAsT/CAsT_collection_with_meta.index', hits=500)(eval_samples)

HBox(children=(FloatProgress(value=0.0, description='Searching queries', max=133.0, style=ProgressStyle(descri…




In [32]:
# experiment
expr = Ranking_Experiment(CAsT_q_rels)
print("BM25 results")
print(expr(eval_BM25_results))

print("RUN File Model")
eval_run_rerank_results = RUN_File_Search_Transform('saved_models/CAsT_y1_pgbert.run', hits=500)(eval_samples)
print(expr(eval_run_rerank_results))

print("ORACLE results")
eval_oracle_rerank_results = Oracle_ReRanker_Transform(CAsT_q_rels)(eval_BM25_results)
for sample in eval_oracle_rerank_results:
    sample["search_results"] = sample["reranked_results"]
print(expr(eval_oracle_rerank_results))

print("BERT+BM25 results")
eval_BERT_rerank_results = BERT_ReRanker_Transform("saved_models/BERT_reranker_q100k_h100_checkpoints/BERT_ReRanker_MARCO_from_valid_0.38551378521302254.ckpt", get_doc_fn, device="cuda:2", batch_size=256)(eval_BM25_results)
for sample in eval_BERT_rerank_results:
    sample["search_results"] = sample["reranked_results"]
print(expr(eval_BERT_rerank_results))

BM25 results
{'map': 0.19679730164325712, 'recip_rank': 0.4596050256759099, 'ndcg_cut_3': 0.306364841910461, 'set_recall': 0.7744165761018662}
RUN File Model


HBox(children=(FloatProgress(value=0.0, description='Searching queries', max=133.0, style=ProgressStyle(descri…


{'map': 0.22648101607514953, 'recip_rank': 0.48689326704789326, 'ndcg_cut_3': 0.3621256024784965, 'set_recall': 0.5633517089242985}
ORACLE results
{'map': 0.5633517089242985, 'recip_rank': 0.8345864661654135, 'ndcg_cut_3': 0.7816506828787912, 'set_recall': 0.5633517089242985}
BERT+BM25 results
cpu
<All keys matched successfully>
BERT ReRanker initialised on cuda:2. Batch size 256


HBox(children=(FloatProgress(value=0.0, description='Reranking queries', max=133.0, style=ProgressStyle(descri…


{'map': 0.10678501989112972, 'recip_rank': 0.25388222588499865, 'ndcg_cut_3': 0.14453453816167808, 'set_recall': 0.5633517089242985}


In [9]:
get_doc_fn('MARCO_5593358')

'â\x80\x8bâ\x80\x8bResidential garage doors from Overhead Door are among the most dependable in the industry, so you can feel good knowing that weâ\x80\x99ll be there â\x80\x94 day or night, winter or summer. For added peace of mind, our home garage doors have also been proven to be durable and long lasting. The reliability of your garage door will help you stay on schedule in the morning. Its beauty will greet you at the end of a busy workday, opening convenient, comfortable passage to your home. And through the night, the security of your garage door will help you rest assured that your family is safe.'

In [6]:
eval_BERT_rerank_results[0]

{'prev_turns': [],
 'q_id': '81_1',
 'q_rel': ['MARCO_5498474'],
 'query': 'How do you know when your garage door opener is going bad?',
 'search_results': [('MARCO_5593358', 0.8783813118934631),
  ('MARCO_5498474', 0.8653429746627808),
  ('MARCO_7308614', 0.8559824824333191),
  ('MARCO_7517892', 0.8260317444801331),
  ('MARCO_5844152', 0.792199432849884),
  ('MARCO_6245022', 0.7700963616371155),
  ('MARCO_6801809', 0.7367041707038879),
  ('MARCO_7699205', 0.6752471923828125),
  ('MARCO_5498468', 0.6368646025657654),
  ('MARCO_7308619', 0.6284499168395996),
  ('MARCO_7713531', 0.6059923768043518),
  ('MARCO_7987331', 0.5933001637458801),
  ('MARCO_6154877', 0.5182024240493774),
  ('MARCO_6154876', 0.4936941862106323),
  ('MARCO_4516819', 0.47940194606781006),
  ('MARCO_1900270', 0.4233276844024658),
  ('MARCO_6015699', 0.42029985785484314),
  ('MARCO_700026', 0.38879796862602234),
  ('MARCO_1945742', 0.3215405344963074),
  ('MARCO_7517889', 0.30050134658813477),
  ('MARCO_1381083', 0.2