Skip to content

Commit

Permalink
Small fixes to update mode (#60)
Browse files Browse the repository at this point in the history
- Fix describe (not to return autoincrement column)
- Add check for validity for update_keys in write
- Fix hardcoded reference to autoincrement column
- Read data in a transaction
- Correct update return values depending on autoincrement mode
  • Loading branch information
akariv authored and roll committed Mar 30, 2017
1 parent cc86d5e commit ed62cf9
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 16 deletions.
7 changes: 6 additions & 1 deletion jsontableschema_sql/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def descriptor_to_columns_and_constraints(prefix, bucket, descriptor,
return (columns, constraints, indexes)


def columns_and_constraints_to_descriptor(prefix, tablename, columns, constraints):
def columns_and_constraints_to_descriptor(prefix, tablename, columns,
constraints, autoincrement_column):
"""Convert SQLAlchemy columns and constraints to descriptor.
"""

Expand All @@ -134,6 +135,8 @@ def columns_and_constraints_to_descriptor(prefix, tablename, columns, constraint
# Fields
fields = []
for column in columns:
if column.name == autoincrement_column:
continue
field_type = None
for key, value in mapping.items():
if isinstance(column.type, key):
Expand All @@ -153,6 +156,8 @@ def columns_and_constraints_to_descriptor(prefix, tablename, columns, constraint
for constraint in constraints:
if isinstance(constraint, PrimaryKeyConstraint):
for column in constraint.columns:
if column.name == autoincrement_column:
continue
pk.append(column.name)
if len(pk) > 0:
if len(pk) == 1:
Expand Down
24 changes: 16 additions & 8 deletions jsontableschema_sql/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,22 +172,27 @@ def describe(self, bucket, descriptor=None):
if descriptor is None:
table = self.__get_table(bucket)
descriptor = mappers.columns_and_constraints_to_descriptor(
self.__prefix, table.name, table.columns, table.constraints)
self.__prefix, table.name, table.columns, table.constraints,
self.__autoincrement)

return descriptor

def iter(self, bucket):

# Get result
table = self.__get_table(bucket)
# Streaming could be not working for some backends:
# http://docs.sqlalchemy.org/en/latest/core/connections.html
select = table.select().execution_options(stream_results=True)
result = select.execute()

# Yield data
for row in result:
yield list(row)
# Make sure we close the transaction after iterating,
# otherwise it is left hanging
with self.__connection.begin():
# Streaming could be not working for some backends:
# http://docs.sqlalchemy.org/en/latest/core/connections.html
select = table.select().execution_options(stream_results=True)
result = select.execute()

# Yield data
for row in result:
yield list(row)

def read(self, bucket):

Expand All @@ -198,6 +203,9 @@ def read(self, bucket):

def write(self, bucket, rows, keyed=False, as_generator=False, update_keys=None):

if update_keys is not None and len(update_keys) == 0:
raise ValueError('update_keys cannot be an empty list')

table = self.__get_table(bucket)
descriptor = self.describe(bucket)

Expand Down
16 changes: 10 additions & 6 deletions jsontableschema_sql/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def write(self, rows, keyed):
for wr in self.__insert():
yield wr
ret = self.__update(row)
if ret > 0:
if ret is not None:
yield WrittenRow(keyed_row,
True,
ret if self.autoincrement else None)
Expand Down Expand Up @@ -82,14 +82,18 @@ def __update(self, row):
expr = self.table.update().values(row)
for key in self.update_keys:
expr = expr.where(getattr(self.table.c, key) == row[key])
expr = expr.returning(getattr(self.table.c, '_id'))
if self.autoincrement:
expr = expr.returning(getattr(self.table.c, self.autoincrement))
res = expr.execute()
if res.rowcount > 0:
first = next(iter(res))
last_row_id = first[0]
return last_row_id
if self.autoincrement:
first = next(iter(res))
last_row_id = first[0]
return last_row_id
else:
return 0
else:
return 0
return None

@staticmethod
def __convert_to_keyed(schema, row):
Expand Down
25 changes: 24 additions & 1 deletion tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def test_storage():

def test_update():


# Get resources
descriptor = json.load(io.open('data/original.json', encoding='utf-8'))
original_rows = Stream('data/original.csv', headers=1).open().read()
Expand All @@ -92,22 +93,31 @@ def test_update():
engine = create_engine(os.environ['DATABASE_URL'])

# Storage
storage = Storage(engine=engine, prefix='test_update_', autoincrement='_id')
storage = Storage(engine=engine, prefix='test_update_', autoincrement='__id')

# Delete buckets
storage.delete()

# Create buckets
storage.create('colors', descriptor)


# Write data to buckets
storage.write('colors', original_rows, update_keys=update_keys)

gen = storage.write('colors', update_rows, update_keys=update_keys, as_generator=True)
gen = list(gen)
assert len(gen) == 5
assert len(list(filter(lambda i: i.updated, gen))) == 3
assert list(map(lambda i: i.updated_id, gen)) == [5, 3, 6, 4, 5]

storage = Storage(engine=engine, prefix='test_update_', autoincrement='__id')
gen = storage.write('colors', update_rows, update_keys=update_keys, as_generator=True)
gen = list(gen)
assert len(gen) == 5
assert len(list(filter(lambda i: i.updated, gen))) == 5
assert list(map(lambda i: i.updated_id, gen)) == [5, 3, 6, 4, 5]

# Create new storage to use reflection only
storage = Storage(engine=engine, prefix='test_update_')

Expand All @@ -127,6 +137,19 @@ def test_update():
6: 'grey'
}

# Storage without autoincrement
storage = Storage(engine=engine, prefix='test_update_')
storage.delete()
storage.create('colors', descriptor)

storage.write('colors', original_rows, update_keys=update_keys)
gen = storage.write('colors', update_rows, update_keys=update_keys, as_generator=True)
gen = list(gen)
assert len(gen) == 5
assert len(list(filter(lambda i: i.updated, gen))) == 3
assert list(map(lambda i: i.updated_id, gen)) == [None, None, None, None, None]


def test_bad_type():

# Engine
Expand Down

0 comments on commit ed62cf9

Please sign in to comment.