diff --git a/.github/workflows/e2e-tests.yml b/.github/workflows/e2e-tests.yml new file mode 100644 index 00000000..f5c67da5 --- /dev/null +++ b/.github/workflows/e2e-tests.yml @@ -0,0 +1,54 @@ +name: E2E Tests Comet LLM +env: + COMET_RAISE_EXCEPTIONS_ON_ERROR: "1" + COMET_API_KEY: ${{ secrets.PRODUCTION_CI_COMET_API_KEY }} +on: + pull_request: + +jobs: + UnitTests: + name: E2E_Python_${{matrix.python_version}} + runs-on: ubuntu-20.04 + strategy: + fail-fast: false + matrix: + python_version: ["3.6", "3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] + + steps: + - name: Check out code + uses: actions/checkout@v3 + + - name: Set the project name + run: | + echo "COMET_PROJECT_NAME=comet-llm-e2e-tests-py${{ matrix.python_version }}" >> $GITHUB_ENV + + - name: Print environment variables + run: env + + - name: Print event object + run: cat $GITHUB_EVENT_PATH + + - name: Print the PR title + run: echo "${{ github.event.pull_request.title }}" + + - name: Setup Python ${{ matrix.python_version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python_version }} + + - name: Install comet-llm + run: pip install -e . + + - name: Install test requirements + run: | + cd ./tests + pip install --no-cache-dir --disable-pip-version-check -r test_requirements.txt + + - name: Running SDK e2e Tests + run: python -m pytest --cov=src/comet_llm --cov-report=html:coverage_report_${{matrix.python_version}} -vv tests/e2e/ + + - name: archive coverage report + uses: actions/upload-artifact@v3 + with: + name: coverage_report_${{matrix.python_version}} + path: coverage_report_${{matrix.python_version}} diff --git a/setup.py b/setup.py index e14497a2..105145a6 100644 --- a/setup.py +++ b/setup.py @@ -59,7 +59,7 @@ package_dir={"": "src"}, url="https://www.comet.com", project_urls=project_urls, - version="2.2.2", + version="2.2.4", zip_safe=False, license="MIT", ) diff --git a/src/comet_llm/chains/api.py b/src/comet_llm/chains/api.py index 3318d98c..2f00387a 100644 --- a/src/comet_llm/chains/api.py +++ b/src/comet_llm/chains/api.py @@ -105,6 +105,7 @@ def log_chain(chain: chain.Chain) -> Optional[llm_result.LLMResult]: chain_data = chain.as_dict() message = messages.ChainMessage( + id=messages.generate_id(), experiment_info_=chain.experiment_info, tags=chain.tags, chain_data=chain_data, diff --git a/src/comet_llm/experiment_api/comet_api_client.py b/src/comet_llm/experiment_api/comet_api_client.py index be303032..3b24459f 100644 --- a/src/comet_llm/experiment_api/comet_api_client.py +++ b/src/comet_llm/experiment_api/comet_api_client.py @@ -13,19 +13,22 @@ # ******************************************************* import functools +import logging import urllib.parse import warnings -from typing import IO, List, Optional +from typing import IO, Any, Dict, List, Optional import requests # type: ignore import urllib3.exceptions -from .. import config +from .. import config, exceptions, semantic_version from ..types import JSONEncodable -from . import request_exception_wrapper +from . import error_codes_mapping, payload_constructor ResponseContent = JSONEncodable +LOGGER = logging.getLogger(__name__) + class CometAPIClient: def __init__(self, api_key: str, comet_url: str, session: requests.Session): @@ -33,6 +36,16 @@ def __init__(self, api_key: str, comet_url: str, session: requests.Session): self._comet_url = comet_url self._session = session + self.backend_version = semantic_version.SemanticVersion.parse( + self.is_alive_ver()["version"] + ) + + def is_alive_ver(self) -> ResponseContent: + return self._request( + "GET", + "api/isAlive/ver", + ) + def create_experiment( self, type_: str, @@ -122,6 +135,56 @@ def log_experiment_other( }, ) + def log_chain( + self, + experiment_key: str, + chain_asset: Dict[str, JSONEncodable], + workspace: Optional[str] = None, + project: Optional[str] = None, + parameters: Optional[Dict[str, JSONEncodable]] = None, + metrics: Optional[Dict[str, JSONEncodable]] = None, + tags: Optional[List[str]] = None, + others: Optional[Dict[str, JSONEncodable]] = None, + ) -> ResponseContent: + json = [ + { + "experimentKey": experiment_key, + "createExperimentRequest": { + "workspaceName": workspace, + "projectName": project, + "type": "LLM", + }, + "parameters": payload_constructor.chain_parameters_payload(parameters), + "metrics": payload_constructor.chain_metrics_payload(metrics), + "others": payload_constructor.chain_others_payload(others), + "tags": tags, + "jsonAsset": { + "extension": "json", + "type": "llm_data", + "fileName": "comet_llm_data.json", + "file": chain_asset, + }, + } + ] # we make a list because endpoint is designed for batches + + batched_response: Dict[str, Dict[str, Any]] = self._request( + "POST", + "api/rest/v2/write/experiment/llm", + json=json, + ) + sub_response = list(batched_response.values())[0] + status = sub_response["status"] + if status != 200: + LOGGER.debug( + "Failed to send trace: \nPayload %s, Response %s", + str(json), + str(batched_response), + ) + error_code = sub_response["content"]["sdk_error_code"] + raise exceptions.CometLLMException(error_codes_mapping.MESSAGES[error_code]) + + return sub_response["content"] + def _request(self, method: str, path: str, *args, **kwargs) -> ResponseContent: # type: ignore url = urllib.parse.urljoin(self._comet_url, path) response = self._session.request( diff --git a/src/comet_llm/experiment_api/failed_response_handler.py b/src/comet_llm/experiment_api/error_codes_mapping.py similarity index 61% rename from src/comet_llm/experiment_api/failed_response_handler.py rename to src/comet_llm/experiment_api/error_codes_mapping.py index ed9dc2e5..94dcb7ab 100644 --- a/src/comet_llm/experiment_api/failed_response_handler.py +++ b/src/comet_llm/experiment_api/error_codes_mapping.py @@ -13,24 +13,12 @@ # ******************************************************* import collections -import json -from typing import NoReturn -import requests # type: ignore +from .. import backend_error_codes, logging_messages -from .. import backend_error_codes, exceptions, logging_messages - -_SDK_ERROR_CODES_LOGGING_MESSAGE = collections.defaultdict( +MESSAGES = collections.defaultdict( lambda: logging_messages.FAILED_TO_SEND_DATA_TO_SERVER, { backend_error_codes.UNABLE_TO_LOG_TO_NON_LLM_PROJECT: logging_messages.UNABLE_TO_LOG_TO_NON_LLM_PROJECT }, ) - - -def handle(exception: requests.RequestException) -> NoReturn: - response = exception.response - sdk_error_code = json.loads(response.text)["sdk_error_code"] - error_message = _SDK_ERROR_CODES_LOGGING_MESSAGE[sdk_error_code] - - raise exceptions.CometLLMException(error_message) from exception diff --git a/src/comet_llm/experiment_api/payload_constructor.py b/src/comet_llm/experiment_api/payload_constructor.py new file mode 100644 index 00000000..4ebcd774 --- /dev/null +++ b/src/comet_llm/experiment_api/payload_constructor.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this package. +# ******************************************************* + +from typing import Dict, List, Optional + +from ..types import JSONEncodable + + +def chain_parameters_payload( + parameters: Optional[Dict[str, JSONEncodable]] +) -> List[Dict[str, JSONEncodable]]: + return _dict_to_payload_format(parameters, "parameterName", "parameterValue") + + +def chain_metrics_payload( + metrics: Optional[Dict[str, JSONEncodable]] +) -> List[Dict[str, JSONEncodable]]: + return _dict_to_payload_format(metrics, "metricName", "metricValue") + + +def chain_others_payload( + others: Optional[Dict[str, JSONEncodable]] +) -> List[Dict[str, JSONEncodable]]: + return _dict_to_payload_format(others, "key", "value") + + +def _dict_to_payload_format( + source: Optional[Dict[str, JSONEncodable]], key_name: str, value_name: str +) -> List[Dict[str, JSONEncodable]]: + if source is None: + return [] + + result = [{key_name: key, value_name: value} for key, value in source.items()] + + return result diff --git a/src/comet_llm/experiment_api/request_exception_wrapper.py b/src/comet_llm/experiment_api/request_exception_wrapper.py index 52914b90..88fe4e22 100644 --- a/src/comet_llm/experiment_api/request_exception_wrapper.py +++ b/src/comet_llm/experiment_api/request_exception_wrapper.py @@ -13,15 +13,16 @@ # ******************************************************* import functools +import json import logging import urllib.parse from pprint import pformat -from typing import Any, Callable, List +from typing import Any, Callable, List, NoReturn import requests # type: ignore from .. import config, exceptions, logging_messages -from . import failed_response_handler +from . import error_codes_mapping LOGGER = logging.getLogger(__name__) @@ -49,7 +50,7 @@ def wrapper(*args, **kwargs) -> Any: # type: ignore logging_messages.FAILED_TO_SEND_DATA_TO_SERVER ) from exception - failed_response_handler.handle(exception) + _handle_request_exception(exception) return wrapper @@ -73,3 +74,11 @@ def _debug_log(exception: requests.RequestException) -> None: # Make sure we won't fail on attempt to debug. # It's mainly for tests when response object can be mocked pass + + +def _handle_request_exception(exception: requests.RequestException) -> NoReturn: + response = exception.response + sdk_error_code = json.loads(response.text)["sdk_error_code"] + error_message = error_codes_mapping.MESSAGES[sdk_error_code] + + raise exceptions.CometLLMException(error_message) from exception diff --git a/src/comet_llm/message_processing/messages.py b/src/comet_llm/message_processing/messages.py index b3c2d07d..5bd85dd5 100644 --- a/src/comet_llm/message_processing/messages.py +++ b/src/comet_llm/message_processing/messages.py @@ -13,7 +13,7 @@ # ******************************************************* import dataclasses -import inspect +import uuid from typing import Any, ClassVar, Dict, List, Optional, Union from comet_llm.types import JSONEncodable @@ -21,16 +21,24 @@ from .. import experiment_info, logging_messages +def generate_id() -> str: + return uuid.uuid4().hex + + @dataclasses.dataclass class BaseMessage: experiment_info_: experiment_info.ExperimentInfo + id: str VERSION: ClassVar[int] @classmethod def from_dict( cls, d: Dict[str, Any], api_key: Optional[str] = None ) -> "BaseMessage": - d.pop("VERSION") # + version = d.pop("VERSION") + if version == 1: + # Message was dumped before id was introduced. We can generate it now. + d["id"] = generate_id() experiment_info_dict: Dict[str, Optional[str]] = d.pop("experiment_info_") experiment_info_ = experiment_info.get( @@ -57,7 +65,7 @@ class PromptMessage(BaseMessage): metadata: Optional[Dict[str, Union[str, bool, float, None]]] tags: Optional[List[str]] - VERSION: ClassVar[int] = 1 + VERSION: ClassVar[int] = 2 @dataclasses.dataclass @@ -69,4 +77,4 @@ class ChainMessage(BaseMessage): others: Dict[str, JSONEncodable] # 'other' - is a name of an attribute of experiment, logged via log_other - VERSION: ClassVar[int] = 1 + VERSION: ClassVar[int] = 2 diff --git a/src/comet_llm/message_processing/offline_message_processor.py b/src/comet_llm/message_processing/offline_message_processor.py index b5d032b5..d2fe7b43 100644 --- a/src/comet_llm/message_processing/offline_message_processor.py +++ b/src/comet_llm/message_processing/offline_message_processor.py @@ -44,15 +44,9 @@ def process(self, message: messages.BaseMessage) -> None: file_path = pathlib.Path(self._offline_directory, self._current_file_name) if isinstance(message, messages.PromptMessage): - try: - return prompt.send(message, str(file_path)) - except Exception: - LOGGER.error("Failed to log prompt", exc_info=True) + return prompt.send(message, str(file_path)) elif isinstance(message, messages.ChainMessage): - try: - return chain.send(message, str(file_path)) - except Exception: - LOGGER.error("Failed to log chain", exc_info=True) + return chain.send(message, str(file_path)) LOGGER.debug(f"Unsupported message type {message}") return None diff --git a/src/comet_llm/message_processing/online_message_processor.py b/src/comet_llm/message_processing/online_message_processor.py index 80f7b446..14cb0e26 100644 --- a/src/comet_llm/message_processing/online_message_processor.py +++ b/src/comet_llm/message_processing/online_message_processor.py @@ -28,15 +28,9 @@ def __init__(self) -> None: def process(self, message: messages.BaseMessage) -> Optional[llm_result.LLMResult]: if isinstance(message, messages.PromptMessage): - try: - return prompt.send(message) - except Exception: - LOGGER.error("Failed to log prompt", exc_info=True) + return prompt.send(message) elif isinstance(message, messages.ChainMessage): - try: - return chain.send(message) - except Exception: - LOGGER.error("Failed to log chain", exc_info=True) + return chain.send(message) LOGGER.debug(f"Unsupported message type {message}") return None diff --git a/src/comet_llm/message_processing/online_senders/chain.py b/src/comet_llm/message_processing/online_senders/chain.py index 7502adad..a200b0db 100644 --- a/src/comet_llm/message_processing/online_senders/chain.py +++ b/src/comet_llm/message_processing/online_senders/chain.py @@ -15,12 +15,23 @@ import io import json -from comet_llm import app, convert, experiment_api, llm_result +from comet_llm import app, convert, experiment_api, llm_result, url_helpers +from comet_llm.experiment_api import comet_api_client from .. import messages +from . import constants def send(message: messages.ChainMessage) -> llm_result.LLMResult: + client = comet_api_client.get(message.experiment_info_.api_key) + + if client.backend_version >= constants.V2_BACKEND_VERSION: + return _send_v2(message, client) + + return _send_v1(message) + + +def _send_v1(message: messages.ChainMessage) -> llm_result.LLMResult: experiment_api_ = experiment_api.ExperimentAPI.create_new( api_key=message.experiment_info_.api_key, workspace=message.experiment_info_.workspace, @@ -48,3 +59,24 @@ def send(message: messages.ChainMessage) -> llm_result.LLMResult: return llm_result.LLMResult( id=experiment_api_.id, project_url=experiment_api_.project_url ) + + +def _send_v2( + message: messages.ChainMessage, client: comet_api_client.CometAPIClient +) -> llm_result.LLMResult: + metrics = {"chain_duration": message.duration} + parameters = convert.chain_metadata_to_flat_parameters(message.metadata) + + response = client.log_chain( + experiment_key=message.id, + chain_asset=message.chain_data, + workspace=message.experiment_info_.workspace, + project=message.experiment_info_.project_name, + tags=message.tags, + metrics=metrics, + parameters=parameters, + others=message.others, + ) + project_url: str = url_helpers.experiment_to_project_url(response["link"]) + + return llm_result.LLMResult(id=message.id, project_url=project_url) diff --git a/src/comet_llm/message_processing/online_senders/constants.py b/src/comet_llm/message_processing/online_senders/constants.py new file mode 100644 index 00000000..a8f34e13 --- /dev/null +++ b/src/comet_llm/message_processing/online_senders/constants.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this package. +# ******************************************************* + +V2_BACKEND_VERSION = "3.25.80" diff --git a/src/comet_llm/message_processing/online_senders/prompt.py b/src/comet_llm/message_processing/online_senders/prompt.py index 6a833863..11771626 100644 --- a/src/comet_llm/message_processing/online_senders/prompt.py +++ b/src/comet_llm/message_processing/online_senders/prompt.py @@ -15,12 +15,23 @@ import io import json -from comet_llm import app, convert, experiment_api, llm_result +from comet_llm import app, convert, experiment_api, llm_result, url_helpers +from comet_llm.experiment_api import comet_api_client from .. import messages +from . import constants def send(message: messages.PromptMessage) -> llm_result.LLMResult: + client = comet_api_client.get(message.experiment_info_.api_key) + + if client.backend_version >= constants.V2_BACKEND_VERSION: + return _send_v2(message, client) + + return _send_v1(message) + + +def _send_v1(message: messages.PromptMessage) -> llm_result.LLMResult: experiment_api_ = experiment_api.ExperimentAPI.create_new( api_key=message.experiment_info_.api_key, workspace=message.experiment_info_.workspace, @@ -47,3 +58,23 @@ def send(message: messages.PromptMessage) -> llm_result.LLMResult: return llm_result.LLMResult( id=experiment_api_.id, project_url=experiment_api_.project_url ) + + +def _send_v2( + message: messages.PromptMessage, client: comet_api_client.CometAPIClient +) -> llm_result.LLMResult: + metrics = {"chain_duration": message.duration} + parameters = convert.chain_metadata_to_flat_parameters(message.metadata) + + response = client.log_chain( + experiment_key=message.id, + chain_asset=message.prompt_asset_data, + workspace=message.experiment_info_.workspace, + project=message.experiment_info_.project_name, + tags=message.tags, + metrics=metrics, + parameters=parameters, + ) + project_url: str = url_helpers.experiment_to_project_url(response["link"]) + + return llm_result.LLMResult(id=message.id, project_url=project_url) diff --git a/src/comet_llm/prompts/api.py b/src/comet_llm/prompts/api.py index 930eb2c7..511835d2 100644 --- a/src/comet_llm/prompts/api.py +++ b/src/comet_llm/prompts/api.py @@ -12,21 +12,9 @@ # LICENSE file in the root directory of this package. # ******************************************************* -import io -import json from typing import Dict, List, Optional, Union -import comet_llm.convert - -from .. import ( - app, - config, - exceptions, - experiment_api, - experiment_info, - llm_result, - logging_messages, -) +from .. import app, config, exceptions, experiment_info, llm_result, logging_messages from ..chains import version from ..message_processing import api as message_processing_api, messages from . import convert, preprocess @@ -132,6 +120,7 @@ def log_prompt( } message = messages.PromptMessage( + id=messages.generate_id(), experiment_info_=info, prompt_asset_data=asset_data, duration=duration, diff --git a/src/comet_llm/semantic_version.py b/src/comet_llm/semantic_version.py new file mode 100644 index 00000000..7635c518 --- /dev/null +++ b/src/comet_llm/semantic_version.py @@ -0,0 +1,253 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this package. +# ******************************************************* + +# vendored from and tested in comet-ml + +import collections +import re +from functools import wraps +from typing import Callable, Dict, List, Optional, Tuple, Union + +VersionPart = Union[int, Optional[str]] +VersionTuple = Tuple[ + int, + int, + int, + Optional[Union[str, int]], + Optional[Union[str, int]], + Optional[Union[str, int]], +] + +ComparableVersion = Union[ + "SemanticVersion", Dict[str, VersionPart], List[VersionPart], VersionTuple, str +] +Comparator = Callable[["SemanticVersion", ComparableVersion], bool] + + +def _cmp(a, b) -> int: # type: ignore + return (a > b) - (a < b) # type: ignore + + +def _comparator(operator: Comparator) -> Comparator: + @wraps(operator) + def wrapper(self: "SemanticVersion", other: ComparableVersion) -> bool: + comparable_types = ( + SemanticVersion, + dict, + tuple, + list, + str, + ) + if not isinstance(other, comparable_types): + return NotImplemented + return operator(self, other) + + return wrapper + + +class SemanticVersion: + # Based on regex from https://semver.org + # Regex template for a semver version + _SEMVER_REGEX_TEMPLATE = r""" + ^ + (?P0|[1-9]\d*) + (?:\.(?P0|[1-9]\d*) + (?:-(?P0|[1-9]\d*|[a-zA-Z-_][0-9a-zA-Z-_]*))? + (?:\.(?P0|[1-9]\d*)){opt_patch} + ){opt_minor} + (?:-(?P + (?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*) + (?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))* + ))? + (?:\+(?P + [0-9a-zA-Z-]+ + (?:\.[0-9a-zA-Z-]+)* + ))? + $ + """ + # Regex for a semver version + _SEMVER_REGEX = re.compile( + _SEMVER_REGEX_TEMPLATE.format(opt_patch="", opt_minor=""), + re.VERBOSE, + ) + # Regex for a semver version that might be shorter + _SEMVER_REGEX_OPTIONAL_MINOR_AND_PATCH = re.compile( + _SEMVER_REGEX_TEMPLATE.format(opt_patch="?", opt_minor="?"), + re.VERBOSE, + ) + + def __init__( + self, + major: int, + minor: int = 0, + patch: int = 0, + feature_branch: Optional[Union[str, int]] = None, + pre_release: Optional[Union[str, int]] = None, + build: Optional[Union[str, int]] = None, + ): + self._major = major + self._minor = minor + self._patch = patch + self._feature_branch = None if feature_branch is None else str(feature_branch) + self._pre_release = None if pre_release is None else str(pre_release) + self._build = None if build is None else str(build) + + @property + def major(self) -> int: + return self._major + + @property + def minor(self) -> int: + return self._minor + + @property + def patch(self) -> int: + return self._patch + + @property + def pre_release(self) -> Optional[str]: + return self._pre_release + + @property + def build(self) -> Optional[str]: + return self._build + + @property + def feature_branch(self) -> Optional[str]: + return self._feature_branch + + def to_tuple(self) -> VersionTuple: + return ( + self.major, + self.minor, + self.patch, + self.feature_branch, + self.pre_release, + self.build, + ) + + def to_dict(self) -> collections.OrderedDict: + return collections.OrderedDict( + ( + ("major", self._major), + ("minor", self._minor), + ("feature_branch", self._feature_branch), + ("patch", self._patch), + ("pre_release", self._pre_release), + ("build", self._build), + ) + ) + + def compare(self, other: ComparableVersion) -> int: + """ + Compare self with other version. + + :param other: another version + :return: The return value is negative if self < other, + zero if self == other and strictly positive if self > other + """ + cls = type(self) + if isinstance(other, str): + other = cls.parse(other) + elif isinstance(other, dict): + other = cls(**other) # type: ignore + elif isinstance(other, (tuple, list)): + other = cls(*other) + elif not isinstance(other, cls): + raise TypeError( + "Wrong type. Expected str, bytes, dict, tuple, list, or %r instance, but got %r" + % (cls.__name__, type(other)) + ) + + v1 = self.to_tuple()[:3] + v2 = other.to_tuple()[:3] + return _cmp(v1, v2) + + @_comparator + def __eq__(self, other: ComparableVersion) -> bool: # type: ignore + return self.compare(other) == 0 + + @_comparator + def __ne__(self, other: ComparableVersion) -> bool: # type: ignore + return self.compare(other) != 0 + + @_comparator + def __lt__(self, other: ComparableVersion) -> bool: + return self.compare(other) < 0 + + @_comparator + def __le__(self, other: ComparableVersion) -> bool: + return self.compare(other) <= 0 + + @_comparator + def __gt__(self, other: ComparableVersion) -> bool: + return self.compare(other) > 0 + + @_comparator + def __ge__(self, other: ComparableVersion) -> bool: + return self.compare(other) >= 0 + + def __repr__(self) -> str: + s = ", ".join("%s=%r" % (key, val) for key, val in self.to_dict().items()) + return "%s(%s)" % (type(self).__name__, s) + + def __str__(self) -> str: + version = "%d.%d" % (self.major, self.minor) + if self._feature_branch: + version += "-%s" % self._feature_branch + + version += ".%d" % self.patch + + if self.pre_release: + version += "-%s" % self.pre_release + if self.build: + version += "+%s" % self.build + return version + + @classmethod + def parse( + cls, version: str, optional_minor_and_patch: bool = False + ) -> "SemanticVersion": + + if not isinstance(version, str): + raise TypeError("wrong version string type %r" % type(version)) + + if optional_minor_and_patch: + match = cls._SEMVER_REGEX_OPTIONAL_MINOR_AND_PATCH.match(version) + else: + match = cls._SEMVER_REGEX.match(version) + if match is None: + raise ValueError("%r is not valid SemVer string" % version) + + version_parts = match.groupdict() + if not version_parts["minor"]: + version_parts["minor"] = 0 + if not version_parts["patch"]: + version_parts["patch"] = 0 + + major = int(version_parts["major"]) + minor = int(version_parts["minor"]) + patch = int(version_parts["patch"]) + feature_branch = version_parts.get("feature_branch", None) + pre_release = version_parts.get("pre_release", None) + build = version_parts.get("build", None) + + return cls( + major=major, + minor=minor, + patch=patch, + feature_branch=feature_branch, + pre_release=pre_release, + build=build, + ) diff --git a/src/comet_llm/url_helpers.py b/src/comet_llm/url_helpers.py index ae075cbe..0940cd43 100644 --- a/src/comet_llm/url_helpers.py +++ b/src/comet_llm/url_helpers.py @@ -12,6 +12,8 @@ # LICENSE file in the root directory of this package. # ******************************************************* +# vendored from and tested in comet-ml + from urllib.parse import urljoin, urlparse, urlunparse @@ -53,3 +55,7 @@ def get_root_url(url: str) -> str: scheme, netloc, path, params, query, fragment = parts return urlunparse((scheme, netloc, "", "", "", "")) + + +def experiment_to_project_url(experiment_url: str) -> str: + return "/".join(experiment_url.split("/")[:-1]) diff --git a/tests/e2e/__init__.py b/tests/e2e/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py new file mode 100644 index 00000000..16630c59 --- /dev/null +++ b/tests/e2e/conftest.py @@ -0,0 +1,9 @@ +import comet_ml + +import pytest + + +@pytest.fixture(scope="session") +def comet_api(): + api = comet_ml.API(cache=False) + return api diff --git a/tests/e2e/test_chains.py b/tests/e2e/test_chains.py new file mode 100644 index 00000000..4409e7f9 --- /dev/null +++ b/tests/e2e/test_chains.py @@ -0,0 +1,47 @@ +import logging +from typing import TYPE_CHECKING + +import comet_llm + +from . import verifier + +if TYPE_CHECKING: + import comet_ml + +LOGGER = logging.getLogger(__name__) + + +def test_start_and_end_chain__happyflow(comet_api: "comet_ml.API"): + # Neither chain nor span inputs and outputs are not verified for now + + comet_llm.start_chain( + inputs="chain-inputs", + tags=["tag1", "tag2"], + metadata={"start-metadata-key": "start-metadata-value"} + ) + + with comet_llm.Span(category="grand-parent", inputs="grand-parent-span-input") as grandparent_span: + with comet_llm.Span(category="parent", inputs="parent-span-input") as parent_span: + with comet_llm.Span(category="llm-call", inputs="llm-call-input") as llm_call_span: + llm_call_span.set_outputs({"llm-call-output-key": "llm-call-output-value"}) + parent_span.set_outputs({"parent-output-key": "parent-output-value"}) + grandparent_span.set_outputs({"grandparent-output-key": "grandparent-output-value"}) + + llm_result = comet_llm.end_chain( + outputs="chain-outputs", + metadata={"end-metadata-key": "end-metadata-value"} + ) + + print("test_start_and_end_chain__happyflow trace ID: %s" % llm_result.id) + + verifier.verify_trace( + comet_api, + llm_result.id, + expected_tags=["tag1", "tag2"], + expected_metadata={ + "start-metadata-key": "start-metadata-value", + "end-metadata-key": "end-metadata-value", + } + ) + + diff --git a/tests/e2e/test_prompts.py b/tests/e2e/test_prompts.py new file mode 100644 index 00000000..85164791 --- /dev/null +++ b/tests/e2e/test_prompts.py @@ -0,0 +1,38 @@ +import logging +from typing import TYPE_CHECKING + +import comet_llm + +from . import verifier + +if TYPE_CHECKING: + import comet_ml + +LOGGER = logging.getLogger(__name__) + +def test_log_prompt__happyflow(comet_api: "comet_ml.API"): + # prompt and output are not verified for now + + llm_result = comet_llm.log_prompt( + prompt="the-input", + output="the-output", + duration=42, + tags=["tag1", "tag2"], + metadata={ + "metadata-key-1": "metadata-value-1", + "metadata-key-2": 123, + } + ) + + print("test_log_prompt__happyflow trace ID: %s" % llm_result.id) + + verifier.verify_trace( + comet_api, + llm_result.id, + expected_duration=42, + expected_tags=["tag1", "tag2"], + expected_metadata={ + "metadata-key-1": "metadata-value-1", + "metadata-key-2": 123, + } + ) diff --git a/tests/e2e/verifier.py b/tests/e2e/verifier.py new file mode 100644 index 00000000..3173b40c --- /dev/null +++ b/tests/e2e/verifier.py @@ -0,0 +1,75 @@ +from typing import Any, Dict, List, Optional + +import comet_ml + +from .. import testlib + + +def verify_trace( + comet_api: "comet_ml.API", + trace_id: str, + expected_duration: Optional[float] = None, + expected_tags: Optional[List[float]] = None, + expected_metadata: Optional[Dict[str, Any]] = None, + ): + """ + Performs assertions for various trace (prompt | chain) attributes. + As of today it can check the fact that: + - Trace was saved on the backend side (as experiment) + - It contains comet_llm_data.json asset + - Expected duration, tags, metadata are the same as the actual ones. + + The function takes into account that some data might not be avalable + right after logging, so it can wait for some pieces of data (except for the check + for trace and asset existence). + + TODO: probably add assertions for asset content. E.g. today trace input and output + are not verified, however, they are + """ + api_experiment: "comet_ml.APIExperiment" = comet_api.get_experiment_by_id(experiment=trace_id) + assert api_experiment is not None, "Failed to verify that trace was saved" + + assets = api_experiment.get_asset_list() + assert len(assets) == 1, "Failed to verify that trace contains asset" + assert assets[0]["fileName"] == "comet_llm_data.json" + + if expected_duration is not None: + testlib.until( + function=lambda: len(api_experiment.get_metrics(metric="chain_duration")) != 0 + ), "Failed to get duration (a.k.a. chain_duration metric)" + metrics = api_experiment.get_metrics(metric="chain_duration") + _assert_equal_with_conversion_to_left_type( + expected_duration, + metrics[0]["metricValue"] + ) + + if expected_tags is not None: + assert testlib.until( + function=lambda: len(api_experiment.get_tags()) != 0 + ), "Failed to get tags" + actual_tags = api_experiment.get_tags() + assert actual_tags == expected_tags + + if expected_metadata is not None: + assert testlib.until( + function=lambda: len(api_experiment.get_parameters_summary()) != 0 + ), "Failed to get trace metadata (a.k.a. parameters)" + actual_parameters = api_experiment.get_parameters_summary() + assert len(actual_parameters) == len(expected_metadata) + for actual_parameter in actual_parameters: + name = actual_parameter["name"] + _assert_equal_with_conversion_to_left_type( + expected_metadata[name], + actual_parameter["valueCurrent"] + ) + + +def _assert_equal_with_conversion_to_left_type(left_value: Any, right_value: Any) -> None: + """ + Used for more convenient assertions with + string data returned from backend + """ + left_type = type(left_value) + right_value_converted = left_type(right_value) + + assert left_value == right_value_converted \ No newline at end of file diff --git a/tests/testlib.py b/tests/testlib.py index 37e486b7..72f9c627 100644 --- a/tests/testlib.py +++ b/tests/testlib.py @@ -1,9 +1,12 @@ import contextlib +import logging import os +import time +from typing import Callable, Dict @contextlib.contextmanager -def environ(env): +def environ(env: Dict[str, str]): """Temporarily set environment variables inside the context manager and fully restore previous environment afterwards """ @@ -17,4 +20,16 @@ def environ(env): if value is None: del os.environ[key] else: - os.environ[key] = value \ No newline at end of file + os.environ[key] = value + + +def until(function: Callable, sleep: float = 0.5, max_try_seconds: int = 20) -> bool: + """ + Try assert function(). 20 seconds max + """ + start_time = time.time() + while not function(): + if (time.time() - start_time) > max_try_seconds: + return False + time.sleep(sleep) + return True \ No newline at end of file diff --git a/tests/unit/chains/test_chains_api.py b/tests/unit/chains/test_chains_api.py index 9ec1f555..78601466 100644 --- a/tests/unit/chains/test_chains_api.py +++ b/tests/unit/chains/test_chains_api.py @@ -77,8 +77,9 @@ def test_end_chain__happyflow(): ) s.global_chain.set_outputs(outputs="the-outputs", metadata="the-metadata") s.global_chain.as_dict() >> CHAIN_DICT - + s.messages.generate_id() >> "message-id" s.messages.ChainMessage( + id="message-id", experiment_info_=experiment_info, tags="the-tags", chain_data=CHAIN_DICT, diff --git a/tests/unit/experiment_api/test_failed_response_handler.py b/tests/unit/experiment_api/test_failed_response_handler.py deleted file mode 100644 index cd437b94..00000000 --- a/tests/unit/experiment_api/test_failed_response_handler.py +++ /dev/null @@ -1,21 +0,0 @@ -import json - -import box -import pytest -from testix import * - -from comet_llm.exceptions import exceptions -from comet_llm.experiment_api import failed_response_handler - - -def test_wrap__request_exception_non_llm_project_sdk_code__log_specifc_message_in_exception(): - exception = Exception() - exception.response = box.Box(text=json.dumps({"sdk_error_code": 34323})) - - expected_log_message = "Failed to send prompt to the specified project as it is not an LLM project, please specify a different project name." - - - with pytest.raises(exceptions.CometLLMException) as excinfo: - failed_response_handler.handle(exception) - - assert excinfo.value.args == (expected_log_message, ) \ No newline at end of file diff --git a/tests/unit/experiment_api/test_request_exception_wrapper.py b/tests/unit/experiment_api/test_request_exception_wrapper.py index 504d73ff..a2e8bd3c 100644 --- a/tests/unit/experiment_api/test_request_exception_wrapper.py +++ b/tests/unit/experiment_api/test_request_exception_wrapper.py @@ -9,13 +9,13 @@ from comet_llm.experiment_api import request_exception_wrapper -@pytest.fixture(autouse=True) +@pytest.fixture() def mock_imports(patch_module): patch_module(request_exception_wrapper, "config") patch_module(request_exception_wrapper, "failed_response_handler") -def test_wrap_no_exceptions(): +def test_wrap_no_exceptions(mock_imports): @request_exception_wrapper.wrap() def f(): return "return-value" @@ -23,7 +23,7 @@ def f(): assert f() == "return-value" -def test_wrap__request_exception_caught__comet_exception_raised(): +def test_wrap__request_exception_caught__comet_exception_raised(mock_imports): @request_exception_wrapper.wrap() def f(): raise requests.RequestException @@ -32,7 +32,7 @@ def f(): f() -def test_wrap__on_prem_check_enabled__request_exception_caught__on_prem_detected__comet_exception_raised_with_additional_message(): +def test_wrap__on_prem_check_enabled__request_exception_caught__on_prem_detected__comet_exception_raised_with_additional_message(mock_imports): @request_exception_wrapper.wrap(check_on_prem=True) def f(): raise requests.RequestException @@ -43,7 +43,7 @@ def f(): f() -def test_wrap__on_prem_check_enabled__request_exception_caught__on_prem_not_detected__comet_exception_raised_without_additional_message(): +def test_wrap__on_prem_check_enabled__request_exception_caught__on_prem_not_detected__comet_exception_raised_without_additional_message(mock_imports): @request_exception_wrapper.wrap(check_on_prem=True) def f(): raise requests.RequestException @@ -56,13 +56,18 @@ def f(): f() -def test_wrap__request_exception_with_not_None_response__exception_handled_by_failed_response_handler(): - exception = requests.RequestException(response="not-None") +def test_wrap__request_exception_non_llm_project_sdk_code__log_specifc_message_in_exception(): + exception = requests.RequestException() + exception.response = box.Box(text=json.dumps({"sdk_error_code": 34323})) - @request_exception_wrapper.wrap() + expected_log_message = "Failed to send prompt to the specified project as it is not an LLM project, please specify a different project name." + + @request_exception_wrapper.wrap(check_on_prem=True) def f(): raise exception - with Scenario() as s: - s.failed_response_handler.handle(exception) + + with pytest.raises(exceptions.CometLLMException) as excinfo: f() + + assert excinfo.value.args == (expected_log_message, ) \ No newline at end of file diff --git a/tests/unit/message_processing/online_senders/test_chain_online_sender.py b/tests/unit/message_processing/online_senders/test_chain_online_sender.py index 6d89e039..dae5f234 100644 --- a/tests/unit/message_processing/online_senders/test_chain_online_sender.py +++ b/tests/unit/message_processing/online_senders/test_chain_online_sender.py @@ -6,7 +6,9 @@ from comet_llm import llm_result from comet_llm.message_processing import messages -from comet_llm.message_processing.online_senders import chain +from comet_llm.message_processing.online_senders import chain, constants + +NOT_USED = None @pytest.fixture(autouse=True) @@ -15,15 +17,22 @@ def mock_imports(patch_module): patch_module(chain, "chain") patch_module(chain, "state") patch_module(chain, "convert") - #patch_module(chain, "experiment_info") patch_module(chain, "experiment_api") - # patch_module(chain, "app") + patch_module(chain, "comet_api_client") + patch_module(chain, "url_helpers") + + +@pytest.fixture +def mock_v2_backend_version(patch_module): + patch_module(constants, "V2_BACKEND_VERSION", 10) + return 10 -def test_send__happyflow(): +def test_send__v1_backend__happyflow(mock_v2_backend_version): CHAIN_DICT = {"some-key": "some-value"} message = messages.ChainMessage( + id=NOT_USED, experiment_info_=box.Box(api_key="api-key", workspace="the-workspace", project_name="project-name"), tags="the-tags", chain_data=CHAIN_DICT, @@ -32,8 +41,9 @@ def test_send__happyflow(): others={"other-name-1": "other-value-1", "other-name-2": "other-value-2"} ) - + V1_BACKEND_VERSION = mock_v2_backend_version - 1 with Scenario() as s: + s.comet_api_client.get("api-key") >> box.Box(backend_version=V1_BACKEND_VERSION) s.experiment_api.ExperimentAPI.create_new( api_key="api-key", @@ -61,3 +71,36 @@ def test_send__happyflow(): assert chain.send(message) == llm_result.LLMResult(id="experiment-id", project_url="project-url") + + +def test_send__v2_backend__happyflow(mock_v2_backend_version): + CHAIN_DICT = {"some-key": "some-value"} + message = messages.ChainMessage( + id="experiment-id", + experiment_info_=box.Box(api_key="api-key", workspace="the-workspace", project_name="project-name"), + tags="the-tags", + chain_data=CHAIN_DICT, + duration="chain-duration", + metadata="the-metadata", + others={"other-name-1": "other-value-1", "other-name-2": "other-value-2"} + ) + + V2_BACKEND_VERSION = mock_v2_backend_version + with Scenario() as s: + s.comet_api_client.get("api-key") >> Fake("client", backend_version=V2_BACKEND_VERSION) + s.convert.chain_metadata_to_flat_parameters( + "the-metadata", + ) >> {"parameter-key-1": "value-1", "parameter-key-2": "value-2"} + s.client.log_chain( + experiment_key="experiment-id", + chain_asset=CHAIN_DICT, + workspace="the-workspace", + project="project-name", + tags="the-tags", + metrics={"chain_duration": "chain-duration"}, + parameters={"parameter-key-1": "value-1", "parameter-key-2": "value-2"}, + others={"other-name-1": "other-value-1", "other-name-2": "other-value-2"} + ) >> box.Box(link="experiment-url") + s.url_helpers.experiment_to_project_url("experiment-url") >> "project-url" + + assert chain.send(message) == llm_result.LLMResult(id="experiment-id", project_url="project-url") diff --git a/tests/unit/message_processing/online_senders/test_prompt_online_sender.py b/tests/unit/message_processing/online_senders/test_prompt_online_sender.py index 5a7f2298..2d5f9345 100644 --- a/tests/unit/message_processing/online_senders/test_prompt_online_sender.py +++ b/tests/unit/message_processing/online_senders/test_prompt_online_sender.py @@ -7,27 +7,27 @@ from comet_llm import llm_result from comet_llm.chains import version from comet_llm.message_processing import messages -from comet_llm.message_processing.online_senders import prompt +from comet_llm.message_processing.online_senders import constants, prompt @pytest.fixture(autouse=True) def mock_imports(patch_module): - patch_module(prompt, "comet_ml") patch_module(prompt, "convert") patch_module(prompt, "experiment_api") - patch_module(prompt, "experiment_info") - patch_module(prompt, "flatten_dict") - patch_module(prompt, "datetimes") patch_module(prompt, "io") patch_module(prompt, "preprocess") - patch_module(prompt, "app") - patch_module(prompt, "messages") - patch_module(prompt, "message_processing_api") + patch_module(prompt, "comet_api_client") + patch_module(prompt, "url_helpers") +@pytest.fixture +def mock_v2_backend_version(patch_module): + patch_module(constants, "V2_BACKEND_VERSION", 10) + return 10 -def test_send__happyflow(): +def test_send__v1_backend__happyflow(mock_v2_backend_version): message = messages.PromptMessage( + id="id-which-wont-be-used", experiment_info_=box.Box(api_key="api-key", workspace="the-workspace", project_name="project-name"), prompt_asset_data={"asset-dict-key": "asset-dict-value"}, duration="the-duration", @@ -35,7 +35,11 @@ def test_send__happyflow(): tags="the-tags" ) + V1_BACKEND_VERSION = mock_v2_backend_version - 1 + with Scenario() as s: + s.comet_api_client.get("api-key") >> box.Box(backend_version=V1_BACKEND_VERSION) + s.experiment_api.ExperimentAPI.create_new( api_key="api-key", workspace="the-workspace", @@ -58,4 +62,34 @@ def test_send__happyflow(): s.experiment_api_instance.log_parameter("parameter-key-1", "value-1") s.experiment_api_instance.log_parameter("parameter-key-2", "value-2") - prompt.send(message) \ No newline at end of file + assert prompt.send(message) + + +def test_send__v2_backend__happyflow(mock_v2_backend_version): + message = messages.PromptMessage( + id="experiment-id", + experiment_info_=box.Box(api_key="api-key", workspace="the-workspace", project_name="project-name"), + prompt_asset_data={"asset-dict-key": "asset-dict-value"}, + duration="the-duration", + metadata="the-metadata", + tags="the-tags" + ) + + V2_BACKEND_VERSION = mock_v2_backend_version + with Scenario() as s: + s.comet_api_client.get("api-key") >> Fake("client", backend_version=V2_BACKEND_VERSION) + s.convert.chain_metadata_to_flat_parameters( + "the-metadata", + ) >> {"parameter-key-1": "value-1", "parameter-key-2": "value-2"} + s.client.log_chain( + experiment_key="experiment-id", + chain_asset={"asset-dict-key": "asset-dict-value"}, + workspace="the-workspace", + project="project-name", + tags="the-tags", + metrics={"chain_duration": "the-duration"}, + parameters={"parameter-key-1": "value-1", "parameter-key-2": "value-2"}, + ) >> box.Box(link="experiment-url") + s.url_helpers.experiment_to_project_url("experiment-url") >> "project-url" + + assert prompt.send(message) == llm_result.LLMResult(id="experiment-id", project_url="project-url") diff --git a/tests/unit/message_processing/test_messages.py b/tests/unit/message_processing/test_messages.py index 54f07b28..67604e15 100644 --- a/tests/unit/message_processing/test_messages.py +++ b/tests/unit/message_processing/test_messages.py @@ -1,3 +1,5 @@ +import mock + from comet_llm import experiment_info from comet_llm.message_processing import messages @@ -10,6 +12,7 @@ def test_message_dict_conversion__api_key_excluded_in_to_dict__api_key_included_ ) prompt_message = messages.PromptMessage( + id="the-id", experiment_info_=experiment_info_, prompt_asset_data={"asset-key": "asset-value"}, duration=1000, @@ -20,6 +23,7 @@ def test_message_dict_conversion__api_key_excluded_in_to_dict__api_key_included_ dict_message = prompt_message.to_dict() assert dict_message == { + "id": "the-id", "experiment_info_": {"workspace": "the-workspace", "project_name": "project-name"}, "prompt_asset_data": {"asset-key": "asset-value"}, "duration": 1000, @@ -28,4 +32,32 @@ def test_message_dict_conversion__api_key_excluded_in_to_dict__api_key_included_ "VERSION": messages.PromptMessage.VERSION } - assert prompt_message == messages.PromptMessage.from_dict(dict_message, api_key="api-key") \ No newline at end of file + assert prompt_message == messages.PromptMessage.from_dict(dict_message, api_key="api-key") + + +def test_message_dict_conversion__version_1_dict__converted_to_version_2_message_with_id_generation(): + experiment_info_ = experiment_info.ExperimentInfo( + api_key="api-key", + workspace="the-workspace", + project_name="project-name" + ) + + prompt_message_v2 = messages.PromptMessage( + id=mock.ANY, + experiment_info_=experiment_info_, + prompt_asset_data={"asset-key": "asset-value"}, + duration=1000, + metadata={"metadata-key": "metadata-value"}, + tags=["tag1", "tag2"] + ) + + v1_message_dict = { + "experiment_info_": {"workspace": "the-workspace", "project_name": "project-name"}, + "prompt_asset_data": {"asset-key": "asset-value"}, + "duration": 1000, + "metadata": {"metadata-key": "metadata-value"}, + "tags": ["tag1", "tag2"], + "VERSION": 1 + } + + assert prompt_message_v2 == messages.PromptMessage.from_dict(v1_message_dict, api_key="api-key") \ No newline at end of file diff --git a/tests/unit/message_processing/test_offline_message_processor.py b/tests/unit/message_processing/test_offline_message_processor.py index a3c5376f..93f6641e 100644 --- a/tests/unit/message_processing/test_offline_message_processor.py +++ b/tests/unit/message_processing/test_offline_message_processor.py @@ -23,6 +23,7 @@ def mock_imports(patch_module): def test_offline_message_processor__new_filename_created_because_of_time_passed(mock_imports): message = messages.PromptMessage( + id=NOT_USED, experiment_info_=NOT_USED, prompt_asset_data=NOT_USED, duration=NOT_USED, @@ -58,6 +59,7 @@ def test_offline_message_processor__new_filename_created_because_of_time_passed( def test_offline_message_processor__messages_dispatched_to_correct_senders(mock_imports): prompt_message = messages.PromptMessage( + id=NOT_USED, experiment_info_=NOT_USED, prompt_asset_data=NOT_USED, duration=NOT_USED, @@ -66,6 +68,7 @@ def test_offline_message_processor__messages_dispatched_to_correct_senders(mock_ ) chain_message = messages.ChainMessage( + id=NOT_USED, experiment_info_=NOT_USED, chain_data=NOT_USED, duration=NOT_USED, @@ -114,6 +117,7 @@ def test_offline_message_processor_with_senders__prompt_and_chain_messages__happ ) prompt_message = messages.PromptMessage( + id="prompt-id", experiment_info_=experiment_info_, prompt_asset_data={"prompt-asset-key": "prompt-asset-value"}, duration="prompt-duration", @@ -122,6 +126,7 @@ def test_offline_message_processor_with_senders__prompt_and_chain_messages__happ ) chain_message = messages.ChainMessage( + id="chain-id", experiment_info_=experiment_info_, chain_data={"chain-key": "chain-value"}, duration="chain-duration", diff --git a/tests/unit/message_processing/test_online_message_processor.py b/tests/unit/message_processing/test_online_message_processor.py index 6bb0a451..860743d8 100644 --- a/tests/unit/message_processing/test_online_message_processor.py +++ b/tests/unit/message_processing/test_online_message_processor.py @@ -15,6 +15,7 @@ def mock_imports(patch_module): def test_offline_message_processor__messages_dispatched_to_correct_senders(): tested = online_message_processor.OnlineMessageProcessor() prompt_message = messages.PromptMessage( + id=NOT_USED, experiment_info_=NOT_USED, prompt_asset_data=NOT_USED, duration=NOT_USED, @@ -23,6 +24,7 @@ def test_offline_message_processor__messages_dispatched_to_correct_senders(): ) chain_message = messages.ChainMessage( + id=NOT_USED, experiment_info_=NOT_USED, chain_data=NOT_USED, duration=NOT_USED, diff --git a/tests/unit/prompts/test_prompts_api.py b/tests/unit/prompts/test_prompts_api.py index e9b30015..fe001354 100644 --- a/tests/unit/prompts/test_prompts_api.py +++ b/tests/unit/prompts/test_prompts_api.py @@ -74,7 +74,9 @@ def test_log_prompt__happyflow(): duration="the-duration" ) >> "CALL-DATA-DICT" + s.messages.generate_id() >> "message-id" s.messages.PromptMessage( + id="message-id", experiment_info_=experiment_info, prompt_asset_data=EXPECTED_ASSET_DICT_TO_LOG, duration="the-duration",