Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
tree: 1e15a4d9fd
Fetching contributors…

Cannot retrieve contributors at this time

file 182 lines (136 sloc) 6.233 kb
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
import sqlalchemy

import types
import logging
log = logging.getLogger(__name__)


# TODO: Move this elsewhere or hopefully deprecate it in favour of something in sqlalchemy-migrate
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import Executable, ClauseElement

class InsertFromSelect(Executable, ClauseElement):
    """Executable clause that compiles to ``INSERT INTO <table> (...) SELECT ...``.

    :param table: Table the selected rows are inserted into.
    :param select: Select construct that produces the rows.
    :param defaults: Optional mapping of extra column name -> constant value
        added to every inserted row. Any falsy value (``None``, ``{}``) is
        normalised to a fresh empty dict.
    """

    def __init__(self, table, select, defaults=None):
        if not defaults:
            defaults = {}
        self.table, self.select, self.defaults = table, select, defaults

@compiles(InsertFromSelect)
def visit_insert_from_select(element, compiler, **kw):
    """Compile an ``InsertFromSelect`` into ``INSERT INTO ... (cols) SELECT ...``.

    The insert column list is the names of the select's columns, followed by
    the keys of ``element.defaults``. Each default value is added to the
    select as a literal column in the same order, so both lists line up
    positionally.

    :param element: The ``InsertFromSelect`` being compiled.
    :param compiler: The active SQL compiler.
    :returns: The rendered SQL string.
    """
    select = element.select
    insert_columns = [col.name for col in select.columns]

    for name, value in element.defaults.items():
        insert_columns.append(name)
        # TODO: Add intelligent casting of values
        # ``Select.column`` is generative: it returns a new select instead of
        # mutating ``element.select``. The original ``append_column`` call
        # modified the element in place, so compiling the same element twice
        # appended the literal columns twice.
        select = select.column(sqlalchemy.literal(value))

    select_query = compiler.process(select)

    return "INSERT INTO {insert_table} ({insert_columns}) {select_query}".format(
        insert_table=compiler.process(element.table, asfrom=True),
        insert_columns=', '.join(insert_columns),
        select_query=select_query,
    )
##


def table_migrate(e1, e2, table, table2=None, convert_fn=None, limit=100000):
    """
    Copy all rows of ``table`` read through ``e1`` into ``table2`` through ``e2``.

    Rows are fetched in pages of ``limit`` rows (OFFSET/LIMIT) and inserted
    page by page with a single executemany-style insert per page.

    :param e1: Source engine the rows are read from.
    :param e2: Target engine the rows are inserted through.
    :param table: Source table object.
    :param table2: Target table object. If None, ``table`` itself is used.
    :param convert_fn: Optional callable invoked as
        ``convert_fn(row=row, old_table=table, new_table=table2)`` for every
        source row. It may return a single row to insert, ``None`` to drop
        the row, or a generator yielding any number of rows.
    :param limit: Number of source rows fetched per page.
    """
    if table2 is None:
        table2 = table

    count = e1.execute(table.count()).scalar()

    # OFFSET/LIMIT paging is only reliable over a stable row order. Without an
    # ORDER BY the database may return rows in a different order for each
    # page, silently skipping some rows and duplicating others. Order by the
    # primary key when the table has one.
    query = table.select()
    pk_columns = list(table.primary_key.columns)
    if pk_columns:
        query = query.order_by(*pk_columns)

    log.debug("Inserting {0} rows into: {1}".format(count, table2.name))
    for offset in range(0, count, limit):
        data = e1.execute(query.offset(offset).limit(limit)).fetchall()
        if not data:
            continue

        if convert_fn:
            converted_rows = []
            for row in data:
                converted = convert_fn(row=row, old_table=table, new_table=table2)
                if isinstance(converted, types.GeneratorType):
                    # A generator may expand one source row into several.
                    converted_rows.extend(converted)
                elif converted is not None:
                    converted_rows.append(converted)
            data = converted_rows

        if data:
            e2.execute(table2.insert(), data).close()
        log.debug("-> Inserted {0} rows into: {1}".format(len(data), table2.name))


def table_replace(table_old, table_new, select_query=None, backup_table_name=None, defaults=None):
    """
    Replace ``table_old`` with ``table_new``, copying the rows via INSERT ... SELECT.

    This method is extremely hacky, use at your own risk.

    Steps:
      1. Drop ``table_old``'s indexes.
      2. Create ``table_new`` without its indexes, under a temporary
         ``_gratetmp`` name if its name collides with ``table_old``.
      3. Copy the rows inside a transaction.
      4. Drop ``table_old``, or rename it to ``backup_table_name``.
      5. Rename ``table_new`` to the old name and recreate its indexes.

    :param table_old: Original table object.
    :param table_new: New table object which will be renamed to use table_old.name.
    :param select_query: Custom query to use when porting data between tables. If None, do plain select everything.
    :param backup_table_name: If None, leave no backup. Otherwise save the original table with that name.
    :param defaults: Optional mapping of extra column name -> constant value
        written into every copied row (forwarded to ``InsertFromSelect``).
    """
    # This helper requires sqlalchemy-migrate.
    # NOTE(review): presumably imported for the ``rename`` support it adds to tables.
    import migrate

    name_old = table_old.name

    # Explicit ``is None`` check: newer SQLAlchemy versions raise TypeError
    # when a clause element such as a Select is evaluated as a boolean, so
    # ``select_query or ...`` breaks when a custom query is passed in.
    if select_query is None:
        select_query = table_old.select()

    # Detach the new table's indexes so they are not created together with
    # the (temporarily named) table; they are recreated after the rename.
    indexes = table_new.indexes
    table_new.indexes = set([])

    # Make sure the names aren't colliding
    if table_new.name == name_old:
        table_new.name += "_gratetmp"

    con = table_new.bind.connect()
    try:
        t = con.begin()
        try:
            # Drop all the indices to avoid having to rename them with sensible names
            for idx in table_old.indexes:
                idx.drop()

            # NOTE: only the INSERT below runs on ``con``. The index drops and
            # the CREATE go through the tables' own bind, so they are not
            # undone by the rollback.
            table_new.create(checkfirst=True)
            con.execute(InsertFromSelect(table_new, select_query, defaults))
            t.commit()
        except Exception:
            t.rollback()
            raise
    finally:
        # Always release the connection, even when the copy failed.
        con.close()

    if backup_table_name:
        table_old.rename(backup_table_name)
    else:
        table_old.drop()

    # Swap the table and readd all the indices
    table_new.rename(name_old)
    for idx in indexes:
        idx.create()

    # Reattach the indexes so table_new's metadata matches the database again.
    table_new.indexes = indexes


def migrate_replace(e, metadata, only_tables=None, skip_tables=None):
    """
    Similar to migrate but uses in-place table_replace instead of row-by-row reinsert between two engines.

    Every table reflected from ``e`` that has a same-named table in
    ``metadata`` is replaced by that table via ``table_replace``. Reflected
    tables with no counterpart in ``metadata`` are logged and left untouched.

    :param e: SQLAlchemy engine
    :param metadata: MetaData containing target desired schema
    :param only_tables: If given, only tables whose name is in this collection are replaced.
    :param skip_tables: If given, tables whose name is in this collection are left untouched.
    """
    metadata_old = sqlalchemy.MetaData(bind=e, reflect=True)
    metadata.bind = e

    for table_name, table in metadata_old.tables.items():
        if (only_tables and table_name not in only_tables) or \
           (skip_tables and table_name in skip_tables):
            log.info("Skipping table: {0}".format(table_name))
            continue

        # Mirror ``migrate``: skip tables missing from the target schema
        # instead of aborting halfway through with a KeyError.
        table_new = metadata.tables.get(table_name)
        if table_new is None:
            log.info("No corresponding table found, skipping: {0}".format(table_name))
            continue

        log.info("Replacing table: {0}".format(table_name))
        table_new.name += '_gratetmp'
        table_replace(table, table_new)



def migrate(e1, e2, metadata, convert_map=None, populate_fn=None, only_tables=None, skip_tables=None, limit=100000):
    """
    Copy the data of the schema on ``e1`` into the schema described by ``metadata`` on ``e2``.

    :param e1: Source engine (schema reflected)
    :param e2: Target engine (schema generated from ``metadata``)
    :param metadata: MetaData containing target desired schema.
    :param convert_map: Optional mapping of source table name -> iterable of
        ``(new_table_name, convert_fn)`` pairs. For each pair, the source
        table's rows are passed through ``convert_fn`` (called per row as
        ``convert_fn(row=..., old_table=..., new_table=...)``) and inserted
        into ``new_table_name``. Source tables without an entry are copied
        as-is into the target table of the same name.
    :param populate_fn: Optional callable run once before any table is
        copied, as ``populate_fn(metadata_from=..., metadata_to=...)``.
    :param only_tables: If given, only source tables named here are migrated.
    :param skip_tables: If given, source tables named here are skipped.
    :param limit: Number of rows fetched per page while copying.
    """

    metadata_old = sqlalchemy.MetaData(bind=e1, reflect=True)

    metadata.bind = e2
    metadata.create_all(bind=e2)

    # We create a new metadata which isn't tarnished by fancy columns of the given metadata.
    # FIXME: Should convert functions be getting new_metadata too?
    metadata_new = sqlalchemy.MetaData(bind=e2, reflect=True)

    convert_map = convert_map or {}

    if callable(populate_fn):
        log.info("Running populate function.")
        populate_fn(metadata_from=metadata_old, metadata_to=metadata_new)

    for table in metadata_old.sorted_tables:
        table_name = table.name
        if (only_tables and table_name not in only_tables) or \
           (skip_tables and table_name in skip_tables):
            log.info("Skipping table: {0}".format(table_name))
            continue

        log.info("Migrating table: {0}".format(table_name))

        convert = convert_map.get(table_name)
        if not convert:
            new_table = metadata_new.tables.get(table_name)
            if new_table is None:
                log.info("No corresponding table found, skipping: {0}".format(table_name))
                continue

            table_migrate(e1, e2, table, new_table, limit=limit)
            continue

        for new_table_name, convert_fn in convert:
            new_table = metadata_new.tables.get(new_table_name)
            if new_table is None:
                # Passing None on would make table_migrate fall back to the
                # *source* table as its target, silently writing rows into the
                # wrong table (or failing later with a confusing error).
                log.info("No corresponding table found, skipping: {0}".format(new_table_name))
                continue
            table_migrate(e1, e2, table, new_table, convert_fn=convert_fn, limit=limit)


def upgrade(e, upgrade_fn):
    """Reflect the current schema of engine ``e`` and pass it to ``upgrade_fn``.

    :param e: SQLAlchemy engine whose schema is reflected.
    :param upgrade_fn: Callable invoked with the reflected MetaData; its
        return value is ignored and this function returns None.
    """
    reflected = sqlalchemy.MetaData(bind=e, reflect=True)
    upgrade_fn(reflected)
Something went wrong with that request. Please try again.