Skip to content

Commit

Permalink
Merge pull request #106 from california-civic-data-coalition/105
Browse files Browse the repository at this point in the history
Support binary mode of file-like objects passed to from_csv()
  • Loading branch information
palewire committed Apr 25, 2019
2 parents 14803db + d39e734 commit 032083e
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 9 deletions.
34 changes: 28 additions & 6 deletions postgres_copy/copy_from.py
Expand Up @@ -5,13 +5,13 @@
"""
import os
import sys
import csv
import logging
from collections import OrderedDict
from django.db import NotSupportedError
from django.db import connections, router
from django.core.exceptions import FieldDoesNotExist
from django.contrib.humanize.templatetags.humanize import intcomma
from django.utils.encoding import force_bytes, force_text
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -150,13 +150,35 @@ def get_headers(self):
Returns the column headers from the csv as a list.
"""
logger.debug("Retrieving headers from {}".format(self.csv_file))
# Open it as a CSV
csv_reader = csv.reader(self.csv_file, delimiter=self.delimiter)
# Pop the headers
headers = next(csv_reader)

# determine what mode the file is opened in
file_mode = getattr(
self.csv_file, 'mode', getattr(
self.csv_file, '_mode', None
)
)
# take the user-defined encoding, or assume utf-8
encoding = self.encoding or 'utf-8'
# if file is in binary mode...
if 'b' in file_mode:
# ...coerce delimiter to binary...
delimiter = force_bytes(self.delimiter, encoding=encoding)
# ...and coerce each header item to str (and strip whitespace)
headers = [
force_text(h, encoding=encoding).strip()
for h in self.csv_file.readline().split(delimiter)
]
# if not in binary mode...
else:
delimiter = self.delimiter
# ...just strip whitespace on each header item
headers = [
h.strip()
for h in self.csv_file.readline().split(delimiter)
]
# Move back to the top of the file
self.csv_file.seek(0)
# Return the headers

return headers

def validate_mapping(self):
Expand Down
20 changes: 17 additions & 3 deletions tests/tests.py
Expand Up @@ -102,9 +102,9 @@ def test_export_to_str(self, _):
self._load_objects(self.name_path)
export = MockObject.objects.to_csv()
self.assertEqual(export, b"""id,name,num,dt,parent_id
86,BEN,1,2012-01-01,
87,JOE,2,2012-01-02,
88,JANE,3,2012-01-03,
89,BEN,1,2012-01-01,
90,JOE,2,2012-01-02,
91,JANE,3,2012-01-03,
""")

@mock.patch("django.db.connection.validate_no_atomic_block")
Expand Down Expand Up @@ -350,6 +350,20 @@ def test_simple_save_with_fileobject(self, _):
date(2012, 1, 1)
)

@mock.patch("django.db.connection.validate_no_atomic_block")
def test_save_with_binary_fileobject(self, _):
f = open(self.name_path, 'rb')
MockObject.objects.from_csv(
f,
dict(name='NAME', number='NUMBER', dt='DATE')
)
self.assertEqual(MockObject.objects.count(), 3)
self.assertEqual(MockObject.objects.get(name='BEN').number, 1)
self.assertEqual(
MockObject.objects.get(name='BEN').dt,
date(2012, 1, 1)
)

def test_atomic_block(self):
with transaction.atomic():
try:
Expand Down

0 comments on commit 032083e

Please sign in to comment.