diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index 164efd7..1b6bc52 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -4,7 +4,7 @@ from typing import Annotated, Any, Literal, NamedTuple from fastapi import APIRouter, Body, Depends -from sqlalchemy import text +from sqlalchemy import bindparam, text from sqlalchemy.engine import Row from sqlalchemy.ext.asyncio import AsyncConnection @@ -73,9 +73,26 @@ class DatasetStatusFilter(StrEnum): ALL = "all" +def _quality_clause(quality: str, range_: str | None) -> str: + if not range_: + return "" + if not (match := re.match(integer_range_regex, range_)): + msg = f"`range_` not a valid range: {range_}" + raise ValueError(msg) + start, end = match.groups() + value = f"`value` BETWEEN {start} AND {end[2:]}" if end else f"`value`={start}" + return f""" AND + d.`did` IN ( + SELECT `data` + FROM data_quality + WHERE `quality`='{quality}' AND {value} + ) + """ # noqa: S608 - `quality` is not user provided, value is filtered with regex + + @router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.") @router.get(path="/list") -async def list_datasets( # noqa: PLR0913 +async def list_datasets( # noqa: PLR0913, C901 pagination: Annotated[Pagination, Body(default_factory=Pagination)], data_name: Annotated[str | None, CasualString128] = None, tag: Annotated[str | None, SystemString64] = None, @@ -103,7 +120,7 @@ async def list_datasets( # noqa: PLR0913 expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> list[dict[str, Any]]: assert expdb_db is not None # noqa: S101 - current_status = text( + status_subquery = text( """ SELECT ds1.`did`, ds1.`status` FROM dataset_status as ds1 @@ -115,90 +132,78 @@ async def list_datasets( # noqa: PLR0913 """, ) - if status == DatasetStatusFilter.ALL: - statuses = [ - DatasetStatusFilter.ACTIVE, - DatasetStatusFilter.DEACTIVATED, - DatasetStatusFilter.IN_PREPARATION, - ] - else: - 
statuses = [status] + clauses = [] + parameters: dict[str, Any] = { + "offset": pagination.offset, + "limit": pagination.limit, + } + if status != DatasetStatusFilter.ALL: + clauses.append("AND IFNULL(cs.`status`, 'in_preparation') = :status") + parameters["status"] = status - where_status = ",".join(f"'{status}'" for status in statuses) if user is None: - visible_to_user = "`visibility`='public'" - elif UserGroup.ADMIN in await user.get_groups(): - visible_to_user = "TRUE" - else: - visible_to_user = f"(`visibility`='public' OR `uploader`={user.user_id})" + clauses.append("AND `visibility`='public'") + elif UserGroup.ADMIN not in await user.get_groups(): + clauses.append("AND (`visibility`='public' OR `uploader`=:user_id)") + parameters["user_id"] = user.user_id + + if uploader: + clauses.append("AND `uploader`=:uploader") + parameters["uploader"] = uploader + + if data_name: + clauses.append("AND `name`=:data_name") + parameters["data_name"] = data_name + + if data_version: + clauses.append("AND `version`=:data_version") + parameters["data_version"] = data_version - where_name = "" if data_name is None else "AND `name`=:data_name" - where_version = "" if data_version is None else "AND `version`=:data_version" - where_uploader = "" if uploader is None else "AND `uploader`=:uploader" - data_id_str = ",".join(str(did) for did in data_id) if data_id else "" - where_data_id = "" if not data_id else f"AND d.`did` IN ({data_id_str})" + if data_id: + clauses.append("AND d.`did` IN :data_ids") + parameters["data_ids"] = data_id # requires some benchmarking on whether e.g., IN () is more efficient. 
- matching_tag = ( - text( + if tag: + clauses.append( """ - AND d.`did` IN ( - SELECT `id` - FROM dataset_tag as dt - WHERE dt.`tag`=:tag - ) - """, + AND d.`did` IN ( + SELECT `id` + FROM dataset_tag as dt + WHERE dt.`tag`=:tag + ) + """, ) - if tag - else "" - ) + parameters["tag"] = tag - def quality_clause(quality: str, range_: str | None) -> str: - if not range_: - return "" - if not (match := re.match(integer_range_regex, range_)): - msg = f"`range_` not a valid range: {range_}" - raise ValueError(msg) - start, end = match.groups() - value = f"`value` BETWEEN {start} AND {end[2:]}" if end else f"`value`={start}" - return f""" AND - d.`did` IN ( - SELECT `data` - FROM data_quality - WHERE `quality`='{quality}' AND {value} - ) - """ # noqa: S608 - `quality` is not user provided, value is filtered with regex + number_instances_filter = _quality_clause("NumberOfInstances", number_instances) + number_classes_filter = _quality_clause("NumberOfClasses", number_classes) + number_features_filter = _quality_clause("NumberOfFeatures", number_features) + number_missing_values_filter = _quality_clause("NumberOfMissingValues", number_missing_values) - number_instances_filter = quality_clause("NumberOfInstances", number_instances) - number_classes_filter = quality_clause("NumberOfClasses", number_classes) - number_features_filter = quality_clause("NumberOfFeatures", number_features) - number_missing_values_filter = quality_clause("NumberOfMissingValues", number_missing_values) + columns = ["did", "name", "version", "format", "file_id", "status"] matching_filter = text( f""" SELECT d.`did`,d.`name`,d.`version`,d.`format`,d.`file_id`, IFNULL(cs.`status`, 'in_preparation') FROM dataset AS d - LEFT JOIN ({current_status}) AS cs ON d.`did`=cs.`did` - WHERE {visible_to_user} {where_name} {where_version} {where_uploader} - {where_data_id} {matching_tag} {number_instances_filter} {number_features_filter} + LEFT JOIN ({status_subquery}) AS cs ON d.`did`=cs.`did` + WHERE 1=1 
{number_instances_filter} {number_features_filter} {number_classes_filter} {number_missing_values_filter} - AND IFNULL(cs.`status`, 'in_preparation') IN ({where_status}) - LIMIT {pagination.limit} OFFSET {pagination.offset} + {" ".join(clauses)} + LIMIT :limit OFFSET :offset """, # noqa: S608 # I am not sure how to do this correctly without an error from Bandit here. # However, the `status` input is already checked by FastAPI to be from a set # of given options, so no injection is possible (I think). The `current_status` # subquery also has no user input. So I think this should be safe. ) - columns = ["did", "name", "version", "format", "file_id", "status"] + + if data_id: + matching_filter = matching_filter.bindparams(bindparam("data_ids", expanding=True)) result = await expdb_db.execute( matching_filter, - parameters={ - "tag": tag, - "data_name": data_name, - "data_version": data_version, - "uploader": uploader, - }, + parameters=parameters, ) rows = result.all() datasets: dict[int, dict[str, Any]] = {