Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use adapter functions in dbt seed #516

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 9 additions & 33 deletions dbt/seeder.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,19 @@
import os
import fnmatch
from csvkit import table as csv_table, sql as csv_sql
from sqlalchemy.dialects import postgresql as postgresql_dialect
from dbt.adapters.factory import get_adapter
import psycopg2

from dbt.source import Source
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.adapters.factory import get_adapter
import dbt.exceptions


class Seeder:
def __init__(self, project):
    """Create a seeder bound to ``project``.

    Args:
        project: The project object. It is stored on ``self.project`` and
            its ``run_environment()`` method is called once here.
    """
    self.project = project
    # The return value is not used anywhere in this method; the call is
    # kept so any side effect it has still happens at construction time.
    # NOTE(review): presumably this surfaces a bad profile early -- confirm.
    project.run_environment()

def find_csvs(self):
    """Return the CSVs found under the project's ``data-paths`` setting.

    Returns:
        Whatever ``Source.get_csvs`` returns for ``self.project['data-paths']``.
    """
    data_paths = self.project['data-paths']
    source = Source(self.project)
    return source.get_csvs(data_paths)

def drop_table(self, cursor, schema, table):
    """Drop ``schema.table`` (with cascade) if it exists.

    Args:
        cursor: Object whose ``execute`` method runs the statement.
        schema: Schema name; placed inside double quotes in the statement.
        table: Table name; placed inside double quotes in the statement.
    """
    template = 'drop table if exists "{schema}"."{table}" cascade'
    statement = template.format(table=table, schema=schema)
    logger.info("Dropping table {}.{}".format(schema, table))
    cursor.execute(statement)

def truncate_table(self, cursor, schema, table):
    """Remove all rows from ``schema.table`` without dropping it.

    Args:
        cursor: Object whose ``execute`` method runs the statement.
        schema: Schema name; placed inside double quotes in the statement.
        table: Table name; placed inside double quotes in the statement.
    """
    template = 'truncate table "{schema}"."{table}"'
    statement = template.format(table=table, schema=schema)
    logger.info("Truncating table {}.{}".format(schema, table))
    cursor.execute(statement)

def create_table(self, cursor, schema, table, virtual_table):
sql_table = csv_sql.make_table(virtual_table, db_schema=schema)
create_table_sql = csv_sql.make_create_table_statement(
Expand Down Expand Up @@ -67,36 +49,30 @@ def quote_or_null(s):
.format(len(virtual_table.to_rows()), schema, table))
cursor.execute(insert_sql)

def existing_tables(self, cursor, schema):
    """Return the names of the tables that exist in ``schema``.

    The schema name is sent as a bound query parameter instead of being
    formatted into the SQL text, so a name containing a single quote can
    neither break the statement nor inject SQL.

    Args:
        cursor: DB-API cursor used to run the query. The ``%s`` placeholder
            style is assumed (the file imports psycopg2 and the query
            targets ``pg_tables``).
        schema: Schema name to match against ``pg_tables.schemaname``.

    Returns:
        A set holding the first column (``tablename``) of every row
        returned by the query.
    """
    sql = ("select tablename as name from pg_tables where "
           "schemaname = %s")

    cursor.execute(sql, (schema,))
    return {row[0] for row in cursor.fetchall()}

def do_seed(self, schema, cursor, drop_existing):
existing_tables = self.existing_tables(cursor, schema)
profile = self.project.run_environment()
adapter = get_adapter(profile)

existing = adapter.query_for_existing(profile, schema)

csvs = self.find_csvs()
statuses = []
for csv in csvs:

table_name = csv.name
fh = open(csv.filepath)
virtual_table = csv_table.Table.from_csv(fh, table_name)

if table_name in existing_tables:
if table_name in existing:
if drop_existing:
self.drop_table(cursor, schema, table_name)
adapter.drop(profile, table_name, existing[table_name])
self.create_table(
cursor,
schema,
table_name,
virtual_table
)
else:
self.truncate_table(cursor, schema, table_name)
adapter.truncate(profile, table_name, table_name)
else:
self.create_table(cursor, schema, table_name, virtual_table)

Expand Down Expand Up @@ -125,7 +101,7 @@ def do_seed(self, schema, cursor, drop_existing):
def seed(self, drop_existing=False):
profile = self.project.run_environment()

if profile.get('type') == 'snowflake':
if profile.get('type') not in 'snowflake':
raise dbt.exceptions.NotImplementedException(
"`seed` operation is not supported for snowflake.")

Expand Down