In [1]:
import logging
import math
import os
import random
import sys
import time
from typing import Tuple

import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from torch import Tensor as T
from torch import nn

from dpr.models import init_biencoder_components
from dpr.models.biencoder import BiEncoderNllLoss, BiEncoderBatch
from dpr.options import (
    setup_cfg_gpu,
    set_seed,
    get_encoder_params_state_from_cfg,
    set_cfg_params_from_state,
    setup_logger,
)
from dpr.utils.conf_utils import BiencoderDatasetsCfg
from dpr.utils.data_utils import (
    ShardedDataIterator,
    Tensorizer,
    MultiSetDataIterator,
    LocalShardedDataIterator,
)
from dpr.utils.dist_utils import all_gather_list
from dpr.utils.model_utils import (
    setup_for_distributed_mode,
    move_to_device,
    get_schedule_linear,
    CheckpointState,
    get_model_file,
    get_model_obj,
    load_states_from_checkpoint,
)

logger = logging.getLogger()
setup_logger(logger)



In [2]:
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
import hydra
hydra.initialize()
cfg = compose(config_name="conf/dense_retriever.yaml")
cfg = cfg.conf



The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  hydra.initialize()
See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/changes_to_hydra_main_config_path for more information.
  hydra.initialize()
See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/changes_to_package_header for more information


In [3]:
print(OmegaConf.to_yaml(cfg))

encoder:
  encoder_model_type: hf_bert
  pretrained_model_cfg: bert-base-uncased
  pretrained_file: null
  projection_dim: 0
  sequence_length: 256
  dropout: 0.1
  fix_ctx_encoder: true
  pretrained: true
datasets:
  nq_test:
    _target_: dpr.data.retriever_data.CsvQASrc
    file: data.retriever.qas.nq-test
  nq_train:
    _target_: dpr.data.retriever_data.CsvQASrc
    file: data.retriever.qas.nq-train
  nq_dev:
    _target_: dpr.data.retriever_data.CsvQASrc
    file: data.retriever.qas.nq-dev
  trivia_test:
    _target_: dpr.data.retriever_data.CsvQASrc
    file: data.retriever.qas.trivia-test
  trivia_train:
    _target_: dpr.data.retriever_data.CsvQASrc
    file: data.retriever.qas.trivia-train
  trivia_dev:
    _target_: dpr.data.retriever_data.CsvQASrc
    file: data.retriever.qas.trivia-dev
  webq_test:
    _target_: dpr.data.retriever_data.CsvQASrc
    file: data.retriever.qas.webq-test
  curatedtrec_test:
    _target_: dpr.data.retriever_data.CsvQASrc
    file: data.retriever

In [4]:
cfg.model_file = "outputs/2024-04-09/21-01-56/poisoned_one_positive_one_negative/dpr_biencoder.31"

In [5]:
cfg.qa_dataset = "nq_test"

In [8]:

saved_state = load_states_from_checkpoint(cfg.model_file)
set_cfg_params_from_state(saved_state.encoder_params, cfg)

[140547894019136] 2024-05-16 01:17:21,365 [INFO] root: Reading saved model from outputs/2024-04-09/21-01-56/poisoned_one_positive_one_negative/dpr_biencoder.31
[140547894019136] 2024-05-16 01:17:26,307 [INFO] root: model_state_dict keys dict_keys(['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset', 'epoch', 'encoder_params'])


In [9]:
tensorizer, encoder, _ = init_biencoder_components(cfg.encoder.encoder_model_type, cfg, inference_only=True)


[140547894019136] 2024-05-16 01:17:29,567 [INFO] dpr.models.hf_models: Initializing HF BERT Encoder. cfg_name=bert-base-uncased
[140547894019136] 2024-05-16 01:17:29,789 [INFO] dpr.models.hf_models: Initializing HF BERT Encoder. cfg_name=bert-base-uncased


In [10]:
logger.info("Loading saved model state ...")
encoder.load_state(saved_state, strict=False)

[140547894019136] 2024-05-16 01:17:33,677 [INFO] root: Loading saved model state ...


In [11]:
ctx_files_patterns = cfg.encoded_ctx_files

In [15]:
ctx_files_patterns = "/scratch/gbagwe/Projects/DPR/downloads/data/retriever_results/nq/single/wikipedia_passages_*"

In [16]:
ctx_files_patterns

'/scratch/gbagwe/Projects/DPR/downloads/data/retriever_results/nq/single/wikipedia_passages_*'

In [23]:
ctx_datatsets=["dpr_wiki"]

In [26]:
ctx_datatsets[0]

'dpr_wiki'

In [28]:
id_prefixes = []
ctx_sources = []
print("\n\n\n ********", ctx_datatsets[0], " \n\n\n*********")
for ctx_src in ctx_datatsets:
    ctx_src = hydra.utils.instantiate(cfg.ctx_sources[ctx_src])
    id_prefixes.append(ctx_src.id_prefix)
    ctx_sources.append(ctx_src)




 ******** dpr_wiki  


*********


In [29]:
if ctx_files_patterns:
    assert len(ctx_files_patterns) == len(id_prefixes), "ctx len={} pref leb={}".format(
        len(ctx_files_patterns), len(id_prefixes)
    )
else:
    assert (
        index_path or cfg.rpc_index_id
    ), "Either encoded_ctx_files or index_path pr rpc_index_id parameter should be set."

input_paths = []
path_id_prefixes = []

AssertionError: ctx len=92 pref leb=1

In [40]:
# model_to_load = get_model_obj(encoder)
# vector_size = model_to_load.get_out_size()
from dense_retriever import LocalFaissRetriever
index = hydra.utils.instantiate(cfg.indexers[cfg.indexer])
logger.info("Local Index class %s ", type(index))
index_buffer_sz = index.buffer_size
index.init_index(768)
retriever = LocalFaissRetriever(encoder, cfg.batch_size, tensorizer, index )

[140547894019136] 2024-05-16 01:53:23,556 [INFO] root: Local Index class <class 'dpr.indexer.faiss_indexers.DenseFlatIndexer'> 


In [42]:
retriever.index_encoded_data("/scratch/gbagwe/Projects/DPR/downloads/data/retriever_results/nq/single/wikipedia_passages_1", index_buffer_sz, path_id_prefixes=path_id_prefixes)

NameError: name 'path_id_prefixes' is not defined