Add ability to cache text prop models (#2641)
* Add ability to cache text prop models

* isort fix

* adding test to model caching

* adding a comment

* adding license

* adding license to test

* exclude certifi.2023.7.22

* disabling the test for now

* add docstrings

* docstring changes
shayts7 committed Jul 25, 2023
1 parent fa069e6 commit bd95147
Showing 4 changed files with 261 additions and 197 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -128,7 +128,7 @@ jobs:
        with:
          requirements: 'requirements-all.txt'
          fail: 'Copyleft,Other,Error'
          exclude: '(pyzmq.*23\.2\.1|debugpy.*1\.6\.7|certifi.*2023\.5\.7|tqdm.*4\.65\.0|webencodings.*0\.5\.1|torch.*1\.10\.2.*|torchvision.*0\.11\.3.*|terminado.*0\.15\.0.*|urllib3.*1\.26\.11.*|imageio.*2\.20\.0.*|jsonschema.*4\.8\.0.*|qudida.*0\.0\.4)'
          exclude: '(pyzmq.*23\.2\.1|debugpy.*1\.6\.7|certifi.*2023\.7\.22|tqdm.*4\.65\.0|webencodings.*0\.5\.1|torch.*1\.10\.2.*|torchvision.*0\.11\.3.*|terminado.*0\.15\.0.*|urllib3.*1\.26\.11.*|imageio.*2\.20\.0.*|jsonschema.*4\.8\.0.*|qudida.*0\.0\.4)'
          # pyzmq is Revised BSD https://github.com/zeromq/pyzmq/blob/main/examples/LICENSE
          # debugpy is MIT https://github.com/microsoft/debugpy/blob/main/LICENSE
          # certifi is MPL-2.0 https://github.com/certifi/python-certifi/blob/master/LICENSE
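
The exclude value is a single regular-expression alternation; each branch pairs a package name with its escaped version. A quick sanity check of the updated certifi branch, assuming the checker matches entries of the form 'name version' (the matching format here is an assumption, not taken from the action's docs):

import re

CERTIFI_BRANCH = r'certifi.*2023\.7\.22'              # one branch of the alternation above
assert re.search(CERTIFI_BRANCH, 'certifi 2023.7.22')  # the new pin is excluded
assert re.search(CERTIFI_BRANCH, 'certifi 2023.5.7') is None  # the old pin no longer matches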
206 changes: 16 additions & 190 deletions deepchecks/nlp/utils/text_properties.py
@@ -10,7 +10,6 @@
#
"""Module containing the text properties for the NLP module."""
import gc
import importlib
import pathlib
import re
import string
@@ -20,7 +19,6 @@

import numpy as np
import pandas as pd
import requests
import textblob
from nltk import corpus
from nltk import download as nltk_download
@@ -30,15 +28,14 @@

from deepchecks.core.errors import DeepchecksValueError
from deepchecks.nlp.utils.text import cut_string, hash_text, normalize_text, remove_punctuation
from deepchecks.nlp.utils.text_properties_models import get_cmudict_dict, get_fasttext_model, get_transformer_pipeline
from deepchecks.utils.function import run_available_kwargs
from deepchecks.utils.strings import SPECIAL_CHARACTERS, format_list

__all__ = ['calculate_builtin_properties', 'get_builtin_properties_types']

from deepchecks.utils.validation import is_sequence_not_str

MODELS_STORAGE = pathlib.Path(__file__).absolute().parent / '.nlp-models'
FASTTEXT_LANG_MODEL = 'https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin'
DEFAULT_SENTENCE_SAMPLE_SIZE = 300
MAX_CHARS = 512 # Bert accepts max of 512 tokens, so without counting tokens we go for the lower bound.
# all SPECIAL_CHARACTERS - all string.punctuation except for <>@[]^_`{|}~ - all whitespace
@@ -103,159 +100,12 @@ def _sample_for_property(text: str, mode: str = 'words', limit: int = 10000, ret
    return ' '.join(all_units) if not return_as_list else list(all_units)


def _import_optional_property_dependency(
        module: str,
        property_name: str,
        package_name: Optional[str] = None,
        error_template: Optional[str] = None
):
    try:
        lib = importlib.import_module(module)
    except ImportError as error:
        package_name = package_name or module.split('.', maxsplit=1)[0]
        error_template = error_template or (
            'property {property_name} requires the {package_name} python package. '
            'To get it, run:\n'
            '>> pip install {package_name}\n\n'
            'You may install dependencies for all text properties by running:\n'
            '>> pip install deepchecks[nlp-properties]\n'
        )
        raise ImportError(error_template.format(
            property_name=property_name,
            package_name=package_name
        )) from error
    else:
        return lib


def _warn_if_missing_nltk_dependencies(dependency: str, property_name: str):
    """Warn if NLTK dependency is missing."""
    warnings.warn(f'NLTK {dependency} not found, {property_name} cannot be calculated.'
                  ' Please check your internet connection.', UserWarning)


def get_create_model_storage(models_storage: Union[pathlib.Path, str, None] = None):
    """Get the models storage directory and create it if needed."""
    if models_storage is None:
        models_storage = MODELS_STORAGE
    else:
        if isinstance(models_storage, str):
            models_storage = pathlib.Path(models_storage)
        if not isinstance(models_storage, pathlib.Path):
            raise ValueError(
                f'Unexpected type of the "models_storage" parameter - {type(models_storage)}'
            )
    if not models_storage.exists():
        models_storage.mkdir(parents=True)
    if not models_storage.is_dir():
        raise ValueError('"models_storage" expected to be a directory')

    return models_storage


def get_transformer_model(
        property_name: str,
        model_name: str,
        device: Optional[str] = None,
        quantize_model: bool = False,
        models_storage: Union[pathlib.Path, str, None] = None
):
    """Get the transformer model and decide whether to use optimum.onnxruntime.

    optimum.onnxruntime is used to optimize running times on CPU.
    """
    models_storage = get_create_model_storage(models_storage)

    if device not in (None, 'cpu'):
        transformers = _import_optional_property_dependency('transformers', property_name=property_name)
        # TODO: quantize if 'quantize_model' is True
        return transformers.AutoModelForSequenceClassification.from_pretrained(
            model_name,
            cache_dir=models_storage,
            device_map=device
        )

    onnx = _import_optional_property_dependency(
        'optimum.onnxruntime',
        property_name=property_name,
        error_template=(
            f'The device was set to {device} while computing the {property_name} property, '
            'in which case deepchecks resorts to accelerating the inference by using optimum, '
            'but it is not installed. Either:\n'
            '\t- Set the device according to your hardware;\n'
            '\t- Install optimum by running "pip install optimum";\n'
            '\t- Install all dependencies needed for text properties by running '
            '"pip install deepchecks[nlp-properties]";\n'
        )
    )

    if quantize_model is False:
        model_path = models_storage / 'onnx' / model_name

        if model_path.exists():
            return onnx.ORTModelForSequenceClassification.from_pretrained(model_path).to(device or -1)

        model = onnx.ORTModelForSequenceClassification.from_pretrained(
            model_name,
            export=True,
            cache_dir=models_storage,
        ).to(device or -1)
        # NOTE:
        # 'optimum', after exporting/converting a model to the ONNX format,
        # does not store it on disk, so we save it now to avoid reconverting
        # it each time.
        model.save_pretrained(model_path)
        return model

    model_path = models_storage / 'onnx' / 'quantized' / model_name

    if model_path.exists():
        return onnx.ORTModelForSequenceClassification.from_pretrained(model_path).to(device or -1)

    not_quantized_model = get_transformer_model(
        property_name,
        model_name,
        device,
        quantize_model=False,
        models_storage=models_storage
    )

    quantizer = onnx.ORTQuantizer.from_pretrained(not_quantized_model).to(device or -1)

    quantizer.quantize(
        save_dir=model_path,
        # TODO: make it possible to provide a config as a parameter
        quantization_config=onnx.configuration.AutoQuantizationConfig.avx512_vnni(
            is_static=False,
            per_channel=False
        )
    )
    return onnx.ORTModelForSequenceClassification.from_pretrained(model_path).to(device or -1)


def get_transformer_pipeline(
        property_name: str,
        model_name: str,
        device: Optional[str] = None,
        models_storage: Union[pathlib.Path, str, None] = None
):
    """Return a transformers pipeline for the given model name."""
    transformers = _import_optional_property_dependency('transformers', property_name=property_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, device_map=device)
    model = get_transformer_model(
        property_name=property_name,
        model_name=model_name,
        device=device,
        models_storage=models_storage
    )
    return transformers.pipeline(
        'text-classification',
        model=model,
        tokenizer=tokenizer,
        device=device
    )


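The model-loading helpers deleted above now live in deepchecks/nlp/utils/text_properties_models.py (see the import added near the top of the file), where each accepts a use_cache flag. A minimal sketch of how such a flag can be wired through a module-level dictionary; the _MODELS_CACHE name, key layout, and wrapper name below are illustrative, not the actual implementation:

import pathlib
from typing import Optional, Union

_MODELS_CACHE = {}  # maps (property_name, model_name, device) -> loaded pipeline


def get_cached_transformer_pipeline(property_name: str, model_name: str,
                                    device: Optional[str] = None,
                                    models_storage: Union[pathlib.Path, str, None] = None,
                                    use_cache: bool = False):
    """Build a transformers pipeline, reusing a previously built one when use_cache is True."""
    key = (property_name, model_name, device)
    if use_cache and key in _MODELS_CACHE:
        return _MODELS_CACHE[key]
    pipeline = get_transformer_pipeline(property_name, model_name, device, models_storage)
    if use_cache:
        _MODELS_CACHE[key] = pipeline  # kept alive for the next call in this process
    return pipeline
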
def text_length(text: str) -> int:
    """Return text length."""
    return len(text)
@@ -283,36 +133,6 @@ def max_word_length(text: str) -> int:
    return max(len(w) for w in words) if words else 0


def _get_fasttext_model(models_storage: Union[pathlib.Path, str, None] = None):
    """Return fasttext model."""
    fasttext = _import_optional_property_dependency(module='fasttext', property_name='language')

    model_name = FASTTEXT_LANG_MODEL.rsplit('/', maxsplit=1)[-1]
    model_path = get_create_model_storage(models_storage)
    model_path = model_path / 'fasttext'

    if not model_path.exists():
        model_path.mkdir(parents=True)

    model_path = model_path / model_name

    # Save the model to a file
    if not model_path.exists():
        response = requests.get(FASTTEXT_LANG_MODEL, timeout=240)
        if response.status_code != 200:
            raise RuntimeError('Failed to download fasttext model')
        model_path.write_bytes(response.content)

    # This weird code is to suppress a warning from fasttext about a deprecated function
    try:
        fasttext.FastText.eprint = lambda *args, **kwargs: None
        fasttext_model = fasttext.load_model(str(model_path))
    except Exception as exp:
        raise exp

    return fasttext_model


def language(
        text: str,
        lang_certainty_threshold: float = 0.8,
@@ -324,7 +144,7 @@ def language(
    # Load the model if it wasn't received as a parameter. This is done to avoid loading the model
    # each time the function is called.
    if fasttext_model is None:
        fasttext_model = _get_fasttext_model()
        fasttext_model = get_fasttext_model()

    # Predictions are the first prediction (k=1), only if the probability is above the threshold
    prediction = fasttext_model.predict(text.replace('\n', ' '), k=1, threshold=lang_certainty_threshold)[0]
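    # With k=1, fasttext's predict returns a pair like (('__label__en',), array([0.97]));
    # the trailing [0] keeps the labels tuple, which is empty when no language clears
    # lang_certainty_threshold.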
@@ -830,7 +650,8 @@ def calculate_builtin_properties(
        ignore_non_english_samples_for_english_properties: bool = True,
        device: Optional[str] = None,
        models_storage: Union[pathlib.Path, str, None] = None,
        batch_size: Optional[int] = 16
        batch_size: Optional[int] = 16,
        cache_models: bool = False
) -> Tuple[Dict[str, List[float]], Dict[str, str]]:
"""Calculate properties on provided text samples.
Expand Down Expand Up @@ -875,6 +696,8 @@ def calculate_builtin_properties(
        Also, if a folder already contains relevant resources they are not re-downloaded.
    batch_size : int, default 16
        The batch size.
    cache_models : bool, default False
        Whether to cache the models used by this function, to save load time on the next execution.

    Returns
    -------
@@ -902,7 +725,7 @@

    # Prepare kwargs for properties that require outside resources:
    if 'fasttext_model' not in kwargs:
        kwargs['fasttext_model'] = _get_fasttext_model(models_storage=models_storage)
        kwargs['fasttext_model'] = get_fasttext_model(models_storage=models_storage, use_cache=cache_models)

    if 'cmudict_dict' not in kwargs:
        properties_requiring_cmudict = list(set(CMUDICT_PROPERTIES) & set(text_properties_names))
@@ -911,20 +734,22 @@
                _warn_if_missing_nltk_dependencies('cmudict', format_list(properties_requiring_cmudict))
                for prop in properties_requiring_cmudict:
                    calculated_properties[prop] = [np.nan] * len(raw_text)
            cmudict_dict = corpus.cmudict.dict()
            kwargs['cmudict_dict'] = cmudict_dict
            kwargs['cmudict_dict'] = get_cmudict_dict(use_cache=cache_models)

    if 'Toxicity' in text_properties_names and 'toxicity_classifier' not in kwargs:
        kwargs['toxicity_classifier'] = get_transformer_pipeline(
            property_name='toxicity', model_name=TOXICITY_MODEL_NAME, device=device, models_storage=models_storage)
            property_name='toxicity', model_name=TOXICITY_MODEL_NAME, device=device,
            models_storage=models_storage, use_cache=cache_models)

    if 'Formality' in text_properties_names and 'formality_classifier' not in kwargs:
        kwargs['formality_classifier'] = get_transformer_pipeline(
            property_name='formality', model_name=FORMALITY_MODEL_NAME, device=device, models_storage=models_storage)
            property_name='formality', model_name=FORMALITY_MODEL_NAME, device=device,
            models_storage=models_storage, use_cache=cache_models)

    if 'Fluency' in text_properties_names and 'fluency_classifier' not in kwargs:
        kwargs['fluency_classifier'] = get_transformer_pipeline(
            property_name='fluency', model_name=FLUENCY_MODEL_NAME, device=device, models_storage=models_storage)
            property_name='fluency', model_name=FLUENCY_MODEL_NAME, device=device,
            models_storage=models_storage, use_cache=cache_models)

    is_language_property_requested = 'Language' in [prop['name'] for prop in text_properties]
    # Remove language property from the list of properties to calculate as it will be calculated separately:
@@ -994,7 +819,8 @@ def calculate_builtin_properties(
        sentences_cache.clear()

    # Clean all remaining RAM:
    gc.collect()
    if not cache_models:
        gc.collect()

    if not calculated_properties:
        raise RuntimeError('Failed to calculate any of the properties.')
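
A minimal usage sketch of the new flag (the sample text is illustrative):

from deepchecks.nlp.utils.text_properties import calculate_builtin_properties

texts = ['The quick brown fox jumps over the lazy dog.']

# First call downloads/loads the models; with cache_models=True they are kept in memory.
properties, property_types = calculate_builtin_properties(texts, cache_models=True)

# A later call in the same process reuses the cached models and skips the load time.
properties, property_types = calculate_builtin_properties(texts, cache_models=True)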
