Skip to content

Commit

Permalink
Merge 03970f3 into 79f71ef
Browse files Browse the repository at this point in the history
  • Loading branch information
ckirby committed Oct 10, 2016
2 parents 79f71ef + 03970f3 commit 8114bd5
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 2 deletions.
4 changes: 4 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ Keyword Arguments
for every row in the database by providing a dictionary
with the name of the columns as keys and the static
inputs as values.

``ignore_headers`` A list of headers from your csv that don't have
equivalent fields in your model. These columns will
be ignored.
===================== =====================================================


Expand Down
15 changes: 14 additions & 1 deletion postgres_copy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(
model,
csv_path,
mapping,
ignore_headers=None,
using=None,
delimiter=',',
null=None,
Expand All @@ -37,6 +38,10 @@ def __init__(
if self.conn.vendor != 'postgresql':
raise TypeError("Only PostgreSQL backends supported")
self.backend = self.conn.ops
if ignore_headers is None:
self.ignore_headers = []
else:
self.ignore_headers = ignore_headers
self.delimiter = delimiter
self.null = null
self.encoding = encoding
Expand All @@ -48,13 +53,17 @@ def __init__(
# Connect the headers from the CSV with the fields on the model
self.field_header_crosswalk = []
inverse_mapping = {v: k for k, v in self.mapping.items()}
for ignore in self.ignore_headers:
inverse_mapping[ignore] = ignore.lower()
for h in self.get_headers():
try:
f_name = inverse_mapping[h]
except KeyError:
raise ValueError("Map does not include %s field" % h)
try:
f = [f for f in self.model._meta.fields if f.name == f_name][0]
if f_name not in [ih.lower() for ih in self.ignore_headers]:
f = [f for f in self.model._meta.fields
if f.name == f_name][0]
except IndexError:
raise ValueError("Model does not include %s field" % f_name)
self.field_header_crosswalk.append((f, h))
Expand Down Expand Up @@ -204,6 +213,8 @@ def prep_insert(self):
model_fields = []

for field, header in self.field_header_crosswalk:
if header in self.ignore_headers:
continue
model_fields.append('"%s"' % field.get_attname_column()[1])

for k in self.static_mapping.keys():
Expand All @@ -213,6 +224,8 @@ def prep_insert(self):

temp_fields = []
for field, header in self.field_header_crosswalk:
if header in self.ignore_headers:
continue
string = '"%s"' % header
if hasattr(field, 'copy_template'):
string = field.copy_template % dict(name=header)
Expand Down
13 changes: 13 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,16 @@ class Meta:
def copy_name_template(self):
return 'upper("%(name)s")'
copy_name_template.copy_type = 'text'


class LimitedMockObject(models.Model):
name = models.CharField(max_length=500)
dt = models.DateField(null=True)

class Meta:
app_label = 'tests'

def copy_name_template(self):
return 'upper("%(name)s")'
copy_name_template.copy_type = 'text'

28 changes: 27 additions & 1 deletion tests/tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from datetime import date
from .models import MockObject, ExtendedMockObject
from .models import MockObject, ExtendedMockObject, LimitedMockObject
from postgres_copy import CopyMapping
from django.test import TestCase

Expand All @@ -18,6 +18,7 @@ def setUp(self):
def tearDown(self):
MockObject.objects.all().delete()
ExtendedMockObject.objects.all().delete()
LimitedMockObject.objects.all().delete()

def test_bad_call(self):
with self.assertRaises(TypeError):
Expand Down Expand Up @@ -57,6 +58,17 @@ def test_bad_field(self):
dict(name1='NAME', number='NUMBER', dt='DATE'),
)

def test_limited_fields(self):
try:
CopyMapping(
LimitedMockObject,
self.name_path,
dict(name='NAME', dt='DATE'),
ignore_headers=['NUMBER']
)
except ValueError:
self.fail("Failed trying to ignore fields")

def test_simple_save(self):
c = CopyMapping(
MockObject,
Expand All @@ -71,6 +83,20 @@ def test_simple_save(self):
date(2012, 1, 1)
)

def test_limited_save(self):
c = CopyMapping(
LimitedMockObject,
self.name_path,
dict(name='NAME', dt='DATE'),
ignore_headers=['NUMBER']
)
c.save()
self.assertEqual(LimitedMockObject.objects.count(), 3)
self.assertEqual(
LimitedMockObject.objects.get(name='BEN').dt,
date(2012, 1, 1)
)

def test_save_foreign_key(self):
c = CopyMapping(
MockObject,
Expand Down

0 comments on commit 8114bd5

Please sign in to comment.