Skip to content

Commit

Permalink
Merge d1a176a into 7fe4fa6
Browse files Browse the repository at this point in the history
  • Loading branch information
pomegranited committed Aug 3, 2016
2 parents 7fe4fa6 + d1a176a commit 4f6aa11
Show file tree
Hide file tree
Showing 10 changed files with 429 additions and 53 deletions.
2 changes: 0 additions & 2 deletions analytics_data_api/constants/learner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
LEARNER_API_DEFAULT_LIST_PAGE_SIZE = 25

SEGMENTS = ["highly_engaged", "disengaging", "struggling", "inactive", "unenrolled"]
22 changes: 22 additions & 0 deletions analytics_data_api/renderers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
API renderers common to all versions of the API.
"""
from rest_framework_csv.renderers import CSVRenderer


class PaginatedCsvRenderer(CSVRenderer):
"""
Render CSV data using just the results array.
Use with PaginatedHeadersMixin to preserve the pagination links in the response header.
"""
results_field = 'results'
media_type = 'text/csv'

def render(self, data, *args, **kwargs):
"""
Replace the rendered data with just what is in the results_field.
"""
if not isinstance(data, list):
data = data.get(self.results_field, [])
return super(PaginatedCsvRenderer, self).render(data, *args, **kwargs)
37 changes: 36 additions & 1 deletion analytics_data_api/v0/serializers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from urlparse import urljoin
from django.conf import settings
from django.utils.datastructures import SortedDict
from rest_framework import pagination, serializers

from analytics_data_api.constants import (
Expand Down Expand Up @@ -45,6 +46,40 @@ class ModelSerializerWithCreatedField(serializers.ModelSerializer):
created = serializers.DateTimeField(format=settings.DATETIME_FORMAT)


class DynamicFieldsSerializerMixin(object):
"""
Allows the `fields` query parameter to determine which fields should be returned in the response.
"""
fields_sep = ','

def get_fields(self):
"""
Filter the list of available fields based on the list of fields passed to the request.
"""
fields = super(DynamicFieldsSerializerMixin, self).get_fields()

request = self.context.get('request')
if request:
request_fields = request.QUERY_PARAMS.get('fields')
if request_fields is not None:
# Include only fields that are specified in the `fields` argument,
# in the order they are given.
request_fields = request_fields.split(self.fields_sep)
allowed = SortedDict()
for field_name in request_fields:
if field_name in fields:
allowed[field_name] = fields[field_name]
fields = allowed

# Set the renderer's header attribute to the sorted fields list, if relevant.
# 'header' is used by the CSVRenderer to decide what order to display fields in.
renderer = request.accepted_renderer
if renderer and hasattr(renderer, 'header'):
renderer.header = fields.keys()

return fields


class ProblemSerializer(serializers.Serializer):
"""
Serializer for problems.
Expand Down Expand Up @@ -333,7 +368,7 @@ class LastUpdatedSerializer(serializers.Serializer):
last_updated = serializers.DateField(source='date', format=settings.DATE_FORMAT)


class LearnerSerializer(serializers.Serializer, DefaultIfNoneMixin):
class LearnerSerializer(DynamicFieldsSerializerMixin, serializers.Serializer, DefaultIfNoneMixin):
username = serializers.CharField(source='username')
enrollment_mode = serializers.CharField(source='enrollment_mode')
name = serializers.CharField(source='name')
Expand Down
61 changes: 61 additions & 0 deletions analytics_data_api/v0/tests/views/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import json
import StringIO
import csv

from opaque_keys.edx.keys import CourseKey
from rest_framework import status

from analytics_data_api.v0.tests.utils import flatten


DEMO_COURSE_ID = u'course-v1:edX+DemoX+Demo_2014'


Expand Down Expand Up @@ -36,3 +41,59 @@ def verify_bad_course_id(self, response, course_id='malformed-course-id'):
u"developer_message": u"Course id/key {} malformed.".format(course_id)
}
self.assertDictEqual(json.loads(response.content), expected)


class VerifyCsvResponseMixin(object):

def assertCsvResponseIsValid(self, response, expected_filename, expected_data=None, expected_headers=None):

# Validate the basic response status, content type, and filename
self.assertEquals(response.status_code, 200)
if expected_data:
self.assertEquals(response['Content-Type'].split(';')[0], 'text/csv')
self.assertEquals(response['Content-Disposition'], u'attachment; filename={}'.format(expected_filename))

# Validate other response headers
if expected_headers:
for header_name, header_content in expected_headers.iteritems():
self.assertEquals(response.get(header_name), header_content)

# Validate the content data
if expected_data:
data = map(flatten, expected_data)

# The CSV renderer sorts the headers alphabetically
fieldnames = sorted(data[0].keys())

# Generate the expected CSV output
expected = StringIO.StringIO()
writer = csv.DictWriter(expected, fieldnames)
writer.writeheader()
writer.writerows(data)
self.assertEqual(response.content, expected.getvalue())
else:
self.assertEqual(response.content, '')


class VerifyDynamicFieldsMixin(object):

def assertResponseFields(self, response, fields):
content_type = response.get('Content-Type', '').split(';')[0]
if content_type == 'text/csv':
return self.assertCsvResponseFields(response, fields)
else:
return self.assertJsonResponseFields(response, fields)

def assertCsvResponseFields(self, response, fields):
data = StringIO.StringIO(response.content)
reader = csv.DictReader(data)
for row in reader:
self.assertEqual(row.keys(), fields)
self.assertEqual(reader.fieldnames, fields)

def assertJsonResponseFields(self, response, fields):
data = json.loads(response.content)
results = data.get('results')
self.assertIsNotNone(results)
for row in results:
self.assertEquals(fields, row.keys())
25 changes: 3 additions & 22 deletions analytics_data_api/v0/tests/views/test_courses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
# change for versions greater than 1.0.0. Tests target a specific version of the API, additional tests should be added
# for subsequent versions if there are breaking changes introduced in those versions.

import StringIO
import csv
import datetime
from itertools import groupby
import urllib
Expand All @@ -17,8 +15,7 @@
from analytics_data_api.v0 import models
from analytics_data_api.constants import country, enrollment_modes, genders
from analytics_data_api.v0.models import CourseActivityWeekly
from analytics_data_api.v0.tests.utils import flatten
from analytics_data_api.v0.tests.views import DemoCourseMixin, DEMO_COURSE_ID
from analytics_data_api.v0.tests.views import DemoCourseMixin, VerifyCsvResponseMixin, DEMO_COURSE_ID
from analyticsdataserver.tests import TestCaseWithAuthentication


Expand All @@ -37,7 +34,7 @@ def test_default_fill(self):


# pylint: disable=no-member
class CourseViewTestCaseMixin(DemoCourseMixin):
class CourseViewTestCaseMixin(DemoCourseMixin, VerifyCsvResponseMixin):
model = None
api_root_path = '/api/v0/'
path = None
Expand Down Expand Up @@ -92,24 +89,8 @@ def assertCSVIsValid(self, course_id, filename):
csv_content_type = 'text/csv'
response = self.authenticated_get(path, HTTP_ACCEPT=csv_content_type)

# Validate the basic response status, content type, and filename
self.assertEquals(response.status_code, 200)
self.assertEquals(response['Content-Type'].split(';')[0], csv_content_type)
self.assertEquals(response['Content-Disposition'], u'attachment; filename={}'.format(filename))

# Validate the actual data
data = self.format_as_response(*self.get_latest_data(course_id=course_id))
data = map(flatten, data)

# The CSV renderer sorts the headers alphabetically
fieldnames = sorted(data[0].keys())

# Generate the expected CSV output
expected = StringIO.StringIO()
writer = csv.DictWriter(expected, fieldnames)
writer.writeheader()
writer.writerows(data)
self.assertEqual(response.content, expected.getvalue())
self.assertCsvResponseIsValid(response, filename, data)

def test_get_csv(self):
""" Verify the endpoint returns data that has been properly converted to CSV. """
Expand Down
Loading

0 comments on commit 4f6aa11

Please sign in to comment.