diff --git a/src/coldfront_plugin_api/serializers.py b/src/coldfront_plugin_api/serializers.py index 8a29059..8f29417 100644 --- a/src/coldfront_plugin_api/serializers.py +++ b/src/coldfront_plugin_api/serializers.py @@ -1,14 +1,29 @@ +import logging +from datetime import datetime, timedelta + from rest_framework import serializers -from coldfront.core.allocation.models import Allocation, AllocationAttribute -from coldfront.core.allocation.models import Project +from coldfront.core.allocation.models import ( + Allocation, + AllocationAttribute, + AllocationStatusChoice, + AllocationAttributeType, +) +from coldfront.core.allocation.models import Project, Resource +from coldfront.core.allocation import signals + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) class ProjectSerializer(serializers.ModelSerializer): class Meta: model = Project fields = ["id", "title", "pi", "description", "field_of_science", "status"] + read_only_fields = ["title", "pi", "description", "field_of_science", "status"] + id = serializers.IntegerField() pi = serializers.SerializerMethodField() field_of_science = serializers.SerializerMethodField() status = serializers.SerializerMethodField() @@ -23,15 +38,62 @@ def get_status(self, obj: Project) -> str: return obj.status.name +class AllocationAttributeSerializer(serializers.ModelSerializer): + class Meta: + model = AllocationAttribute + fields = ["attribute_type", "value"] + + attribute_type = ( + serializers.SlugRelatedField( # Peforms validation to ensure attribute exists + read_only=False, + slug_field="name", + queryset=AllocationAttributeType.objects.all(), + source="allocation_attribute_type", + ) + ) + value = serializers.CharField(read_only=False) + + +class ResourceSerializer(serializers.ModelSerializer): + class Meta: + model = Resource + fields = ["id", "name", "resource_type"] + + id = serializers.IntegerField() + name = serializers.CharField(required=False) + resource_type = serializers.SerializerMethodField(required=False) + + def get_resource_type(self, obj: Resource): + return obj.resource_type.name + + class AllocationSerializer(serializers.ModelSerializer): class Meta: model = Allocation - fields = ["id", "project", "description", "resource", "status", "attributes"] + fields = [ + "id", + "project", + "description", + "resource", + "status", + "attributes", + "requested_resource", + "requested_attributes", + ] resource = serializers.SerializerMethodField() project = ProjectSerializer() attributes = serializers.SerializerMethodField() - status = serializers.SerializerMethodField() + status = serializers.SlugRelatedField( + slug_field="name", queryset=AllocationStatusChoice.objects.all() + ) + + requested_attributes = AllocationAttributeSerializer( + many=True, source="allocationattribute_set", required=False, write_only=True + ) + requested_resource = serializers.SlugRelatedField( + slug_field="name", queryset=Resource.objects.all(), write_only=True + ) def get_resource(self, obj: Allocation) -> dict: resource = obj.resources.first() @@ -46,5 +108,53 @@ def get_attributes(self, obj: Allocation): for a in attrs } - def get_status(self, obj: Allocation) -> str: - return obj.status.name + def create(self, validated_data): + project_obj = Project.objects.get(id=validated_data["project"]["id"]) + allocation = Allocation.objects.create( + project=project_obj, + status=validated_data["status"], + justification="", + start_date=datetime.now(), + end_date=datetime.now() + timedelta(days=365), + ) + allocation.resources.add(validated_data["requested_resource"]) + allocation.save() + + for attribute in validated_data.pop("allocationattribute_set", []): + AllocationAttribute.objects.create( + allocation=allocation, + allocation_attribute_type=attribute["allocation_attribute_type"], + value=attribute["value"], + ) + + logger.info( + f"Created allocation {allocation.id} for project {project_obj.title}" + ) + return allocation + + def update(self, allocation: Allocation, validated_data): + """ + Only allow updating allocation status for now + + Certain status transitions will have side effects (activating/deactivating allocations) + """ + + old_status = allocation.status.name + new_status = validated_data.get("status", old_status).name + + allocation.status = validated_data.get("status", allocation.status) + allocation.save() + + if old_status == "New" and new_status == "Active": + signals.allocation_activate.send( + sender=self.__class__, allocation_pk=allocation.pk + ) + elif old_status == "Active" and new_status in ["Denied", "Revoked"]: + signals.allocation_disable.send( + sender=self.__class__, allocation_pk=allocation.pk + ) + + logger.info( + f"Updated allocation {allocation.id} for project {allocation.project.title}" + ) + return allocation diff --git a/src/coldfront_plugin_api/tests/unit/test_allocations.py b/src/coldfront_plugin_api/tests/unit/test_allocations.py index af71803..5b9ce0b 100644 --- a/src/coldfront_plugin_api/tests/unit/test_allocations.py +++ b/src/coldfront_plugin_api/tests/unit/test_allocations.py @@ -1,5 +1,7 @@ from os import devnull +from datetime import datetime, timedelta import sys +from unittest.mock import patch, ANY from coldfront.core.allocation import models as allocation_models from django.core.management import call_command @@ -146,3 +148,93 @@ def test_filter_allocations(self): "/api/allocations?fake_model_attribute=fake" ).json() self.assertEqual(r_json, []) + + def test_create_allocation(self): + user = self.new_user() + project = self.new_project(pi=user) + + payload = { + "requested_attributes": [ + {"attribute_type": "OpenShift Limit on CPU Quota", "value": 8}, + {"attribute_type": "OpenShift Limit on RAM Quota (MiB)", "value": 16}, + ], + "project": {"id": project.id}, + "requested_resource": self.resource.name, + "status": "New", + } + + self.admin_client.post("/api/allocations", payload, format="json") + + created_allocation = allocation_models.Allocation.objects.get( + project=project, + resources__in=[self.resource], + ) + self.assertEqual(created_allocation.status.name, "New") + self.assertEqual(created_allocation.justification, "") + self.assertEqual(created_allocation.start_date, datetime.now().date()) + self.assertEqual( + created_allocation.end_date, (datetime.now() + timedelta(days=365)).date() + ) + + allocation_models.AllocationAttribute.objects.get( + allocation=created_allocation, + allocation_attribute_type=allocation_models.AllocationAttributeType.objects.get( + name="OpenShift Limit on CPU Quota" + ), + value=8, + ) + allocation_models.AllocationAttribute.objects.get( + allocation=created_allocation, + allocation_attribute_type=allocation_models.AllocationAttributeType.objects.get( + name="OpenShift Limit on RAM Quota (MiB)" + ), + value=16, + ) + + def test_update_allocation_status_new_to_active(self): + user = self.new_user() + project = self.new_project(pi=user) + allocation = self.new_allocation(project, self.resource, 1) + allocation.status = allocation_models.AllocationStatusChoice.objects.get( + name="New" + ) + allocation.save() + + payload = {"status": "Active"} + + with patch( + "coldfront.core.allocation.signals.allocation_activate.send" + ) as mock_activate: + response = self.admin_client.patch( + f"/api/allocations/{allocation.id}?all=true", payload, format="json" + ) + self.assertEqual(response.status_code, 200) + allocation.refresh_from_db() + self.assertEqual(allocation.status.name, "Active") + mock_activate.assert_called_once_with( + sender=ANY, allocation_pk=allocation.pk + ) + + def test_update_allocation_status_active_to_denied(self): + user = self.new_user() + project = self.new_project(pi=user) + allocation = self.new_allocation(project, self.resource, 1) + allocation.status = allocation_models.AllocationStatusChoice.objects.get( + name="Active" + ) + allocation.save() + + payload = {"status": "Denied"} + + with patch( + "coldfront.core.allocation.signals.allocation_disable.send" + ) as mock_disable: + response = self.admin_client.patch( + f"/api/allocations/{allocation.id}", payload, format="json" + ) + self.assertEqual(response.status_code, 200) + allocation.refresh_from_db() + self.assertEqual(allocation.status.name, "Denied") + mock_disable.assert_called_once_with( + sender=ANY, allocation_pk=allocation.pk + ) diff --git a/src/coldfront_plugin_api/urls.py b/src/coldfront_plugin_api/urls.py index e5ac5a4..2e83fce 100644 --- a/src/coldfront_plugin_api/urls.py +++ b/src/coldfront_plugin_api/urls.py @@ -10,7 +10,7 @@ from coldfront_plugin_api import auth, serializers -class AllocationViewSet(viewsets.ReadOnlyModelViewSet): +class AllocationViewSet(viewsets.ModelViewSet): """ This viewset implements the API to Coldfront's allocation object The API allows filtering allocations by any of Coldfront's allocation model attributes, diff --git a/tools/create_allocations.py b/tools/create_allocations.py new file mode 100644 index 0000000..0ba5a9a --- /dev/null +++ b/tools/create_allocations.py @@ -0,0 +1,94 @@ +""" +Download ColdFront allocation data. + +Usage: + python3 download_allocation_data.py + +- Environment variables CLIENT_ID and CLIENT_SECRET must be set, +corresponding to a service account in Keycloak. +""" +import json +import logging +import os +import argparse + +import requests +from requests.auth import HTTPBasicAuth + + +API_URL = "http://localhost:8000/api/allocations" + + +logger = logging.getLogger() + + +class ColdFrontClient(object): + def __init__(self, keycloak_url, keycloak_client_id, keycloak_client_secret): + self.session = self.get_session( + keycloak_url, keycloak_client_id, keycloak_client_secret + ) + + @staticmethod + def get_session(keycloak_url, keycloak_client_id, keycloak_client_secret): + """Authenticate as a client with Keycloak to receive an access token.""" + token_url = f"{keycloak_url}/auth/realms/mss/protocol/openid-connect/token" + + r = requests.post( + token_url, + data={"grant_type": "client_credentials"}, + auth=HTTPBasicAuth(keycloak_client_id, keycloak_client_secret), + ) + client_token = r.json()["access_token"] + + session = requests.session() + headers = { + "Authorization": f"Bearer {client_token}", + "Content-Type": "application/json", + } + session.headers.update(headers) + return session + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "allocations_file", + help="JSON file containing list of allocations, must adhere to API specifications", + ) + parser.add_argument( + "--activate", + action="store_true", + help="If set, will also trigger the Coldfront `activate_allocations` signal." + "For OpenShift and OpenStack allocations, this will create the projects on remote clusters" + "Only works if the requested status choice is `New`", + ) + args = parser.parse_args() + + client = ColdFrontClient( + "https://keycloak.mss.mghpcc.org", + os.environ.get("CLIENT_ID"), + os.environ.get("CLIENT_SECRET"), + ) + + with open(args.allocations_file, "r") as f: + allocation_payloads = json.load(f) + + for allocation_payload in allocation_payloads: + r = client.session.post(url=API_URL, json=allocation_payload) + r.raise_for_status() + allocation_id = r.json()["id"] + + if args.activate: + allocation_payload["status"] = "Active" + r = client.session.put( + url=f"{API_URL}/{allocation_id}", json=allocation_payload + ) + r.raise_for_status() + + logger.info( + f"Created allocation {allocation_id} for project {allocation_payload["project"]["id"]}" + ) + + +if __name__ == "__main__": + main()