Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
1 contributor

Users who have contributed to this file

110 lines (90 sloc) 4.24 KB
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Rank documents with an ElasticSearch index"""
import logging
import scipy.sparse as sp
from multiprocessing.pool import ThreadPool
from functools import partial
from elasticsearch import Elasticsearch
from . import utils
from . import DEFAULTS
from .. import tokenizers
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class ElasticDocRanker(object):
    """Connect to an ElasticSearch index.

    Score pairs based on Elasticsearch. The object also acts as a simple
    document store (see the "Elastic DB" methods) and can be used as a
    context manager, which closes the connection on exit.
    """

    def __init__(self, elastic_url=None, elastic_index=None, elastic_fields=None,
                 elastic_field_doc_name=None, strict=True, elastic_field_content=None):
        """
        Args:
            elastic_url: URL of the ElasticSearch server containing port.
                Falls back to DEFAULTS['elastic_url'] when not given.
            elastic_index: Index name of ElasticSearch
            elastic_fields: Fields of the Elasticsearch index to search in
            elastic_field_doc_name: Field containing the name of the document
                (index); either a field name or a list of nested field names
            strict: fail on empty queries (raise RuntimeError) or continue
                (and return an empty result)
            elastic_field_content: Field containing the content of the
                document in plain text
        """
        # Connect to the Elasticsearch server (nothing is loaded from disk).
        elastic_url = elastic_url or DEFAULTS['elastic_url']
        logger.info('Connecting to %s', elastic_url)
        self.es = Elasticsearch(hosts=elastic_url)
        self.elastic_index = elastic_index
        self.elastic_fields = elastic_fields
        self.elastic_field_doc_name = elastic_field_doc_name
        self.elastic_field_content = elastic_field_content
        self.strict = strict

    # Elastic Ranker

    def get_doc_index(self, doc_id):
        """Convert doc_id --> doc_index.

        Searches the document-name field for ``doc_id`` and returns the
        Elasticsearch ``_id`` of the top hit.

        Raises:
            IndexError: if the search returns no hits.
        """
        field_index = self.elastic_field_doc_name
        if isinstance(field_index, list):
            # Nested fields are addressed with dot notation in queries.
            field_index = '.'.join(field_index)
        # NOTE(review): 'match' is a full-text query, so the top hit is not
        # guaranteed to be an exact name match — confirm the field mapping.
        result = self.es.search(index=self.elastic_index,
                                body={'query': {'match': {field_index: doc_id}}})
        return result['hits']['hits'][0]['_id']

    def get_doc_id(self, doc_index):
        """Convert doc_index --> doc_id.

        Raises:
            IndexError: if no document has the given ``_id``.
        """
        result = self.es.search(index=self.elastic_index,
                                body={'query': {'match': {"_id": doc_index}}})
        source = result['hits']['hits'][0]['_source']
        return utils.get_field(source, self.elastic_field_doc_name)

    def closest_docs(self, query, k=1):
        """Closest docs by using ElasticSearch.

        Args:
            query: free-text query matched against ``elastic_fields``.
            k: maximum number of documents to return.

        Returns:
            (doc_ids, doc_scores): parallel lists in the order Elasticsearch
            returned the hits.

        Raises:
            RuntimeError: if ``query`` is empty or blank and ``strict`` is set.
        """
        # Honour the documented `strict` contract for empty queries.
        if not query or not query.strip():
            if self.strict:
                raise RuntimeError('Empty query: %r' % (query,))
            return [], []
        results = self.es.search(index=self.elastic_index, body={
            'size': k,
            'query': {
                'multi_match': {
                    'query': query,
                    'type': 'most_fields',
                    'fields': self.elastic_fields,
                },
            },
        })
        hits = results['hits']['hits']
        doc_ids = [utils.get_field(row['_source'], self.elastic_field_doc_name)
                   for row in hits]
        doc_scores = [row['_score'] for row in hits]
        return doc_ids, doc_scores

    def batch_closest_docs(self, queries, k=1, num_workers=None):
        """Process a batch of closest_docs requests multithreaded.

        Plain threads are sufficient here because each ``closest_docs`` call
        spends its time waiting on an Elasticsearch request.

        Returns:
            A list with one (doc_ids, doc_scores) tuple per query, in the
            same order as ``queries``.
        """
        with ThreadPool(num_workers) as threads:
            closest_docs = partial(self.closest_docs, k=k)
            results = threads.map(closest_docs, queries)
        return results

    # Elastic DB

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Without __exit__ the class could not be used in a `with` statement.
        self.close()

    def close(self):
        """Close the connection to the database."""
        self.es = None

    def get_doc_ids(self):
        """Fetch all ids of docs stored in the db.

        Iterates with the scroll API (``elasticsearch.helpers.scan``) because
        a plain search only returns the first page of hits.
        """
        from elasticsearch.helpers import scan
        hits = scan(self.es, index=self.elastic_index,
                    query={"query": {"match_all": {}}})
        return [utils.get_field(hit['_source'], self.elastic_field_doc_name)
                for hit in hits]

    def get_doc_text(self, doc_id):
        """Fetch the raw text of the doc for 'doc_id'."""
        idx = self.get_doc_index(doc_id)
        result = self.es.get(index=self.elastic_index, doc_type='_doc', id=idx)
        if result is None:
            return result
        return result['_source'][self.elastic_field_content]
You can’t perform that action at this time.