Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,14 @@ jobs:
- uses: actions/setup-python@v6
with:
python-version: 3.x

# https://github.com/docker/compose/issues/10596
- name: Start services
run: |
services="python-api"
if [ "${{ matrix.php_api }}" = "true" ]; then
sed -i 's/INDEX_ES_DURING_STARTUP=false/INDEX_ES_DURING_STARTUP=true/' docker/php/.env
services="$services php-api"
fi
docker compose up $services --detach --wait --remove-orphans || exit $(docker compose ps -q | xargs docker inspect -f '{{.State.ExitCode}}' | grep -v '^0' | wc -l)

docker compose up $services --detach --wait --remove-orphans
- name: Run tests
run: |
marker="${{ matrix.php_api == true && 'php_api' || 'not php_api' }} and ${{ matrix.mutations == true && 'mut' || 'not mut' }}"
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
docker/mysql/data
.claude/
.ignore/
*.log
logs/
.DS_Store
Expand Down
41 changes: 28 additions & 13 deletions src/database/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,15 @@

from sqlalchemy import text
from sqlalchemy.engine import Row
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncConnection

from database.exceptions import (
_DUPLICATE_ENTRY,
_FOREIGN_KEY_CONSTRAINT_FAILED,
DuplicatePrimaryKeyError,
ForeignKeyConstraintError,
)
from schemas.datasets.openml import Feature


Expand Down Expand Up @@ -54,19 +61,27 @@ async def get_tags_for(id_: int, connection: AsyncConnection) -> list[str]:


async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection) -> None:
await connection.execute(
text(
"""
INSERT INTO dataset_tag(`id`, `tag`, `uploader`)
VALUES (:dataset_id, :tag, :user_id)
""",
),
parameters={
"dataset_id": id_,
"user_id": user_id,
"tag": tag_,
},
)
try:
await connection.execute(
text(
"""
INSERT INTO dataset_tag(`id`, `tag`, `uploader`)
VALUES (:dataset_id, :tag, :user_id)
""",
),
parameters={
"dataset_id": id_,
"user_id": user_id,
"tag": tag_,
},
)
except IntegrityError as e:
code, msg = e.orig.args
if code == _FOREIGN_KEY_CONSTRAINT_FAILED:
raise ForeignKeyConstraintError(msg) from e
if code == _DUPLICATE_ENTRY:
raise DuplicatePrimaryKeyError(msg) from e
raise


async def get_description(
Expand Down
22 changes: 22 additions & 0 deletions src/database/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Defines exceptions of the database layer."""

_FOREIGN_KEY_CONSTRAINT_FAILED = 1452
_DUPLICATE_ENTRY = 1062


class ForeignKeyConstraintError(Exception):
    """Foreign key constraint violated.

    The message is available both as the `msg` attribute and through the
    standard exception protocol (`str(error)` and `error.args`).
    """

    def __init__(self, msg: str) -> None:
        """Initialize the error with a message `msg`."""
        # Forward the message to `Exception`; calling `super().__init__()` with
        # no arguments would leave `str(error)` empty and `error.args` as `()`,
        # hiding the message from logs and tracebacks.
        super().__init__(msg)
        self.msg: str = msg


class DuplicatePrimaryKeyError(Exception):
    """Primary key already present.

    The message is available both as the `msg` attribute and through the
    standard exception protocol (`str(error)` and `error.args`).
    """

    def __init__(self, msg: str) -> None:
        """Initialize the error with a message `msg`."""
        # Forward the message to `Exception`; calling `super().__init__()` with
        # no arguments would leave `str(error)` empty and `error.args` as `()`,
        # hiding the message from logs and tracebacks.
        super().__init__(msg)
        self.msg: str = msg
32 changes: 22 additions & 10 deletions src/routers/openml/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
_format_dataset_url,
_format_parquet_url,
)
from database.exceptions import DuplicatePrimaryKeyError, ForeignKeyConstraintError
from database.users import User
from routers.dependencies import (
Pagination,
Expand All @@ -40,7 +41,13 @@
fetch_user_or_raise,
userdb_connection,
)
from routers.types import CasualString128, IntegerRange, SystemString64, integer_range_regex
from routers.types import (
CasualString128,
Identifier,
IntegerRange,
SystemString64,
integer_range_regex,
)
from schemas.datasets.openml import DatasetMetadata, DatasetStatus, Feature, FeatureType

router = APIRouter(prefix="/datasets", tags=["datasets"])
Expand All @@ -50,21 +57,26 @@
path="/tag",
)
async def tag_dataset(
data_id: Annotated[int, Body()],
data_id: Annotated[Identifier, Body()],
tag: Annotated[str, SystemString64],
user: Annotated[User, Depends(fetch_user_or_raise)],
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None,
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[str, dict[str, Any]]:
assert expdb_db is not None # noqa: S101
tags = await database.datasets.get_tags_for(data_id, expdb_db)
if tag.casefold() in [t.casefold() for t in tags]:
try:
await database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db)
except ForeignKeyConstraintError:
msg = f"Dataset {data_id} not found."
raise DatasetNotFoundError(msg, code=472) from None
except DuplicatePrimaryKeyError:
msg = f"Dataset {data_id} already tagged with {tag!r}."
raise TagAlreadyExistsError(msg)
raise TagAlreadyExistsError(msg) from None

logger.info("Dataset {data_id} tagged '{tag}'.", data_id=data_id, tag=tag)

tags = await database.datasets.get_tags_for(data_id, expdb_db)

await database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db)
logger.info("Dataset {dataset_id} tagged '{tag}'.", dataset_id=data_id, tag=tag)
return {
"data_tag": {"id": str(data_id), "tag": [*tags, tag]},
"data_tag": {"id": str(data_id), "tag": tags},
}


Expand Down
5 changes: 5 additions & 0 deletions src/routers/types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from typing import Annotated

from fastapi import Body
from pydantic import Field

# Body parameter restricted to 1-64 characters, each of which must be a word
# character (`\w`), a hyphen, or a dot.
SystemString64 = Body(pattern=r"^[\w\-\.]+$", min_length=1, max_length=64)

# Body parameter restricted to 1-128 characters; on top of the characters
# accepted by `SystemString64` it also allows parentheses and commas.
CasualString128 = Body(pattern=r"^[\w\-\.\(\),]+$", min_length=1, max_length=128)

# Integer that must be strictly positive (`gt=0`).
# NOTE(review): presumably meant for database identifiers - confirm against callers.
Identifier = Annotated[int, Field(gt=0)]

# Matches one run of digits, optionally followed by ".." and a second run of
# digits, e.g. "5" or "5..10". Both parts are captured as groups.
integer_range_regex = r"^(\d+)(\.\.\d+)?$"
IntegerRange = Body(
pattern=integer_range_regex,
Expand Down
3 changes: 3 additions & 0 deletions tests/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Id used by tests that need a dataset which is not part of `DATASETS`.
DATASET_ID_THAT_DOES_NOT_EXIST = 99_999_999

# Single representative ids; the matching sets below are built from them so
# the two can never drift apart.
SOME_PRIVATE_DATASET_ID = 130
SOME_DEACTIVATED_DATASET_ID = 131

PRIVATE_DATASET_ID = {SOME_PRIVATE_DATASET_ID}
DEACTIVATED_DATASETS = {SOME_DEACTIVATED_DATASET_ID}
IN_PREPARATION_ID = {33, 161, 162, 163}

# Ids 1 through 131 plus the three ids 161-163.
DATASETS = {*range(1, 132), 161, 162, 163}

Expand Down
33 changes: 31 additions & 2 deletions tests/routers/openml/dataset_tag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sqlalchemy.ext.asyncio import AsyncConnection

from core.conversions import nested_remove_single_element_list
from core.errors import TagAlreadyExistsError
from core.errors import DatasetNotFoundError, TagAlreadyExistsError
from database.datasets import get_tags_for
from database.users import User
from routers.openml.datasets import tag_dataset
Expand Down Expand Up @@ -96,13 +96,32 @@ async def test_dataset_tag_fails_if_tag_exists(expdb_test: AsyncConnection) -> N
assert tag in e.value.detail


async def test_dataset_tag_fails_if_dataset_does_not_exist(expdb_test: AsyncConnection) -> None:
    """Tagging an unknown dataset raises `DatasetNotFoundError` with code 472."""
    # Use the shared constant so this test and the parametrized migration test
    # in this module agree on which id stands for "no such dataset".
    dataset_id = constants.DATASET_ID_THAT_DOES_NOT_EXIST
    with pytest.raises(DatasetNotFoundError) as e:
        await tag_dataset(
            data_id=dataset_id,
            tag="foo",
            user=ADMIN_USER,
            expdb_db=expdb_test,
        )
    # The error detail must name the dataset that could not be found.
    assert str(dataset_id) in e.value.detail
    dataset_not_found_in_tag_endpoint = 472
    assert e.value.code == dataset_not_found_in_tag_endpoint


# -- migration tests --


@pytest.mark.mut
@pytest.mark.parametrize(
"dataset_id",
[*range(1, 10), 101, 131],
[
*range(1, 10),
101,
constants.SOME_DEACTIVATED_DATASET_ID,
constants.DATASET_ID_THAT_DOES_NOT_EXIST,
],
)
@pytest.mark.parametrize(
"api_key",
Expand Down Expand Up @@ -142,6 +161,7 @@ async def test_dataset_tag_response_is_identical(
and php_response.json()["error"]["message"] == "An Elastic Search Exception occured."
):
pytest.skip("Encountered Elastic Search error.")

py_response = await py_api.post(
f"/datasets/tag?api_key={api_key}",
json={"data_id": dataset_id, "tag": tag},
Expand All @@ -158,6 +178,15 @@ async def test_dataset_tag_response_is_identical(
)
return

if py_response.status_code == HTTPStatus.NOT_FOUND:
assert php_response.status_code == HTTPStatus.PRECONDITION_FAILED
py_error = py_response.json()
php_error = php_response.json()["error"]
assert py_error["code"] == php_error["code"]
assert php_error["message"] == "Entity not found."
assert re.match(r"Dataset \d+ not found.", py_error["detail"])
return

assert py_response.status_code == php_response.status_code, php_response.json()
if py_response.status_code != HTTPStatus.OK:
assert py_response.json()["code"] == php_response.json()["error"]["code"]
Expand Down
Loading