Merge pull request #43 from NielsZeilemaker/add_sql
Expose location, clustered_by to dbt-spark
jtcohen6 committed Feb 6, 2020
2 parents 2c31520 + f7f8d84 · commit c1a53dd
Showing 4 changed files with 182 additions and 2 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -118,7 +118,10 @@ The following configurations can be supplied to models run with the dbt-spark plugin:
| Option | Description | Required? | Example |
|---------------|-------------------------------------------------------------------------------------------------------------|-----------------------------------------|---------------|
| file_format | The file format to use when creating tables | Optional | `parquet` |
| location_root | The created table uses the specified directory to store its data. The table alias is appended to it. | Optional | `/mnt/root` |
| partition_by | Partition the created table by the specified columns. A directory is created for each partition. | Optional | `partition_1` |
| clustered_by | Each partition in the created table will be split into a fixed number of buckets by the specified columns. | Optional | `cluster_1` |
| buckets | The number of buckets to create while clustering | Required if `clustered_by` is specified | `8` |
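
For example, a model might combine these options in its config block (a hypothetical model; names and values are illustrative):

    {{ config(
        file_format='parquet',
        location_root='/mnt/root',
        partition_by=['partition_1'],
        clustered_by=['cluster_1'],
        buckets=8
    ) }}

    select 1 as partition_1, 2 as cluster_1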


**Incremental Models**
4 changes: 4 additions & 0 deletions dbt/adapters/spark/impl.py
@@ -16,6 +16,10 @@ class SparkAdapter(SQLAdapter):
    ConnectionManager = SparkConnectionManager
    Relation = SparkRelation

    AdapterSpecificConfigs = frozenset({"file_format", "location_root",
                                        "partition_by", "clustered_by",
                                        "buckets"})

    @classmethod
    def date_function(cls):
        return 'CURRENT_TIMESTAMP()'
51 changes: 50 additions & 1 deletion dbt/include/spark/macros/adapters.sql
@@ -5,13 +5,37 @@
{{ sql }}
{% endmacro %}


{% macro file_format_clause() %}
  {%- set file_format = config.get('file_format', validator=validation.any[basestring]) -%}
  {%- if file_format is not none %}
    using {{ file_format }}
  {%- endif %}
{%- endmacro -%}


{% macro location_clause() %}
  {%- set location_root = config.get('location_root', validator=validation.any[basestring]) -%}
  {%- set identifier = model['alias'] -%}
  {%- if location_root is not none %}
    location '{{ location_root }}/{{ identifier }}'
  {%- endif %}
{%- endmacro -%}


{% macro comment_clause() %}
  {%- set raw_persist_docs = config.get('persist_docs', {}) -%}

  {%- if raw_persist_docs is mapping -%}
    {%- set raw_relation = raw_persist_docs.get('relation', false) -%}
    {%- if raw_relation -%}
      comment '{{ model.description }}'
    {% endif %}
  {%- else -%}
    {{ exceptions.raise_compiler_error("Invalid value provided for 'persist_docs'. Expected dict but got value: " ~ raw_persist_docs) }}
  {% endif %}
{%- endmacro -%}
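
A model opts in to the table comment by enabling persist_docs for the relation (illustrative usage, mirroring the unit tests below; the comment text is taken from the model's description):

    {{ config(persist_docs={'relation': true}) }}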

{% macro partition_cols(label, required=false) %}
  {%- set cols = config.get('partition_by', validator=validation.any[list, basestring]) -%}
  {%- if cols is not none %}
@@ -27,23 +51,48 @@
  {%- endif %}
{%- endmacro -%}


{% macro clustered_cols(label, required=false) %}
  {%- set cols = config.get('clustered_by', validator=validation.any[list, basestring]) -%}
  {%- set buckets = config.get('buckets', validator=validation.any[int]) -%}
  {%- if (cols is not none) and (buckets is not none) %}
    {%- if cols is string -%}
      {%- set cols = [cols] -%}
    {%- endif -%}
    {{ label }} (
    {%- for item in cols -%}
      {{ item }}
      {%- if not loop.last -%},{%- endif -%}
    {%- endfor -%}
    ) into {{ buckets }} buckets
  {%- endif %}
{%- endmacro -%}
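
Rendered, partition_cols and clustered_cols emit fragments like the following (matching the expected strings in the unit tests below):

    partitioned by (partition_1,partition_2)
    clustered by (cluster_1,cluster_2) into 1 buckets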


{% macro spark__create_table_as(temporary, relation, sql) -%}
  {% if temporary -%}
    {{ spark_create_temporary_view(relation, sql) }}
  {%- else -%}
    create table {{ relation }}
    {{ file_format_clause() }}
    {{ partition_cols(label="partitioned by") }}
    {{ clustered_cols(label="clustered by") }}
    {{ location_clause() }}
    {{ comment_clause() }}
    as
      {{ sql }}
  {%- endif %}
{%- endmacro -%}
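
End to end, spark__create_table_as renders DDL of roughly this shape (illustrative; compare test_macros_create_table_as_all below):

    create table my_table
    using delta
    partitioned by (partition_1,partition_2)
    clustered by (cluster_1,cluster_2) into 1 buckets
    location '/mnt/root/my_table'
    comment 'Description Test'
    as
    select 1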


{% macro spark__create_view_as(relation, sql) -%}
  create view {{ relation }}
  {{ comment_clause() }}
  as
    {{ sql }}
{% endmacro %}


{% macro spark__get_columns_in_relation(relation) -%}
  {% call statement('get_columns_in_relation', fetch_result=True) %}
    describe {{ relation }}
124 changes: 124 additions & 0 deletions test/unit/test_macros.py
@@ -0,0 +1,124 @@
import mock
import unittest
import re
from jinja2 import Environment, FileSystemLoader


class TestSparkMacros(unittest.TestCase):

    def setUp(self):
        self.jinja_env = Environment(loader=FileSystemLoader('dbt/include/spark/macros'),
                                     extensions=['jinja2.ext.do'])

        self.config = {}

        # Stub the dbt context objects the macros reference; config.get
        # reads straight out of self.config so each test can set options.
        self.default_context = {}
        self.default_context['validation'] = mock.Mock()
        self.default_context['model'] = mock.Mock()
        self.default_context['exceptions'] = mock.Mock()
        self.default_context['config'] = mock.Mock()
        self.default_context['config'].get = lambda key, default=None, **kwargs: self.config.get(key, default)

    def __get_template(self, template_filename):
        return self.jinja_env.get_template(template_filename, globals=self.default_context)

    def __run_macro(self, template, name, temporary, relation, sql):
        # Call the macro exported on the template module, then collapse
        # runs of whitespace so the assertions can compare single-line SQL.
        self.default_context['model'].alias = relation
        value = getattr(template.module, name)(temporary, relation, sql)
        return re.sub(r'\s\s+', ' ', value)

    def test_macros_load(self):
        self.jinja_env.get_template('adapters.sql')

    def test_macros_create_table_as(self):
        template = self.__get_template('adapters.sql')

        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
                         "create table my_table as select 1")

    def test_macros_create_table_as_file_format(self):
        template = self.__get_template('adapters.sql')

        self.config['file_format'] = 'delta'
        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
                         "create table my_table using delta as select 1")

    def test_macros_create_table_as_partition(self):
        template = self.__get_template('adapters.sql')

        self.config['partition_by'] = 'partition_1'
        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
                         "create table my_table partitioned by (partition_1) as select 1")

    def test_macros_create_table_as_partitions(self):
        template = self.__get_template('adapters.sql')

        self.config['partition_by'] = ['partition_1', 'partition_2']
        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
                         "create table my_table partitioned by (partition_1,partition_2) as select 1")

    def test_macros_create_table_as_cluster(self):
        template = self.__get_template('adapters.sql')

        self.config['clustered_by'] = 'cluster_1'
        self.config['buckets'] = '1'
        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
                         "create table my_table clustered by (cluster_1) into 1 buckets as select 1")

    def test_macros_create_table_as_clusters(self):
        template = self.__get_template('adapters.sql')

        self.config['clustered_by'] = ['cluster_1', 'cluster_2']
        self.config['buckets'] = '1'
        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
                         "create table my_table clustered by (cluster_1,cluster_2) into 1 buckets as select 1")

    def test_macros_create_table_as_location(self):
        template = self.__get_template('adapters.sql')

        self.config['location_root'] = '/mnt/root'
        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
                         "create table my_table location '/mnt/root/my_table' as select 1")

    def test_macros_create_table_as_comment(self):
        template = self.__get_template('adapters.sql')

        self.config['persist_docs'] = {'relation': True}
        self.default_context['model'].description = 'Description Test'
        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
                         "create table my_table comment 'Description Test' as select 1")

    def test_macros_create_table_as_all(self):
        template = self.__get_template('adapters.sql')

        self.config['file_format'] = 'delta'
        self.config['location_root'] = '/mnt/root'
        self.config['partition_by'] = ['partition_1', 'partition_2']
        self.config['clustered_by'] = ['cluster_1', 'cluster_2']
        self.config['buckets'] = '1'
        self.config['persist_docs'] = {'relation': True}
        self.default_context['model'].description = 'Description Test'

        self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
                         "create table my_table using delta partitioned by (partition_1,partition_2) clustered by (cluster_1,cluster_2) into 1 buckets location '/mnt/root/my_table' comment 'Description Test' as select 1")
