Add Vector PRF on top of SimpleDenseSearcher (#773)
* implement vector PRF for DenseSearcher
hanglics committed Sep 22, 2021
1 parent 436c9c9 commit 82f8422
Showing 3 changed files with 245 additions and 36 deletions.
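Vector PRF folds the top-k passages from a first dense retrieval back into the query embedding before searching again. A minimal sketch of the two variants this commit wires in, with the arithmetic inferred from the flag descriptions and defaults below (not from the committed _prf.py, which does not render on this page):

import numpy as np

q = np.random.rand(768).astype('float32')          # original query embedding
docs = np.random.rand(3, 768).astype('float32')    # top-3 feedback passage vectors

# avg: mean of the query vector and the feedback passage vectors
q_avg = np.vstack((q, docs)).mean(axis=0)

# rocchio: alpha * query + beta * mean(feedback passages); 0.9 and 0.1 are the CLI defaults
alpha, beta = 0.9, 0.1
q_rocchio = alpha * q + beta * docs.mean(axis=0)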
63 changes: 54 additions & 9 deletions pyserini/dsearch/__main__.py
@@ -24,6 +24,8 @@
from pyserini.query_iterator import get_query_iterator, TopicsFormat
from pyserini.output_writer import get_output_writer, OutputFormat

from ._prf import AveragePRF, RocchioPRF

# Fixes this error: "OMP: Error #15: Initializing libomp.a, but found libomp.dylib already initialized."
# https://stackoverflow.com/questions/53014306/error-15-initializing-libiomp5-dylib-but-found-libiomp5-dylib-already-initial
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
@@ -47,6 +49,15 @@ def define_dsearch_args(parser):
help="Query prefix if exists.")
parser.add_argument('--searcher', type=str, metavar='str', required=False, default='simple',
help="dense searcher type")
parser.add_argument('--prf-depth', type=int, metavar='num of passages used for PRF', required=False, default=0,
help="Specify how many passages are used for PRF, 0: Simple retrieval with no PRF, > 0: perform PRF")
parser.add_argument('--prf-method', type=str, metavar='avg or rocchio', required=False, default='avg',
help="Choose PRF methods, avg or rocchio")
parser.add_argument('--rocchio-alpha', type=float, metavar='alpha parameter for rocchio', required=False,
default=0.9,
help="The alpha parameter to control the contribution from the query vector")
parser.add_argument('--rocchio-beta', type=float, metavar='beta parameter for rocchio', required=False, default=0.1,
help="The beta parameter to control the contribution from the average vector of the PRF passages")


def init_query_encoder(encoder, tokenizer_name, topics_name, encoded_queries, device, prefix):
@@ -88,25 +99,39 @@ def init_query_encoder(encoder, tokenizer_name, topics_name, encoded_queries, de
return BprQueryEncoder.load_encoded_queries(encoded_queries)
else:
return QueryEncoder.load_encoded_queries(encoded_queries)

if topics_name in encoded_queries_map:
return QueryEncoder.load_encoded_queries(encoded_queries_map[topics_name])
raise ValueError(f'No encoded queries for topic {topics_name}')


def run_prf(topic_ids, query_embs, candidates, arg):
if arg.prf_method.lower() == 'avg':
average_prf = AveragePRF(topic_ids, query_embs, candidates)
prf_query_embs = average_prf.get_prf_q_emb()
elif arg.prf_method.lower() == 'rocchio':
rocchio_prf = RocchioPRF(topic_ids, query_embs, candidates,
rocchio_alpha=arg.rocchio_alpha, rocchio_beta=arg.rocchio_beta)
prf_query_embs = rocchio_prf.get_prf_q_emb()
else:
raise ValueError(f'PRF Method {arg.prf_method} Not Implemented')
return prf_query_embs


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Search a Faiss index.')
parser.add_argument('--topics', type=str, metavar='topic_name', required=True,
help="Name of topics. Available: msmarco-passage-dev-subset.")
parser.add_argument('--hits', type=int, metavar='num', required=False, default=1000, help="Number of hits.")
parser.add_argument('--binary-hits', type=int, metavar='num', required=False, default=1000, help="Number of binary hits.")
parser.add_argument('--binary-hits', type=int, metavar='num', required=False, default=1000,
help="Number of binary hits.")
parser.add_argument("--rerank", action="store_true", help='whethere rerank bpr sparse results.')
parser.add_argument('--topics-format', type=str, metavar='format', default=TopicsFormat.DEFAULT.value,
help=f"Format of topics. Available: {[x.value for x in list(TopicsFormat)]}")
parser.add_argument('--output-format', type=str, metavar='format', default=OutputFormat.TREC.value,
help=f"Format of output. Available: {[x.value for x in list(OutputFormat)]}")
parser.add_argument('--output', type=str, metavar='path', required=True, help="Path to output file.")
parser.add_argument('--max-passage', action='store_true',
parser.add_argument('--max-passage', action='store_true',
default=False, help="Select only max passage from document.")
parser.add_argument('--max-passage-hits', type=int, metavar='num', required=False, default=100,
help="Final number of hits when selecting only max passage.")
@@ -122,7 +147,8 @@ def init_query_encoder(encoder, tokenizer_name, topics_name, encoded_queries, de
query_iterator = get_query_iterator(args.topics, TopicsFormat(args.topics_format))
topics = query_iterator.topics

query_encoder = init_query_encoder(args.encoder, args.tokenizer, args.topics, args.encoded_queries, args.device, args.query_prefix)
query_encoder = init_query_encoder(args.encoder, args.tokenizer, args.topics, args.encoded_queries, args.device,
args.query_prefix)
kwargs = {}
if os.path.exists(args.index):
# create searcher from index directory
@@ -138,10 +164,16 @@ def init_query_encoder(encoder, tokenizer_name, topics_name, encoded_queries, de
searcher = BinaryDenseSearcher.from_prebuilt_index(args.index, query_encoder)
else:
searcher = SimpleDenseSearcher.from_prebuilt_index(args.index, query_encoder)

if not searcher:
exit()

# Check PRF Flag
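# (type() rather than isinstance(): the BinaryDenseSearcher subclass never takes the PRF path)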
if args.prf_depth > 0 and type(searcher) == SimpleDenseSearcher:
PRF_FLAG = True
else:
PRF_FLAG = False

# build output path
output_path = args.output

@@ -159,16 +191,29 @@ def init_query_encoder(encoder, tokenizer_name, topics_name, encoded_queries, de
batch_topic_ids = list()
for index, (topic_id, text) in enumerate(tqdm(query_iterator, total=len(topics.keys()))):
if args.batch_size <= 1 and args.threads <= 1:
hits = searcher.search(text, args.hits, **kwargs)
if PRF_FLAG:
emb_q, prf_candidates = searcher.get_prf_candidates(text, args.prf_depth, **kwargs)
prf_emb_q = run_prf(topic_id, emb_q, prf_candidates, args)
hits = searcher.search(prf_emb_q, args.hits, **kwargs)
else:
hits = searcher.search(text, args.hits, **kwargs)
results = [(topic_id, hits)]
else:
batch_topic_ids.append(str(topic_id))
batch_topics.append(text)
if (index + 1) % args.batch_size == 0 or \
index == len(topics.keys()) - 1:
results = searcher.batch_search(
batch_topics, batch_topic_ids, args.hits, threads=args.threads, **kwargs)
results = [(id_, results[id_]) for id_ in batch_topic_ids]
if PRF_FLAG:
q_embs, prf_candidates = searcher.get_batch_prf_candidates(batch_topics, batch_topic_ids,
args.prf_depth, **kwargs)
prf_embs_q = run_prf(batch_topic_ids, q_embs, prf_candidates, args)
results = searcher.batch_search(prf_embs_q, batch_topic_ids, args.hits, threads=args.threads,
**kwargs)
results = [(id_, results[id_]) for id_ in batch_topic_ids]
else:
results = searcher.batch_search(batch_topics, batch_topic_ids, args.hits, threads=args.threads,
**kwargs)
results = [(id_, results[id_]) for id_ in batch_topic_ids]
batch_topic_ids.clear()
batch_topics.clear()
else:
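For orientation, the same two-pass flow driven from Python instead of the CLI. This is a hedged sketch: the prebuilt index, encoder, and query are illustrative, and the RocchioPRF import path assumes the reconstruction of _prf.py at the end of this page.

from pyserini.dsearch import SimpleDenseSearcher, TctColBertQueryEncoder
from pyserini.dsearch._prf import RocchioPRF

encoder = TctColBertQueryEncoder('castorini/tct_colbert-msmarco')
searcher = SimpleDenseSearcher.from_prebuilt_index('msmarco-passage-tct_colbert-bf', encoder)

# Pass 1: query embedding plus the top-3 passages with their stored vectors.
emb_q, candidates = searcher.get_prf_candidates('what is a lobster roll', k=3)

# Combine the query and feedback vectors, as run_prf() does above.
prf_q = RocchioPRF('q0', emb_q, candidates, rocchio_alpha=0.9, rocchio_beta=0.1).get_prf_q_emb()

# Pass 2: search again with the adjusted query embedding.
hits = searcher.search(prf_q, k=10)
for hit in hits:
    print(hit.docid, hit.score)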
132 changes: 105 additions & 27 deletions pyserini/dsearch/_dsearcher.py
@@ -26,7 +26,7 @@
import numpy as np
import pandas as pd

from transformers import (AutoModel, AutoTokenizer, BertModel, BertTokenizer, BertTokenizerFast,
from transformers import (AutoModel, AutoTokenizer, BertModel, BertTokenizer, BertTokenizerFast,
DPRQuestionEncoder, DPRQuestionEncoderTokenizer, RobertaTokenizer)
from transformers.file_utils import is_faiss_available, requires_backends

@@ -147,7 +147,7 @@ def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
if encoded_query_dir:
self.embedding = self._load_embeddings(encoded_query_dir)
self.has_encoded_query = True

if encoder_dir:
self.device = device
self.model = DPRQuestionEncoder.from_pretrained(encoder_dir)
@@ -164,25 +164,27 @@ def encode(self, query: str):
embeddings = self.model(input_ids["input_ids"]).pooler_output.detach().cpu()
dense_embeddings = embeddings.numpy()
sparse_embeddings = self.convert_to_binary_code(embeddings).numpy()
return {'dense':dense_embeddings.flatten(), 'sparse':sparse_embeddings.flatten()}
return {'dense': dense_embeddings.flatten(), 'sparse': sparse_embeddings.flatten()}
else:
return super().encode(query)

def convert_to_binary_code(self, input_repr: torch.Tensor):
return input_repr.new_ones(input_repr.size()).masked_fill_(input_repr < 0, -1.0)

@staticmethod
def _load_embeddings(encoded_query_dir):
df = pd.read_pickle(os.path.join(encoded_query_dir, 'embedding.pkl'))
ret = {}
for text, dense, sparse in zip(df['text'].tolist(), df['dense_embedding'].tolist(), df['sparse_embedding'].tolist()):
for text, dense, sparse in zip(df['text'].tolist(), df['dense_embedding'].tolist(),
df['sparse_embedding'].tolist()):
ret[text] = {'dense': dense, 'sparse': sparse}
return ret


class DkrrDprQueryEncoder(QueryEncoder):

def __init__(self, encoder_dir: str = None, encoded_query_dir: str = None, device: str = 'cpu', prefix: str = "question:"):
def __init__(self, encoder_dir: str = None, encoded_query_dir: str = None, device: str = 'cpu',
prefix: str = "question:"):
super().__init__(encoded_query_dir)
self.device = device
self.model = BertModel.from_pretrained(encoder_dir)
@@ -204,7 +206,7 @@ def encode(self, query: str):
inputs = self.tokenizer(query, return_tensors='pt', max_length=40, padding="max_length")
inputs.to(self.device)
outputs = self.model(input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"])
attention_mask=inputs["attention_mask"])
embeddings = self._mean_pooling(outputs, inputs['attention_mask']).detach().cpu().numpy()
return embeddings.flatten()
else:
@@ -295,6 +297,13 @@ class DenseSearchResult:
score: float


@dataclass
class PRFDenseSearchResult:
docid: str
score: float
vectors: List[float]


class SimpleDenseSearcher:
"""Simple Searcher for dense representation
@@ -304,7 +313,8 @@ class SimpleDenseSearcher:
Path to faiss index directory.
"""

def __init__(self, index_dir: str, query_encoder: Union[QueryEncoder, str], prebuilt_index_name: Optional[str] = None):
def __init__(self, index_dir: str, query_encoder: Union[QueryEncoder, str],
prebuilt_index_name: Optional[str] = None):
requires_backends(self, "faiss")
if isinstance(query_encoder, QueryEncoder):
self.query_encoder = query_encoder
@@ -313,7 +323,7 @@ def __init__(self, index_dir: str, query_encoder: Union[QueryEncoder, str], preb
self.index, self.docids = self.load_index(index_dir)
self.dimension = self.index.d
self.num_docs = self.index.ntotal

assert self.docids is None or self.num_docs == len(self.docids)
if prebuilt_index_name:
sparse_index = get_sparse_index(prebuilt_index_name)
@@ -350,13 +360,13 @@ def list_prebuilt_indexes():
"""Display information about available prebuilt indexes."""
get_dense_indexes_info()

def search(self, query: str, k: int = 10, threads: int = 1) -> List[DenseSearchResult]:
def search(self, query: Union[str, np.ndarray], k: int = 10, threads: int = 1) -> List[DenseSearchResult]:
"""Search the collection.
Parameters
----------
query : str
query text
query : Union[str, np.ndarray]
query text or query embeddings
k : int
Number of hits to return.
threads : int
@@ -366,19 +376,50 @@ def search(self, query: str, k: int = 10, threads: int = 1) -> List[DenseSearchR
List[DenseSearchResult]
List of search results.
"""
emb_q = self.query_encoder.encode(query)
assert len(emb_q) == self.dimension
emb_q = emb_q.reshape((1, len(emb_q)))
if isinstance(query, str):
emb_q = self.query_encoder.encode(query)
assert len(emb_q) == self.dimension
emb_q = emb_q.reshape((1, len(emb_q)))
else:
emb_q = query
faiss.omp_set_num_threads(threads)
distances, indexes = self.index.search(emb_q, k)
distances = distances.flat
indexes = indexes.flat
return [DenseSearchResult(self.docids[idx], score)
for score, idx in zip(distances, indexes) if idx != -1]

def batch_search(self, queries: List[str], q_ids: List[str], k: int = 10, threads: int = 1) \
-> Dict[str, List[DenseSearchResult]]:
def get_prf_candidates(self, query: str, k: int = 10, threads: int = 1):
"""Search the collection to get PRF candidates
Parameters
----------
query : str
query text
k : int
Number of hits to return.
threads : int
Maximum number of threads to use for intra-query search.
Returns
-------
np.ndarray
Holds the query embeddings
List[PRFDenseSearchResult]
List of search results, with doc vectors returned.
"""
emb_q = self.query_encoder.encode(query)
assert len(emb_q) == self.dimension
emb_q = emb_q.reshape((1, len(emb_q)))
faiss.omp_set_num_threads(threads)
distances, indexes, vectors = self.index.search_and_reconstruct(emb_q, k)
vectors = vectors[0]
distances = distances.flat
indexes = indexes.flat
return emb_q, [PRFDenseSearchResult(self.docids[idx], score, vector)
for score, idx, vector in zip(distances, indexes, vectors) if idx != -1]
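
    # Illustration (not part of this commit): get_prf_candidates() relies on
    # faiss's search_and_reconstruct(), which returns the stored vectors of
    # the hits alongside scores and ids, so PRF needs no second index lookup.
    # A toy standalone check, assuming random data and an IndexFlatIP:
    #
    #   import faiss
    #   import numpy as np
    #   index = faiss.IndexFlatIP(8)
    #   index.add(np.random.rand(100, 8).astype('float32'))
    #   q = np.random.rand(1, 8).astype('float32')
    #   distances, indexes, vectors = index.search_and_reconstruct(q, 5)
    #   assert vectors.shape == (1, 5, 8)   # one stored vector per hit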

def get_batch_prf_candidates(self, queries: List[str], q_ids: List[str], k: int = 10, threads: int = 1):
"""Batch search to get the PRF candidates
Parameters
----------
@@ -393,14 +434,49 @@ def batch_search(self, queries: List[str], q_ids: List[str], k: int = 10, thread
Returns
-------
Dict[str, List[DenseSearchResult]]
Dictionary holding the search results, with the query ids as keys and the corresponding lists of search
results as the values.
np.ndarray
Holds the query embeddings
Dict[str, List[PRFDenseSearchResult]]
Dictionary holding the PRF candidate results, with the query ids as keys and the corresponding lists of
candidates as the values.
"""
q_embs = np.array([self.query_encoder.encode(q) for q in queries])
n, m = q_embs.shape
assert m == self.dimension
faiss.omp_set_num_threads(threads)
D, I, V = self.index.search_and_reconstruct(q_embs, k)
return q_embs, {key: [PRFDenseSearchResult(self.docids[idx], score, vector)
for score, idx, vector in zip(distances, indexes, vectors) if idx != -1]
for key, distances, indexes, vectors in zip(q_ids, D, I, V)}

def batch_search(self, queries: Union[List[str], np.ndarray], q_ids: List[str], k: int = 10, threads: int = 1) \
-> Dict[str, List[DenseSearchResult]]:
"""
Parameters
----------
queries : Union[List[str], np.ndarray]
List of query texts or list of query embeddings
q_ids : List[str]
List of corresponding query ids.
k : int
Number of hits to return.
threads : int
Maximum number of threads to use.
Returns
-------
Dict[str, List[DenseSearchResult]]
Dictionary holding the search results, with the query ids as keys and the corresponding lists of search
results as the values.
"""
if isinstance(queries, np.ndarray):
q_embs = queries
else:
q_embs = np.array([self.query_encoder.encode(q) for q in queries])
n, m = q_embs.shape
assert m == self.dimension
faiss.omp_set_num_threads(threads)
D, I = self.index.search(q_embs, k)
return {key: [DenseSearchResult(self.docids[idx], score)
for score, idx in zip(distances, indexes) if idx != -1]
@@ -462,10 +538,12 @@ class BinaryDenseSearcher(SimpleDenseSearcher):
Path to faiss index directory.
"""

def __init__(self, index_dir: str, query_encoder: Union[QueryEncoder, str], prebuilt_index_name: Optional[str] = None):
def __init__(self, index_dir: str, query_encoder: Union[QueryEncoder, str],
prebuilt_index_name: Optional[str] = None):
super().__init__(index_dir, query_encoder, prebuilt_index_name)

def search(self, query: str, k: int = 10, binary_k: int = 100, rerank: bool = True, threads: int = 1) -> List[DenseSearchResult]:
def search(self, query: str, k: int = 10, binary_k: int = 100, rerank: bool = True, threads: int = 1) \
-> List[DenseSearchResult]:
"""Search the collection.
Parameters
@@ -490,7 +568,7 @@ def search(self, query: str, k: int = 10, binary_k: int = 100, rerank: bool = Tr
sparse_emb_q = ret['sparse']
assert len(dense_emb_q) == self.dimension
assert len(sparse_emb_q) == self.dimension

dense_emb_q = dense_emb_q.reshape((1, len(dense_emb_q)))
sparse_emb_q = sparse_emb_q.reshape((1, len(sparse_emb_q)))
faiss.omp_set_num_threads(threads)
Expand All @@ -500,8 +578,8 @@ def search(self, query: str, k: int = 10, binary_k: int = 100, rerank: bool = Tr
return [DenseSearchResult(str(idx), score)
for score, idx in zip(distances, indexes) if idx != -1]

def batch_search(self, queries: List[str], q_ids: List[str], k: int = 10, binary_k: int = 100, \
rerank: bool = True, threads: int = 1) -> Dict[str, List[DenseSearchResult]]:
def batch_search(self, queries: List[str], q_ids: List[str], k: int = 10, binary_k: int = 100,
rerank: bool = True, threads: int = 1) -> Dict[str, List[DenseSearchResult]]:
"""
Parameters
@@ -566,7 +644,7 @@ def binary_dense_search(self, k, binary_k, rerank, dense_emb_q, sparse_emb_q):
indexes = indexes.reshape(num_queries, -1)[:, :k]
distances = distances[np.arange(num_queries)[:, None], sorted_indices][:, :k]
return distances, indexes

def load_index(self, index_dir: str):
index_path = os.path.join(index_dir, 'index')
index = faiss.read_index_binary(index_path)
86 changes: 86 additions & 0 deletions pyserini/dsearch/_prf.py
(diff not rendered on this page; the counts follow from the totals above, and the path from the "from ._prf import" line in __main__.py)
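Because that diff is not shown, here is a minimal sketch of what AveragePRF and RocchioPRF plausibly look like, inferred purely from how run_prf() in __main__.py constructs and calls them; every body below is an assumption, not the committed code.

import numpy as np


class AveragePRF:
    """Hypothetical: new query = mean of the query vector and the top-k feedback passage vectors."""

    def __init__(self, topic_ids, emb_qs, prf_candidates):
        # run_prf() passes a single topic id, a (1, dim) embedding, and a flat
        # candidate list in the per-query path; lists/arrays and a dict of
        # candidates keyed by query id in batch mode.
        self.topic_ids = topic_ids if isinstance(topic_ids, list) else [topic_ids]
        self.emb_qs = np.atleast_2d(emb_qs)
        self.prf_candidates = prf_candidates

    def _candidates(self, qid):
        if isinstance(self.prf_candidates, dict):
            return self.prf_candidates[qid]
        return self.prf_candidates

    def get_prf_q_emb(self):
        new_embs = []
        for i, qid in enumerate(self.topic_ids):
            doc_vecs = np.array([hit.vectors for hit in self._candidates(qid)])
            new_embs.append(np.vstack((self.emb_qs[i], doc_vecs)).mean(axis=0))
        return np.array(new_embs, dtype=np.float32)


class RocchioPRF(AveragePRF):
    """Hypothetical: alpha weighs the query vector, beta the mean feedback vector."""

    def __init__(self, topic_ids, emb_qs, prf_candidates,
                 rocchio_alpha: float = 0.9, rocchio_beta: float = 0.1):
        super().__init__(topic_ids, emb_qs, prf_candidates)
        self.rocchio_alpha = rocchio_alpha
        self.rocchio_beta = rocchio_beta

    def get_prf_q_emb(self):
        new_embs = []
        for i, qid in enumerate(self.topic_ids):
            doc_vecs = np.array([hit.vectors for hit in self._candidates(qid)])
            new_embs.append(self.rocchio_alpha * self.emb_qs[i]
                            + self.rocchio_beta * doc_vecs.mean(axis=0))
        return np.array(new_embs, dtype=np.float32)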
