Skip to content

Commit

Permalink
add first implementation of bm25_tokenize_with
Browse files Browse the repository at this point in the history
  • Loading branch information
CarlosFerLo committed May 16, 2024
1 parent 686a499 commit ae7ae65
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions haystack/document_stores/in_memory/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re
from collections import Counter
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -49,15 +49,17 @@ class InMemoryDocumentStore:

def __init__(
self,
bm25_tokenization_regex: str = r"(?u)\b\w\w+\b",
bm25_tokenization_regex: str = r"(?u)\b\w\w+\b", # deprecated
bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25L",
bm25_parameters: Optional[Dict] = None,
embedding_similarity_function: Literal["dot_product", "cosine"] = "dot_product",
bm25_tokenize_with: Optional[Union[str, Callable[[str], List[str]]]] = None,
):
"""
Initializes the DocumentStore.
:param bm25_tokenization_regex: The regular expression used to tokenize the text for BM25 retrieval.
:param bm25_tokenization_regex: The regular expression used to tokenize the text for BM25 retrieval. Deprecated; use `bm25_tokenize_with` instead.
:param bm25_tokenize_with: Either a regular expression (as a string) used to tokenize the text for BM25 retrieval, or a callable that takes a text string and returns its list of tokens. Takes precedence over `bm25_tokenization_regex`.
:param bm25_algorithm: The BM25 algorithm to use. One of "BM25Okapi", "BM25L", or "BM25Plus".
:param bm25_parameters: Parameters for BM25 implementation in a dictionary format.
For example: {'k1':1.5, 'b':0.75, 'epsilon':0.25}
Expand All @@ -68,8 +70,18 @@ def __init__(
To choose the most appropriate function, look for information about your embedding model.
"""
self.storage: Dict[str, Document] = {}
self.bm25_tokenization_regex = bm25_tokenization_regex
self.tokenizer = re.compile(bm25_tokenization_regex).findall

if bm25_tokenize_with is None:
self.bm25_tokenization_regex = bm25_tokenization_regex # deprecated
self.bm25_tokenize_with = None
self.tokenizer = re.compile(bm25_tokenization_regex).findall
else:
self.bm25_tokenization_regex = None # deprecated
self.bm25_tokenize_with = bm25_tokenize_with
if isinstance(bm25_tokenize_with, str):
self.tokenizer = re.compile(bm25_tokenize_with).findall
else:
self.tokenizer = bm25_tokenize_with

self.bm25_algorithm = bm25_algorithm
self.bm25_algorithm_inst = self._dispatch_bm25()
Expand Down

0 comments on commit ae7ae65

Please sign in to comment.