Expose location, clustered_by to dbt-spark #43

Merged
merged 11 commits on Feb 6, 2020
README.md (5 changes: 4 additions & 1 deletion)
@@ -82,7 +82,10 @@ The following configurations can be supplied to models run with the dbt-spark plugin:
| Option | Description | Required? | Example |
|---------|----------------------------------------------------|-------------------------|--------------------------|
| file_format | The file format to use when creating tables | Optional | `parquet` |
| location_root | The created table uses the specified directory to store its data. The table alias is appended to it. | Optional | `/mnt/root` |
| partition_by | Partition the created table by the specified columns. A directory is created for each partition. | Optional | `partition_1` |
| clustered_by | Each partition in the created table will be split into a fixed number of buckets by the specified columns. | Optional | `cluster_1` |
| buckets | The number of buckets to create while clustering | Required if `clustered_by` is specified | `8` |

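A model can combine these options in its `config()` block; for example (a minimal sketch, with illustrative column names and model body):

```sql
{{ config(
    materialized='table',
    file_format='parquet',
    location_root='/mnt/root',
    partition_by=['partition_1'],
    clustered_by=['cluster_1'],
    buckets=8
) }}

-- illustrative model body
select 1 as partition_1, 2 as cluster_1
```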

**Incremental Models**
dbt/adapters/spark/impl.py (4 changes: 4 additions & 0 deletions)
@@ -16,6 +16,10 @@ class SparkAdapter(SQLAdapter):
ConnectionManager = SparkConnectionManager
Relation = SparkRelation

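    # Config keys specific to this adapter; dbt surfaces these as model configs,
    # and the macros in adapters.sql read them via config.get().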
AdapterSpecificConfigs = frozenset({"file_format", "location_root",
"partition_by", "clustered_by",
"buckets"})

@classmethod
def date_function(cls):
return 'CURRENT_TIMESTAMP()'
dbt/include/spark/macros/adapters.sql (51 changes: 50 additions & 1 deletion)
@@ -5,13 +5,37 @@
{{ sql }}
{% endmacro %}


{% macro file_format_clause() %}
{%- set file_format = config.get('file_format', validator=validation.any[basestring]) -%}
{%- if file_format is not none %}
using {{ file_format }}
{%- endif %}
{%- endmacro -%}


{% macro location_clause() %}
{%- set location_root = config.get('location_root', validator=validation.any[basestring]) -%}
{%- set identifier = model['alias'] -%}
{%- if location_root is not none %}
location '{{ location_root }}/{{ identifier }}'
{%- endif %}
{%- endmacro -%}


{% macro comment_clause() %}
{%- set raw_persist_docs = config.get('persist_docs', {}) -%}

{%- if raw_persist_docs is mapping -%}
{%- set raw_relation = raw_persist_docs.get('relation', false) -%}
{%- if raw_relation -%}
comment '{{ model.description }}'
{% endif %}
{%- else -%}
{{ exceptions.raise_compiler_error("Invalid value provided for 'persist_docs'. Expected dict but got value: " ~ raw_persist_docs) }}
{% endif %}
{%- endmacro -%}

{% macro partition_cols(label, required=false) %}
{%- set cols = config.get('partition_by', validator=validation.any[list, basestring]) -%}
{%- if cols is not none %}
@@ -27,23 +27,51 @@
{%- endif %}
{%- endmacro -%}


{% macro clustered_cols(label, required=false) %}
{%- set cols = config.get('clustered_by', validator=validation.any[list, basestring]) -%}
{%- set buckets = config.get('buckets', validator=validation.any[int]) -%}
{%- if (cols is not none) and (buckets is not none) %}
{%- if cols is string -%}
{%- set cols = [cols] -%}
{%- endif -%}
{{ label }} (
{%- for item in cols -%}
{{ item }}
{%- if not loop.last -%},{%- endif -%}
{%- endfor -%}
) into {{ buckets }} buckets
{%- endif %}
{%- endmacro -%}


{% macro spark__create_table_as(temporary, relation, sql) -%}
{% if temporary -%}
{{ spark_create_temporary_view(relation, sql) }}
{%- else -%}
create table {{ relation }}
{{ file_format_clause() }}
{{ partition_cols(label="partitioned by") }}
{{ clustered_cols(label="clustered by") }}
{{ location_clause() }}
{{ comment_clause() }}
as
{{ sql }}
{%- endif %}
{%- endmacro -%}
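
-- Example rendered output, mirroring test_macros_create_table_as_all below:
-- with file_format='delta', partition_by=['partition_1', 'partition_2'],
-- clustered_by=['cluster_1', 'cluster_2'], buckets=1, location_root='/mnt/root',
-- and persist_docs={'relation': True}, this macro renders (whitespace collapsed):
--
--   create table my_table
--   using delta
--   partitioned by (partition_1,partition_2)
--   clustered by (cluster_1,cluster_2) into 1 buckets
--   location '/mnt/root/my_table'
--   comment 'Description Test'
--   as
--   select 1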


{% macro spark__create_view_as(relation, sql) -%}
create view {{ relation }}
{{ comment_clause() }}
as
{{ sql }}
{% endmacro %}


{% macro spark__get_columns_in_relation(relation) -%}
{% call statement('get_columns_in_relation', fetch_result=True) %}
describe {{ relation }}
test/unit/test_macros.py (124 changes: 124 additions & 0 deletions)
@@ -0,0 +1,124 @@
import mock
import unittest
import re
from collections import defaultdict
from jinja2 import Environment, FileSystemLoader
from dbt.context.common import _add_validation


class TestSparkMacros(unittest.TestCase):

def setUp(self):
self.jinja_env = Environment(loader=FileSystemLoader('dbt/include/spark/macros'),
extensions=['jinja2.ext.do',])

self.config = {}

self.default_context = {}
self.default_context['validation'] = mock.Mock()
self.default_context['model'] = mock.Mock()
self.default_context['exceptions'] = mock.Mock()
self.default_context['config'] = mock.Mock()
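        # Stub config.get to ignore dbt's validator kwarg and read from the test's config dict.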
self.default_context['config'].get = lambda key, default=None, **kwargs: self.config.get(key, default)


def __get_template(self, template_filename):
return self.jinja_env.get_template(template_filename, globals=self.default_context)


def __run_macro(self, template, name, temporary, relation, sql):
self.default_context['model'].alias = relation
value = getattr(template.module, name)(temporary, relation, sql)
return re.sub(r'\s\s+', ' ', value)


def test_macros_load(self):
self.jinja_env.get_template('adapters.sql')


def test_macros_create_table_as(self):
template = self.__get_template('adapters.sql')

self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
"create table my_table as select 1")


def test_macros_create_table_as_file_format(self):
        template = self.__get_template('adapters.sql')

self.config['file_format'] = 'delta'
self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
"create table my_table using delta as select 1")


def test_macros_create_table_as_partition(self):
        template = self.__get_template('adapters.sql')

self.config['partition_by'] = 'partition_1'
self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
"create table my_table partitioned by (partition_1) as select 1")


def test_macros_create_table_as_partitions(self):
        template = self.__get_template('adapters.sql')

self.config['partition_by'] = ['partition_1', 'partition_2']
self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
"create table my_table partitioned by (partition_1,partition_2) as select 1")


def test_macros_create_table_as_cluster(self):
        template = self.__get_template('adapters.sql')

self.config['clustered_by'] = 'cluster_1'
self.config['buckets'] = '1'
self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
"create table my_table clustered by (cluster_1) into 1 buckets as select 1")


def test_macros_create_table_as_clusters(self):
        template = self.__get_template('adapters.sql')

self.config['clustered_by'] = ['cluster_1', 'cluster_2']
self.config['buckets'] = '1'
self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
"create table my_table clustered by (cluster_1,cluster_2) into 1 buckets as select 1")


def test_macros_create_table_as_location(self):
        template = self.__get_template('adapters.sql')

self.config['location_root'] = '/mnt/root'
self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
"create table my_table location '/mnt/root/my_table' as select 1")


def test_macros_create_table_as_comment(self):
        template = self.__get_template('adapters.sql')

self.config['persist_docs'] = {'relation': True}
self.default_context['model'].description = 'Description Test'
self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
"create table my_table comment 'Description Test' as select 1")


def test_macros_create_table_as_all(self):
template = self.__get_template('adapters.sql')

self.config['file_format'] = 'delta'
self.config['location_root'] = '/mnt/root'
self.config['partition_by'] = ['partition_1', 'partition_2']
self.config['clustered_by'] = ['cluster_1', 'cluster_2']
self.config['buckets'] = '1'
self.config['persist_docs'] = {'relation': True}
self.default_context['model'].description = 'Description Test'

self.assertEqual(self.__run_macro(template, 'spark__create_table_as', False, 'my_table', 'select 1'),
"create table my_table using delta partitioned by (partition_1,partition_2) clustered by (cluster_1,cluster_2) into 1 buckets location '/mnt/root/my_table' comment 'Description Test' as select 1")