
Commit

Merge e7ec401 into bdd99d1
theGOTOguy committed Nov 29, 2018
2 parents: bdd99d1 + e7ec401 · commit 4c00ce5
Showing 3 changed files with 74 additions and 26 deletions.
47 changes: 42 additions & 5 deletions postgres_copy/copy_from.py
@@ -4,7 +4,9 @@
Handlers for working with PostgreSQL's COPY command.
"""
import os
import six
import sys
import tempfile
import csv
import logging
from collections import OrderedDict
@@ -32,7 +34,8 @@ def __init__(
force_null=None,
encoding=None,
ignore_conflicts=False,
static_mapping=None
static_mapping=None,
max_rows=None
):
# Set the required arguments
self.model = model
@@ -57,6 +60,8 @@ def __init__(
self.encoding = encoding
self.supports_ignore_conflicts = True
self.ignore_conflicts = ignore_conflicts
self.max_rows = max_rows

if static_mapping is not None:
self.static_mapping = OrderedDict(static_mapping)
else:
@@ -267,20 +272,52 @@ def pre_copy(self, cursor):
def copy(self, cursor):
"""
Generate and run the COPY command to copy data from csv to temp table.
Calls `self.pre_copy(cursor)` and `self.post_copy(cursor)` before and
after running the copy, respectively.
cursor:
    A cursor object on the db
"""
# Run pre-copy hook
self.pre_copy(cursor)

logger.debug("Running COPY command")
copy_sql = self.prep_copy()
logger.debug(copy_sql)
cursor.copy_expert(copy_sql, self.csv_file)

if not self.max_rows:
cursor.copy_expert(copy_sql, self.csv_file)
else:  # Split the CSV up into smaller chunk files.
# This header will be shared across many smaller temp files.
header = self.csv_file.readline()

# Keep going through the whole source file.
line = True
while line:
line_count = 0
# Create a temp file with the prescribed number of rows from
# the source file.
encoding_args = {}
if six.PY3:
# Postgres COPY expects UTF-8 encoded input.
encoding_args = {'encoding': 'utf-8'}

with tempfile.NamedTemporaryFile(
"w", delete=False, **encoding_args) as chunk_file:
chunk_file.write(header)

while line_count < self.max_rows and line:
line = self.csv_file.readline()
if line:
chunk_file.write(line)
line_count += 1

chunk_file.close()

with open(chunk_file.name, "r") as chunk_stream:
cursor.copy_expert(copy_sql, chunk_stream)

# We're done with the temp file now, go ahead and get rid of it.
os.remove(chunk_file.name)

# At this point all data has been loaded to the temp table
self.csv_file.close()
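The new `max_rows` branch above streams the source CSV into header-prefixed chunk files and runs one COPY per chunk, so only `max_rows` data lines sit in a temp file at a time. A minimal standalone sketch of the same idea outside of `CopyMapping` (the function name is hypothetical; it assumes a psycopg2 cursor and an already-prepared `COPY ... FROM STDIN WITH CSV HEADER` statement):

import os
import tempfile


def copy_csv_in_chunks(cursor, copy_sql, csv_file, max_rows):
    # Sketch only: chunked COPY of an open CSV text stream.
    # The header line is repeated at the top of every chunk file so each
    # chunk is a valid CSV on its own.
    header = csv_file.readline()

    line = True
    while line:
        line_count = 0
        # Write the header plus up to max_rows data lines to a temp file.
        with tempfile.NamedTemporaryFile(
                "w", delete=False, encoding="utf-8") as chunk_file:
            chunk_file.write(header)
            while line_count < max_rows and line:
                line = csv_file.readline()
                if line:
                    chunk_file.write(line)
                    line_count += 1

        # Re-open the finished chunk, hand it to COPY, then clean up.
        with open(chunk_file.name, "r") as chunk_stream:
            cursor.copy_expert(copy_sql, chunk_stream)
        os.remove(chunk_file.name)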
9 changes: 5 additions & 4 deletions postgres_copy/managers.py
@@ -39,8 +39,8 @@ def edit_schema(self, schema_editor, method_name, args):
"""
try:
getattr(schema_editor, method_name)(*args)
except Exception:
logger.debug("Edit of {}.{} failed. Skipped".format(schema_editor, method_name))
except Exception as err:
logger.debug("Edit of {}.{} failed with message {}. Skipped".format(schema_editor, method_name, str(err)))
pass

def drop_constraints(self):
@@ -128,7 +128,8 @@ class CopyQuerySet(ConstraintQuerySet):
"""
Subclass of QuerySet that adds from_csv and to_csv methods.
"""
def from_csv(self, csv_path, mapping=None, drop_constraints=True, drop_indexes=True, silent=True, **kwargs):
def from_csv(self, csv_path, mapping=None, drop_constraints=True,
drop_indexes=True, silent=True, batch_size=None, **kwargs):
"""
Copy CSV file from the provided path to the current model using the provided mapping.
"""
@@ -146,7 +147,7 @@ def from_csv(self, csv_path, mapping=None, drop_constraints=True, drop_indexes=T
"anyway. Either remove the transaction block, or set "
"drop_constraints=False and drop_indexes=False.")

mapping = CopyMapping(self.model, csv_path, mapping, **kwargs)
mapping = CopyMapping(self.model, csv_path, mapping, max_rows=batch_size, **kwargs)

if drop_constraints:
self.drop_constraints()
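With this change, callers can opt into chunked loading by passing `batch_size` to `from_csv`; the manager forwards it to `CopyMapping` as `max_rows`. A usage sketch, borrowing the `MockObject` model and the NAME/NUMBER/DATE column mapping from the test suite below (the file path and batch size are illustrative):

# Illustrative only: load a large CSV in chunks of 10,000 rows per COPY.
MockObject.objects.from_csv(
    "path/to/big_file.csv",  # hypothetical path
    dict(name='NAME', number='NUMBER', dt='DATE'),
    batch_size=10000,        # forwarded to CopyMapping as max_rows
)

When `batch_size` is left as None (the default), the whole file is handed to a single COPY, exactly as before.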
44 changes: 27 additions & 17 deletions tests/tests.py
@@ -101,11 +101,19 @@ def test_export_to_file(self, _):
def test_export_to_str(self, _):
self._load_objects(self.name_path)
export = MockObject.objects.to_csv()
self.assertEqual(export, b"""id,name,num,dt,parent_id
86,BEN,1,2012-01-01,
87,JOE,2,2012-01-02,
88,JANE,3,2012-01-03,
""")
lines = export.decode('utf-8').split('\n')
self.assertEqual(5, len(lines))
self.assertEqual(lines[0], "id,name,num,dt,parent_id")

# Because the id field can vary depending on other tests,
# we only check that the individual data lines in the CSV
# end with the appropriate fields
self.assertTrue(lines[1].endswith(",BEN,1,2012-01-01,"))
self.assertTrue(lines[2].endswith(",JOE,2,2012-01-02,"))
self.assertTrue(lines[3].endswith(",JANE,3,2012-01-03,"))

# We end with a newline.
self.assertEqual(lines[4], '')

@mock.patch("django.db.connection.validate_no_atomic_block")
def test_export_header_setting(self, _):
@@ -350,18 +358,20 @@ def test_simple_save_with_fileobject(self, _):
date(2012, 1, 1)
)

def test_atomic_block(self):
with transaction.atomic():
try:
f = open(self.name_path, 'r')
MockObject.objects.from_csv(
f,
dict(name='NAME', number='NUMBER', dt='DATE')
)
self.fail("Expected TransactionManagementError.")
except TransactionManagementError:
# Expected
pass
@mock.patch("django.db.connection.validate_no_atomic_block")
def test_simple_save_with_fileobject_and_batches(self, _):
f = open(self.name_path, 'r')
MockObject.objects.from_csv(
f,
dict(name='NAME', number='NUMBER', dt='DATE'),
batch_size=1
)
self.assertEqual(MockObject.objects.count(), 3)
self.assertEqual(MockObject.objects.get(name='BEN').number, 1)
self.assertEqual(
MockObject.objects.get(name='BEN').dt,
date(2012, 1, 1)
)

@mock.patch("django.db.connection.validate_no_atomic_block")
def test_simple_save(self, _):
