Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Add equality support for Project/ModelState

  • Loading branch information...
commit 05656f2388b1989c9e99e1ff2aae8b2e1c805af2 1 parent 9027da6
@andrewgodwin andrewgodwin authored
Showing with 61 additions and 0 deletions.
  1. +21 −0 django/db/migrations/state.py
  2. +40 −0 tests/migrations/test_state.py
View
21 django/db/migrations/state.py
@@ -59,6 +59,14 @@ def from_app_cache(cls, app_cache):
models[(model_state.app_label, model_state.name.lower())] = model_state
return cls(models)
+ def __eq__(self, other):
+ if set(self.models.keys()) != set(other.models.keys()):
+ return False
+ return all(model == other.models[key] for key, model in self.models.items())
+
+ def __ne__(self, other):
+ return not (self == other)
+
class ModelState(object):
"""
@@ -167,3 +175,16 @@ def get_field_by_name(self, name):
if fname == name:
return field
raise ValueError("No field called %s on model %s" % (name, self.name))
+
+ def __eq__(self, other):
+ return (
+ (self.app_label == other.app_label) and
+ (self.name == other.name) and
+ (len(self.fields) == len(other.fields)) and
+ all((k1 == k2 and (f1.deconstruct()[1:] == f2.deconstruct()[1:])) for (k1, f1), (k2, f2) in zip(self.fields, other.fields)) and
+ (self.options == other.options) and
+ (self.bases == other.bases)
+ )
+
+ def __ne__(self, other):
+ return not (self == other)
View
40 tests/migrations/test_state.py
@@ -175,3 +175,43 @@ class Meta:
project_state.add_model_state(ModelState.from_model(F))
with self.assertRaises(InvalidBasesError):
project_state.render()
+
+ def test_equality(self):
+ """
+ Tests that == and != are implemented correctly.
+ """
+
+ # Test two things that should be equal
+ project_state = ProjectState()
+ project_state.add_model_state(ModelState(
+ "migrations",
+ "Tag",
+ [
+ ("id", models.AutoField(primary_key=True)),
+ ("name", models.CharField(max_length=100)),
+ ("hidden", models.BooleanField()),
+ ],
+ {},
+ None,
+ ))
+ other_state = project_state.clone()
+ self.assertEqual(project_state, project_state)
+ self.assertEqual(project_state, other_state)
+ self.assertEqual(project_state != project_state, False)
+ self.assertEqual(project_state != other_state, False)
+
+ # Make a very small change (max_len 99) and see if that affects it
+ project_state = ProjectState()
+ project_state.add_model_state(ModelState(
+ "migrations",
+ "Tag",
+ [
+ ("id", models.AutoField(primary_key=True)),
+ ("name", models.CharField(max_length=99)),
+ ("hidden", models.BooleanField()),
+ ],
+ {},
+ None,
+ ))
+ self.assertNotEqual(project_state, other_state)
+ self.assertEqual(project_state == other_state, False)
Please sign in to comment.
Something went wrong with that request. Please try again.