Skip to content

Commit

Permalink
Set device optional for config and remove index_management (#818)
Browse files Browse the repository at this point in the history
  • Loading branch information
wanliAlex committed Apr 23, 2024
1 parent d2b96af commit a5ba6e9
Show file tree
Hide file tree
Showing 17 changed files with 78 additions and 55 deletions.
File renamed without changes.
8 changes: 6 additions & 2 deletions src/marqo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
from marqo.core.monitoring.monitoring import Monitoring
from marqo.core.search.recommender import Recommender
from marqo.tensor_search import enums
from marqo.tensor_search import utils
from marqo.tensor_search.enums import EnvVars
from marqo.vespa.vespa_client import VespaClient


class Config:
def __init__(
self,
vespa_client: VespaClient,
default_device: str,
default_device: Optional[str] = None,
timeout: Optional[int] = None,
backend: Optional[Union[enums.SearchDb, str]] = None,
) -> None:
Expand All @@ -27,13 +29,15 @@ def __init__(
self.set_is_remote(vespa_client)
self.timeout = timeout
self.backend = backend if backend is not None else enums.SearchDb.vespa
self.default_device = default_device if default_device is not None else (
utils.read_env_vars_and_defaults(EnvVars.MARQO_BEST_AVAILABLE_DEVICE))

# Initialize Core layer dependencies
self.index_management = IndexManagement(vespa_client)
self.monitoring = Monitoring(vespa_client, self.index_management)
self.document = Document(vespa_client, self.index_management)
self.recommender = Recommender(vespa_client, self.index_management)
self.embed = Embed(vespa_client, self.index_management, default_device)
self.embed = Embed(vespa_client, self.index_management, self.default_device)

def set_is_remote(self, vespa_client: VespaClient):
local_host_markers = ["localhost", "0.0.0.0", "127.0.0.1"]
Expand Down
39 changes: 8 additions & 31 deletions src/marqo/core/embed/embed.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,15 @@
from timeit import default_timer as timer
from typing import List, Optional, Union, Iterable, Sequence, Dict, Any, Tuple
import numpy as np
from typing import List, Optional, Union, Dict

import pydantic
import marqo.core.unstructured_vespa_index.common as unstructured_common
from marqo import marqo_docs
from marqo import exceptions as base_exceptions
from marqo.api import exceptions as api_exceptions
from marqo.core import exceptions as core_exceptions
from marqo.core import constants

from marqo import exceptions as base_exceptions
from marqo.core.index_management.index_management import IndexManagement
from marqo.core.models.marqo_index import IndexType
from marqo.core.models.marqo_index import MarqoIndex, FieldType, UnstructuredMarqoIndex, StructuredMarqoIndex
from marqo.core.models.marqo_query import MarqoTensorQuery, MarqoLexicalQuery
from marqo.core.structured_vespa_index.structured_vespa_index import StructuredVespaIndex
from marqo.core.unstructured_vespa_index import unstructured_validation as unstructured_index_add_doc_validation
from marqo.core.unstructured_vespa_index.unstructured_vespa_index import UnstructuredVespaIndex
from marqo.core.vespa_index import for_marqo_index as vespa_index_factory
from marqo.s2_inference import errors as s2_inference_errors
from marqo.s2_inference import s2_inference
from marqo.s2_inference.clip_utils import _is_image
from marqo.s2_inference.processing import image as image_processor
from marqo.s2_inference.processing import text as text_processor
from marqo.s2_inference.reranking import rerank
from marqo.tensor_search.models.add_docs_objects import AddDocsParams
from marqo.tensor_search.models.api_models import BulkSearchQueryEntity, ScoreModifier
from marqo.tensor_search.models.api_models import BulkSearchQueryEntity
from marqo.tensor_search.models.private_models import ModelAuth
from marqo.tensor_search.models.search import Qidx, JHash, SearchContext, VectorisedJobs, VectorisedJobPointer
from marqo.tensor_search.models.search import Qidx
from marqo.tensor_search.telemetry import RequestMetricsStore
from marqo.tensor_search.tensor_search_logging import get_logger
from marqo.vespa.exceptions import VespaStatusError
from marqo.vespa.models import VespaDocument, FeedBatchResponse, QueryResult
from marqo.vespa.vespa_client import VespaClient

logger = get_logger(__name__)
Expand All @@ -53,7 +31,7 @@ def embed_content(
self, content: Union[str, Dict[str, float], List[Union[str, Dict[str, float]]]],
index_name: str, device: str = None, image_download_headers: Optional[Dict] = None,
model_auth: Optional[ModelAuth] = None
) -> List[List[float]]:
) -> Dict:
"""
Use the index's model to embed the content
Expand All @@ -65,11 +43,9 @@ def embed_content(
# TODO: Remove this config constructor once vectorise pipeline doesn't need it. Just pass the vespa client
# and index management objects.
from marqo import config
from marqo.tensor_search import utils, validation, tensor_search, index_meta_cache
from marqo.tensor_search import tensor_search, index_meta_cache
temp_config = config.Config(
vespa_client=self.vespa_client,
index_management=self.index_management,
default_device=self.default_device
)

# Set default device if not provided
Expand Down Expand Up @@ -108,7 +84,8 @@ def embed_content(
# Vectorise the queries
with RequestMetricsStore.for_request().time(f"embed.vector_inference_full_pipeline"):
qidx_to_vectors: Dict[Qidx, List[float]] = tensor_search.run_vectorise_pipeline(temp_config, queries, device)
embeddings = list(qidx_to_vectors.values())

embeddings: List[List[float]] = list(qidx_to_vectors.values())

# Record time and return final result
time_taken = timer() - t0
Expand Down
2 changes: 2 additions & 0 deletions src/marqo/core/search/recommender.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from marqo.core.models.marqo_index import IndexType
from marqo.core.utils.vector_interpolation import from_interpolation_method
from marqo.exceptions import InvalidArgumentError
from marqo.tensor_search import utils
from marqo.tensor_search.enums import EnvVars
from marqo.tensor_search.models.score_modifiers_object import ScoreModifier
from marqo.tensor_search.models.search import SearchContext, SearchContextTensor
from marqo.vespa.vespa_client import VespaClient
Expand Down
5 changes: 5 additions & 0 deletions src/marqo/inference/inference_cache/abstract_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def popitem(self) -> None:
"""
pass

@abstractmethod
def clear(self) -> None:
"""Remove all items from the cache."""
pass

@abstractmethod
def __contains__(self, key: Hashable) -> bool:
"""Return True if the key is in the cache, else False.
Expand Down
5 changes: 5 additions & 0 deletions src/marqo/inference/inference_cache/marqo_inference_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ def _generate_key(self, model_cache_key: str, content: str) -> str:
raise TypeError(f"content must be a string, not {type(content)}")
return f"{model_cache_key}||{content}"

def clear(self) -> None:
"""Clear the cache."""
if self._cache is not None:
self._cache.clear()

def is_enabled(self) -> bool:
"""Return True if the cache is enabled, else False."""
return self._cache is not None
Expand Down
4 changes: 4 additions & 0 deletions src/marqo/inference/inference_cache/marqo_lfu_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def popitem(self) -> None:
with self.lock.gen_wlock():
self._cache.popitem()

def clear(self) -> None:
with self.lock.gen_wlock():
self._cache.clear()

@property
def maxsize(self) -> int:
"""Return the maximum size of the cache."""
Expand Down
4 changes: 4 additions & 0 deletions src/marqo/inference/inference_cache/marqo_lru_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def popitem(self) -> None:
with self.lock.gen_wlock():
self._cache.popitem()

def clear(self) -> None:
with self.lock.gen_wlock():
self._cache.clear()

@property
def maxsize(self) -> int:
"""Return the maximum size of the cache."""
Expand Down
7 changes: 6 additions & 1 deletion src/marqo/s2_inference/s2_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from marqo.s2_inference.model_registry import load_model_properties
from marqo.s2_inference.models.model_type import ModelType
from marqo.s2_inference.types import *
from marqo.tensor_search.configs import EnvVars
from marqo.api.configs import EnvVars
from marqo.tensor_search.enums import AvailableModelsKey
from marqo.tensor_search.models.private_models import ModelAuth
from marqo.tensor_search.utils import read_env_vars_and_defaults, generate_batches, read_env_vars_and_defaults_ints
Expand Down Expand Up @@ -474,6 +474,11 @@ def clear_loaded_models() -> None:
torch.cuda.empty_cache()


def clear_marqo_inference_cache() -> None:
""" clears the inference cache if it is enabled"""
if _marqo_inference_cache.is_enabled():
_marqo_inference_cache.clear()

def get_model_properties_from_registry(model_name: str) -> dict:
""" Returns a dict describing properties of a model.
Expand Down
4 changes: 2 additions & 2 deletions src/marqo/tensor_search/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
import torch
from fastapi import HTTPException

from marqo.api import exceptions
from marqo.api import exceptions, configs
from marqo.marqo_logging import logger
from marqo.tensor_search import enums, configs
from marqo.tensor_search import enums
from marqo.tensor_search.enums import EnvVars


Expand Down
13 changes: 12 additions & 1 deletion tests/core/inference/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,15 @@ def test_cache_maxsizeAndCurrsize(self):
self.assertEqual(cache.maxsize, 2, f"{cache_type} Maxsize property incorrect.")
cache.set('key1', 'value1')
cache.set('key2', 'value2')
self.assertEqual(cache.currsize, 2, f"{cache_type} Currsize property incorrect after adding items.")
self.assertEqual(cache.currsize, 2, f"{cache_type} Currsize property incorrect after adding items.")

def test_cache_clear(self):
"""Test the clear method for both cache types."""
for cache_type, cache in self.caches.items():
with self.subTest(cache_type=cache_type):
cache.set('key1', 'value1')
cache.set('key2', 'value2')
cache.clear()
self.assertEqual(len(cache), 0, f"{cache_type} cache did not clear.")
self.assertEqual(cache.currsize, 0, f"{cache_type} cache did not clear currsize.")
self.assertEqual(cache.maxsize, 2, f"{cache_type} cache maxsize changed after clear.")
11 changes: 11 additions & 0 deletions tests/core/inference/test_inference_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,14 @@ def read_write_cache(cache):
self.assertEqual(ITERATIONS, len(result))
self.assertTrue(hits.qsize() > 0)
self.assertTrue(misses.qsize() > 0)

def test_inference_cache_clear(self):
"""Test if the cache clears all items."""
for cache_type in ['LRU', 'LFU']:
with self.subTest(cache_type=cache_type):
cache = MarqoInferenceCache(cache_size=self.cache_size, cache_type=cache_type)
cache.set("model-cache-key", "content", [1.0])
cache.clear()
self.assertEqual(cache.currsize, 0)
self.assertIsNone(cache.get("model-cache-key", "content"))
self.assertEqual(cache.maxsize, self.cache_size)
4 changes: 3 additions & 1 deletion tests/core/inference/test_vectorise_inference_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
import numpy as np
from PIL import Image

from marqo.s2_inference.s2_inference import get_marqo_inference_cache
from marqo.s2_inference.s2_inference import get_marqo_inference_cache, clear_marqo_inference_cache, clear_loaded_models


class TestVectoriseInferenceCache(unittest.TestCase):

def tearDown(self):
clear_marqo_inference_cache()
clear_loaded_models()
# Remove the specific environment variables and loaded modules
if 'MARQO_INFERENCE_CACHE_TYPE' in os.environ:
del os.environ['MARQO_INFERENCE_CACHE_TYPE']
Expand Down
10 changes: 2 additions & 8 deletions tests/tensor_search/integ_tests/test_delete_documents.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,19 @@
import datetime
import unittest
from copy import deepcopy, copy
from copy import copy
import marqo.tensor_search.delete_docs
from marqo.tensor_search.models.delete_docs_objects import MqDeleteDocsRequest, MqDeleteDocsResponse
from marqo.tensor_search.models.add_docs_objects import AddDocsParams
import marqo.tensor_search.tensor_search
from marqo.core.index_management import index_management
from marqo.core.models.marqo_index import Model
from marqo.tensor_search import tensor_search, delete_docs
from tests.marqo_test import MarqoTestCase
import requests
from unittest.mock import patch
from marqo.core.models.marqo_index import *
from marqo.core.models.marqo_index_request import FieldRequest
from marqo.core import exceptions as core_exceptions
from marqo.api import exceptions as api_exceptions
from marqo.tensor_search import enums
from tests.utils.transition import add_docs_caller, add_docs_batched
from tests.utils.transition import add_docs_batched
import os
from marqo.vespa.models.delete_document_response import DeleteBatchDocumentResponse, DeleteBatchResponse
from marqo.tensor_search.configs import default_env_vars


class TestDeleteDocuments(MarqoTestCase):
Expand Down
6 changes: 3 additions & 3 deletions tests/tensor_search/test_index_meta_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
from marqo.tensor_search import tensor_search
from marqo.tensor_search import index_meta_cache
from marqo.config import Config
from marqo.api.exceptions import MarqoError, MarqoApiError, IndexNotFoundError
from marqo.api.exceptions import IndexNotFoundError
from marqo.tensor_search import utils
from marqo.tensor_search.enums import TensorField, SearchMethod
from marqo.tensor_search import configs
from tests.marqo_test import MarqoTestCase
from unittest import mock
from marqo.api import exceptions
from marqo.api import exceptions, configs


@unittest.skip
class TestIndexMetaCache(MarqoTestCase):
Expand Down
5 changes: 2 additions & 3 deletions tests/tensor_search/test_on_start_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

from tests.marqo_test import MarqoTestCase
from unittest import mock
from marqo.tensor_search import enums, configs
from marqo.tensor_search import enums
from marqo.tensor_search import on_start_script
from marqo.s2_inference import s2_inference
from marqo.api import exceptions
from marqo.api import exceptions, configs
import os


Expand Down
6 changes: 3 additions & 3 deletions tests/tensor_search/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def test_read_env_vars_and_defaults(self):
mock_default_env_vars = mock.MagicMock()
mock_default_env_vars.return_value = default_vars

@mock.patch("marqo.tensor_search.configs.default_env_vars", mock_default_env_vars)
@mock.patch("marqo.api.configs.default_env_vars", mock_default_env_vars)
@mock.patch.dict(os.environ, mock_real_environ)
def run():
assert expected == utils.read_env_vars_and_defaults(var=key)
Expand Down Expand Up @@ -184,7 +184,7 @@ def test_read_env_vars_and_defaults_ints(self):
mock_default_env_vars = mock.MagicMock()
mock_default_env_vars.return_value = default_vars

@mock.patch("marqo.tensor_search.configs.default_env_vars", mock_default_env_vars)
@mock.patch("marqo.api.configs.default_env_vars", mock_default_env_vars)
@mock.patch.dict(os.environ, mock_real_environ)
def run():
result = utils.read_env_vars_and_defaults_ints(var=key)
Expand All @@ -204,7 +204,7 @@ def test_read_env_vars_and_defaults_ints_invalid_values(self):
mock_default_env_vars = mock.MagicMock()
mock_default_env_vars.return_value = default_vars

@mock.patch("marqo.tensor_search.configs.default_env_vars", mock_default_env_vars)
@mock.patch("marqo.api.configs.default_env_vars", mock_default_env_vars)
@mock.patch.dict(os.environ, mock_real_environ)
def run():
with self.assertRaises(exceptions.ConfigurationError):
Expand Down

0 comments on commit a5ba6e9

Please sign in to comment.