From 99ea4b91dafc46416e12b1bbdfd49ace2fd31cde Mon Sep 17 00:00:00 2001 From: Manabu Niseki Date: Sun, 28 Jun 2020 19:28:43 +0900 Subject: [PATCH] refactor: improve input validation Add mime-type based input validation --- app/api/endpoints/analyze.py | 33 ++++++++++++++++++++--------- app/schemas/payload.py | 15 ++++++++++++- app/services/validator.py | 15 +++++++++++++ pyproject.toml | 2 +- tests/api/endpoints/test_analyze.py | 22 ++++++++++++++++--- tests/schemas/test_payload.py | 24 +++++++++++++++++++++ 6 files changed, 96 insertions(+), 15 deletions(-) create mode 100644 app/services/validator.py create mode 100644 tests/schemas/test_payload.py diff --git a/app/api/endpoints/analyze.py b/app/api/endpoints/analyze.py index 246cd4e..54a7cea 100644 --- a/app/api/endpoints/analyze.py +++ b/app/api/endpoints/analyze.py @@ -1,32 +1,45 @@ from fastapi import APIRouter, File +from fastapi.encoders import jsonable_encoder +from fastapi.responses import JSONResponse +from pydantic import ValidationError from app.factories.response import ResponseFactory -from app.schemas.payload import Payload +from app.schemas.payload import FilePayload, Payload from app.schemas.response import Response router = APIRouter() +async def _analyze(file: bytes) -> Response: + try: + payload = FilePayload(file=file) + except ValidationError as exc: + return JSONResponse( + status_code=422, content=jsonable_encoder({"detail": exc.errors()}) + ) + + return await ResponseFactory.from_bytes(payload.file) + + @router.post( "/", response_model=Response, - response_description="Return a parsed result", - summary="Parse an eml", - description="Parse an eml and return a parsed result", + response_description="Return an analysis result", + summary="Analyze an eml", + description="Analyze an eml and return an analysis result", status_code=200, ) async def analyze(payload: Payload) -> Response: - eml_file = payload.eml_file.encode() - return await ResponseFactory.from_bytes(eml_file) + return await _analyze(payload.file.encode()) @router.post( "/file", response_model=Response, - response_description="Return a parsed result", - summary="Parse an eml", - description="Parse an eml and return a parsed result", + response_description="Return an analysis result", + summary="Analyze an eml", + description="Analyze an eml and return an analysis result", status_code=200, ) async def analyze_file(file: bytes = File(...)) -> Response: - return await ResponseFactory.from_bytes(file) + return await _analyze(file) diff --git a/app/schemas/payload.py b/app/schemas/payload.py index cc545c6..f39995d 100644 --- a/app/schemas/payload.py +++ b/app/schemas/payload.py @@ -1,5 +1,18 @@ from fastapi_utils.api_model import APIModel +from pydantic import validator + +from app.services.validator import is_eml_file class Payload(APIModel): - eml_file: str + file: str + + +class FilePayload(APIModel): + file: bytes + + @validator("file") + def eml_file_must_be_eml(cls, v: bytes): + if is_eml_file(v) is False: + raise ValueError("Invalid EML file.") + return v diff --git a/app/services/validator.py b/app/services/validator.py new file mode 100644 index 0000000..a697e25 --- /dev/null +++ b/app/services/validator.py @@ -0,0 +1,15 @@ +from typing import cast + +import magic + +VALID_MIME_TYPES = ["message/rfc822", "text/html", "text/plain"] + + +def is_eml_file(data: bytes) -> bool: + detected = magic.detect_from_content(data) + mime_type = cast(str, detected.mime_type) + + if mime_type in VALID_MIME_TYPES: + return True + + return False diff --git a/pyproject.toml b/pyproject.toml index 3b71744..47a3d73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ seed-isort-config = "^2.1.0" [tool.isort] force_grid_wrap = 0 include_trailing_comma = true -known_third_party = ["aiospamc", "arrow", "async_timeout", "asynctest", "eml_parser", "fastapi", "fastapi_utils", "httpx", "ioc_finder", "loguru", "olefile", "oletools", "pytest", "respx", "starlette"] +known_third_party = ["aiospamc", "arrow", "async_timeout", "asynctest", "eml_parser", "fastapi", "fastapi_utils", "httpx", "ioc_finder", "loguru", "magic", "olefile", "oletools", "pydantic", "pytest", "respx", "starlette"] line_length = 88 multi_line_output = 3 use_parentheses= true diff --git a/tests/api/endpoints/test_analyze.py b/tests/api/endpoints/test_analyze.py index 0957187..517f068 100644 --- a/tests/api/endpoints/test_analyze.py +++ b/tests/api/endpoints/test_analyze.py @@ -4,8 +4,8 @@ @pytest.mark.asyncio -async def test_analyze(client, emailrep_response): - payload = {"eml_file": read_file("sample.eml")} +async def test_analyze(client): + payload = {"file": read_file("sample.eml")} response = await client.post("/api/analyze/", json=payload) json = response.json() @@ -14,10 +14,26 @@ async def test_analyze(client, emailrep_response): @pytest.mark.asyncio -async def test_analyze_file(client, emailrep_response): +async def test_analyze_with_invalid_file(client): + payload = {"file": ""} + response = await client.post("/api/analyze/", json=payload) + + assert response.status_code == 422 + + +@pytest.mark.asyncio +async def test_analyze_file(client): data = {"file": read_file("sample.eml").encode()} response = await client.post("/api/analyze/file", data=data) json = response.json() assert json.get("eml", {}).get("header", {}).get("subject") == "Winter promotions" assert json.get("eml", {}).get("header", {}).get("from") == "no-reply@example.com" + + +@pytest.mark.asyncio +async def test_analyze_file_with_invalid_file(client): + data = {"file": b""} + response = await client.post("/api/analyze/file", data=data) + + assert response.status_code == 422 diff --git a/tests/schemas/test_payload.py b/tests/schemas/test_payload.py new file mode 100644 index 0000000..0f7ab19 --- /dev/null +++ b/tests/schemas/test_payload.py @@ -0,0 +1,24 @@ +import pytest + +from app.schemas.payload import FilePayload + + +def test_sample_eml(sample_eml: bytes): + FilePayload(file=sample_eml) + + +def test_multipart_eml(multipart_eml: bytes): + FilePayload(file=multipart_eml) + + +def test_encrypted_docx_eml(encrypted_docx_eml: bytes): + FilePayload(file=encrypted_docx_eml) + + +def test_cc_eml(cc_eml: bytes): + FilePayload(file=cc_eml) + + +def test_invalid_eml_file(): + with pytest.raises(ValueError): + FilePayload(file=b"")