From 53dba9c775a0dc8a61bfd55f410ddac9c32b5184 Mon Sep 17 00:00:00 2001
From: Brad Smith
Date: Wed, 16 Nov 2022 15:16:03 -0500
Subject: [PATCH] create /internal/tasks/ API to call arbitrary Celery tasks
 remotely and check their async results

---
 cloudigrade/internal/serializers.py           | 15 ++-
 .../internal/tests/views/test_task_get.py     | 96 +++++++++++++++++++
 .../internal/tests/views/test_task_run.py     | 87 +++++++++++++++++
 cloudigrade/internal/urls.py                  |  6 ++
 cloudigrade/internal/views.py                 | 82 ++++++++++++++++
 5 files changed, 285 insertions(+), 1 deletion(-)
 create mode 100644 cloudigrade/internal/tests/views/test_task_get.py
 create mode 100644 cloudigrade/internal/tests/views/test_task_run.py

diff --git a/cloudigrade/internal/serializers.py b/cloudigrade/internal/serializers.py
index dd2a93aa..b0579279 100644
--- a/cloudigrade/internal/serializers.py
+++ b/cloudigrade/internal/serializers.py
@@ -3,7 +3,13 @@
 from django_celery_beat.models import PeriodicTask
 from rest_framework import serializers
 from rest_framework.exceptions import ValidationError
-from rest_framework.fields import BooleanField, CharField, ChoiceField, ListField
+from rest_framework.fields import (
+    BooleanField,
+    CharField,
+    ChoiceField,
+    JSONField,
+    ListField,
+)
 from rest_framework.serializers import ModelSerializer, Serializer
 
 from api import models
@@ -374,3 +380,10 @@ class InternalRedisRawInputSerializer(Serializer):
 
     command = ChoiceField(allowed_commands, required=True)
     args = ListField(required=False, allow_empty=True, child=CharField(min_length=1))
+
+
+class InternalRunTaskInputSerializer(Serializer):
+    """Serializer to validate input for the internal run_task API."""
+
+    task_name = CharField(required=True)
+    kwargs = JSONField(required=False)
diff --git a/cloudigrade/internal/tests/views/test_task_get.py b/cloudigrade/internal/tests/views/test_task_get.py
new file mode 100644
index 00000000..0a314fae
--- /dev/null
+++ b/cloudigrade/internal/tests/views/test_task_get.py
@@ -0,0 +1,96 @@
+"""Collection of tests for getting task results from the internal API."""
+from unittest.mock import patch
+
+import faker
+from django.test import TestCase
+from rest_framework.test import APIRequestFactory
+
+from internal.views import task_get
+
+_faker = faker.Faker()
+
+
+class TaskGetTest(TestCase):
+    """Task get view test case."""
+
+    def setUp(self):
+        """Set up shared test data."""
+        self.factory = APIRequestFactory()
+        AsyncResult_patch = patch("internal.views.AsyncResult")
+        self.mock_AsyncResult = AsyncResult_patch.start()
+        self.addCleanup(AsyncResult_patch.stop)
+
+    def test_task_get(self):
+        """Test happy path success getting a completed (ready) task."""
+        async_result_id = str(_faker.uuid4())
+        self.mock_AsyncResult.return_value.ready.return_value = True
+        expected_returned_value = self.mock_AsyncResult.return_value.get.return_value
+
+        request = self.factory.get(f"/task_get/{async_result_id}/")
+        response = task_get(request, async_result_id)
+
+        self.mock_AsyncResult.assert_called_once_with(async_result_id)
+        self.mock_AsyncResult.return_value.get.assert_called_once_with()
+
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(response.data["async_result_id"], async_result_id)
+        self.assertTrue(response.data["ready"])
+        self.assertEqual(response.data["result"], expected_returned_value)
+
+    def test_task_get_not_ready(self):
+        """Test when task has not completed (is not ready)."""
+        async_result_id = str(_faker.uuid4())
+        self.mock_AsyncResult.return_value.ready.return_value = False
+
+        request = self.factory.get(f"/task_get/{async_result_id}/")
+        response = task_get(request, async_result_id)
+
+        self.mock_AsyncResult.assert_called_once_with(async_result_id)
+        self.mock_AsyncResult.return_value.get.assert_not_called()
+
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(response.data["async_result_id"], async_result_id)
+        self.assertFalse(response.data["ready"])
+        self.assertIsNone(response.data["result"])
+
+    def test_task_get_execution_raised_exception(self):
+        """
+        Test when the worker raised an exception when executing the task.
+
+        Typically, what happens is the async task reports "ready" and the "get" call
+        raises the exception that occurred when the worker executed the task. Example:
+
+            $ http localhost:8000/internal/tasks/ \
+                task_name="api.tasks.enable_account" kwargs:='{"cloud_account_id":"potato"}'
+
+            HTTP/1.1 201 Created
+            {
+                "async_result_id": "9fb41f3d-e4ed-4b55-844a-6d7b62732d80"
+            }
+
+            $ http localhost:8000/internal/tasks/9fb41f3d-e4ed-4b55-844a-6d7b62732d80/
+
+            HTTP/1.1 200 OK
+            {
+                "async_result_id": "9fb41f3d-e4ed-4b55-844a-6d7b62732d80",
+                "error_args": [
+                    "Field 'id' expected a number but got 'potato'."
+                ],
+                "error_class": "builtins.ValueError"
+            }
+        """
+        async_result_id = str(_faker.uuid4())
+        self.mock_AsyncResult.return_value.ready.return_value = True
+        error_message = "Field 'id' expected a number but got 'potato'."
+        self.mock_AsyncResult.return_value.get.side_effect = ValueError(error_message)
+
+        request = self.factory.get(f"/task_get/{async_result_id}/")
+        response = task_get(request, async_result_id)
+
+        self.mock_AsyncResult.assert_called_once_with(async_result_id)
+        self.mock_AsyncResult.return_value.get.assert_called_once_with()
+
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(response.data["async_result_id"], async_result_id)
+        self.assertEqual(response.data["error_args"], [error_message])
+        self.assertEqual(response.data["error_class"], "builtins.ValueError")
diff --git a/cloudigrade/internal/tests/views/test_task_run.py b/cloudigrade/internal/tests/views/test_task_run.py
new file mode 100644
index 00000000..6fd28229
--- /dev/null
+++ b/cloudigrade/internal/tests/views/test_task_run.py
@@ -0,0 +1,87 @@
+"""Collection of tests for running tasks from the internal API."""
+from unittest.mock import patch
+
+import faker
+from django.test import TestCase
+from rest_framework.test import APIRequestFactory
+
+from internal.views import task_run
+
+_faker = faker.Faker()
+
+
+class TaskRunTest(TestCase):
+    """Task run view test case."""
+
+    def setUp(self):
+        """Set up shared test data."""
+        self.factory = APIRequestFactory()
+        celery_app_patch = patch("internal.views.celery_app")
+        self.mock_celery_app = celery_app_patch.start()
+        self.addCleanup(celery_app_patch.stop)
+
+    def test_task_run(self):
+        """Test happy path success for running an arbitrary task."""
+        task_name = _faker.slug()
+        task_kwargs = {_faker.slug(): _faker.slug(), _faker.slug(): _faker.slug()}
+
+        mock_signature = self.mock_celery_app.signature.return_value
+        mock_async_result = mock_signature.delay.return_value
+
+        request = self.factory.post(
+            "/task_run/",
+            data={"task_name": task_name, "kwargs": task_kwargs},
+            format="json",
+        )
+        response = task_run(request)
+
+        self.mock_celery_app.signature.assert_called_with(task_name)
+        mock_signature.delay.assert_called_with(**task_kwargs)
+
+        self.assertEqual(response.status_code, 201)
+        self.assertEqual(response.data["async_result_id"], mock_async_result.id)
+
+    def test_task_run_bad_post_params(self):
+        """Test error is returned and task is not called when bad input is given."""
+        task_name = _faker.slug()
+        request = self.factory.post(
+            "/task_run/",
+            data={task_name: task_name, "unrelated_potato_argument": task_name},
+            format="json",
+        )
+        response = task_run(request)
+        self.assertEqual(response.status_code, 400)
+        self.mock_celery_app.signature.assert_not_called()
+
+    def test_task_run_bad_task_kwargs(self):
+        """
+        Test error is returned when task is called with bad kwargs.
+
+        Celery's delay raises TypeError if the arguments it is given do not match the
+        actual function signature of the task being called. Example:
+
+            $ http localhost:8000/internal/tasks/ \
+                task_name="api.tasks.enable_account" kwargs:='{"potato":1}'
+
+            HTTP/1.1 400 Bad Request
+            {"error":"enable_account() got an unexpected keyword argument 'potato'"}
+        """
+        task_name = _faker.slug()
+        task_kwargs = {_faker.slug(): _faker.slug(), _faker.slug(): _faker.slug()}
+
+        type_error_message = _faker.sentence()
+        mock_signature = self.mock_celery_app.signature.return_value
+        mock_signature.delay.side_effect = TypeError(type_error_message)
+
+        request = self.factory.post(
+            "/task_run/",
+            data={"task_name": task_name, "kwargs": task_kwargs},
+            format="json",
+        )
+        response = task_run(request)
+
+        self.mock_celery_app.signature.assert_called_with(task_name)
+        mock_signature.delay.assert_called_with(**task_kwargs)
+
+        self.assertEqual(response.status_code, 400)
+        self.assertEqual(response.data["error"], type_error_message)
diff --git a/cloudigrade/internal/urls.py b/cloudigrade/internal/urls.py
index 0cbca5ab..df424784 100644
--- a/cloudigrade/internal/urls.py
+++ b/cloudigrade/internal/urls.py
@@ -125,6 +125,12 @@ class PermissiveAPIRootView(routers.APIRootView):
     ),
     path("recalculate_runs/", views.recalculate_runs, name="internal-recalculate-runs"),
     path("redis_raw/", views.redis_raw, name="internal-redis-raw"),
+    path("tasks/", views.task_run, name="internal-task-run"),
+    path(
+        "tasks/<str:async_result_id>/",
+        views.task_get,
+        name="internal-task-collect",
+    ),
     path("sources_kafka/", views.sources_kafka, name="internal-sources-kafka"),
 ]
 
diff --git a/cloudigrade/internal/views.py b/cloudigrade/internal/views.py
index 36ae5391..d1a8910d 100644
--- a/cloudigrade/internal/views.py
+++ b/cloudigrade/internal/views.py
@@ -4,6 +4,9 @@
 import os
 
 from app_common_python import isClowderEnabled
+from celery import current_app as celery_app
+from celery.exceptions import NotRegistered
+from celery.result import AsyncResult
 from dateutil import tz
 from dateutil.parser import ParserError, parse
 from django.conf import settings
@@ -504,3 +507,82 @@ def redis_raw(request):
         return Response(status=status.HTTP_500_INTERNAL_SERVER_ERROR)
 
     return Response({"results": results}, status=status.HTTP_200_OK)
+
+
+@api_view(["POST"])
+@authentication_classes([IdentityHeaderAuthenticationInternal])
+@permission_classes([permissions.AllowAny])
+@schema(None)
+def task_run(request):
+    """
+    Execute an arbitrary task with optional keyword arguments.
+
+    This is an internal only API, so we do not want it to be in the openapi.spec.
+    Typical calls to this internal HTTP API would look like:
+
+        http localhost:8000/internal/tasks/ \
+            task_name="api.tasks.persist_inspection_cluster_results_task"
+
+        http localhost:8000/internal/tasks/ \
+            task_name="api.tasks.synthesize_concurrent_usage" \
+            kwargs:='{"__": null, "user_id": 1, "on_date": "2022-01-01"}'
+    """
+    serializer = serializers.InternalRunTaskInputSerializer(data=request.data)
+    serializer.is_valid(raise_exception=True)
+    task_name = serializer.validated_data["task_name"]
+    kwargs = serializer.validated_data.get("kwargs", {})
+
+    signature = celery_app.signature(task_name)
+    try:
+        async_result = signature.delay(**kwargs)
+    except TypeError as e:
+        # TypeError may be raised if you send bad arguments to the delay function.
+        logger.info(e)
+        type_error_arg = e.args[0] if e.args else None
+        return Response(
+            {"error": type_error_arg},
+            status=status.HTTP_400_BAD_REQUEST,
+        )
+    return Response(
+        {"async_result_id": async_result.id}, status=status.HTTP_201_CREATED
+    )
+
+
+@api_view(["GET"])
+@authentication_classes([IdentityHeaderAuthenticationInternal])
+@permission_classes([permissions.AllowAny])
+@schema(None)
+def task_get(request, async_result_id):
+    """
+    Get the async result of a task by its async result ID.
+
+    This is an internal only API, so we do not want it to be in the openapi.spec.
+    Typical calls to this internal HTTP API would look like:
+
+        http localhost:8000/internal/tasks/e5a666a1-dadc-4de8-8f6e-53d4a0bec3a0/
+    """
+    try:
+        async_result = AsyncResult(async_result_id)
+        ready = async_result.ready()
+        result = async_result.get() if ready else None
+        return Response(
+            {"async_result_id": async_result_id, "ready": ready, "result": result}
+        )
+    except NotRegistered as e:
+        task_name = e.args[0] if e.args else None
+        return Response(
+            {
+                "async_result_id": async_result_id,
+                "error": f"not a registered task name: {task_name}",
+            }
+        )
+    except Exception as e:
+        logger.info(e)
+        error_args = [str(arg) for arg in e.args]
+        return Response(
+            {
+                "async_result_id": async_result_id,
+                "error_class": f"{e.__class__.__module__}.{e.__class__.__name__}",
+                "error_args": error_args,
+            }
+        )