From fb2268a24c49f3a6cd8647a4eee62d338f9ae956 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Fri, 7 Sep 2018 21:36:09 -0500 Subject: [PATCH] Allow fields/columns to be list of strings with bulk INSERT. Refs #1712. --- peewee.py | 13 +++++++++++-- tests/model_sql.py | 15 ++++++++++++++- tests/sql.py | 13 +++++++++++++ 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/peewee.py b/peewee.py index db4230a5e..71cbbff7f 100644 --- a/peewee.py +++ b/peewee.py @@ -2208,8 +2208,17 @@ def _generate_insert(self, insert, ctx): columns = sorted(accum, key=lambda obj: obj.get_sort_key(ctx)) rows_iter = itertools.chain(iter((row,)), rows_iter) else: - columns = list(columns) - value_lookups = dict((column, column) for column in columns) + clean_columns = [] + value_lookups = {} + for column in columns: + if isinstance(column, basestring): + column_obj = getattr(self.table, column) + else: + column_obj = column + value_lookups[column_obj] = column + clean_columns.append(column_obj) + + columns = clean_columns for col in sorted(defaults, key=lambda obj: obj.get_sort_key(ctx)): if col not in value_lookups: columns.append(col) diff --git a/tests/model_sql.py b/tests/model_sql.py index 3fcf165ce..d0001616a 100644 --- a/tests/model_sql.py +++ b/tests/model_sql.py @@ -14,7 +14,7 @@ class TestModelSQL(ModelDatabaseTestCase): database = get_in_memory_db() - requires = [Category, Note, Person, Relationship] + requires = [Category, Note, Person, Relationship, User] def test_select(self): query = (Person @@ -281,6 +281,19 @@ def test_insert_many(self): 'VALUES (?, ?), (?, ?)'), [1, 'note-1', 2, 'note-2']) + def test_insert_many_list_with_fields(self): + data = [(i,) for i in ('charlie', 'huey', 'zaizee')] + query = User.insert_many(data, fields=[User.username]) + self.assertSQL(query, ( + 'INSERT INTO "users" ("username") VALUES (?), (?), (?)'), + ['charlie', 'huey', 'zaizee']) + + # Use field name instead of field obj. + query = User.insert_many(data, fields=['username']) + self.assertSQL(query, ( + 'INSERT INTO "users" ("username") VALUES (?), (?), (?)'), + ['charlie', 'huey', 'zaizee']) + def test_insert_query(self): select = (Person .select(Person.id, Person.first) diff --git a/tests/sql.py b/tests/sql.py index 94082bfd6..a3fe7d6ea 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -509,6 +509,19 @@ def test_insert_list(self): 'INSERT INTO "person" ("name") VALUES (?), (?), (?)'), ['charlie', 'huey', 'zaizee']) + def test_insert_list_with_columns(self): + data = [(i,) for i in ('charlie', 'huey', 'zaizee')] + query = Person.insert(data, columns=[Person.name]) + self.assertSQL(query, ( + 'INSERT INTO "person" ("name") VALUES (?), (?), (?)'), + ['charlie', 'huey', 'zaizee']) + + # Use column name instead of column instance. + query = Person.insert(data, columns=['name']) + self.assertSQL(query, ( + 'INSERT INTO "person" ("name") VALUES (?), (?), (?)'), + ['charlie', 'huey', 'zaizee']) + def test_insert_query(self): source = User.select(User.c.username).where(User.c.admin == False) query = Person.insert(source, columns=[Person.name])