Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Danyil/feature/use marqo commons #607

Closed
wants to merge 12 commits into from
3 changes: 2 additions & 1 deletion requirements.dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@ certifi==2019.11.28
idna==2.8
six==1.14.0
typing-extensions==4.5.0
urllib3==1.25.8
urllib3==1.25.8
marqo-commons @ git+https://github.com/marqo-ai/marqo-commons
29 changes: 29 additions & 0 deletions src/marqo/s2_inference/model_loaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from marqo.s2_inference.hf_utils import HF_MODEL
from marqo.s2_inference.sbert_onnx_utils import SBERT_ONNX
from marqo.s2_inference.sbert_utils import SBERT, TEST
from marqo.s2_inference.random_utils import Random
from marqo.s2_inference.clip_utils import CLIP, OPEN_CLIP, MULTILINGUAL_CLIP, FP16_CLIP, get_multilingual_clip_properties
from marqo.s2_inference.types import Any, Dict, List, Optional, Union, FloatTensor
from marqo.s2_inference.onnx_clip_utils import CLIP_ONNX

# we need to keep track of the embed dim and model load functions/classes
# we can use this as a registry

def _get_model_load_mappings() -> Dict:
    """Build the registry that maps a model ``type`` string to its loader class.

    Returns:
        Dict: keys are the model-type identifiers; values are the loader
        classes imported at the top of this module. A new dict is created on
        every call.
    """
    type_to_loader = {
        'clip': CLIP,
        'open_clip': OPEN_CLIP,
        'sbert': SBERT,
        'test': TEST,
        'sbert_onnx': SBERT_ONNX,
        'clip_onnx': CLIP_ONNX,
        'multilingual_clip': MULTILINGUAL_CLIP,
        'fp16_clip': FP16_CLIP,
        'random': Random,
        'hf': HF_MODEL,
    }
    return type_to_loader

def get_model_loaders() -> Dict:
    """Return the model-type -> loader-class mapping as a new dict.

    Returns:
        Dict: a shallow copy of the mapping produced by
        ``_get_model_load_mappings()``, with the same keys in the same order.
    """
    # dict(mapping) performs the same key-by-key shallow copy as an explicit loop.
    return dict(_get_model_load_mappings())
1,782 changes: 0 additions & 1,782 deletions src/marqo/s2_inference/model_registry.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/marqo/s2_inference/onnx_clip_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from marqo.s2_inference.logger import get_logger
import onnxruntime as ort
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import marqo.s2_inference.model_registry as model_registry
from zipfile import ZipFile
from huggingface_hub.utils import RevisionNotFoundError,RepositoryNotFoundError, EntryNotFoundError, LocalEntryNotFoundError
from marqo.s2_inference.errors import ModelDownloadError
Expand Down Expand Up @@ -60,6 +59,7 @@ class CLIP_ONNX(object):
def __init__(self, model_name: str ="onnx32/openai/ViT-L/14", device: str = None, embedding_dim: int = None,
truncate: bool = True,
load=True, **kwargs):
from marqo.s2_inference.s2_inference import get_model_properties_from_registry
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this not declared at the top of the file?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is due to a circular import issue that we'd have to resolve first. Maybe the ONNX CLIP object can work without that function, but it needs deeper investigation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pandu-k Do you have any input on resolving this circular import? Declaring it in this init function works but it feels a bit hacky. The issue is that onnx_clip_utils here calls s2_inference, which imports from model_loaders, which imports from onnx_clip_utils.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This strange situation is an indication that our package structure isn't ideal. Good pick up @vicilliar!

I think the best way is for the model_registry to be purely for static data (i.e. the model properties).

We can then create a new file that has _get_model_load_mappings() and load_model_properties(). Would that solve this?

self.model_name = model_name
self.onnx_type, self.source, self.clip_model = self.model_name.split("/", 2)
if not device:
Expand All @@ -70,7 +70,7 @@ def __init__(self, model_name: str ="onnx32/openai/ViT-L/14", device: str = None
"CPUExecutionProvider"]
self.visual_session = None
self.textual_session = None
self.model_info = model_registry._get_onnx_clip_properties()[self.model_name]
self.model_info = get_model_properties_from_registry(self.model_name)

self.visual_type = np.float16 if self.onnx_type == "onnx16" else np.float32
self.textual_type = np.int64 if self.source == "open_clip" else np.int32
Expand Down
15 changes: 9 additions & 6 deletions src/marqo/s2_inference/s2_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
The functions defined here would have endpoints, later on.
"""
import numpy as np
from marqo_commons.model_registry.model_registry import get_model_properties_dict

from marqo.errors import ModelCacheManagementError, InvalidArgError, ConfigurationError, InternalError
from marqo.s2_inference.errors import (
VectoriseError, InvalidModelPropertiesError, ModelLoadError,
UnknownModelError, ModelNotInCacheError, ModelDownloadError, S2InferenceError)
from PIL import UnidentifiedImageError
from marqo.s2_inference.model_registry import load_model_properties
from marqo.s2_inference.model_loaders import get_model_loaders
from marqo.s2_inference.configs import get_default_normalization, get_default_seq_length
from marqo.s2_inference.types import *
from marqo.s2_inference.logger import get_logger
Expand All @@ -28,7 +30,8 @@
available_models = dict()
# A lock to protect the model loading process
lock = threading.Lock()
MODEL_PROPERTIES = load_model_properties()
MODEL_PROPERTIES = get_model_properties_dict()
MODEL_LOADERS = get_model_loaders()


def vectorise(model_name: str, content: Union[str, List[str]], model_properties: dict = None,
Expand Down Expand Up @@ -369,11 +372,11 @@ def get_model_properties_from_registry(model_name: str) -> dict:
dict: a dictionary describing properties of the model.
"""

if model_name not in MODEL_PROPERTIES['models']:
if model_name not in MODEL_PROPERTIES:
raise UnknownModelError(f"Could not find model properties in model registry for model={model_name}. "
f"Model is not supported by default.")

return MODEL_PROPERTIES['models'][model_name]
return MODEL_PROPERTIES[model_name]


def _check_output_type(output: List[List[float]]) -> bool:
Expand Down Expand Up @@ -505,10 +508,10 @@ def _get_model_loader(model_name: str, model_properties: dict) -> Any:

model_type = model_properties['type']

if model_type not in MODEL_PROPERTIES['loaders']:
if model_type not in MODEL_LOADERS:
raise KeyError(f"model_name={model_name} for model_type={model_type} not in allowed model types")

return MODEL_PROPERTIES['loaders'][model_type]
return MODEL_LOADERS[model_type]


def get_available_models():
Expand Down
227 changes: 0 additions & 227 deletions src/marqo/tensor_search/models/settings_object.py

This file was deleted.

18 changes: 16 additions & 2 deletions src/marqo/tensor_search/tensor_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
import functools
import pprint
import typing

from marqo.errors import InvalidArgError
from marqo.tensor_search.models.private_models import ModelAuth
import uuid
from typing import List, Optional, Union, Iterable, Sequence, Dict, Any, Tuple
Expand All @@ -52,6 +54,7 @@
)
from marqo.tensor_search.enums import IndexSettingsField as NsField
from marqo.tensor_search import utils, backend, validation, configs, add_docs, filtering
from marqo_commons.settings_validation.settings_validation import validate_index_settings
from marqo.tensor_search.formatting import _clean_doc
from marqo.tensor_search.index_meta_cache import get_cache, get_index_info
from marqo.tensor_search import index_meta_cache
Expand All @@ -61,7 +64,7 @@
from marqo.tensor_search.models.external_apis.abstract_classes import ExternalAuth
from marqo.tensor_search.telemetry import RequestMetricsStore
from marqo.tensor_search.health import generate_heath_check_response
from marqo.tensor_search.utils import add_timing
from marqo.tensor_search.utils import add_timing, read_env_vars_and_defaults_ints
from marqo.tensor_search import delete_docs
from marqo.s2_inference.processing import text as text_processor
from marqo.s2_inference.processing import image as image_processor
Expand All @@ -77,6 +80,7 @@
from marqo import errors
from marqo.s2_inference import errors as s2_inference_errors
import threading
from marqo_commons.shared_utils.errors import InvalidSettingsArgError
from dataclasses import replace
from marqo.tensor_search.tensor_search_logging import get_logger

Expand Down Expand Up @@ -140,7 +144,17 @@ def create_vector_index(
else:
the_index_settings = configs.get_default_index_settings()

validation.validate_settings_object(settings_object=the_index_settings)
try:
"""Validates the index settings using validate_index_settings function from marqo-commons.
validate_index_settings on error raises InvalidSettingsArgError from marqo-commons.
To propagate native error catches InvalidSettingsArgError and raises InvalidArgError from marqo."""
validate_index_settings(
settings_to_validate=the_index_settings,
MAX_EF_CONSTRUCTION_VALUE=read_env_vars_and_defaults_ints(EnvVars.MARQO_EF_CONSTRUCTION_MAX_VALUE),
MAX_NUMBER_OF_REPLICAS=read_env_vars_and_defaults_ints(EnvVars.MARQO_MAX_NUMBER_OF_REPLICAS),
)
except InvalidSettingsArgError as e:
danyilq marked this conversation as resolved.
Show resolved Hide resolved
raise InvalidArgError(e)

vector_index_settings = {
"settings": {
Expand Down
Loading
Loading