Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create /internal/tasks/ API to call arbitrary Celery tasks remotely and check their async results #1404

Merged
merged 1 commit into from
Nov 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 14 additions & 1 deletion cloudigrade/internal/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
from django_celery_beat.models import PeriodicTask
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from rest_framework.fields import BooleanField, CharField, ChoiceField, ListField
from rest_framework.fields import (
BooleanField,
CharField,
ChoiceField,
JSONField,
ListField,
)
from rest_framework.serializers import ModelSerializer, Serializer

from api import models
Expand Down Expand Up @@ -374,3 +380,10 @@ class InternalRedisRawInputSerializer(Serializer):

command = ChoiceField(allowed_commands, required=True)
args = ListField(required=False, allow_empty=True, child=CharField(min_length=1))


class InternalRunTaskInputSerializer(Serializer):
"""Serializer to validate input for the internal run_task API."""

task_name = CharField(required=True)
kwargs = JSONField(required=False)
96 changes: 96 additions & 0 deletions cloudigrade/internal/tests/views/test_task_get.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Collection of tests for getting task results from the internal API."""
from unittest.mock import patch

import faker
from django.test import TestCase
from rest_framework.test import APIRequestFactory

from internal.views import task_get

_faker = faker.Faker()


class TaskGetTest(TestCase):
"""Task get view test case."""

def setUp(self):
"""Set up shared test data."""
self.factory = APIRequestFactory()
AsyncResult_patch = patch("internal.views.AsyncResult")
self.mock_AsyncResult = AsyncResult_patch.start()
self.addCleanup(AsyncResult_patch.stop)

def test_task_get(self):
"""Test happy path success getting a completed (ready) task."""
async_result_id = str(_faker.uuid4())
self.mock_AsyncResult.return_value.ready.return_value = True
expected_returned_value = self.mock_AsyncResult.return_value.get.return_value

request = self.factory.get(f"/task_get/{async_result_id}/")
response = task_get(request, async_result_id)

self.mock_AsyncResult.assert_called_once_with(async_result_id)
self.mock_AsyncResult.return_value.get.assert_called_once_with()

self.assertEqual(response.status_code, 200)
self.assertTrue(response.data["async_result_id"], async_result_id)
self.assertTrue(response.data["ready"])
self.assertEqual(response.data["result"], expected_returned_value)

def test_task_get_not_ready(self):
"""Test when task has not completed (is not ready)."""
async_result_id = str(_faker.uuid4())
self.mock_AsyncResult.return_value.ready.return_value = False

request = self.factory.get(f"/task_get/{async_result_id}/")
response = task_get(request, async_result_id)

self.mock_AsyncResult.assert_called_once_with(async_result_id)
self.mock_AsyncResult.return_value.get.assert_not_called()

self.assertEqual(response.status_code, 200)
self.assertTrue(response.data["async_result_id"], async_result_id)
self.assertFalse(response.data["ready"])
self.assertIsNone(response.data["result"])

def test_task_get_execution_raised_exception(self):
"""
Test when the worker raised an execution when executing the task.

Typically, what happens is the async task reports "ready" and the "get" call
raises the exception that occurred when the worker executed the task. Example:

$ http localhost:8000/internal/tasks/ \
task_name="api.tasks.enable_account" kwargs:='{"cloud_account_id":"potato"}'

HTTP/1.1 201 Created
{
"async_result_id": "9fb41f3d-e4ed-4b55-844a-6d7b62732d80"
}

$ http localhost:8000/internal/tasks/9fb41f3d-e4ed-4b55-844a-6d7b62732d80/

HTTP/1.1 200 OK
{
"async_result_id": "9fb41f3d-e4ed-4b55-844a-6d7b62732d80",
"error_args": [
"Field 'id' expected a number but got 'potato'."
],
"error_class": "builtins.ValueError"
}
"""
async_result_id = str(_faker.uuid4())
self.mock_AsyncResult.return_value.ready.return_value = True
error_message = "Field 'id' expected a number but got 'potato'."
self.mock_AsyncResult.return_value.get.side_effect = ValueError(error_message)

request = self.factory.get(f"/task_get/{async_result_id}/")
response = task_get(request, async_result_id)

self.mock_AsyncResult.assert_called_once_with(async_result_id)
self.mock_AsyncResult.return_value.get.assert_called_once_with()

self.assertEqual(response.status_code, 200)
self.assertTrue(response.data["async_result_id"], async_result_id)
self.assertEqual(response.data["error_args"], [error_message])
self.assertEqual(response.data["error_class"], "builtins.ValueError")
87 changes: 87 additions & 0 deletions cloudigrade/internal/tests/views/test_task_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Collection of tests for running tasks from the internal API."""
from unittest.mock import patch

import faker
from django.test import TestCase
from rest_framework.test import APIRequestFactory

from internal.views import task_run

_faker = faker.Faker()


class TaskRunTest(TestCase):
"""Task run view test case."""

def setUp(self):
"""Set up shared test data."""
self.factory = APIRequestFactory()
celery_app_patch = patch("internal.views.celery_app")
self.mock_celery_app = celery_app_patch.start()
self.addCleanup(celery_app_patch.stop)

def test_task_run(self):
"""Test happy path success for running an arbitrary task."""
task_name = _faker.slug()
task_kwargs = {_faker.slug(): _faker.slug(), _faker.slug(): _faker.slug()}

mock_signature = self.mock_celery_app.signature.return_value
mock_async_result = mock_signature.delay.return_value

request = self.factory.post(
"/task_run/",
data={"task_name": task_name, "kwargs": task_kwargs},
format="json",
)
response = task_run(request)

self.mock_celery_app.signature.assert_called_with(task_name)
mock_signature.delay.assert_called_with(**task_kwargs)

self.assertEqual(response.status_code, 201)
self.assertEqual(response.data["async_result_id"], mock_async_result.id)

def test_task_run_bad_post_params(self):
"""Test error is returned and task is not called when bad input is given."""
task_name = _faker.slug()
request = self.factory.post(
"/task_run/",
data={task_name: task_name, "unrelated_potato_argument": task_name},
format="json",
)
response = task_run(request)
self.assertEqual(response.status_code, 400)
self.mock_celery_app.signature.assert_not_called()

def test_task_run_bad_task_kwargs(self):
"""
Test error is returned when task is called when bad kwargs.

Celery's delay raises TypeError if the arguments it is given do not match the
actual function signature of the task being called. Example:

$ http localhost:8000/internal/tasks/ \
task_name="api.tasks.enable_account" kwargs:='{"potato":1}'

HTTP/1.1 400 Bad Request
{"error":"enable_account() got an unexpected keyword argument 'potato'"}
"""
task_name = _faker.slug()
task_kwargs = {_faker.slug(): _faker.slug(), _faker.slug(): _faker.slug()}

type_error_message = _faker.sentence()
mock_signature = self.mock_celery_app.signature.return_value
mock_signature.delay.side_effect = TypeError(type_error_message)

request = self.factory.post(
"/task_run/",
data={"task_name": task_name, "kwargs": task_kwargs},
format="json",
)
response = task_run(request)

self.mock_celery_app.signature.assert_called_with(task_name)
mock_signature.delay.assert_called_with(**task_kwargs)

self.assertEqual(response.status_code, 400)
self.assertEqual(response.data["error"], type_error_message)
6 changes: 6 additions & 0 deletions cloudigrade/internal/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ class PermissiveAPIRootView(routers.APIRootView):
),
path("recalculate_runs/", views.recalculate_runs, name="internal-recalculate-runs"),
path("redis_raw/", views.redis_raw, name="internal-redis-raw"),
path("tasks/", views.task_run, name="internal-task-run"),
path(
"tasks/<str:async_result_id>/",
views.task_get,
name="internal-task-collect",
),
path("sources_kafka/", views.sources_kafka, name="internal-sources-kafka"),
]

Expand Down
82 changes: 82 additions & 0 deletions cloudigrade/internal/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import os

from app_common_python import isClowderEnabled
from celery import current_app as celery_app
from celery.exceptions import NotRegistered
from celery.result import AsyncResult
from dateutil import tz
from dateutil.parser import ParserError, parse
from django.conf import settings
Expand Down Expand Up @@ -504,3 +507,82 @@ def redis_raw(request):
return Response(status=status.HTTP_500_INTERNAL_SERVER_ERROR)

return Response({"results": results}, status=status.HTTP_200_OK)


@api_view(["POST"])
@authentication_classes([IdentityHeaderAuthenticationInternal])
@permission_classes([permissions.AllowAny])
@schema(None)
def task_run(request):
"""
Execute an arbitrary task with optional keyword arguments.

This is an internal only API, so we do not want it to be in the openapi.spec.
Typical calls to this internal HTTP API would look like:

http localhost:8000/internal/tasks/ \
task_name="api.tasks.persist_inspection_cluster_results_task"

http localhost:8000/internal/tasks/ \
task_name="api.tasks.synthesize_concurrent_usage" \
kwargs:='{"__": null, "user_id": 1, "on_date": "2022-01-01"}'
"""
serializer = serializers.InternalRunTaskInputSerializer(data=request.data)
if serializer.is_valid(raise_exception=True):
task_name = serializer.validated_data["task_name"]
kwargs = serializer.validated_data.get("kwargs", {})

signature = celery_app.signature(task_name)
try:
async_result = signature.delay(**kwargs)
except TypeError as e:
# TypeError may be raised if you send bad arguments to the delay function.
logger.info(e)
type_error_arg = e.args[0] if getattr(e, "args") else None
return Response(
{"error": type_error_arg},
status=status.HTTP_400_BAD_REQUEST,
)
return Response(
{"async_result_id": async_result.id}, status=status.HTTP_201_CREATED
)


@api_view(["GET"])
@authentication_classes([IdentityHeaderAuthenticationInternal])
@permission_classes([permissions.AllowAny])
@schema(None)
def task_get(request, async_result_id):
"""
Get the async result of a task by its async result ID.

This is an internal only API, so we do not want it to be in the openapi.spec.
Typical calls to this internal HTTP API would look like:

http localhost:8000/internal/tasks/e5a666a1-dadc-4de8-8f6e-53d4a0bec3a0/
"""
try:
async_result = AsyncResult(async_result_id)
ready = async_result.ready()
result = async_result.get() if ready else None
return Response(
{"async_result_id": async_result_id, "ready": ready, "result": result}
)
except NotRegistered as e:
abellotti marked this conversation as resolved.
Show resolved Hide resolved
task_name = e.args[0] if getattr(e, "args") else None
return Response(
{
"async_result_id": async_result_id,
"error": f"not a registered task name: {task_name}",
}
)
except Exception as e:
logger.info(e)
error_args = [str(arg) for arg in getattr(e, "args", [])]
return Response(
{
"async_result_id": async_result_id,
"error_class": f"{e.__class__.__module__}.{e.__class__.__name__}",
"error_args": error_args,
}
)