Skip to content

Commit

Permalink
fix: add more tests around proper usage patterns (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
viveknair committed Apr 24, 2023
1 parent b7414c6 commit fe8970c
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 85 deletions.
20 changes: 20 additions & 0 deletions package/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,30 @@
import os

import pytest
from dotenv import load_dotenv

import gentrace
from fixtures.aioresponse import mockaio
from fixtures.completion import completion_response
from fixtures.embedding import embedding_response
from fixtures.gentrace import gentrace_pipeline_run_response


@pytest.fixture()
def setupTeardown():
gentrace.api_key = os.getenv("GENTRACE_API_KEY")
gentrace.host = "http://localhost:3000/api/v1"

gentrace.configure_openai()

yield "done"
gentrace.api_key = ""
gentrace.host = ""


def setup():
print("Invoking setup")


def pytest_configure():
load_dotenv()
13 changes: 3 additions & 10 deletions package/gentrace/providers/getters.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,20 @@
import importlib.util
import os
from typing import Any, Dict, Optional, cast

import openai

from gentrace.configuration import Configuration as GentraceConfiguration

openai.api_key = os.getenv("OPENAI_KEY")

configured = False


def configure_openai():
global configured

if configured:
return

configured = True
from gentrace import api_key, host

from .llms.openai import annotate_openai_module

if not api_key:
raise ValueError("Gentrace API key not set")

gentrace_config = GentraceConfiguration(host=host)
gentrace_config.access_token = api_key

Expand Down
48 changes: 19 additions & 29 deletions package/gentrace/providers/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,17 @@ async def wrapper(cls, *args, **kwargs):
return wrapper


def swap_methods(cls, attribute: str, gentrace_config: Configuration, intercept_fn):
has_old_sync_saved = hasattr(cls, attribute + "_old")
original_create = getattr(cls, attribute)
if has_old_sync_saved:
original_create = getattr(cls, attribute + "_old")
else:
setattr(cls, attribute + "_old", getattr(cls, attribute))

setattr(cls, attribute, intercept_fn(original_create, gentrace_config))


def annotate_openai_module(
gentrace_config: Configuration,
):
Expand All @@ -649,40 +660,19 @@ def annotate_openai_module(
for name, cls in vars(openai.api_resources).items():
if isinstance(cls, type):
if name == "Completion":
original_create = (
cls.create_old if hasattr(cls, "create_old") else cls.create
)
cls.create = intercept_completion(original_create, gentrace_config)

original_acreate = (
cls.acreate_old if hasattr(cls, "acreate_old") else cls.acreate
)
cls.acreate = intercept_completion_async(
original_acreate, gentrace_config
swap_methods(cls, "create", gentrace_config, intercept_completion)
swap_methods(
cls, "acreate", gentrace_config, intercept_completion_async
)
elif name == "ChatCompletion":
original_create = (
cls.create_old if hasattr(cls, "create_old") else cls.create
swap_methods(cls, "create", gentrace_config, intercept_chat_completion)
swap_methods(
cls, "acreate", gentrace_config, intercept_chat_completion_async
)
cls.create = intercept_chat_completion(original_create, gentrace_config)

original_acreate = (
cls.acreate_old if hasattr(cls, "acreate_old") else cls.acreate
)
cls.acreate = intercept_chat_completion_async(
original_acreate, gentrace_config
)
elif name == "Embedding":
original_create = (
cls.create_old if hasattr(cls, "create_old") else cls.create
)
cls.create = intercept_embedding(original_create, gentrace_config)
original_acreate = (
cls.acreate_old if hasattr(cls, "acreate_old") else cls.acreate
)
cls.acreate = intercept_embedding_async(
original_acreate, gentrace_config
)
swap_methods(cls, "create", gentrace_config, intercept_embedding)
swap_methods(cls, "acreate", gentrace_config, intercept_embedding_async)

setattr(openai.api_resources, name, cls)

Expand Down
28 changes: 17 additions & 11 deletions package/tests/test_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,9 @@

import gentrace

gentrace.api_key = os.getenv("GENTRACE_API_KEY")
gentrace.host = "http://localhost:3000/api/v1"

gentrace.configure_openai()


def test_openai_completion_self_contained_pipeline_id(
mocker, completion_response, gentrace_pipeline_run_response
mocker, completion_response, gentrace_pipeline_run_response, setupTeardown
):
openai.api_key = os.getenv("OPENAI_KEY")

Expand Down Expand Up @@ -69,9 +64,11 @@ def test_openai_completion_self_contained_pipeline_id(

assert uuid.UUID(result.pipeline_run_id) is not None

print(setupTeardown)


def test_openai_completion_self_contained_no_pipeline_id(
mocker, completion_response, gentrace_pipeline_run_response
mocker, completion_response, gentrace_pipeline_run_response, setupTeardown
):
openai.api_key = os.getenv("OPENAI_KEY")

Expand Down Expand Up @@ -118,11 +115,12 @@ def test_openai_completion_self_contained_no_pipeline_id(
)

assert not hasattr(result, "pipeline_run_id")
print(setupTeardown)


@pytest.mark.asyncio
async def test_openai_completion_self_contained_no_pipeline_id_async(
mocker, mockaio, completion_response, gentrace_pipeline_run_response
mocker, mockaio, completion_response, gentrace_pipeline_run_response, setupTeardown
):
# Setup OpenAI mocked request
pattern = re.compile(r"^https://api\.openai\.com/v1/.*$")
Expand Down Expand Up @@ -161,10 +159,12 @@ async def test_openai_completion_self_contained_no_pipeline_id_async(

assert not hasattr(result, "pipeline_run_id")

print(setupTeardown)


@pytest.mark.asyncio
async def test_openai_completion_self_contained_pipeline_id_async(
mocker, mockaio, completion_response, gentrace_pipeline_run_response
mocker, mockaio, completion_response, gentrace_pipeline_run_response, setupTeardown
):
# Setup OpenAI mocked request
pattern = re.compile(r"^https://api\.openai\.com/v1/.*$")
Expand Down Expand Up @@ -204,10 +204,12 @@ async def test_openai_completion_self_contained_pipeline_id_async(

assert uuid.UUID(result.pipeline_run_id) is not None

print(setupTeardown)


@responses.activate
def test_openai_completion_self_contained_pipeline_id_stream(
mocker, completion_response, gentrace_pipeline_run_response
mocker, completion_response, gentrace_pipeline_run_response, setupTeardown
):
openai.api_key = os.getenv("OPENAI_KEY")

Expand Down Expand Up @@ -258,9 +260,11 @@ def test_openai_completion_self_contained_pipeline_id_stream(

assert uuid.UUID(pipeline_run_id) is not None

print(setupTeardown)


@pytest.mark.asyncio
async def test_openai_completion_self_contained_pipeline_id_stream_async():
async def test_openai_completion_self_contained_pipeline_id_stream_async(setupTeardown):
responses.add_passthru("https://api.openai.com/v1/")

openai.api_key = os.getenv("OPENAI_KEY")
Expand All @@ -278,3 +282,5 @@ async def test_openai_completion_self_contained_pipeline_id_stream_async():
pipeline_run_id = value["pipeline_run_id"]

assert uuid.UUID(pipeline_run_id) is not None

print(setupTeardown)
58 changes: 23 additions & 35 deletions package/tests/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,10 @@

import gentrace

gentrace.api_key = os.getenv("GENTRACE_API_KEY")
gentrace.host = "http://localhost:3000/api/v1"

# TODO: must move back into test once GEN-143 is resolved
gentrace.configure_openai()


def test_openai_embedding_self_contained_pipeline_id(
mocker, embedding_response, gentrace_pipeline_run_response
mocker, embedding_response, gentrace_pipeline_run_response, setupTeardown
):
gentrace.api_key = os.getenv("GENTRACE_API_KEY")
gentrace.host = "http://localhost:3000/api/v1"

openai.api_key = os.getenv("OPENAI_KEY")

# Setup OpenAI mocked request
openai_api_key_getter = mocker.patch.object(openai.util, "default_api_key")
openai_api_key_getter.return_value = "test-key"
Expand Down Expand Up @@ -72,13 +61,12 @@ def test_openai_embedding_self_contained_pipeline_id(

assert uuid.UUID(result.pipeline_run_id) is not None

print(setupTeardown)


def test_openai_embedding_self_contained_no_pipeline_id(
mocker, embedding_response, gentrace_pipeline_run_response
mocker, embedding_response, gentrace_pipeline_run_response, setupTeardown
):
gentrace.api_key = os.getenv("GENTRACE_API_KEY")
gentrace.host = "http://localhost:3000/api/v1"

openai.api_key = os.getenv("OPENAI_KEY")

# Setup OpenAI mocked request
Expand Down Expand Up @@ -123,12 +111,10 @@ def test_openai_embedding_self_contained_no_pipeline_id(
)

assert not hasattr(result, "pipeline_run_id")
print(setupTeardown)


def test_openai_embedding_self_contained_pipeline_id_server(mocker):
gentrace.api_key = os.getenv("GENTRACE_API_KEY")
gentrace.host = "http://localhost:3000/api/v1"

def test_openai_embedding_self_contained_pipeline_id_server(mocker, setupTeardown):
openai.api_key = os.getenv("OPENAI_KEY")

responses.add_passthru("https://api.openai.com/v1/")
Expand All @@ -140,12 +126,10 @@ def test_openai_embedding_self_contained_pipeline_id_server(mocker):
)

assert uuid.UUID(result.pipeline_run_id) is not None
print(setupTeardown)


def test_openai_embedding_self_contained_no_pipeline_id_server(mocker):
gentrace.api_key = os.getenv("GENTRACE_API_KEY")
gentrace.host = "http://localhost:3000/api/v1"

def test_openai_embedding_self_contained_no_pipeline_id_server(setupTeardown):
openai.api_key = os.getenv("OPENAI_KEY")

responses.add_passthru("https://api.openai.com/v1/")
Expand All @@ -156,9 +140,10 @@ def test_openai_embedding_self_contained_no_pipeline_id_server(mocker):
)

assert not hasattr(result, "pipeline_run_id")
print(setupTeardown)


def test_openai_embedding_pipeline_server(mocker, embedding_response):
def test_openai_embedding_pipeline_server(setupTeardown):
responses.add_passthru("https://api.openai.com/v1/")

pipeline = gentrace.Pipeline(
Expand All @@ -181,11 +166,12 @@ def test_openai_embedding_pipeline_server(mocker, embedding_response):
info = runner.submit()

assert uuid.UUID(info["pipelineRunId"]) is not None
print(setupTeardown)


@responses.activate
def test_openai_embedding_pipeline(
mocker, embedding_response, gentrace_pipeline_run_response
mocker, embedding_response, gentrace_pipeline_run_response, setupTeardown
):
# Setup OpenAI mocked request
responses.add(
Expand Down Expand Up @@ -240,13 +226,13 @@ def test_openai_embedding_pipeline(
info = runner.submit()

assert uuid.UUID(info["pipelineRunId"]) is not None
print(setupTeardown)


@pytest.mark.asyncio
async def test_openai_embedding_self_contained_no_pipeline_id_server_async():
gentrace.api_key = os.getenv("GENTRACE_API_KEY")
gentrace.host = "http://localhost:3000/api/v1"

async def test_openai_embedding_self_contained_no_pipeline_id_server_async(
setupTeardown,
):
openai.api_key = os.getenv("OPENAI_KEY")

result = await openai.Embedding.acreate(
Expand All @@ -255,13 +241,11 @@ async def test_openai_embedding_self_contained_no_pipeline_id_server_async():
)

assert not hasattr(result, "pipeline_run_id")
print(setupTeardown)


@pytest.mark.asyncio
async def test_openai_embedding_self_contained_pipeline_id_server_async():
gentrace.api_key = os.getenv("GENTRACE_API_KEY")
gentrace.host = "http://localhost:3000/api/v1"

async def test_openai_embedding_self_contained_pipeline_id_server_async(setupTeardown):
openai.api_key = os.getenv("OPENAI_KEY")

result = await openai.Embedding.acreate(
Expand All @@ -272,10 +256,12 @@ async def test_openai_embedding_self_contained_pipeline_id_server_async():

assert uuid.UUID(result.pipeline_run_id) is not None

print(setupTeardown)


@pytest.mark.asyncio
async def test_openai_embedding_pipeline_async(
mocker, mockaio, embedding_response, gentrace_pipeline_run_response
mocker, mockaio, embedding_response, gentrace_pipeline_run_response, setupTeardown
):
# Setup OpenAI mocked request
mockaio.post(
Expand Down Expand Up @@ -329,3 +315,5 @@ async def test_openai_embedding_pipeline_async(
info = await runner.asubmit()

assert uuid.UUID(info["pipelineRunId"]) is not None

print(setupTeardown)
22 changes: 22 additions & 0 deletions package/tests/test_usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os
from unittest.mock import create_autospec

import pytest
from urllib3.response import HTTPResponse

import gentrace


def test_should_raise_error():
with pytest.raises(ValueError):
gentrace.configure_openai()

gentrace.api_key = os.getenv("GENTRACE_API_KEY")
gentrace.host = "http://localhost:3000/api/v1"


def test_should_not_raise_error():
gentrace.api_key = os.getenv("GENTRACE_API_KEY")
gentrace.host = "http://localhost:3000/api/v1"

gentrace.configure_openai()

0 comments on commit fe8970c

Please sign in to comment.