Skip to content
Browse files

Inter-upgrader dependencies and resolution

  • Loading branch information...
1 parent 3b08778 commit ba44d2a8aea0662651ad417b6143c7770a7dca68 @brendonh committed Sep 9, 2011
Showing with 137 additions and 39 deletions.
  1. +10 −7 schemup/schemup/commands.py
  2. +9 −1 schemup/schemup/dbs/postgres.py
  3. +118 −31 schemup/schemup/upgraders.py
View
17 schemup/schemup/commands.py
@@ -42,20 +42,23 @@ def upgrade(dbSchema, ormSchema):
and run them.
"""
+ import pprint
+
paths = [(tableName, upgraders.findUpgradePath(tableName, fromVersion, toVersion))
for (tableName, fromVersion, toVersion)
in validator.findMismatches(dbSchema, ormSchema)]
if not paths:
return
+ stepGraph = upgraders.UpgradeStepGraph()
+
for tableName, path in paths:
- print "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
- print "Upgrading %s" % tableName
- print "%s => %s" % (path.firstVersion(), path.lastVersion())
- print "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
- path.apply(dbSchema)
- dbSchema.setSchema(tableName, path.lastVersion(), log=False)
- dbSchema.printLog()
+ path.addToGraph(stepGraph)
+
+ stepGraph.calculateEdges()
+
+ for upgrader in stepGraph.topologicalSort():
+ upgrader.run(dbSchema)
dbSchema.commit()
View
10 schemup/schemup/dbs/postgres.py
@@ -17,7 +17,7 @@ def execute(self, query, args=(), cur=None, log=True):
cur.execute(query, args)
def flushLog(self):
- log, self.runlog = self.runLog, []
+ log, self.runLog = self.runLog, []
return log
def printLog(self):
@@ -65,6 +65,14 @@ def getSchema(self, tableName):
return u"\n".join(u"|".join(unicode(c) for c in row) for row in cur)
+ def getTableVersions(self):
+ cur = self.conn.cursor()
+ cur.execute(
+ "SELECT table_name, version"
+ " FROM schemup_tables"
+ " WHERE is_current = 't'")
+ return cur
+
def getVersionedTableSchemas(self):
cur = self.conn.cursor()
View
149 schemup/schemup/upgraders.py
@@ -3,7 +3,37 @@
# table -> {(from, to) -> function}
registeredUpgraders = {}
-def registerUpgrader(tableName, fromVersion, toVersion, upgrader):
+
+class Upgrader(object):
+ def __init__(self, tableName, fromVersion, toVersion, upgrader, dependencies=()):
+ self.tableName = tableName
+ self.fromVersion = fromVersion
+ self.toVersion = toVersion
+ self.upgrader = upgrader
+ self.dependencies = list(dependencies)
+
+ def copy(self):
+ return Upgrader(self.tableName, self.fromVersion, self.toVersion,
+ self.upgrader, self.dependencies)
+
+ def run(self, dbSchema):
+ if self.upgrader is None:
+ return
+
+ print "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
+ print "Upgrading %s" % self.tableName
+ print "%s => %s" % (self.fromVersion, self.toVersion)
+ self.upgrader(dbSchema)
+ dbSchema.setSchema(self.tableName, self.toVersion, log=False)
+ dbSchema.printLog()
+
+
+ def __repr__(self):
+ return "<%s: %s => %s (%s)>" % (
+ self.tableName, self.fromVersion, self.toVersion, self.upgrader)
+
+
+def registerUpgrader(tableName, fromVersion, toVersion, upgrader, dependencies=()):
if tableName not in registeredUpgraders:
registeredUpgraders[tableName] = {}
@@ -18,20 +48,51 @@ def registerUpgrader(tableName, fromVersion, toVersion, upgrader):
raise ValueError("Upgrader already exists for %s (%s => %s)" % (
tableName, fromVersion, toVersion))
- versionUpgraders[toVersion] = upgrader
+ versionUpgraders[toVersion] = Upgrader(
+ tableName, fromVersion, toVersion, upgrader, dependencies)
-def upgrader(tableName, fromVersion, toVersion):
+def upgrader(tableName, fromVersion, toVersion, dependencies=()):
"""
Decorator shortcut for registerUpgrader
"""
def decorate(function):
- registerUpgrader(tableName, fromVersion, toVersion, function)
+ registerUpgrader(tableName, fromVersion, toVersion, function, dependencies)
return function
return decorate
+# ------------------------------------------------------------
+
+def findUpgradePath(tableName, fromVersion, toVersion):
+ upgraders = registeredUpgraders.get(tableName, {})
+
+ initialPath = UpgradePath([])
+
+ if fromVersion is None:
+ initialPath.push(Upgrader(tableName, None, fromVersion, None))
+ else:
+ for upgrader in findUpgradePath(tableName, None, fromVersion).steps:
+ stubUpgrader = upgrader.copy()
+ stubUpgrader.upgrader = None
+ initialPath.push(stubUpgrader)
+
+ paths = deque([ initialPath ])
+
+ while paths:
+ path = paths.popleft()
+
+ lastVersion = path.lastVersion()
+ if lastVersion == toVersion:
+ return path
+
+ for (nextVersion, upgrader) in upgraders.get(lastVersion, {}).iteritems():
+ paths.append( path.pushNew(upgrader) )
+
+ raise ValueError("No upgrade path for %s (%s -> %s)" % (tableName, fromVersion, toVersion))
+
+
class UpgradePath(object):
def __init__(self, steps=None, seen=None):
self.steps = steps or []
@@ -40,45 +101,71 @@ def __init__(self, steps=None, seen=None):
def copy(self):
return UpgradePath(self.steps[:], self.seen.copy())
- def push(self, version, upgrader):
- if version in self.seen:
+ def push(self, upgrader):
+ if upgrader.toVersion in self.seen:
raise ValueError("Upgrader cycle", self.steps, version)
- self.seen.add(version)
- self.steps.append((version, upgrader))
+ self.seen.add(upgrader.toVersion)
+ self.steps.append(upgrader)
- def pushNew(self, version, upgrader):
+ def pushNew(self, upgrader):
copy = self.copy()
- copy.push(version, upgrader)
+ copy.push(upgrader)
return copy
def firstVersion(self):
- return self.steps[0][0]
+ return self.steps[0].toVersion
def lastVersion(self):
- return self.steps[-1][0]
+ return self.steps[-1].toVersion
- def apply(self, dbSchema):
- for _version, upgrader in self.steps:
- if upgrader is not None:
- upgrader(dbSchema)
+ def addToGraph(self, graph):
+ prev = None
+ for origUpgrader in self.steps:
+ upgrader = origUpgrader.copy()
+ if prev is not None:
+ upgrader.dependencies.append((prev.tableName, prev.toVersion))
+ graph.addUpgrader(upgrader)
+ prev = upgrader
def __str__(self):
- return "\n".join("-> %s: %s" % (v, u) for (v, u) in self.steps)
+ return "\n".join("-> %s: %s" % (u.toVersion, u.upgrader) for u in self.steps)
-def findUpgradePath(tableName, fromVersion, toVersion):
- upgraders = registeredUpgraders.get(tableName, {})
- paths = deque([ UpgradePath([(fromVersion, None)]) ])
-
- while paths:
- path = paths.popleft()
+class UpgradeStepGraph(object):
+ def __init__(self):
+ self.nodes = {}
+ self.edges = {}
- lastVersion = path.lastVersion()
- if lastVersion == toVersion:
- return path
-
- for (nextVersion, upgrader) in upgraders.get(lastVersion, {}).iteritems():
- paths.append( path.pushNew(nextVersion, upgrader) )
-
- raise ValueError("No upgrade path for %s (%s -> %s)" % (tableName, fromVersion, toVersion))
+ def addUpgrader(self, upgrader):
+ self.nodes[(upgrader.tableName, upgrader.toVersion)] = upgrader
+
+ def calculateEdges(self):
+ for fromKey, upgrader in self.nodes.iteritems():
+ if fromKey not in self.edges:
+ self.edges[fromKey] = set()
+
+ for toKey in upgrader.dependencies:
+ if toKey not in self.nodes:
+ raise ValueError(
+ "Upgrader %s has unmet dependency on %s"
+ % (upgrader, toKey))
+
+ self.edges[fromKey].add(toKey)
+
+ def topologicalSort(self):
+ edges = self.edges.copy()
+ path = []
+ while True:
+ freeKeys = set(key for (key, deps) in edges.iteritems() if not deps)
+ if not freeKeys:
+ break
+ path.extend(freeKeys)
+ edges = dict((key, deps - freeKeys) for (key, deps) in edges.iteritems()
+ if key not in freeKeys)
+
+ if edges:
+ raise ValueError(
+ "Cyclic upgrader dependencies", self.edges)
+
+ return [self.nodes[key] for key in path]

0 comments on commit ba44d2a

Please sign in to comment.
Something went wrong with that request. Please try again.