Skip to content

Commit

Permalink
Merge 1103952 into 5d3b81c
Browse files Browse the repository at this point in the history
  • Loading branch information
igncampa committed Apr 12, 2018
2 parents 5d3b81c + 1103952 commit cab9677
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 3 deletions.
40 changes: 37 additions & 3 deletions postgres_copy/copy_from.py
Expand Up @@ -153,20 +153,54 @@ def validate_mapping(self):
Raises errors if something goes wrong. Returns nothing if everything is kosher.
"""
# Get the model name for a more verbose output on error
_model_name = str(self.model._meta).split('.')[1]

# Make sure all of the CSV headers in the mapping actually exist
failing_headers = []
for map_header in self.mapping.values():
if map_header not in self.headers:
raise ValueError("Header '{}' not found in CSV file".format(map_header))
failing_headers.append(map_header)
if len(failing_headers) == 1:
raise ValueError("Header '{}' not found in CSV file.".format(failing_headers[0]))
elif len(failing_headers) > 1:
raise ValueError("Headers '{}' not found in CSV file.".format(
"', '".join(h for h in failing_headers)
))
else:
pass

# Make sure all the model fields in the mapping actually exist
failing_fields = []
for map_field in self.mapping.keys():
if not self.get_field(map_field):
raise FieldDoesNotExist("Model does not include {} field".format(map_field))
failing_fields.append(map_field)
if len(failing_fields) == 1:
raise FieldDoesNotExist("Model '{}' does not include field '{}'.".format(
_model_name, failing_fields[0]
))
elif len(failing_fields) > 1:
raise FieldDoesNotExist("Model '{}' does not include fields '{}'.".format(
_model_name, "', '".join(f for f in failing_fields)
))
else:
pass

# Make sure any static mapping columns exist
failing_static = []
for static_field in self.static_mapping.keys():
if not self.get_field(static_field):
raise ValueError("Model does not include {} field".format(static_field))
failing_static.append(static_field)
if len(failing_static) == 1:
raise ValueError("Model '{}' does not include field '{}'.".format(
_model_name, failing_static[0]
))
elif len(failing_static) > 1:
raise ValueError("Model '{}' does not include fields '{}'.".format(
_model_name, "', '".join(s for s in failing_static)
))
else:
pass

#
# CREATE commands
Expand Down
53 changes: 53 additions & 0 deletions tests/tests.py
@@ -1,5 +1,7 @@
import os
import csv
import sys
from collections import OrderedDict
from datetime import date
from .models import (
MockObject,
Expand All @@ -19,6 +21,9 @@

class BaseTest(TestCase):

if sys.version_info.major == 2:
assertRaisesRegex = unittest.TestCase.assertRaisesRegexp

def setUp(self):
self.data_dir = os.path.join(os.path.dirname(__file__), 'data')
self.name_path = os.path.join(self.data_dir, 'names.csv')
Expand Down Expand Up @@ -281,6 +286,22 @@ def test_bad_header(self):
dict(name='NAME1', number='NUMBER', dt='DATE'),
)

def test_bad_header_error_msg(self):
with self.assertRaisesRegex(ValueError, "Header 'NAME1' not found in CSV file."):
CopyMapping(
MockObject,
self.name_path,
dict(name='NAME1', number='NUMBER', dt='DATE'),
)

def test_multiple_bad_headers_error_msg(self):
with self.assertRaisesRegex(ValueError, "Headers 'NAME1', 'NUMBER1' not found in CSV file."):
CopyMapping(
MockObject,
self.name_path,
OrderedDict(name='NAME1', number='NUMBER1', dt='DATE'),
)

def test_bad_field(self):
with self.assertRaises(FieldDoesNotExist):
CopyMapping(
Expand All @@ -289,6 +310,22 @@ def test_bad_field(self):
dict(name1='NAME', number='NUMBER', dt='DATE'),
)

def test_bad_field_error_msg(self):
with self.assertRaisesRegex(FieldDoesNotExist, "Model 'mockobject' does not include field 'name1'."):
CopyMapping(
MockObject,
self.name_path,
dict(name1='NAME', number='NUMBER', dt='DATE'),
)

def test_multiple_bad_fields_error_msg(self):
with self.assertRaisesRegex(FieldDoesNotExist, "Model 'mockobject' does not include fields 'name1', 'number1'."):
CopyMapping(
MockObject,
self.name_path,
OrderedDict(name1='NAME', number1='NUMBER', dt='DATE'),
)

def test_limited_fields(self):
CopyMapping(
LimitedMockObject,
Expand Down Expand Up @@ -519,6 +556,22 @@ def test_bad_static_values(self):
static_mapping=dict(static_bad=1)
)

def test_bad_static_values_error_msg(self):
with self.assertRaisesRegex(ValueError, "Model 'extendedmockobject' does not include field 'static_bad'."):
ExtendedMockObject.objects.from_csv(
self.name_path,
dict(name='NAME', number='NUMBER', dt='DATE'),
static_mapping=dict(static_bad=1)
)

def test_multiple_bad_static_values_error_msg(self):
with self.assertRaisesRegex(ValueError, "Model 'extendedmockobject' does not include fields 'static_bad1', 'static_bad2'."):
ExtendedMockObject.objects.from_csv(
self.name_path,
dict(name='NAME', number='NUMBER', dt='DATE'),
static_mapping=OrderedDict(static_bad1=1, static_bad2=2)
)

def test_overload_save(self):
OverloadMockObject.objects.from_csv(
self.name_path,
Expand Down

0 comments on commit cab9677

Please sign in to comment.