Skip to content

Commit

Permalink
Merge branch 'main' into dsb/allow-config-in-streamlit
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed May 29, 2024
2 parents 5e12cb1 + fafae47 commit ff51f9d
Show file tree
Hide file tree
Showing 31 changed files with 876 additions and 106 deletions.
54 changes: 54 additions & 0 deletions .github/workflows/e2e-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
name: E2E Tests Comet LLM
env:
COMET_RAISE_EXCEPTIONS_ON_ERROR: "1"
COMET_API_KEY: ${{ secrets.PRODUCTION_CI_COMET_API_KEY }}
on:
pull_request:

jobs:
UnitTests:
name: E2E_Python_${{matrix.python_version}}
runs-on: ubuntu-20.04
strategy:
fail-fast: false
matrix:
python_version: ["3.6", "3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- name: Check out code
uses: actions/checkout@v3

- name: Set the project name
run: |
echo "COMET_PROJECT_NAME=comet-llm-e2e-tests-py${{ matrix.python_version }}" >> $GITHUB_ENV
- name: Print environment variables
run: env

- name: Print event object
run: cat $GITHUB_EVENT_PATH

- name: Print the PR title
run: echo "${{ github.event.pull_request.title }}"

- name: Setup Python ${{ matrix.python_version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python_version }}

- name: Install comet-llm
run: pip install -e .

- name: Install test requirements
run: |
cd ./tests
pip install --no-cache-dir --disable-pip-version-check -r test_requirements.txt
- name: Running SDK e2e Tests
run: python -m pytest --cov=src/comet_llm --cov-report=html:coverage_report_${{matrix.python_version}} -vv tests/e2e/

- name: archive coverage report
uses: actions/upload-artifact@v3
with:
name: coverage_report_${{matrix.python_version}}
path: coverage_report_${{matrix.python_version}}
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
package_dir={"": "src"},
url="https://www.comet.com",
project_urls=project_urls,
version="2.2.2",
version="2.2.4",
zip_safe=False,
license="MIT",
)
1 change: 1 addition & 0 deletions src/comet_llm/chains/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def log_chain(chain: chain.Chain) -> Optional[llm_result.LLMResult]:
chain_data = chain.as_dict()

message = messages.ChainMessage(
id=messages.generate_id(),
experiment_info_=chain.experiment_info,
tags=chain.tags,
chain_data=chain_data,
Expand Down
69 changes: 66 additions & 3 deletions src/comet_llm/experiment_api/comet_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,39 @@
# *******************************************************

import functools
import logging
import urllib.parse
import warnings
from typing import IO, List, Optional
from typing import IO, Any, Dict, List, Optional

import requests # type: ignore
import urllib3.exceptions

from .. import config
from .. import config, exceptions, semantic_version
from ..types import JSONEncodable
from . import request_exception_wrapper
from . import error_codes_mapping, payload_constructor

ResponseContent = JSONEncodable

LOGGER = logging.getLogger(__name__)


class CometAPIClient:
def __init__(self, api_key: str, comet_url: str, session: requests.Session):
self._headers = {"Authorization": api_key}
self._comet_url = comet_url
self._session = session

self.backend_version = semantic_version.SemanticVersion.parse(
self.is_alive_ver()["version"]
)

def is_alive_ver(self) -> ResponseContent:
return self._request(
"GET",
"api/isAlive/ver",
)

def create_experiment(
self,
type_: str,
Expand Down Expand Up @@ -122,6 +135,56 @@ def log_experiment_other(
},
)

def log_chain(
self,
experiment_key: str,
chain_asset: Dict[str, JSONEncodable],
workspace: Optional[str] = None,
project: Optional[str] = None,
parameters: Optional[Dict[str, JSONEncodable]] = None,
metrics: Optional[Dict[str, JSONEncodable]] = None,
tags: Optional[List[str]] = None,
others: Optional[Dict[str, JSONEncodable]] = None,
) -> ResponseContent:
json = [
{
"experimentKey": experiment_key,
"createExperimentRequest": {
"workspaceName": workspace,
"projectName": project,
"type": "LLM",
},
"parameters": payload_constructor.chain_parameters_payload(parameters),
"metrics": payload_constructor.chain_metrics_payload(metrics),
"others": payload_constructor.chain_others_payload(others),
"tags": tags,
"jsonAsset": {
"extension": "json",
"type": "llm_data",
"fileName": "comet_llm_data.json",
"file": chain_asset,
},
}
] # we make a list because endpoint is designed for batches

batched_response: Dict[str, Dict[str, Any]] = self._request(
"POST",
"api/rest/v2/write/experiment/llm",
json=json,
)
sub_response = list(batched_response.values())[0]
status = sub_response["status"]
if status != 200:
LOGGER.debug(
"Failed to send trace: \nPayload %s, Response %s",
str(json),
str(batched_response),
)
error_code = sub_response["content"]["sdk_error_code"]
raise exceptions.CometLLMException(error_codes_mapping.MESSAGES[error_code])

return sub_response["content"]

def _request(self, method: str, path: str, *args, **kwargs) -> ResponseContent: # type: ignore
url = urllib.parse.urljoin(self._comet_url, path)
response = self._session.request(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,12 @@
# *******************************************************

import collections
import json
from typing import NoReturn

import requests # type: ignore
from .. import backend_error_codes, logging_messages

from .. import backend_error_codes, exceptions, logging_messages

_SDK_ERROR_CODES_LOGGING_MESSAGE = collections.defaultdict(
MESSAGES = collections.defaultdict(
lambda: logging_messages.FAILED_TO_SEND_DATA_TO_SERVER,
{
backend_error_codes.UNABLE_TO_LOG_TO_NON_LLM_PROJECT: logging_messages.UNABLE_TO_LOG_TO_NON_LLM_PROJECT
},
)


def handle(exception: requests.RequestException) -> NoReturn:
response = exception.response
sdk_error_code = json.loads(response.text)["sdk_error_code"]
error_message = _SDK_ERROR_CODES_LOGGING_MESSAGE[sdk_error_code]

raise exceptions.CometLLMException(error_message) from exception
46 changes: 46 additions & 0 deletions src/comet_llm/experiment_api/payload_constructor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this package.
# *******************************************************

from typing import Dict, List, Optional

from ..types import JSONEncodable


def chain_parameters_payload(
parameters: Optional[Dict[str, JSONEncodable]]
) -> List[Dict[str, JSONEncodable]]:
return _dict_to_payload_format(parameters, "parameterName", "parameterValue")


def chain_metrics_payload(
metrics: Optional[Dict[str, JSONEncodable]]
) -> List[Dict[str, JSONEncodable]]:
return _dict_to_payload_format(metrics, "metricName", "metricValue")


def chain_others_payload(
others: Optional[Dict[str, JSONEncodable]]
) -> List[Dict[str, JSONEncodable]]:
return _dict_to_payload_format(others, "key", "value")


def _dict_to_payload_format(
source: Optional[Dict[str, JSONEncodable]], key_name: str, value_name: str
) -> List[Dict[str, JSONEncodable]]:
if source is None:
return []

result = [{key_name: key, value_name: value} for key, value in source.items()]

return result
15 changes: 12 additions & 3 deletions src/comet_llm/experiment_api/request_exception_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
# *******************************************************

import functools
import json
import logging
import urllib.parse
from pprint import pformat
from typing import Any, Callable, List
from typing import Any, Callable, List, NoReturn

import requests # type: ignore

from .. import config, exceptions, logging_messages
from . import failed_response_handler
from . import error_codes_mapping

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,7 +50,7 @@ def wrapper(*args, **kwargs) -> Any: # type: ignore
logging_messages.FAILED_TO_SEND_DATA_TO_SERVER
) from exception

failed_response_handler.handle(exception)
_handle_request_exception(exception)

return wrapper

Expand All @@ -73,3 +74,11 @@ def _debug_log(exception: requests.RequestException) -> None:
# Make sure we won't fail on attempt to debug.
# It's mainly for tests when response object can be mocked
pass


def _handle_request_exception(exception: requests.RequestException) -> NoReturn:
response = exception.response
sdk_error_code = json.loads(response.text)["sdk_error_code"]
error_message = error_codes_mapping.MESSAGES[sdk_error_code]

raise exceptions.CometLLMException(error_message) from exception
16 changes: 12 additions & 4 deletions src/comet_llm/message_processing/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,32 @@
# *******************************************************

import dataclasses
import inspect
import uuid
from typing import Any, ClassVar, Dict, List, Optional, Union

from comet_llm.types import JSONEncodable

from .. import experiment_info, logging_messages


def generate_id() -> str:
return uuid.uuid4().hex


@dataclasses.dataclass
class BaseMessage:
experiment_info_: experiment_info.ExperimentInfo
id: str
VERSION: ClassVar[int]

@classmethod
def from_dict(
cls, d: Dict[str, Any], api_key: Optional[str] = None
) -> "BaseMessage":
d.pop("VERSION") #
version = d.pop("VERSION")
if version == 1:
# Message was dumped before id was introduced. We can generate it now.
d["id"] = generate_id()

experiment_info_dict: Dict[str, Optional[str]] = d.pop("experiment_info_")
experiment_info_ = experiment_info.get(
Expand All @@ -57,7 +65,7 @@ class PromptMessage(BaseMessage):
metadata: Optional[Dict[str, Union[str, bool, float, None]]]
tags: Optional[List[str]]

VERSION: ClassVar[int] = 1
VERSION: ClassVar[int] = 2


@dataclasses.dataclass
Expand All @@ -69,4 +77,4 @@ class ChainMessage(BaseMessage):
others: Dict[str, JSONEncodable]
# 'other' - is a name of an attribute of experiment, logged via log_other

VERSION: ClassVar[int] = 1
VERSION: ClassVar[int] = 2
10 changes: 2 additions & 8 deletions src/comet_llm/message_processing/offline_message_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,9 @@ def process(self, message: messages.BaseMessage) -> None:
file_path = pathlib.Path(self._offline_directory, self._current_file_name)

if isinstance(message, messages.PromptMessage):
try:
return prompt.send(message, str(file_path))
except Exception:
LOGGER.error("Failed to log prompt", exc_info=True)
return prompt.send(message, str(file_path))
elif isinstance(message, messages.ChainMessage):
try:
return chain.send(message, str(file_path))
except Exception:
LOGGER.error("Failed to log chain", exc_info=True)
return chain.send(message, str(file_path))

LOGGER.debug(f"Unsupported message type {message}")
return None
Expand Down
10 changes: 2 additions & 8 deletions src/comet_llm/message_processing/online_message_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,9 @@ def __init__(self) -> None:

def process(self, message: messages.BaseMessage) -> Optional[llm_result.LLMResult]:
if isinstance(message, messages.PromptMessage):
try:
return prompt.send(message)
except Exception:
LOGGER.error("Failed to log prompt", exc_info=True)
return prompt.send(message)
elif isinstance(message, messages.ChainMessage):
try:
return chain.send(message)
except Exception:
LOGGER.error("Failed to log chain", exc_info=True)
return chain.send(message)

LOGGER.debug(f"Unsupported message type {message}")
return None
Loading

0 comments on commit ff51f9d

Please sign in to comment.