Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update obs end #75

Merged
merged 10 commits into from
Sep 16, 2019
40 changes: 40 additions & 0 deletions observation_portal/observations/serializers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from rest_framework import serializers
from django.utils import timezone
from django.core.cache import cache
from django.db import transaction
from django.utils.translation import ugettext as _

Expand All @@ -10,6 +11,10 @@
from observation_portal.requestgroups.models import RequestGroup, AcquisitionConfig, GuidingConfig, Target
from observation_portal.proposals.models import Proposal

import logging

logger = logging.getLogger()


class SummarySerializer(serializers.ModelSerializer):
class Meta:
Expand Down Expand Up @@ -285,6 +290,14 @@ class Meta:
fields = ('site', 'enclosure', 'telescope', 'start', 'end', 'priority', 'configuration_statuses', 'request')

def validate(self, data):
if self.context.get('request').method == 'PATCH':
# For a partial update, only validate that the 'end' field is set, and that it is > now
if 'end' not in data:
raise serializers.ValidationError(_('Observation update must include `end` field'))
if data['end'] <= timezone.now():
raise serializers.ValidationError(_('Updated end time must be in the future'))
return data

if data['end'] <= data['start']:
raise serializers.ValidationError(_('End time must be after start time'))

Expand Down Expand Up @@ -330,6 +343,33 @@ def validate(self, data):

return data

def update(self, instance, validated_data):
if validated_data['end'] > instance.start:
# Only update the end time if it is > start time
old_end_time = instance.end
instance.end = validated_data['end']
instance.save()
# Cancel observations that used to be under this observation
if instance.end > old_end_time:
observations = Observation.objects.filter(
site=instance.site,
enclosure=instance.enclosure,
telescope=instance.telescope,
start__lte=instance.end,
start__gte=old_end_time,
state='PENDING'
)
if instance.request.request_group.observation_type != RequestGroup.RAPID_RESPONSE:
observations = observations.exclude(
request__request_group__observation_type=RequestGroup.RAPID_RESPONSE
)
num_canceled = Observation.cancel(observations)
logger.info(
f"updated end time for observation {instance.id} to {instance.end}. Canceled {num_canceled} overlapping observations.")
cache.set('observation_portal_last_change_time', timezone.now(), None)

return instance

def create(self, validated_data):
configuration_statuses = validated_data.pop('configuration_statuses')
with transaction.atomic():
Expand Down
1 change: 0 additions & 1 deletion observation_portal/observations/signals/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,3 @@ def cb_summary_pre_save(sender, instance, *args, **kwargs):
else:
current_summary = None
on_summary_update_time_accounting(current_summary, instance)

131 changes: 128 additions & 3 deletions observation_portal/observations/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,14 +403,15 @@ def setUp(self):
configuration.save()

@staticmethod
def _generate_observation_data(request_id, configuration_id_list, guide_camera_name='xx03'):
def _generate_observation_data(request_id, configuration_id_list, guide_camera_name='xx03',
start="2016-09-05T22:35:39Z", end="2016-09-05T23:35:40Z"):
observation = {
"request": request_id,
"site": "tst",
"enclosure": "domb",
"telescope": "1m0a",
"start": "2016-09-05T22:35:39Z",
"end": "2016-09-05T23:35:40Z",
"start": start,
"end": end,
"configuration_statuses": []
}
for configuration_id in configuration_id_list:
Expand Down Expand Up @@ -946,6 +947,130 @@ def test_update_summary_triggers_request_status_without_completing(self):
self.assertEqual(self.requestgroup.state, 'PENDING')


class TestUpdateObservationApi(TestObservationApiBase):
def setUp(self):
super().setUp()

@staticmethod
def _create_clone_observation(observation, start, end):
return mixer.blend(
Observation,
site=observation.site,
enclosure=observation.enclosure,
telescope=observation.telescope,
start=start.replace(tzinfo=timezone.utc),
end=end.replace(tzinfo=timezone.utc),
request=observation.request
)

def test_update_observation_end_time_succeeds(self):
original_end = datetime(2016, 9, 5, 23, 35, 40).replace(tzinfo=timezone.utc)
observation = self._generate_observation_data(
self.requestgroup.requests.first().id, [self.requestgroup.requests.first().configurations.first().id]
)
self._create_observation(observation)
observation = Observation.objects.first()
self.assertEqual(observation.end, original_end)

new_end = datetime(2016, 9, 5, 23, 47, 22).replace(tzinfo=timezone.utc)
update_data = {"end": datetime.strftime(new_end, '%Y-%m-%dT%H:%M:%SZ')}
self.client.patch(reverse('api:observations-detail', args=(observation.id,)), update_data)
observation.refresh_from_db()
self.assertEqual(observation.end, new_end)

def test_update_observation_end_time_cancels_proper_overlapping_observations(self):
self.window.start = datetime(2016, 9, 1, tzinfo=timezone.utc)
self.window.save()
observation = self._generate_observation_data(
self.requestgroup.requests.first().id, [self.requestgroup.requests.first().configurations.first().id],
start="2016-09-02T22:35:39Z",
end="2016-09-02T23:35:40Z"
)
self._create_observation(observation)
observation = Observation.objects.first()
cancel_obs_1 = self._create_clone_observation(observation, datetime(2016, 9, 2, 23, 35, 41), datetime(2016, 9, 2, 23, 39, 59))
cancel_obs_2 = self._create_clone_observation(observation, datetime(2016, 9, 2, 23, 42, 0), datetime(2016, 9, 2, 23, 55, 34))
extra_obs_1 = self._create_clone_observation(observation, datetime(2016, 9, 2, 23, 55, 35), datetime(2016, 9, 3, 0, 14, 21))
rr_obs_1 = self._create_clone_observation(observation, datetime(2016, 9, 2, 23, 40, 0), datetime(2016, 9, 2, 23, 41, 59))
rr_requestgroup = create_simple_requestgroup(self.user, self.proposal, window=self.window, location=self.location)
rr_requestgroup.observation_type = RequestGroup.RAPID_RESPONSE
rr_requestgroup.save()
rr_obs_1.request = rr_requestgroup.requests.first()
rr_obs_1.save()

new_end = datetime(2016, 9, 2, 23, 47, 22).replace(tzinfo=timezone.utc)
update_data = {"end": datetime.strftime(new_end, '%Y-%m-%dT%H:%M:%SZ')}
self.client.patch(reverse('api:observations-detail', args=(observation.id,)), update_data)
observation.refresh_from_db()
self.assertEqual(observation.end, new_end)
cancel_obs_1.refresh_from_db()
self.assertEqual(cancel_obs_1.state, 'CANCELED')
cancel_obs_2.refresh_from_db()
self.assertEqual(cancel_obs_2.state, 'CANCELED')
extra_obs_1.refresh_from_db()
self.assertEqual(extra_obs_1.state, 'PENDING')
rr_obs_1.refresh_from_db()
self.assertEqual(rr_obs_1.state, 'PENDING')

def test_update_observation_end_time_rr_cancels_overlapping_rr(self):
self.window.start = datetime(2016, 9, 1, tzinfo=timezone.utc)
self.window.save()
self.requestgroup.observation_type = RequestGroup.RAPID_RESPONSE
self.requestgroup.save()
observation = self._generate_observation_data(
self.requestgroup.requests.first().id, [self.requestgroup.requests.first().configurations.first().id],
start="2016-09-02T22:35:39Z",
end="2016-09-02T23:35:40Z"
)
self._create_observation(observation)
observation = Observation.objects.first()
cancel_obs_1 = self._create_clone_observation(observation, datetime(2016, 9, 2, 23, 35, 41), datetime(2016, 9, 2, 23, 39, 59))
new_end = datetime(2016, 9, 2, 23, 47, 22).replace(tzinfo=timezone.utc)
update_data = {"end": datetime.strftime(new_end, '%Y-%m-%dT%H:%M:%SZ')}
self.client.patch(reverse('api:observations-detail', args=(observation.id,)), update_data)
cancel_obs_1.refresh_from_db()
self.assertEqual(cancel_obs_1.state, 'CANCELED')

def test_update_observation_end_before_start_does_nothing(self):
original_end = datetime(2016, 9, 5, 23, 35, 40).replace(tzinfo=timezone.utc)
observation = self._generate_observation_data(
self.requestgroup.requests.first().id, [self.requestgroup.requests.first().configurations.first().id]
)
self._create_observation(observation)
observation = Observation.objects.first()

new_end = datetime(2016, 9, 5, 19, 35, 40).replace(tzinfo=timezone.utc)
update_data = {"end": datetime.strftime(new_end, '%Y-%m-%dT%H:%M:%SZ')}
self.client.patch(reverse('api:observations-detail', args=(observation.id,)), update_data)
observation.refresh_from_db()
self.assertEqual(observation.end, original_end)

def test_update_observation_end_must_be_in_future(self):
observation = self._generate_observation_data(
self.requestgroup.requests.first().id, [self.requestgroup.requests.first().configurations.first().id]
)
self._create_observation(observation)
observation = Observation.objects.first()

new_end = datetime(2016, 8, 5, 19, 35, 40).replace(tzinfo=timezone.utc)
update_data = {"end": datetime.strftime(new_end, '%Y-%m-%dT%H:%M:%SZ')}
response = self.client.patch(reverse('api:observations-detail', args=(observation.id,)), update_data)
self.assertEqual(response.status_code, 400)
self.assertEqual(response.json()['non_field_errors'], ['Updated end time must be in the future'])

def test_update_observation_update_must_include_end(self):
observation = self._generate_observation_data(
self.requestgroup.requests.first().id, [self.requestgroup.requests.first().configurations.first().id]
)
self._create_observation(observation)
observation = Observation.objects.first()

update_data = {'field_1': 'testtest', 'not_end': 2341}
response = self.client.patch(reverse('api:observations-detail', args=(observation.id,)), update_data)
self.assertEqual(response.status_code, 400)
self.assertEqual(response.json()['non_field_errors'], ['Observation update must include `end` field'])


class TestLastScheduled(TestObservationApiBase):
def setUp(self):
super().setUp()
Expand Down
2 changes: 1 addition & 1 deletion observation_portal/observations/viewsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_queryset(self):

class ObservationViewSet(CreateListModelMixin, ListAsDictMixin, viewsets.ModelViewSet):
permission_classes = (IsAdminUser,)
http_method_names = ['get', 'post', 'head', 'options']
http_method_names = ['get', 'post', 'head', 'options', 'patch']
filter_class = ObservationFilter
serializer_class = ObservationSerializer
filter_backends = (
Expand Down