added redshift support for export
darthbear committed May 7, 2015
1 parent a098b6a commit 8895619
Showing 7 changed files with 142 additions and 71 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -213,6 +213,7 @@ Data import | :white_check_mark:
Constraints export | :x:
Constraints import | :x:
Postgres | :white_check_mark:
Redshift | :white_check_mark:
MySQL | :white_check_mark:
Oracle | :x:
Vertica | :x:
@@ -227,3 +228,4 @@ Export to S3 | :x:
Import from S3 | :x:
Compatible with Turbine XML format | :x:
Common console | :x:
Export/Import Compression | :x:
22 changes: 22 additions & 0 deletions catdb/__init__.py
@@ -1,2 +1,24 @@
from contextlib import contextmanager
import sys


class CatDbException(Exception):
pass


@contextmanager
def open_output_file(filename):
if filename == '-':
yield sys.stdout
else:
with open(filename, 'w') as fd:
yield fd


@contextmanager
def open_input_file(filename):
if filename == '-':
yield sys.stdin
else:
with open(filename, 'r') as fd:
yield fd
70 changes: 38 additions & 32 deletions catdb/db.py
@@ -1,9 +1,12 @@
from abc import abstractmethod
from contextlib import contextmanager
import csv
import importlib
import pkg_resources
from pyhocon import ConfigFactory
import sys
from catdb import CatDbException
from catdb import open_output_file, open_input_file


class Db(object):
@@ -124,12 +127,8 @@ def execute(self, query):

conn.close()

def import_from_file(self, fd, table, schema=None, delimiter='|', null_values='\\N'):
def import_from_file(self, filename, table, schema=None, delimiter='|', null_values='\\N'):
"""default implementation that should be overriden"""
conn = self.get_connection(False)
reader = csv.reader(fd, delimiter=delimiter, quoting=csv.QUOTE_MINIMAL)
# group rows in packets
header = next(reader)

def insert_all(rows):
query = "INSERT INTO {schema_str}{table} ({fields})\nVALUES".format(schema_str='' if schema is None else schema + '.',
@@ -140,39 +139,46 @@ def insert_all(rows):
cursor.execute(query)
conn.commit()

buffer = []
for row in reader:
buffer.append(row)
if len(buffer) >= Db.ROW_BUFFER_SIZE:
insert_all(buffer)
buffer = []
with open_input_file(filename) as fd:
conn = self.get_connection(False)
reader = csv.reader(fd, delimiter=delimiter, quoting=csv.QUOTE_MINIMAL)
# group rows in packets
header = next(reader)

buffer = []
for row in reader:
buffer.append(row)
if len(buffer) >= Db.ROW_BUFFER_SIZE:
insert_all(buffer)
buffer = []

if len(buffer) > 0:
insert_all(buffer)
if len(buffer) > 0:
insert_all(buffer)

def export_to_file(self, fd, table=None, schema=None, delimiter='|', null_value='\\N'):
def export_to_file(self, filename, table=None, schema=None, delimiter='|', null_value='\\N'):
"""default implementation that should be overriden"""
def format_row(row):
return [null_value if e is None else e for e in row]

writer = csv.writer(fd, delimiter=delimiter, quoting=csv.QUOTE_MINIMAL)
conn = self.get_connection(False)
try:
cursor = self.get_cursor(conn)
actual_table = ('' if schema is None else schema + '.') + table
cursor.execute('SELECT * FROM ' + actual_table)

rows = cursor.fetchmany()
if rows:
# write header
fields = [desc[0] for desc in cursor.description]
writer.writerow(fields)
while rows:
formatted_rows = [format_row(row) for row in rows]
writer.writerows(formatted_rows)
rows = cursor.fetchmany()
finally:
conn.close()
with open_output_file(filename) as fd:
writer = csv.writer(fd, delimiter=delimiter, quoting=csv.QUOTE_MINIMAL)
conn = self.get_connection(False)
try:
cursor = self.get_cursor(conn)
actual_table = ('' if schema is None else schema + '.') + table
cursor.execute('SELECT * FROM ' + actual_table)

rows = cursor.fetchmany()
if rows:
# write header
fields = [desc[0] for desc in cursor.description]
writer.writerow(fields)
while rows:
formatted_rows = [format_row(row) for row in rows]
writer.writerows(formatted_rows)
rows = cursor.fetchmany()
finally:
conn.close()


class DbManager:
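
With this change the generic export_to_file and import_from_file take a filename (or '-') rather than an already-open file object, and the buffered multi-row INSERT logic now runs inside the open_input_file block. A rough usage sketch against the Postgres backend (the connection values are illustrative):

from catdb.postgres import Postgres

db = Postgres({'database': 'mydb', 'hostname': 'localhost', 'port': 5432,
               'username': 'me', 'password': 'secret'})

# Dump a table to a delimited file; passing '-' streams it to stdout.
db.export_to_file('events.csv', table='events', schema='public',
                  delimiter='|', null_value='\\N')

# Load it back; backends that do not override this method fall back to the
# buffered INSERT path above (batches of Db.ROW_BUFFER_SIZE rows).
db.import_from_file('events.csv', table='events', schema='public',
                    delimiter='|', null_value='\\N')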
38 changes: 13 additions & 25 deletions catdb/main.py
@@ -1,34 +1,18 @@
import argparse
from contextlib import contextmanager
import json
import os
from pyhocon import ConfigFactory
import sys
from catdb import open_output_file, open_input_file
from catdb.db import DbManager


def main():

@contextmanager
def open_output_file(filename):
if filename == '-':
yield sys.stdout
else:
with open(filename, 'w') as fd:
yield fd

@contextmanager
def open_input_file(filename):
if filename == '-':
yield sys.stdin
else:
with open(filename, 'r') as fd:
yield fd

parent_parser = argparse.ArgumentParser(add_help=False)
parent_parser.add_argument('-d', '--database', help='database', required=True, action='store')
parent_parser.add_argument('-s', '--schema', help='schema', required=False, action='store', default=None)
parent_parser.add_argument('-t', '--table', help='table filter (using % as a wildcard)', required=False, action='store')
parent_parser.add_argument('-t', '--table', help='table filter (using % as a wildcard)', required=False,
action='store')
parent_parser.add_argument('-dr', '--dry-run', dest='dry_run', help='dry run', required=False, action='store_true')

argparser = argparse.ArgumentParser(description='export')
@@ -45,7 +29,9 @@ def open_input_file(filename):
home_dir = os.environ['HOME']
config_path = os.path.join(home_dir, '.catdb')
if not os.path.exists(config_path):
sys.stderr.write('File {config_path} not found. Go to https://github.com/chimpler/catdb for more details\n'.format(config_path=config_path))
sys.stderr.write(
'File {config_path} not found. Go to https://github.com/chimpler/catdb for more details\n'.format(
config_path=config_path))
sys.exit(1)

config = ConfigFactory.parse_file(config_path)
@@ -56,7 +42,8 @@ def open_input_file(filename):
print '\n'.join(db.list_tables(args.table, args.schema))
elif args.subparser_name == 'ddl':
if args.export_file:
ddl_str = json.dumps(db.get_ddl(args.table, args.schema), sort_keys=True, indent=config['ddl-format.indent'], separators=(',', ': '))
ddl_str = json.dumps(db.get_ddl(args.table, args.schema), sort_keys=True,
indent=config['ddl-format.indent'], separators=(',', ': '))
with open_output_file(args.export_file) as fd:
fd.write(ddl_str)
elif args.import_file:
@@ -69,12 +56,13 @@ def open_input_file(filename):
db.execute(table_statement)
elif args.subparser_name == 'data':
if args.export_file:
with open_output_file(args.export_file) as fd:
db.export_to_file(fd, args.table, args.schema, config['data-format.delimiter'], config['data-format.null'])
db.export_to_file(args.export_file, args.table, args.schema, config['data-format.delimiter'],
config['data-format.null'])

elif args.import_file:
with open_input_file(args.import_file) as fd:
db.import_from_file(fd, args.table, args.schema, config['data-format.delimiter'], config['data-format.null'])
db.import_from_file(args.import_file, args.table, args.schema, config['data-format.delimiter'],
config['data-format.null'])


if __name__ == '__main__':
main()
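
For reference, main() resolves its formatting options from ~/.catdb, parsed with pyhocon. Only three keys appear in this diff; a small sketch of reading them (parse_string stands in for parse_file here, and the values shown are illustrative, not defaults defined by the commit):

from pyhocon import ConfigFactory

# Stand-in for ConfigFactory.parse_file(os.path.join(os.environ['HOME'], '.catdb')).
config = ConfigFactory.parse_string("""
    ddl-format { indent = 4 }
    data-format { delimiter = "|", null = "NULL" }
""")

print config['ddl-format.indent']      # 4
print config['data-format.delimiter']  # |
print config['data-format.null']       # NULL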
29 changes: 16 additions & 13 deletions catdb/postgres.py
@@ -1,4 +1,5 @@
import psycopg2
from catdb import open_input_file, open_output_file
from catdb.db import Db


@@ -8,7 +9,7 @@ class Postgres(Db):
def __init__(self, params):
super(Postgres, self).__init__('postgres', params)

def __get_connect_params(self):
def _get_connect_params(self):
return {
'database': self._params['database'],
'host': self._params.get('hostname'),
@@ -17,9 +18,9 @@ def __get_connect_params(self):
'password': self._params.get('password', None)
}

def get_connection(self, use_dict_cursor=False, db=None):
def open_connection(self, use_dict_cursor=False, db=None):
params = {}
params.update(self.__get_connect_params())
params.update(self._get_connect_params())
if use_dict_cursor:
params.update({
'cursor_factory': 'psycopg2.extras.RealDictCursor'
@@ -30,18 +31,20 @@ def get_cursor(self, connection):
cursor = connection.cursor('cursor')
return cursor

def export_to_file(self, fd, table=None, schema=None, delimiter='|', null_value='\\N'):
with psycopg2.connect(**self.__get_connect_params()) as conn:
with conn.cursor() as cursor:
cursor.copy_to(fd, table=table, sep=delimiter, null=null_value)
def export_to_file(self, filename, table=None, schema=None, delimiter='|', null_value='\\N'):
with open_output_file(filename) as fd:
with psycopg2.connect(**self._get_connect_params()) as conn:
with conn.cursor() as cursor:
cursor.copy_to(fd, table=table, sep=delimiter, null=null_value)

def import_from_file(self, fd, table=None, schema=None, delimiter='|', null_value='\\N'):
with psycopg2.connect(**self.__get_connect_params()) as conn:
with conn.cursor() as cursor:
cursor.copy_from(fd, table=table, sep=delimiter, null=null_value)
def import_from_file(self, filename, table=None, schema=None, delimiter='|', null_value='\\N'):
with open_input_file(filename) as fd:
with psycopg2.connect(**self._get_connect_params()) as conn:
with conn.cursor() as cursor:
cursor.copy_from(fd, table=table, sep=delimiter, null=null_value)

def list_tables(self, filter=None, schema=None):
with psycopg2.connect(**self.__get_connect_params()) as conn:
with psycopg2.connect(**self._get_connect_params()) as conn:
with conn.cursor() as cursor:
cursor.execute(
"SELECT table_name FROM information_schema.tables WHERE table_schema = '{schema}' AND table_name LIKE '{filter}'".format(
@@ -53,7 +56,7 @@ def list_tables(self, filter=None, schema=None):
# http://stackoverflow.com/questions/2204058/list-columns-with-indexes-in-postgresql
# http://www.alberton.info/postgresql_meta_info.html#.VT2sIhPF-d4
def get_column_info(self, table, schema):
with psycopg2.connect(**self.__get_connect_params()) as conn:
with psycopg2.connect(**self._get_connect_params()) as conn:
with conn.cursor() as cursor:
cursor.execute(
"""SELECT
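
A side note on the rename from __get_connect_params to _get_connect_params: a double-underscore attribute is name-mangled to _Postgres__get_connect_params, so the new Redshift subclass could not have called it. A standalone illustration of that Python behaviour (not code from the repo):

class Base(object):
    def __mangled(self):   # stored as _Base__mangled
        return 'params'

    def _shared(self):     # single underscore: inherited under its own name
        return 'params'


class Child(Base):
    def use(self):
        # self.__mangled() here would resolve to _Child__mangled and raise
        # AttributeError; the single-underscore helper works as expected.
        return self._shared()


print Child().use()  # params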
49 changes: 49 additions & 0 deletions catdb/redshift.py
@@ -0,0 +1,49 @@
import psycopg2
import time
from catdb.postgres import Postgres
from boto.s3.connection import S3Connection


class Redshift(Postgres):

def __init__(self, params):
super(Redshift, self).__init__(params)

def export_to_file(self, filename, table=None, schema=None, delimiter='|', null_value='\\N'):
"""TODO: Support explicit export to S3
:param filename:
:param table:
:param schema:
:param delimiter:
:param null_value:
:return:
"""
aws_config = self._params['aws']
key = aws_config['access_key_id']
secret = aws_config['secret_access_key']
bucket_name = aws_config['temp_bucket']
prefix = aws_config['temp_prefix']
conn = S3Connection(key, secret)
bucket = conn.get_bucket(bucket_name)

temp_file_prefix = 'catdb_{ts}'.format(ts=int(time.time() * 1000000))
s3_path_prefix = 's3://{bucket}/{prefix}/{file}'.format(
bucket=bucket_name,
prefix=prefix,
file=temp_file_prefix
)
s3_file = temp_file_prefix + '000'

with psycopg2.connect(**self._get_connect_params()) as conn:
with conn.cursor() as cursor:
cursor.execute(
"""
UNLOAD ('SELECT * FROM {schema}.{table}')
TO '{filename}'
CREDENTIALS 'aws_access_key_id={aws_key};aws_secret_access_key={aws_secret}'
PARALLEL OFF
""".format(schema=schema, table=table, filename=s3_path_prefix, aws_key=key, aws_secret=secret))

key = bucket.get_key('{prefix}/{file}'.format(prefix=prefix, file=s3_file))
key.get_contents_to_filename(filename)
key.delete()
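
Redshift keeps the Postgres connection settings and additionally reads an aws block with the S3 staging details used by UNLOAD. A hedged sketch of driving the new export directly (the key names match what export_to_file reads; every value, and the flat-dict shape, are illustrative):

from catdb.redshift import Redshift

db = Redshift({
    # Postgres-style connection settings (see Postgres._get_connect_params)
    'database': 'analytics',
    'hostname': 'example.redshift.amazonaws.com',
    'port': 5439,
    'username': 'me',
    'password': 'secret',
    # S3 staging settings read by Redshift.export_to_file
    'aws': {
        'access_key_id': 'AKIA...',       # illustrative
        'secret_access_key': 'secret',    # illustrative
        'temp_bucket': 'my-temp-bucket',
        'temp_prefix': 'catdb-staging',
    },
})

# UNLOADs public.events to s3://my-temp-bucket/catdb-staging/catdb_<ts>000,
# downloads that object to events.csv, then deletes the temporary S3 key.
db.export_to_file('events.csv', table='events', schema='public')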
3 changes: 2 additions & 1 deletion setup.py
@@ -48,7 +48,8 @@ def run_tests(self):
install_requires=[
'pyhocon==0.3.1',
'psycopg2==2.6',
'PyMySQL==0.6.6'
'PyMySQL==0.6.6',
'boto==2.38.0'
] + (['importlib==1.0.3'] if sys.version_info[:2] == (2, 6) else []),
tests_require=[
'pytest',
