Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Fixed

- Pydantic model validation error when querying Project and listing Organizations.

## [0.0.1a3] - 2025-02-06

### Added
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
]
dependencies = [
"codex-sdk==0.1.0a9",
"pydantic>=1.9.0, <3",
"pydantic>=2.0.0, <3",
]

[project.urls]
Expand Down
4 changes: 3 additions & 1 deletion src/cleanlab_codex/internal/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@


def list_organizations(client: _Codex) -> list[Organization]:
return [Organization.model_validate(org) for org in client.users.myself.organizations.list().organizations]
return [
Organization.model_validate(org.model_dump()) for org in client.users.myself.organizations.list().organizations
]
6 changes: 4 additions & 2 deletions src/cleanlab_codex/internal/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,16 @@ def query_project(
) -> tuple[Optional[str], Optional[Entry]]:
maybe_entry = client.projects.entries.query(project_id, question=question)
if maybe_entry is not None:
entry = Entry.model_validate(maybe_entry)
entry = Entry.model_validate(maybe_entry.model_dump())
if entry.answer is not None:
return entry.answer, entry

return fallback_answer, entry

if not read_only:
created_entry = Entry.model_validate(client.projects.entries.add_question(project_id, question=question))
created_entry = Entry.model_validate(
client.projects.entries.add_question(project_id, question=question).model_dump()
)
return fallback_answer, created_entry

return fallback_answer, None
6 changes: 3 additions & 3 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from codex import AuthenticationError
from codex.types.project_return_schema import Config as ProjectReturnConfig
from codex.types.project_return_schema import ProjectReturnSchema
from codex.types.users.myself.user_organizations_schema import Organization as SDKOrganization
from codex.types.users.myself.user_organizations_schema import UserOrganizationsSchema

from cleanlab_codex.client import Client
from cleanlab_codex.project import MissingProjectError
from cleanlab_codex.types.organization import Organization
from cleanlab_codex.types.project import ProjectConfig

FAKE_PROJECT_ID = str(uuid.uuid4())
Expand All @@ -29,7 +29,7 @@ def test_client_uses_default_organization(mock_client_from_api_key: MagicMock) -
default_org_id = "default-org-id"
mock_client_from_api_key.users.myself.organizations.list.return_value = UserOrganizationsSchema(
organizations=[
Organization(
SDKOrganization(
organization_id=default_org_id,
created_at=datetime.now(),
updated_at=datetime.now(),
Expand Down Expand Up @@ -98,7 +98,7 @@ def test_get_project_not_found(mock_client_from_api_key: MagicMock) -> None:
def test_list_organizations(mock_client_from_api_key: MagicMock) -> None:
mock_client_from_api_key.users.myself.organizations.list.return_value = UserOrganizationsSchema(
organizations=[
Organization(
SDKOrganization(
organization_id=FAKE_ORGANIZATION_ID,
created_at=datetime.now(),
updated_at=datetime.now(),
Expand Down
34 changes: 25 additions & 9 deletions tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from codex import AuthenticationError
from codex.types.project_create_params import Config
from codex.types.projects.access_key_retrieve_project_id_response import AccessKeyRetrieveProjectIDResponse
from codex.types.projects.entry import Entry as SDKEntry

from cleanlab_codex.project import MissingProjectError, Project
from cleanlab_codex.types.entry import Entry, EntryCreate
from cleanlab_codex.types.entry import EntryCreate

FAKE_PROJECT_ID = str(uuid.uuid4())
FAKE_USER_ID = "Test User"
Expand Down Expand Up @@ -138,11 +139,12 @@ def test_query_read_only(mock_client_from_access_key: MagicMock) -> None:
FAKE_PROJECT_ID, question="What is the capital of France?"
)
mock_client_from_access_key.projects.entries.add_question.assert_not_called()
assert res == (None, None)
assert res[0] is None
assert res[1] is None


def test_query_question_found_fallback_answer(mock_client_from_access_key: MagicMock) -> None:
unanswered_entry = Entry(
unanswered_entry = SDKEntry(
id=str(uuid.uuid4()),
created_at=datetime.now(tz=timezone.utc),
question="What is the capital of France?",
Expand All @@ -151,22 +153,32 @@ def test_query_question_found_fallback_answer(mock_client_from_access_key: Magic
mock_client_from_access_key.projects.entries.query.return_value = unanswered_entry
project = Project(mock_client_from_access_key, FAKE_PROJECT_ID)
res = project.query("What is the capital of France?")
assert res == (None, unanswered_entry)
assert res[0] is None
assert res[1] is not None
assert res[1].model_dump() == unanswered_entry.model_dump()


def test_query_question_not_found_fallback_answer(mock_client_from_access_key: MagicMock) -> None:
mock_client_from_access_key.projects.entries.query.return_value = None
mock_client_from_access_key.projects.entries.add_question.return_value = MagicMock(spec=Entry)
mock_entry = SDKEntry(
id="fake-id",
created_at=datetime.now(tz=timezone.utc),
question="What is the capital of France?",
answer=None,
)
mock_client_from_access_key.projects.entries.add_question.return_value = mock_entry

project = Project(mock_client_from_access_key, FAKE_PROJECT_ID)
res = project.query("What is the capital of France?", fallback_answer="Paris")
assert res[0] == "Paris"
assert res[1] is not None
assert res[1].model_dump() == mock_entry.model_dump()


def test_query_add_question_when_not_found(mock_client_from_access_key: MagicMock) -> None:
"""Test that query adds question when not found and not read_only"""
mock_client_from_access_key.projects.entries.query.return_value = None
new_entry = Entry(
new_entry = SDKEntry(
id=str(uuid.uuid4()),
created_at=datetime.now(tz=timezone.utc),
question="What is the capital of France?",
Expand All @@ -180,11 +192,13 @@ def test_query_add_question_when_not_found(mock_client_from_access_key: MagicMoc
mock_client_from_access_key.projects.entries.add_question.assert_called_once_with(
FAKE_PROJECT_ID, question="What is the capital of France?"
)
assert res == (None, new_entry)
assert res[0] is None
assert res[1] is not None
assert res[1].model_dump() == new_entry.model_dump()


def test_query_answer_found(mock_client_from_access_key: MagicMock) -> None:
answered_entry = Entry(
answered_entry = SDKEntry(
id=str(uuid.uuid4()),
created_at=datetime.now(tz=timezone.utc),
question="What is the capital of France?",
Expand All @@ -193,7 +207,9 @@ def test_query_answer_found(mock_client_from_access_key: MagicMock) -> None:
mock_client_from_access_key.projects.entries.query.return_value = answered_entry
project = Project(mock_client_from_access_key, FAKE_PROJECT_ID)
res = project.query("What is the capital of France?")
assert res == ("Paris", answered_entry)
assert res[0] == answered_entry.answer
assert res[1] is not None
assert res[1].model_dump() == answered_entry.model_dump()


def test_add_entries_empty_list(mock_client_from_access_key: MagicMock) -> None:
Expand Down