Skip to content

Commit

Permalink
moved delimiter, indent to config, use builtin copy to/from for postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
darthbear committed May 6, 2015
1 parent 763fdf1 commit a098b6a
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 82 deletions.
58 changes: 33 additions & 25 deletions .catdb
Original file line number Diff line number Diff line change
@@ -1,28 +1,36 @@
{
# block read by catdb
databases {
pg_testdb: ${defaults.credentials} {
hostname: localhost
database: testdb
type: postgres
}
my_testdb: ${defaults.credentials} {
hostname: localhost
database: testdb
type: mysql
}
}
# block read by catdb
databases {
pg_testdb: ${defaults.credentials} {
hostname: localhost
database: testdb
type: postgres
}
my_testdb: ${defaults.credentials} {
hostname: localhost
database: testdb
type: mysql
}
}

# block read by catdb
data-format {
delimiter: "|"
null: "\\N"
}

# default values to make things easier but not used directly by catdb
defaults {
credentials {
username: scott
password: tiger
}
aws {
region: us-east-1
aws_access_key_id: ${AWS_ACCESS_KEY_ID}
aws_secret_access_key: ${AWS_SECRET_ACCESS_KEY}
}
ddl-format {
indent: 4
}

# default values to make things easier but not used directly by catdb
defaults {
credentials {
username: scott
password: tiger
}
aws {
region: us-east-1
aws_access_key_id: ${AWS_ACCESS_KEY_ID}
aws_secret_access_key: ${AWS_SECRET_ACCESS_KEY}
}
}
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@

CatDB allows you to migrate data between various databases.

### Installation

It is available on PyPI, so you can install it as follows:

$ pip install catdb

### Configuration

Create a file `$HOME/.catdb` with the list of databases to use:
Expand Down
53 changes: 30 additions & 23 deletions catdb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,21 @@ def __init__(self, dbname, params={}):
)

@abstractmethod
def list_tables(self, schema=None, table_filter=None):
def list_tables(self, table_filter=None, schema=None):
pass

@abstractmethod
def get_column_info(self, schema, table):
def get_column_info(self, table, schema=None):
pass

@abstractmethod
def get_connection(self, use_dict_cursor=True):
pass

@abstractmethod
def get_cursor(self, connection):
pass

def get_ddl(self, schema=None, table_filter=None):
"""translate db specific ddl to generic ddl"""
def get_default(col_type, value):
Expand Down Expand Up @@ -65,7 +69,7 @@ def get_column_def(entry):
return dict((k, v) for k, v in row.items() if v is not None)

def get_table_def(table):
meta = self.get_column_info(schema, table)
meta = self.get_column_info(table, schema)
return {
'name': table,
'columns': [get_column_def(col) for col in meta]
Expand Down Expand Up @@ -120,22 +124,21 @@ def execute(self, query):

conn.close()

def import_from_file(self, fd, schema=None, table=None, dry_run=False):
conn = self.get_connection()
reader = csv.reader(fd, delimiter='|', quotechar="'", quoting=csv.QUOTE_MINIMAL)
def import_from_file(self, fd, table, schema=None, delimiter='|', null_values='\\N'):
"""default implementation that should be overriden"""
conn = self.get_connection(False)
reader = csv.reader(fd, delimiter=delimiter, quoting=csv.QUOTE_MINIMAL)
# group rows in packets
header = next(reader)

def insert_all(rows):
query = "INSERT INTO {schema_str}{table} ({fields})\nVALUES".format(schema_str='' if schema is None else schema + '.',
table=table,
fields=','.join(header)) \
+ ',\n'.join('(' + ','.join("'" + f + "'" for f in row) + ')' for row in rows) + ';'

if dry_run:
print(query)
else:
conn.execute(query)
+ ',\n'.join('(' + ','.join(("'" + f + "'") if f != '\\N' else 'NULL' for f in row) + ')' for row in rows) + ';'
cursor = self.get_cursor(conn)
cursor.execute(query)
conn.commit()

buffer = []
for row in reader:
Expand All @@ -147,23 +150,27 @@ def insert_all(rows):
if len(buffer) > 0:
insert_all(buffer)

def export_to_file(self, fd, schema=None, table=None):
writer = csv.writer(fd, delimiter='|', quotechar="'", quoting=csv.QUOTE_MINIMAL)
def export_to_file(self, fd, table=None, schema=None, delimiter='|', null_value='\\N'):
"""default implementation that should be overriden"""
def format_row(row):
return [null_value if e is None else e for e in row]

writer = csv.writer(fd, delimiter=delimiter, quoting=csv.QUOTE_MINIMAL)
conn = self.get_connection(False)
try:
cursor = conn.cursor()
cursor = self.get_cursor(conn)
actual_table = ('' if schema is None else schema + '.') + table
cursor.execute('SELECT * FROM ' + actual_table)

# write header
fields = [desc[0] for desc in cursor.description]
writer.writerow(fields)

rows = cursor.fetchmany()
while rows:
for row in rows:
writer.writerow(row)
rows = cursor.fetchmany()
if rows:
# write header
fields = [desc[0] for desc in cursor.description]
writer.writerow(fields)
while rows:
formatted_rows = [format_row(row) for row in rows]
writer.writerows(formatted_rows)
rows = cursor.fetchmany()
finally:
conn.close()

Expand Down
61 changes: 33 additions & 28 deletions catdb/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
from contextlib import contextmanager
import json
import os
from pyhocon import ConfigFactory
Expand All @@ -8,6 +9,22 @@

def main():

@contextmanager
def open_output_file(filename):
    """Yield a writable text stream for *filename*.

    The conventional '-' maps to sys.stdout (which is left open);
    any other value is opened for writing and closed on exit.
    """
    if filename != '-':
        with open(filename, 'w') as out:
            yield out
    else:
        yield sys.stdout

@contextmanager
def open_input_file(filename):
    """Yield a readable text stream for *filename*.

    The conventional '-' maps to sys.stdin (which is left open);
    any other value is opened for reading and closed on exit.
    """
    if filename != '-':
        with open(filename, 'r') as inp:
            yield inp
    else:
        yield sys.stdin

parent_parser = argparse.ArgumentParser(add_help=False)
parent_parser.add_argument('-d', '--database', help='database', required=True, action='store')
parent_parser.add_argument('-s', '--schema', help='schema', required=False, action='store', default=None)
Expand Down Expand Up @@ -36,40 +53,28 @@ def main():

db = DbManager.get_db(db_config['type'], db_config)
if args.subparser_name == 'list':
print '\n'.join(db.list_tables(args.schema, args.table))
print '\n'.join(db.list_tables(args.table, args.schema))
elif args.subparser_name == 'ddl':
if args.export_file:
ddl_str = json.dumps(db.get_ddl(args.schema, args.table), sort_keys=True, indent=4, separators=(',', ': '))
if args.export_file == '-':
print ddl_str
else:
with open(args.export_file, 'w') as fd:
fd.write(ddl_str)
ddl_str = json.dumps(db.get_ddl(args.table, args.schema), sort_keys=True, indent=config['ddl-format.indent'], separators=(',', ': '))
with open_output_file(args.export_file) as fd:
fd.write(ddl_str)
elif args.import_file:
if args.import_file == '-':
ddl = json.loads(sys.stdin)
else:
with open(args.import_file, 'r') as fd:
ddl = json.loads(fd.read())

table_statement = db.create_database_statement(ddl, args.database, args.schema)
if args.dry_run:
print table_statement
else:
db.execute(table_statement)
with open_input_file(args.import_file) as fd:
ddl = json.loads(fd.read())
table_statement = db.create_database_statement(ddl, args.database, args.schema)
if args.dry_run:
print table_statement
else:
db.execute(table_statement)
elif args.subparser_name == 'data':
if args.export_file:
if args.export_file == '-':
db.export_to_file(sys.stdout, args.schema, args.table)
else:
with open(args.export_file, 'w') as fd:
db.export_to_file(fd, args.schema, args.table)
with open_output_file(args.export_file) as fd:
db.export_to_file(fd, args.table, args.schema, config['data-format.delimiter'], config['data-format.null'])

elif args.import_file:
if args.import_file == '-':
db.import_to_file(sys.stdin, args.schema, args.table, args.dry_run)
else:
with open(args.import_file, 'r') as fd:
db.import_from_file(fd, args.schema, args.table, args.dry_run)
with open_input_file(args.import_file) as fd:
db.import_from_file(fd, args.table, args.schema, config['data-format.delimiter'], config['data-format.null'])

if __name__ == '__main__':
main()
7 changes: 5 additions & 2 deletions catdb/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@ def get_connection(self, is_dict_cursor=True, db=None):
charset='utf8mb4',
cursorclass=cursor_class)

def list_tables(self, schema=None, table_filter=None):
def get_cursor(self, connection):
    """Return a default client-side cursor for *connection*."""
    cursor = connection.cursor()
    return cursor

def list_tables(self, table_filter=None, schema=None):
    """Return the names of tables matching *table_filter*.

    table_filter is a SQL LIKE pattern; None lists every table.
    schema is accepted for interface symmetry with other backends but is
    not used here — the MySQL connection already selects a database.
    """
    pattern = '%' if table_filter is None else table_filter
    with self.get_connection(False) as cursor:
        # Pass the pattern as a bound parameter instead of interpolating
        # it into the statement, so quotes in the filter cannot break or
        # inject into the query.
        cursor.execute("SHOW TABLES LIKE %s", (pattern,))
        return [row[0] for row in cursor.fetchall()]

def get_column_info(self, schema, table):
def get_column_info(self, table, schema=None):
def get_col_def(row):
# DOUBLE(10,2)
data_type_tokens = row['Type'].split('(')
Expand Down
18 changes: 16 additions & 2 deletions catdb/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,21 @@ def get_connection(self, use_dict_cursor=False, db=None):
})
return psycopg2.connect(**params)

def list_tables(self, schema=None, filter=None):
def get_cursor(self, connection):
    """Return a named cursor ('cursor') — with psycopg2 this is a
    server-side cursor, so large result sets are fetched in batches.
    """
    return connection.cursor('cursor')

def export_to_file(self, fd, table=None, schema=None, delimiter='|', null_value='\\N'):
    """Stream *table* to the file object *fd* using PostgreSQL's native
    COPY ... TO STDOUT.

    Fixes: the schema argument was previously ignored (copy_to cannot
    take a schema-qualified table name), so copy_expert is used instead.
    NOTE(review): delimiter/null_value are interpolated as SQL literals
    and must not contain single quotes — confirm callers only pass the
    configured one-character delimiter.
    """
    qualified = table if schema is None else '%s.%s' % (schema, table)
    with psycopg2.connect(**self.__get_connect_params()) as conn:
        with conn.cursor() as cursor:
            cursor.copy_expert(
                "COPY {0} TO STDOUT WITH DELIMITER '{1}' NULL '{2}'".format(
                    qualified, delimiter, null_value),
                fd)

def import_from_file(self, fd, table=None, schema=None, delimiter='|', null_value='\\N'):
    """Load rows from the file object *fd* into *table* using
    PostgreSQL's native COPY ... FROM STDIN.

    Fixes: the schema argument was previously ignored (copy_from cannot
    take a schema-qualified table name), so copy_expert is used instead.
    The surrounding `with psycopg2.connect(...)` block commits the
    transaction on successful exit.
    NOTE(review): delimiter/null_value are interpolated as SQL literals
    and must not contain single quotes — confirm callers only pass the
    configured one-character delimiter.
    """
    qualified = table if schema is None else '%s.%s' % (schema, table)
    with psycopg2.connect(**self.__get_connect_params()) as conn:
        with conn.cursor() as cursor:
            cursor.copy_expert(
                "COPY {0} FROM STDIN WITH DELIMITER '{1}' NULL '{2}'".format(
                    qualified, delimiter, null_value),
                fd)

def list_tables(self, filter=None, schema=None):
with psycopg2.connect(**self.__get_connect_params()) as conn:
with conn.cursor() as cursor:
cursor.execute(
Expand All @@ -38,7 +52,7 @@ def list_tables(self, schema=None, filter=None):

# http://stackoverflow.com/questions/2204058/list-columns-with-indexes-in-postgresql
# http://www.alberton.info/postgresql_meta_info.html#.VT2sIhPF-d4
def get_column_info(self, schema, table):
def get_column_info(self, table, schema):
with psycopg2.connect(**self.__get_connect_params()) as conn:
with conn.cursor() as cursor:
cursor.execute(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_get_column_info(self, mock_pymysql):
'database': 'test'
})

assert mysql.get_column_info(None, 'test') == [
assert mysql.get_column_info('test', None) == [
{
'column': 'field',
'type': 'varchar',
Expand Down
2 changes: 1 addition & 1 deletion tests/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_get_column_info(self, mock_psycopg2):
'database': 'test'
})

assert postgres.get_column_info(None, 'test') == [
assert postgres.get_column_info('test', None) == [
{
'column': 'field',
'type': 'character varying',
Expand Down

0 comments on commit a098b6a

Please sign in to comment.