-
Notifications
You must be signed in to change notification settings - Fork 3
/
views.py
99 lines (85 loc) · 3.33 KB
/
views.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from django.conf import settings
from django.utils.translation import gettext_lazy as _
from rest_framework import generics
from rest_framework.response import Response
from rest_framework import status
from data.serializers import (
DataClassificationSerializer, DatasetClassificationSerializer
)
from data.models import Data
from ..models import CurrentClassifier
class Classify(generics.GenericAPIView):
    """
    Endpoint for classifying data (single observation).

    ``GET`` returns the current classifier's metadata plus ``chtuid``;
    ``POST`` classifies the submitted observation. Both respond with
    HTTP 503 when no current classifier is available.

    Query parameters accepted by ``POST`` (only "True"/"true" enable them):

    * ``use_network`` -- predict with the classifier's ``network_predict``
      and include its ``votes`` in the response.
    * ``graph`` -- include a ``graph`` entry built by ``generate_graph``
      (``None`` unless ``settings.GRAPHING`` is enabled).
    """
    serializer_class = DataClassificationSerializer
    use_network = False
    graph = False
    _classifier = None

    # Query-string values that switch a boolean flag on.
    _TRUTHY_FLAG_VALUES = ("True", "true")

    def get_data(self, serializer):
        """
        Return the serializer's validated data with the ``Data`` conversion
        rules applied.
        """
        return Data.apply_conversion_fields_rules_to_dict(
            serializer.validated_data
        )

    def get_classifier(self):
        """
        Return the current classifier, looked up lazily via
        ``CurrentClassifier.get()`` and cached on this view instance.

        A falsy lookup result is not cached, so the lookup is retried on the
        next call.
        """
        if not self._classifier:
            self._classifier = CurrentClassifier.get()
        return self._classifier

    def predict(self, data):
        """
        Classify ``data`` and return a ``(result, prob, votes)`` triple.

        With ``use_network`` set, the whole triple comes from the
        classifier's ``network_predict``; otherwise ``votes`` is ``None``.
        """
        classifier = self.get_classifier()
        if self.use_network:
            return classifier.network_predict(data)
        # predict() presumably returns a (result, prob) tuple; pad it so both
        # branches yield three items for the caller to unpack.
        return classifier.predict(data) + (None, )

    def generate_graph(self, data):
        """
        Return the classifier's graph for ``data``, or ``None`` when
        ``settings.GRAPHING`` is disabled.
        """
        if settings.GRAPHING:
            return self.get_classifier().generate_graph(data)
        return None

    def _query_flag(self, request, name):
        """Return True if query parameter ``name`` is "True" or "true"."""
        return request.query_params.get(name) in self._TRUTHY_FLAG_VALUES

    def _unavailable_response(self):
        """Return the HTTP 503 response used when no classifier is available."""
        return Response(
            {'detail': _("Classification Unavailable")},
            status=status.HTTP_503_SERVICE_UNAVAILABLE
        )

    def post(self, request, format=None):
        """
        Return the result of classification if data submitted is valid.

        Responds 200 with ``{'result', 'prob'}`` (plus ``votes`` and/or
        ``graph`` when requested), 400 with the serializer errors for invalid
        input, or 503 when no classifier is available.
        """
        self.use_network = self._query_flag(request, 'use_network')
        self.graph = self._query_flag(request, 'graph')
        if not self.get_classifier():
            return self._unavailable_response()

        serializer = self.serializer_class(data=request.data)
        if not serializer.is_valid():
            return Response(
                serializer.errors,
                status=status.HTTP_400_BAD_REQUEST
            )

        # Apply the conversion rules once and reuse the result for both
        # prediction and graphing (assumes predict() does not mutate its
        # input -- NOTE(review): confirm).
        data = self.get_data(serializer)
        result, result_prob, votes = self.predict(data)
        response_content = {'result': result, 'prob': result_prob}
        if self.use_network:
            response_content['votes'] = votes
        if self.graph:
            response_content['graph'] = self.generate_graph(data)
        return Response(response_content, status=status.HTTP_200_OK)

    def get(self, request, format=None):
        """
        Return the current classifier's metadata with ``chtuid`` added, or
        HTTP 503 when no classifier is available.
        """
        classifier = self.get_classifier()
        if not classifier:
            return self._unavailable_response()
        # Copy before adding "chtuid": the metadata object belongs to the
        # classifier, and writing into it directly would mutate that
        # object's state for every later reader.
        metadata = dict(classifier.get_local_classifier().metadata)
        metadata["chtuid"] = settings.CHTUID
        return Response(metadata, status=status.HTTP_200_OK)
class ClassifyDataset(Classify):
    """
    Endpoint for classifying datasets (multiple observations).

    Works like ``Classify``, but validates input with
    ``DatasetClassificationSerializer``. It hands the classifier a list
    with one converted observation per entry of the submitted ``dataset``.
    """
    serializer_class = DatasetClassificationSerializer

    def get_data(self, serializer):
        """
        Return the ``dataset`` entries as a list, each with the ``Data``
        conversion rules applied.
        """
        observations = serializer.validated_data['dataset']
        return [
            Data.apply_conversion_fields_rules_to_dict(observation)
            for observation in observations
        ]