Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 69 additions & 64 deletions src/routers/openml/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Annotated, Any, Literal, NamedTuple

from fastapi import APIRouter, Body, Depends
from sqlalchemy import text
from sqlalchemy import bindparam, text
from sqlalchemy.engine import Row
from sqlalchemy.ext.asyncio import AsyncConnection

Expand Down Expand Up @@ -73,9 +73,26 @@ class DatasetStatusFilter(StrEnum):
ALL = "all"


def _quality_clause(quality: str, range_: str | None) -> str:
if not range_:
return ""
if not (match := re.match(integer_range_regex, range_)):
msg = f"`range_` not a valid range: {range_}"
raise ValueError(msg)
start, end = match.groups()
value = f"`value` BETWEEN {start} AND {end[2:]}" if end else f"`value`={start}"
return f""" AND
d.`did` IN (
SELECT `data`
FROM data_quality
WHERE `quality`='{quality}' AND {value}
)
""" # noqa: S608 - `quality` is not user provided, value is filtered with regex


@router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.")
@router.get(path="/list")
async def list_datasets( # noqa: PLR0913
async def list_datasets( # noqa: PLR0913, C901
pagination: Annotated[Pagination, Body(default_factory=Pagination)],
data_name: Annotated[str | None, CasualString128] = None,
tag: Annotated[str | None, SystemString64] = None,
Expand Down Expand Up @@ -103,7 +120,7 @@ async def list_datasets( # noqa: PLR0913
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None,
) -> list[dict[str, Any]]:
assert expdb_db is not None # noqa: S101
current_status = text(
status_subquery = text(
"""
SELECT ds1.`did`, ds1.`status`
FROM dataset_status as ds1
Expand All @@ -115,90 +132,78 @@ async def list_datasets( # noqa: PLR0913
""",
)

if status == DatasetStatusFilter.ALL:
statuses = [
DatasetStatusFilter.ACTIVE,
DatasetStatusFilter.DEACTIVATED,
DatasetStatusFilter.IN_PREPARATION,
]
else:
statuses = [status]
clauses = []
parameters: dict[str, Any] = {
"offset": pagination.offset,
"limit": pagination.limit,
}
if status != DatasetStatusFilter.ALL:
clauses.append("AND IFNULL(cs.`status`, 'in_preparation') = :status")
parameters["status"] = status

where_status = ",".join(f"'{status}'" for status in statuses)
if user is None:
visible_to_user = "`visibility`='public'"
elif UserGroup.ADMIN in await user.get_groups():
visible_to_user = "TRUE"
else:
visible_to_user = f"(`visibility`='public' OR `uploader`={user.user_id})"
clauses.append("AND `visibility`='public'")
elif UserGroup.ADMIN not in await user.get_groups():
clauses.append("AND (`visibility`='public' OR `uploader`=:user_id)")
parameters["user_id"] = user.user_id

if uploader:
clauses.append("AND `uploader`=:uploader")
parameters["uploader"] = uploader

if data_name:
clauses.append("AND `name`=:data_name")
parameters["data_name"] = data_name

if data_version:
clauses.append("AND `version`=:data_version")
parameters["data_version"] = data_version

where_name = "" if data_name is None else "AND `name`=:data_name"
where_version = "" if data_version is None else "AND `version`=:data_version"
where_uploader = "" if uploader is None else "AND `uploader`=:uploader"
data_id_str = ",".join(str(did) for did in data_id) if data_id else ""
where_data_id = "" if not data_id else f"AND d.`did` IN ({data_id_str})"
if data_id:
clauses.append("AND d.`did` IN :data_ids")
parameters["data_ids"] = data_id

# requires some benchmarking on whether e.g., IN () is more efficient.
matching_tag = (
text(
if tag:
clauses.append(
"""
AND d.`did` IN (
SELECT `id`
FROM dataset_tag as dt
WHERE dt.`tag`=:tag
)
""",
AND d.`did` IN (
SELECT `id`
FROM dataset_tag as dt
WHERE dt.`tag`=:tag
)
""",
)
if tag
else ""
)
parameters["tag"] = tag

# Removed-side copy of the helper (old code nested it inside `list_datasets`);
# the merged change hoists it to module level as `_quality_clause`.
def quality_clause(quality: str, range_: str | None) -> str:
# Builds an SQL fragment filtering on a data_quality metric.
# Empty/None range_ -> no filter; non-matching range_ -> ValueError.
if not range_:
return ""
if not (match := re.match(integer_range_regex, range_)):
msg = f"`range_` not a valid range: {range_}"
raise ValueError(msg)
start, end = match.groups()
# NOTE(review): `end[2:]` presumably strips a leading ".." separator captured
# by the regex's second group — confirm against `integer_range_regex`.
value = f"`value` BETWEEN {start} AND {end[2:]}" if end else f"`value`={start}"
return f""" AND
d.`did` IN (
SELECT `data`
FROM data_quality
WHERE `quality`='{quality}' AND {value}
)
""" # noqa: S608 - `quality` is not user provided, value is filtered with regex
number_instances_filter = _quality_clause("NumberOfInstances", number_instances)
number_classes_filter = _quality_clause("NumberOfClasses", number_classes)
number_features_filter = _quality_clause("NumberOfFeatures", number_features)
number_missing_values_filter = _quality_clause("NumberOfMissingValues", number_missing_values)

number_instances_filter = quality_clause("NumberOfInstances", number_instances)
number_classes_filter = quality_clause("NumberOfClasses", number_classes)
number_features_filter = quality_clause("NumberOfFeatures", number_features)
number_missing_values_filter = quality_clause("NumberOfMissingValues", number_missing_values)
columns = ["did", "name", "version", "format", "file_id", "status"]
matching_filter = text(
f"""
SELECT d.`did`,d.`name`,d.`version`,d.`format`,d.`file_id`,
IFNULL(cs.`status`, 'in_preparation')
FROM dataset AS d
LEFT JOIN ({current_status}) AS cs ON d.`did`=cs.`did`
WHERE {visible_to_user} {where_name} {where_version} {where_uploader}
{where_data_id} {matching_tag} {number_instances_filter} {number_features_filter}
LEFT JOIN ({status_subquery}) AS cs ON d.`did`=cs.`did`
WHERE 1=1 {number_instances_filter} {number_features_filter}
{number_classes_filter} {number_missing_values_filter}
AND IFNULL(cs.`status`, 'in_preparation') IN ({where_status})
LIMIT {pagination.limit} OFFSET {pagination.offset}
{" ".join(clauses)}
LIMIT :limit OFFSET :offset
""", # noqa: S608
# I am not sure how to do this correctly without an error from Bandit here.
# However, the `status` input is already checked by FastAPI to be from a set
# of given options, so no injection is possible (I think). The `current_status`
# subquery also has no user input. So I think this should be safe.
)
columns = ["did", "name", "version", "format", "file_id", "status"]

if data_id:
matching_filter.bindparams(bindparam("data_ids", expanding=True))
result = await expdb_db.execute(
matching_filter,
parameters={
"tag": tag,
"data_name": data_name,
"data_version": data_version,
"uploader": uploader,
},
parameters=parameters,
)
rows = result.all()
datasets: dict[int, dict[str, Any]] = {
Expand Down
Loading