Skip to content

Commit

Permalink
Merge pull request #329 from andy8zhao/moreStaticColumnFix
Browse files Browse the repository at this point in the history
More static column related save/update/delete fix
  • Loading branch information
rustyrazorblade committed Jan 23, 2015
2 parents 32811f1 + 2acfcbe commit e52352f
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 11 deletions.
5 changes: 2 additions & 3 deletions cqlengine/columns.py
Expand Up @@ -26,6 +26,7 @@ def __init__(self, instance, column, value):
self.column = column
self.previous_value = deepcopy(value)
self.value = value
self.explicit = False

@property
def deleted(self):
Expand Down Expand Up @@ -140,9 +141,7 @@ def validate(self, value):
if there's a problem
"""
if value is None:
if self.has_default:
return self.get_default()
elif self.required:
if self.required:
raise ValidationError('{} - None values are not allowed'.format(self.column_name or self.db_field))
return value

Expand Down
7 changes: 6 additions & 1 deletion cqlengine/models.py
Expand Up @@ -336,6 +336,8 @@ def __init__(self, **values):
if value is not None or isinstance(column, columns.BaseContainerColumn):
value = column.to_python(value)
value_mngr = column.value_manager(self, column, value)
if name in values:
value_mngr.explicit = True
self._values[name] = value_mngr

# a flag set by the deserializer to indicate
Expand Down Expand Up @@ -490,7 +492,10 @@ def column_family_name(cls, include_keyspace=True):
def validate(self):
""" Cleans and validates the field values """
for name, col in self._columns.items():
val = col.validate(getattr(self, name))
v = getattr(self, name)
if v is None and not self._values[name].explicit and col.has_default:
v = col.get_default()
val = col.validate(v)
setattr(self, name, val)

### Let an instance be used like a dict of its columns keys/values
Expand Down
27 changes: 22 additions & 5 deletions cqlengine/query.py
Expand Up @@ -806,7 +806,9 @@ def update(self, **values):
if col.is_primary_key:
raise ValidationError("Cannot apply update to primary key '{}' for {}.{}".format(col_name, self.__module__, self.model.__name__))

# we should not provide default values in this use case.
val = col.validate(val)

if val is None:
nulled_columns.add(col_name)
continue
Expand Down Expand Up @@ -914,11 +916,17 @@ def update(self):
if self.instance is None:
raise CQLEngineException("DML Query intance attribute is None")
assert type(self.instance) == self.model
static_update_only = True
null_clustering_key = False if len(self.instance._clustering_keys) == 0 else True
static_changed_only = True
statement = UpdateStatement(self.column_family_name, ttl=self._ttl,
timestamp=self._timestamp, transactions=self._transaction)
for name, col in self.instance._clustering_keys.items():
null_clustering_key = null_clustering_key and col._val_is_null(getattr(self.instance, name, None))
#get defined fields and their column names
for name, col in self.model._columns.items():
# if clustering key is null, don't include non static columns
if null_clustering_key and not col.static and not col.partition_key:
continue
if not col.is_primary_key:
val = getattr(self.instance, name, None)
val_mgr = self.instance._values[name]
Expand All @@ -931,7 +939,7 @@ def update(self):
if not val_mgr.changed and not isinstance(col, Counter):
continue

static_update_only = (static_update_only and col.static)
static_changed_only = static_changed_only and col.static
if isinstance(col, (BaseContainerColumn, Counter)):
# get appropriate clause
if isinstance(col, List): klass = ListUpdateClause
Expand All @@ -953,7 +961,8 @@ def update(self):

if statement.get_context_size() > 0 or self.instance._has_counter:
for name, col in self.model._primary_keys.items():
if static_update_only and (not col.partition_key):
# only include clustering key if clustering key is not null, and non static columns are changed to avoid cql error
if (null_clustering_key or static_changed_only) and (not col.partition_key):
continue
statement.add_where_clause(WhereClause(
col.db_field_name,
Expand All @@ -962,7 +971,8 @@ def update(self):
))
self._execute(statement)

self._delete_null_columns()
if not null_clustering_key:
self._delete_null_columns()

def save(self):
"""
Expand All @@ -980,7 +990,12 @@ def save(self):
return self.update()
else:
insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists)
static_save_only = False if len(self.instance._clustering_keys) == 0 else True
for name, col in self.instance._clustering_keys.items():
static_save_only = static_save_only and col._val_is_null(getattr(self.instance, name, None))
for name, col in self.instance._columns.items():
if static_save_only and not col.static and not col.partition_key:
continue
val = getattr(self.instance, name, None)
if col._val_is_null(val):
if self.instance._values[name].changed:
Expand All @@ -996,7 +1011,8 @@ def save(self):
if not insert.is_empty:
self._execute(insert)
# delete any nulled columns
self._delete_null_columns()
if not static_save_only:
self._delete_null_columns()

def delete(self):
""" Deletes one instance """
Expand All @@ -1005,6 +1021,7 @@ def delete(self):

ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp)
for name, col in self.model._primary_keys.items():
if (not col.partition_key) and (getattr(self.instance, name) is None): continue
ds.add_where_clause(WhereClause(
col.db_field_name,
EqualsOperator(),
Expand Down
2 changes: 1 addition & 1 deletion cqlengine/tests/columns/test_container_columns.py
Expand Up @@ -529,4 +529,4 @@ def tearDownClass(cls):
drop_table(TestCamelMapModel)

def test_camelcase_column(self):
TestCamelMapModel.create(partition=None, camelMap={'blah': 1})
TestCamelMapModel.create(camelMap={'blah': 1})
12 changes: 11 additions & 1 deletion cqlengine/tests/columns/test_static_column.py
Expand Up @@ -9,7 +9,6 @@


class TestStaticModel(Model):

partition = columns.UUID(primary_key=True, default=uuid4)
cluster = columns.UUID(primary_key=True, default=uuid4)
static = columns.Text(static=True)
Expand Down Expand Up @@ -60,3 +59,14 @@ def test_static_only_updates(self):
actual = TestStaticModel.get(partition=u.partition)
assert actual.static == "it's still shared"

@skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0")
def test_static_with_null_cluster_key(self):
""" Tests that save/update/delete works for static column works when clustering key is null"""
instance = TestStaticModel.create(cluster=None, static = "it's shared")
instance.save()

u = TestStaticModel.get(partition=instance.partition)
u.static = "it's still shared"
u.update()
actual = TestStaticModel.get(partition=u.partition)
assert actual.static == "it's still shared"

0 comments on commit e52352f

Please sign in to comment.