Skip to content

Commit

Permalink
Added an auto-increment column (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
akariv authored and roll committed Mar 6, 2017
1 parent c6109b8 commit 32f3827
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,4 @@ target/
# Extras
tabulator
jsontableschema
!/jsontableschema_sql/mappers.py
13 changes: 11 additions & 2 deletions jsontableschema_sql/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
Text, VARCHAR, Float, Integer, Boolean, Date, Time, DateTime)
from sqlalchemy.dialects.postgresql import ARRAY, JSON, JSONB, UUID


# Module API


def bucket_to_tablename(prefix, bucket):
"""Convert bucket to SQLAlchemy tablename.
"""
Expand All @@ -27,7 +27,8 @@ def tablename_to_bucket(prefix, tablename):
return None


def descriptor_to_columns_and_constraints(prefix, bucket, descriptor, index_fields):
def descriptor_to_columns_and_constraints(prefix, bucket, descriptor,
index_fields, autoincrement):
"""Convert descriptor to SQLAlchemy columns and constraints.
"""

Expand All @@ -52,6 +53,8 @@ def descriptor_to_columns_and_constraints(prefix, bucket, descriptor, index_fiel
'geojson': JSONB,
}

if autoincrement is not None:
columns.append(Column(autoincrement, Integer, autoincrement=True, nullable=False))
# Fields
for field in descriptor['fields']:
try:
Expand All @@ -76,6 +79,12 @@ def descriptor_to_columns_and_constraints(prefix, bucket, descriptor, index_fiel
if pk is not None:
if isinstance(pk, six.string_types):
pk = [pk]
if autoincrement is not None:
if pk is not None:
pk = [autoincrement] + pk
else:
pk = [autoincrement]
if pk is not None:
constraint = PrimaryKeyConstraint(*pk)
constraints.append(constraint)

Expand Down
8 changes: 5 additions & 3 deletions jsontableschema_sql/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ class Storage(object):

# Public

def __init__(self, engine, dbschema=None, prefix='', reflect_only=None):
def __init__(self, engine, dbschema=None, prefix='', reflect_only=None,
autoincrement=None):

# Set attributes
self.__connection = engine.connect()
self.__dbschema = dbschema
self.__prefix = prefix
self.__descriptors = {}
self.__autoincrement = autoincrement
if reflect_only is not None:
self.__only = reflect_only
else:
Expand Down Expand Up @@ -120,7 +122,7 @@ def create(self, bucket, descriptor, force=False, indexes_fields=None):
jsontableschema.validate(descriptor)
tablename = mappers.bucket_to_tablename(self.__prefix, bucket)
columns, constraints, indexes = mappers.descriptor_to_columns_and_constraints(
self.__prefix, bucket, descriptor, index_fields)
self.__prefix, bucket, descriptor, index_fields, self.__autoincrement)
Table(tablename, self.__metadata, *(columns+constraints+indexes))

# Create tables, update metadata
Expand Down Expand Up @@ -199,7 +201,7 @@ def write(self, bucket, rows, keyed=False, as_generator=False, update_keys=None)
table = self.__get_table(bucket)
descriptor = self.describe(bucket)

writer = StorageWriter(table, descriptor, update_keys)
writer = StorageWriter(table, descriptor, update_keys, self.__autoincrement)

with self.__connection.begin():
gen = writer.write(rows, keyed)
Expand Down
43 changes: 33 additions & 10 deletions jsontableschema_sql/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@


BUFFER_SIZE = 1000
WrittenRow = namedtuple('WrittenRow', ['row', 'updated'])
WrittenRow = namedtuple('WrittenRow', ['row', 'updated', 'updated_id'])


class StorageWriter(object):

def __init__(self, table, descriptor, update_keys):
def __init__(self, table, descriptor, update_keys, autoincrement):

self.table = table
self.descriptor = descriptor
self.update_keys = update_keys
self.autoincrement = autoincrement
if update_keys is not None:
self.__prepare_bloom()
self.__buffer = []
Expand All @@ -41,32 +42,54 @@ def write(self, rows, keyed):
keyed_row = row

if self.__check_existing(keyed_row):
self.__insert()
if self.__update(row):
yield WrittenRow(keyed_row, True)
for wr in self.__insert():
yield wr
ret = self.__update(row)
if ret > 0:
yield WrittenRow(keyed_row,
True,
ret if self.autoincrement else None)
continue

self.__buffer.append(keyed_row)

if len(self.__buffer) > BUFFER_SIZE:
self.__insert()
yield WrittenRow(keyed_row, False)
for wr in self.__insert():
yield wr

self.__insert()
for wr in self.__insert():
yield wr

def __insert(self):
    """Flush the buffered rows to the table and report each of them.

    Yields one ``WrittenRow(row, False, row_id)`` per buffered row, where
    ``row_id`` is the value returned for the ``self.autoincrement`` column
    when such a column is configured and ``None`` otherwise. The buffer is
    reset once every row has been reported; an empty buffer is a no-op.
    """
    if not self.__buffer:
        return

    # Insert data
    statement = self.table.insert()
    if self.autoincrement:
        # Ask the database to hand back the generated key of every row.
        id_column = getattr(self.table.c, self.autoincrement)
        statement = statement.returning(id_column).values(self.__buffer)
        result = statement.execute()
        # Pair rows with returned keys positionally instead of popping the
        # head of the list for each row (list.pop(0) is O(n) per call).
        # NOTE(review): assumes the RETURNING rows arrive in insertion
        # order, exactly as the previous pop(0)-based pairing did - confirm.
        for row, (row_id,) in zip(self.__buffer, result):
            yield WrittenRow(row, False, row_id)
    else:
        statement.execute(self.__buffer)
        for row in self.__buffer:
            yield WrittenRow(row, False, None)

    # Clean memory
    self.__buffer = []

def __update(self, row):
    """Update the stored row(s) matching ``row`` on ``self.update_keys``.

    Returns a positive number when at least one row was updated and ``0``
    when nothing matched, so callers can keep testing ``ret > 0``:

    - with an auto-increment column configured, the value of that column
      for the first updated row is returned;
    - without one, the number of updated rows is returned.
    """
    expr = self.table.update().values(row)
    for key in self.update_keys:
        expr = expr.where(getattr(self.table.c, key) == row[key])

    # Only request RETURNING when an auto-increment column exists, and use
    # its configured name rather than a hard-coded '_id'.
    if self.autoincrement:
        expr = expr.returning(getattr(self.table.c, self.autoincrement))

    res = expr.execute()
    if res.rowcount <= 0:
        return 0
    if not self.autoincrement:
        # No key to report; any positive value signals "updated".
        return res.rowcount

    # NOTE(review): assumes generated ids are always > 0, otherwise the
    # caller's ``ret > 0`` check would treat the update as a miss - confirm.
    first = next(iter(res))
    return first[0]

@staticmethod
def __convert_to_keyed(schema, row):
Expand Down
22 changes: 20 additions & 2 deletions tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_update():
engine = create_engine(os.environ['DATABASE_URL'])

# Storage
storage = Storage(engine=engine, prefix='test_update_')
storage = Storage(engine=engine, prefix='test_update_', autoincrement='_id')

# Delete buckets
storage.delete()
Expand All @@ -106,6 +106,7 @@ def test_update():
gen = list(gen)
assert len(gen) == 5
assert len(list(filter(lambda i: i.updated, gen))) == 3
assert list(map(lambda i: i.updated_id, gen)) == [5, 3, 6, 4, 5]

# Create new storage to use reflection only
storage = Storage(engine=engine, prefix='test_update_')
Expand All @@ -114,7 +115,7 @@ def test_update():

assert len(rows) == 6
color_by_person = dict(
(row[0], row[2])
(row[1], row[3])
for row in rows
)
assert color_by_person == {
Expand All @@ -126,6 +127,23 @@ def test_update():
6: 'grey'
}

def test_bad_type():
    """Creating a bucket whose field has type 'any' must raise TypeError."""

    # Engine
    engine = create_engine(os.environ['DATABASE_URL'])

    # Storage
    storage = Storage(engine=engine, prefix='test_bad_type_')

    # Descriptor with a single field of type 'any'
    bad_field = {'name': 'bad_field', 'type': 'any'}
    descriptor = {'fields': [bad_field]}

    with pytest.raises(TypeError):
        storage.create('bad_type', descriptor)


def test_only_parameter():
# Check the 'only' parameter
Expand Down

0 comments on commit 32f3827

Please sign in to comment.