Skip to content

Commit

Permalink
Upgrade to Pydantic 2
Browse files Browse the repository at this point in the history
  • Loading branch information
hluk committed Mar 5, 2024
1 parent c1c8f38 commit 52cfd37
Show file tree
Hide file tree
Showing 9 changed files with 459 additions and 329 deletions.
279 changes: 172 additions & 107 deletions poetry.lock

Large diffs are not rendered by default.

10 changes: 3 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,12 @@ SQLAlchemy = {version = "^2.0.24"}
psycopg2-binary = {version = "^2.9.7"}
alembic = "^1.13.1"
iso8601 = "^2.1.0"
pydantic = "^1.10.14"
Flask-Pydantic = "^0.11.0"
pydantic = "^2.6.3"
Flask-Pydantic = "^0.12.0"

email-validator = "^2.1.1"
python-ldap = "^3.4.3"
Flask-pyoidc = "^3.14.0"
# dependency of Flask-pyoidc;
# oic-1.6.1 is compatible only with pydantic 2
# whereas Flask-Pydantic needs older versions
oic = {version = "<=1.6.0", optional = true}
Flask-pyoidc = "^3.14.3"
Flask-Session = "^0.6.0"

# tracing support
Expand Down
15 changes: 15 additions & 0 deletions resultsdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import os

from flask import Flask, current_app, jsonify, send_from_directory, session
from flask_pydantic.exceptions import ValidationError
from flask_pyoidc import OIDCAuthentication
from flask_pyoidc.provider_configuration import (
ClientMetadata,
Expand Down Expand Up @@ -52,6 +53,8 @@
except NameError:
basestring = (str, bytes)

VALIDATION_KEYS = frozenset({"input", "loc", "msg", "type", "url"})


def create_app(config_obj=None):
app = Flask(__name__)
Expand Down Expand Up @@ -225,6 +228,18 @@ def bad_gateway(error):
app.logger.error("External error received: %s", error)
return jsonify({"message": "Bad Gateway"}), 502

app.register_error_handler(ValidationError, handle_validation_error)


def handle_validation_error(error: ValidationError):
errors = error.body_params or error.form_params or error.path_params or error.query_params
# Keep only interesting stuff and remove objects potentially
# unserializable in JSON.
err = [{k: v for k, v in e.items() if k in VALIDATION_KEYS} for e in errors]
response = jsonify({"validation_error": err})
response.status_code = 400
return response


def init_session(app):
app.config["SESSION_SQLALCHEMY"] = db
Expand Down
2 changes: 2 additions & 0 deletions resultsdb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class Config(object):
SQLALCHEMY_DATABASE_URI = "sqlite://"
SHOW_DB_URI = True

FLASK_PYDANTIC_VALIDATION_ERROR_RAISE = True

LOGGING = {
"version": 1,
"disable_existing_loggers": False,
Expand Down
32 changes: 11 additions & 21 deletions resultsdb/controllers/api_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from flask import Blueprint, jsonify, render_template
from flask import current_app as app
from flask_pydantic import validate
from pydantic import BaseModel
from pydantic import RootModel

from resultsdb.models import db
from resultsdb.authorization import match_testcase_permissions, verify_authorization
Expand All @@ -22,20 +22,6 @@
api = Blueprint("api_v3", __name__)


def ensure_dict_input(cls):
"""
Wraps Pydantic model to ensure that the input type is dict.
This is a workaround for a bug in flask-pydantic that causes validation to
fail with unexpected exception.
"""

class EnsureJsonObject(BaseModel):
__root__: cls

return EnsureJsonObject


def permissions():
return app.config.get("PERMISSIONS", [])

Expand All @@ -58,13 +44,15 @@ def create_result(body: ResultParamsBase):
app.logger.debug(
"Updating ref_url for testcase %s: %s", body.testcase, body.testcase_ref_url
)
testcase.ref_url = body.testcase_ref_url
testcase.ref_url = str(body.testcase_ref_url)
db.session.add(testcase)

ref_url = str(body.ref_url) if body.ref_url else None

result = Result(
testcase=testcase,
outcome=body.outcome,
ref_url=body.ref_url,
ref_url=ref_url,
note=body.note,
groups=[],
)
Expand All @@ -83,8 +71,10 @@ def create_endpoint(params_class, oidc, provider):

@oidc.token_auth(provider)
@validate()
def create(body: ensure_dict_input(params_class)):
return create_result(body)
# Using RootModel is a workaround for a bug in flask-pydantic that causes
# validation to fail with unexpected exception.
def create(body: RootModel[params_class]):
return create_result(body.root)

def get_schema():
return jsonify(params.construct().schema()), 200
Expand Down Expand Up @@ -126,8 +116,8 @@ def index():
"method": "POST",
"description": example.__doc__,
"query_type": "JSON",
"example": example.json(exclude_unset=True, indent=2),
"schema": example.schema(),
"example": example.model_dump_json(exclude_unset=True, indent=2),
"schema": example.model_json_schema(),
"schema_endpoint": f".schemas_{example.artifact_type()}s",
}
for example in examples
Expand Down
89 changes: 49 additions & 40 deletions resultsdb/parsers/api_v2.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
# SPDX-License-Identifier: LGPL-2.0-or-later
from datetime import datetime, timezone
from numbers import Number
from typing import Any, List, Optional
from typing import Any, List, Optional, Union
from typing_extensions import Annotated

import iso8601
from pydantic import BaseModel, Field, validator
from pydantic import (
AfterValidator,
BaseModel,
Field,
StringConstraints,
ValidationInfo,
field_validator,
)
from pydantic.types import constr

from resultsdb.models.results import result_outcomes
Expand Down Expand Up @@ -41,49 +49,48 @@ class BaseListParams(BaseModel):


class GroupsParams(BaseListParams):
uuid: Optional[str]
description: Optional[str]
description_like_: Optional[str] = Field(alias="description:like")
uuid: Optional[str] = None
description: Optional[str] = None
description_like_: Optional[str] = Field(alias="description:like", default=None)


class CreateGroupParams(BaseModel):
uuid: Optional[str]
ref_url: Optional[str]
description: Optional[str]
uuid: Optional[str] = None
ref_url: Optional[str] = None
description: Optional[str] = None


class QueryList(List[str]):
@classmethod
def __get_validators__(cls):
yield cls.validate
def validate_query_list(v: Union[str, List[str]], info: ValidationInfo):
if isinstance(v, str):
return [x for x in (x.strip() for x in v.split(",")) if x]
if isinstance(v, list) and len(v) == 1 and isinstance(v[0], str):
return [x for x in (x.strip() for x in v[0].split(",")) if x]
return v

@classmethod
def validate(cls, v):
if isinstance(v, str):
return cls([x for x in (x.strip() for x in v.split(",")) if x])
if isinstance(v, list) and len(v) == 1 and isinstance(v[0], str):
return cls([x for x in (x.strip() for x in v[0].split(",")) if x])
return cls(v)

QueryList = Annotated[Union[str, List[str]], AfterValidator(validate_query_list)]


class ResultsParams(BaseListParams):
sort_: str = Field(alias="_sort", default="")
since: dict = {"start": None, "end": None}
outcome: Optional[QueryList]
groups: Optional[QueryList]
testcases: Optional[QueryList]
testcases_like_: Optional[QueryList] = Field(alias="testcases:like")
distinct_on_: Optional[QueryList] = Field(alias="_distinct_on")
outcome: Optional[QueryList] = None
groups: Optional[QueryList] = None
testcases: Optional[QueryList] = None
testcases_like_: Optional[QueryList] = Field(alias="testcases:like", default=None)
distinct_on_: Optional[QueryList] = Field(alias="_distinct_on", default=None)

@validator("since", pre=True)
@field_validator("since", mode="before")
@classmethod
def parse_since(cls, v):
try:
s, e = parse_since(v[0])
s, e = parse_since(v)
except iso8601.iso8601.ParseError:
raise ValueError("must be in ISO8601 format")
return {"start": s, "end": e}

@validator("outcome")
@field_validator("outcome", mode="after")
@classmethod
def outcome_must_be_valid(cls, v):
outcomes = [x.upper() for x in v]
if any(x not in result_outcomes() for x in outcomes):
Expand All @@ -92,23 +99,25 @@ def outcome_must_be_valid(cls, v):


class CreateResultParams(BaseModel):
outcome: constr(min_length=1, strip_whitespace=True, to_upper=True)
outcome: Annotated[str, StringConstraints(min_length=1, strip_whitespace=True, to_upper=True)]
testcase: dict
groups: Optional[list]
note: Optional[str]
data: Optional[dict]
ref_url: Optional[str]
submit_time: Any
groups: Optional[list] = None
note: Optional[str] = None
data: Optional[dict] = None
ref_url: Optional[str] = None
submit_time: Any = None

@validator("testcase", pre=True)
@field_validator("testcase", mode="before")
@classmethod
def parse_testcase(cls, v):
if not v or (isinstance(v, dict) and not v.get("name")):
raise ValueError("testcase name must be non-empty")
if isinstance(v, str):
return {"name": v}
return v

@validator("submit_time", pre=True)
@field_validator("submit_time", mode="before")
@classmethod
def parse_submit_time(cls, v):
if isinstance(v, datetime):
return v
Expand All @@ -133,24 +142,24 @@ def parse_submit_time(cls, v):
" got %r" % v
)

@validator("testcase")
@field_validator("testcase", mode="after")
def testcase_must_be_valid(cls, v):
if isinstance(v, dict) and not v.get("name"):
raise ValueError("testcase name must be non-empty")
return v

@validator("outcome")
@field_validator("outcome", mode="after")
def outcome_must_be_valid(cls, v):
if v not in result_outcomes():
raise ValueError(f'must be one of: {", ".join(result_outcomes())}')
return v


class TestcasesParams(BaseListParams):
name: Optional[str]
name_like_: Optional[str] = Field(alias="name:like")
name: Optional[str] = None
name_like_: Optional[str] = Field(alias="name:like", default=None)


class CreateTestcaseParams(BaseModel):
name: constr(min_length=1)
ref_url: Optional[str]
ref_url: Optional[str] = None

0 comments on commit 52cfd37

Please sign in to comment.