Skip to content

Commit

Permalink
Merge ea821c2 into 86670c1
Browse files Browse the repository at this point in the history
  • Loading branch information
ckirby committed Feb 28, 2017
2 parents 86670c1 + ea821c2 commit d4d7593
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 18 deletions.
95 changes: 79 additions & 16 deletions postgres_copy/__init__.py
Expand Up @@ -89,10 +89,84 @@ def validate_mapping(self):
if not self.get_field(static_field):
raise ValueError("Model does not include %s field" % static_field)

def create(self, cursor):
"""
Generate and run create sql for the temp table.
Runs a DROP on same prior to CREATE to avoid collisions.
cursor:
A cursor object on the db
"""
self.drop(cursor)
create_sql = self.prep_create()
cursor.execute(create_sql)

def pre_copy(self, cursor):
pass

def copy(self, cursor):
"""
Generate and run the COPY command to copy data from csv to temp table.
Calls `self.pre_copy(cursor)` and `self.post_copy(cursor)` respectively
before and after running copy
cursor:
A cursor object on the db
"""
self.pre_copy(cursor)
copy_sql = self.prep_copy()
fp = open(self.csv_path, 'r')
cursor.copy_expert(copy_sql, fp)
self.post_copy(cursor)

def post_copy(self, cursor):
pass

def pre_insert(self, cursor):
pass

def insert(self, cursor):
"""
Generate and run the INSERT command to move data from the temp table
to the concrete table.
Calls `self.pre_copy(cursor)` and `self.post_copy(cursor)` respectively
before and after running copy
returns: the count of rows inserted
cursor:
A cursor object on the db
"""
self.pre_insert(cursor)
insert_sql = self.prep_insert()
cursor.execute(insert_sql)
insert_count = cursor.rowcount
self.post_insert(cursor)

return insert_count

def post_insert(self, cursor):
pass

def drop(self, cursor):
"""
Generate and run the DROP command for the temp table.
cursor:
A cursor object on the db
"""
drop_sql = self.prep_drop()
cursor.execute(drop_sql)

def save(self, silent=False, stream=sys.stdout):
"""
Saves the contents of the CSV file to the database.
Override this method and use 'self.create(cursor)`,
`self.copy(cursor)`, `self.insert(cursor)`, and `self.drop(cursor)`
if you need functionality other than the default create/copy/insert/drop
workflow.
silent:
By default, non-fatal error notifications are printed to stdout,
but this keyword may be set to disable these notifications.
Expand All @@ -106,22 +180,11 @@ def save(self, silent=False, stream=sys.stdout):
stream.write("Loading CSV to %s\n" % self.model.__name__)

# Connect to the database
cursor = self.conn.cursor()

# Create all of the raw SQL
drop_sql = self.prep_drop()
create_sql = self.prep_create()
copy_sql = self.prep_copy()
insert_sql = self.prep_insert()

# Run all of the raw SQL
cursor.execute(drop_sql)
cursor.execute(create_sql)
fp = open(self.csv_path, 'r')
cursor.copy_expert(copy_sql, fp)
cursor.execute(insert_sql)
insert_count = cursor.rowcount
cursor.execute(drop_sql)
with self.conn.cursor() as c:
self.create(c)
self.copy(c)
insert_count = self.insert(c)
self.drop(c)

if not silent:
stream.write(
Expand Down
15 changes: 15 additions & 0 deletions tests/models.py
@@ -1,5 +1,6 @@
from django.db import models
from .fields import MyIntegerField
from postgres_copy import CopyMapping


class MockObject(models.Model):
Expand Down Expand Up @@ -60,3 +61,17 @@ def copy_upper_name_template(self):
def copy_lower_name_template(self):
return 'lower("%(name)s")'
copy_lower_name_template.copy_type = 'text'


class HookedCopyMapping(CopyMapping):
def pre_copy(self, cursor):
self.ran_pre_copy = True

def post_copy(self, cursor):
self.ran_post_copy = True

def pre_insert(self, cursor):
self.ran_pre_insert = True

def post_insert(self, cursor):
self.ran_post_insert = True
62 changes: 60 additions & 2 deletions tests/tests.py
@@ -1,11 +1,12 @@
import os
from datetime import date

from .models import (
MockObject,
ExtendedMockObject,
LimitedMockObject,
OverloadMockObject
)
OverloadMockObject,
HookedCopyMapping)
from postgres_copy import CopyMapping
from django.test import TestCase

Expand Down Expand Up @@ -269,3 +270,60 @@ def test_missing_overload_field(self):
self.name_path,
dict(name='NAME', number='NUMBER', dt='DATE', missing='NAME'),
)


def test_save_steps(self):
c = CopyMapping(
MockObject,
self.name_path,
dict(name='NAME', number='NUMBER', dt='DATE'),
)
cursor = c.conn.cursor()

c.create(cursor)
cursor.execute("""SELECT count(*) FROM %s;""" % c.temp_table_name)
self.assertEquals(cursor.fetchone()[0], 0)
cursor.execute("""SELECT count(*) FROM %s;""" % c.model._meta.db_table)
self.assertEquals(cursor.fetchone()[0], 0)

c.copy(cursor)
cursor.execute("""SELECT count(*) FROM %s;""" % c.temp_table_name)
self.assertEquals(cursor.fetchone()[0], 3)
cursor.execute("""SELECT count(*) FROM %s;""" % c.model._meta.db_table)
self.assertEquals(cursor.fetchone()[0], 0)

c.insert(cursor)
cursor.execute("""SELECT count(*) FROM %s;""" % c.model._meta.db_table)
self.assertEquals(cursor.fetchone()[0], 3)

c.drop(cursor)
self.assertEquals(cursor.statusmessage, 'DROP TABLE')
cursor.close()

def test_hooks(self):
c = HookedCopyMapping(
MockObject,
self.name_path,
dict(name='NAME', number='NUMBER', dt='DATE'),
)
cursor = c.conn.cursor()

c.create(cursor)
self.assertRaises(AttributeError, lambda: c.ran_pre_copy)
self.assertRaises(AttributeError, lambda: c.ran_post_copy)
self.assertRaises(AttributeError, lambda: c.ran_pre_insert)
self.assertRaises(AttributeError, lambda: c.ran_post_insert)
c.copy(cursor)
self.assertTrue(c.ran_pre_copy)
self.assertTrue(c.ran_post_copy)
self.assertRaises(AttributeError, lambda: c.ran_pre_insert)
self.assertRaises(AttributeError, lambda: c.ran_post_insert)

c.insert(cursor)
self.assertTrue(c.ran_pre_copy)
self.assertTrue(c.ran_post_copy)
self.assertTrue(c.ran_pre_insert)
self.assertTrue(c.ran_post_insert)

c.drop(cursor)
cursor.close()

0 comments on commit d4d7593

Please sign in to comment.