Permalink
Browse files

Framework, Postgres schema cache, Storm versioning, upgraders, basic …

…commands
  • Loading branch information...
0 parents commit 8bebabb1af0ab33439d296e28b5d4b23b3a2af93 @brendonh committed Sep 8, 2011
@@ -0,0 +1,50 @@
+import psycopg2
+
+from schemup.dbs import postgres
+from schemup.orms import storm
+from schemup.upgraders import upgrader
+from schemup import commands
+
+conn = psycopg2.connect("dbname=schemup_test")
+
+stormSchema = storm.StormSchema()
+postgresSchema = postgres.PostgresSchema(conn, dryRun=False)
+
+@stormSchema.versioned
+class Quick(object):
+ __storm_table__ = "quick"
+ __version__ = "bgh_3"
+
+@upgrader('quick', 'bgh_1', 'bgh_2')
+def quick_bgh1to2(dbSchema):
+ dbSchema.execute("ALTER TABLE quick ADD another VARCHAR NOT NULL DEFAULT 'hey'")
+
+@upgrader('quick', 'bgh_2', 'bgh_3')
+def quick_bgh2to3(dbSchema):
+ dbSchema.execute("ALTER TABLE quick ADD onemore INTEGER")
+
+
+@stormSchema.versioned
+class NewTable(object):
+ __storm_table__ = "new_table"
+ __version__ = "bgh_1"
+
+@upgrader('new_table', None, 'bgh_1')
+def new_table_create(dbSchema):
+ dbSchema.execute("CREATE TABLE new_table ("
+ " id SERIAL NOT NULL PRIMARY KEY,"
+ " name VARCHAR)")
+
+
+commands.upgrade(postgresSchema, stormSchema)
+
+validationError = commands.validate(postgresSchema, stormSchema)
+if validationError is not None:
+ errorType, errors = validationError
+ print "Validation failed (%s)" % errorType
+ for (tableName, actual, expected) in errors:
+ print "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
+ print "Table: %s" % tableName
+ print "- Actual: %s" % actual
+ print "- Expected: %s" % expected
+ raise SystemExit
No changes.
@@ -0,0 +1,59 @@
+from schemup import validator, upgraders
+
+def snapshot(dbSchema, ormSchema):
+ """
+ Write current versions to DB schema table.
+ Used only to initialize schemup on an existing DB
+ """
+
+ dbSchema.clearSchemaTable()
+
+ for tableName, version in ormSchema.getExpectedTableVersions():
+ dbSchema.setSchema(tableName, version)
+
+ dbSchema.commit()
+
+
+def validate(dbSchema, ormSchema):
+ """
+ Check DB versions against ORM versions, returning mismatches.
+ If there are version mismatches, check DB schemas against cache,
+ returning mismatches there.
+ """
+
+ mismatches = validator.findMismatches(dbSchema, ormSchema)
+
+ if mismatches:
+ return ('orm', mismatches)
+
+ schemaMismatches = validator.findSchemaMismatches(dbSchema)
+
+ if schemaMismatches:
+ return ('schema', schemaMismatches)
+
+ return None
+
+
+def upgrade(dbSchema, ormSchema):
+ """
+ Attempt to find upgrade paths for all out-of-sync tables,
+ and run them.
+ """
+
+ paths = [(tableName, upgraders.findUpgradePath(tableName, fromVersion, toVersion))
+ for (tableName, fromVersion, toVersion)
+ in validator.findMismatches(dbSchema, ormSchema)]
+
+ if not paths:
+ return
+
+ for tableName, path in paths:
+ print "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
+ print "Upgrading %s" % tableName
+ print "%s => %s" % (path.firstVersion(), path.lastVersion())
+ print "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
+ path.apply(dbSchema)
+ dbSchema.setSchema(tableName, path.lastVersion(), log=False)
+ dbSchema.printLog()
+
+ dbSchema.commit()
No changes.
@@ -0,0 +1,104 @@
+class PostgresSchema(object):
+
+ def __init__(self, conn, dryRun=False):
+ self.conn = conn
+ self.dryRun = dryRun
+ self.runLog = []
+
+
+ def execute(self, query, args=(), cur=None, log=True):
+ if cur is None:
+ cur = self.conn.cursor()
+
+ if log:
+ self.runLog.append(cur.mogrify(query, args))
+
+ if not self.dryRun:
+ cur.execute(query, args)
+
+ def flushLog(self):
+ log, self.runlog = self.runLog, []
+ return log
+
+ def printLog(self):
+ for line in self.flushLog():
+ print line
+
+ def commit(self):
+ if self.dryRun:
+ return
+ self.conn.commit()
+
+ def ensureSchemaTable(self):
+ cur = self.conn.cursor()
+ cur.execute(
+ "SELECT COUNT(*)"
+ " FROM information_schema.tables"
+ " WHERE table_name = 'schemup_tables'")
+
+ if cur.fetchone()[0]:
+ return
+
+ print "Creating schema table..."
+ cur.execute(
+ "CREATE TABLE schemup_tables ("
+ " table_name VARCHAR NOT NULL,"
+ " version VARCHAR NOT NULL,"
+ " is_current BOOLEAN NOT NULL DEFAULT 'f',"
+ " schema TEXT)")
+
+ self.conn.commit()
+
+
+ def clearSchemaTable(self):
+ self.execute("DELETE FROM schemup_tables")
+
+
+ def getSchema(self, tableName):
+ cur = self.conn.cursor()
+ cur.execute(
+ "SELECT column_name, data_type, is_nullable, column_default"
+ " FROM information_schema.columns"
+ " WHERE table_name = %s"
+ " ORDER BY column_name",
+ (tableName,))
+
+ return u"\n".join(u"|".join(unicode(c) for c in row) for row in cur)
+
+
+ def getVersionedTableSchemas(self):
+ cur = self.conn.cursor()
+ cur.execute(
+ "SELECT table_name, schema"
+ " FROM schemup_tables"
+ " WHERE is_current = 't'")
+ return cur
+
+
+ def setSchema(self, tableName, version, log=True):
+
+ schema = self.getSchema(tableName)
+
+ cur = self.conn.cursor()
+ self.execute(
+ "UPDATE schemup_tables"
+ " SET is_current = 'f'"
+ " WHERE table_name = %s",
+ (tableName,), cur, log)
+ self.execute(
+ "INSERT INTO schemup_tables"
+ " (table_name, version, is_current, schema)"
+ " VALUES (%s, %s, 't', %s)",
+ (tableName, version, schema), cur, log)
+
+
+ def getKnownTableVersions(self):
+ self.ensureSchemaTable()
+
+ cur = self.conn.cursor()
+ cur.execute(
+ "SELECT table_name, version"
+ " FROM schemup_tables"
+ " WHERE is_current = 't'")
+
+ return sorted(cur.fetchall())
No changes.
@@ -0,0 +1,12 @@
+
+class StormSchema(object):
+
+ def __init__(self):
+ self.modelCache = []
+
+ def versioned(self, cls):
+ self.modelCache.append((cls.__storm_table__, cls.__version__))
+ return cls
+
+ def getExpectedTableVersions(self):
+ return sorted(self.modelCache)
@@ -0,0 +1,84 @@
+from collections import deque
+
+# table -> {(from, to) -> function}
+registeredUpgraders = {}
+
+def registerUpgrader(tableName, fromVersion, toVersion, upgrader):
+ if tableName not in registeredUpgraders:
+ registeredUpgraders[tableName] = {}
+
+ tableUpgraders = registeredUpgraders[tableName]
+
+ if fromVersion not in tableUpgraders:
+ tableUpgraders[fromVersion] = {}
+
+ versionUpgraders = tableUpgraders[fromVersion]
+
+ if toVersion in versionUpgraders:
+ raise ValueError("Upgrader already exists for %s (%s => %s)" % (
+ tableName, fromVersion, toVersion))
+
+ versionUpgraders[toVersion] = upgrader
+
+
+def upgrader(tableName, fromVersion, toVersion):
+ """
+ Decorator shortcut for registerUpgrader
+ """
+ def decorate(function):
+ registerUpgrader(tableName, fromVersion, toVersion, function)
+ return function
+
+ return decorate
+
+
+class UpgradePath(object):
+ def __init__(self, steps=None, seen=None):
+ self.steps = steps or []
+ self.seen = seen or set()
+
+ def copy(self):
+ return UpgradePath(self.steps[:], self.seen.copy())
+
+ def push(self, version, upgrader):
+ if version in self.seen:
+ raise ValueError("Upgrader cycle", self.steps, version)
+ self.seen.add(version)
+ self.steps.append((version, upgrader))
+
+ def pushNew(self, version, upgrader):
+ copy = self.copy()
+ copy.push(version, upgrader)
+ return copy
+
+ def firstVersion(self):
+ return self.steps[0][0]
+
+ def lastVersion(self):
+ return self.steps[-1][0]
+
+ def apply(self, dbSchema):
+ for _version, upgrader in self.steps:
+ if upgrader is not None:
+ upgrader(dbSchema)
+
+ def __str__(self):
+ return "\n".join("-> %s: %s" % (v, u) for (v, u) in self.steps)
+
+
+def findUpgradePath(tableName, fromVersion, toVersion):
+ upgraders = registeredUpgraders.get(tableName, {})
+
+ paths = deque([ UpgradePath([(fromVersion, None)]) ])
+
+ while paths:
+ path = paths.popleft()
+
+ lastVersion = path.lastVersion()
+ if lastVersion == toVersion:
+ return path
+
+ for (nextVersion, upgrader) in upgraders.get(lastVersion, {}).iteritems():
+ paths.append( path.pushNew(nextVersion, upgrader) )
+
+ raise ValueError("No upgrade path for %s (%s -> %s)" % (tableName, fromVersion, toVersion))
@@ -0,0 +1,29 @@
+def findMismatches(dbSchema, ormSchema):
+ actual = dict(dbSchema.getKnownTableVersions())
+ expected = dict(ormSchema.getExpectedTableVersions())
+
+ tables = set(actual.keys()) | set(expected.keys())
+
+ mismatches = []
+
+ for table in tables:
+ exTable = expected.get(table)
+ acTable = actual.get(table)
+
+ if exTable == acTable:
+ continue
+
+ mismatches.append((table, acTable, exTable))
+
+ return mismatches
+
+
+def findSchemaMismatches(dbSchema):
+ errors = []
+ for tableName, expectedSchema in dbSchema.getVersionedTableSchemas():
+ actualSchema = dbSchema.getSchema(tableName)
+ if expectedSchema != actualSchema:
+ errors.append((tableName, actualSchema, expectedSchema))
+ return errors
+
+

0 comments on commit 8bebabb

Please sign in to comment.