Release 1.0.4 #7

Merged · 8 commits · Nov 15, 2022
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -1,2 +1,12 @@
## v1.0.4

### Bugfixes
* Add support for partition fields of type timestamp
* Use correct escaper for INSERT queries
* Share the same boto session between all calls

### Features
* Get model owner from manifest

## v1.0.3
* Fix issue on fetching partitions from glue, using pagination
12 changes: 6 additions & 6 deletions dbt/adapters/athena/connections.py
@@ -27,6 +27,8 @@
from tenacity.stop import stop_after_attempt
from tenacity.wait import wait_exponential

from dbt.adapters.athena.session import get_boto3_session

logger = AdapterLogger("Athena")


@@ -51,7 +53,8 @@ def unique_field(self):
return self.host

def _connection_keys(self) -> Tuple[str, ...]:
return "s3_staging_dir", "work_group", "region_name", "database", "schema", "poll_interval", "aws_profile_name", "endpoing_url"
return "s3_staging_dir", "work_group", "region_name", "database", "schema", "poll_interval", \
"aws_profile_name", "endpoing_url"


class AthenaCursor(Cursor):
@@ -140,13 +143,12 @@ def open(cls, connection: Connection) -> Connection:
handle = AthenaConnection(
s3_staging_dir=creds.s3_staging_dir,
endpoint_url=creds.endpoint_url,
region_name=creds.region_name,
schema_name=creds.schema,
work_group=creds.work_group,
cursor_class=AthenaCursor,
formatter=AthenaParameterFormatter(),
poll_interval=creds.poll_interval,
profile_name=creds.aws_profile_name,
session=get_boto3_session(connection),
retry_config=RetryConfig(
attempt=creds.num_retries,
exceptions=(
@@ -213,9 +215,7 @@ def format(
raise ProgrammingError("Query is none or empty.")
operation = operation.strip()

if operation.upper().startswith("SELECT") or operation.upper().startswith(
"WITH"
):
if operation.upper().startswith(("SELECT", "WITH", "INSERT")):
escaper = _escape_presto
else:
# Fixes ParseException that comes with newer version of PyAthena
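The `format` change above routes `INSERT` statements through the same Presto escaper already used for `SELECT` and `WITH`, which is the "correct escaper for INSERT queries" fix from the changelog. A minimal standalone sketch of that dispatch, not the actual PyAthena formatter (function name and return values are illustrative):

```python
# Illustrative only: mirrors the tuple-based startswith dispatch added in
# AthenaParameterFormatter.format, where INSERT now shares the Presto escaper.
def pick_escaper(operation: str) -> str:
    op = operation.strip().upper()
    if op.startswith(("SELECT", "WITH", "INSERT")):
        return "presto"  # _escape_presto in the real formatter
    return "hive"        # DDL path that avoids the ParseException noted above

print(pick_escaper("INSERT INTO t VALUES ('it''s')"))  # presto
print(pick_escaper("CREATE TABLE t (a int)"))          # hive
```

The tuple form of `str.startswith` keeps the condition on one line instead of chaining `or` clauses, which is why the multi-line `if` could be collapsed.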
48 changes: 34 additions & 14 deletions dbt/adapters/athena/impl.py
@@ -1,6 +1,5 @@
import agate
import re
import boto3
from botocore.exceptions import ClientError
from itertools import chain
from threading import Lock
@@ -20,6 +19,7 @@

boto3_client_lock = Lock()


class AthenaAdapter(SQLAdapter):
ConnectionManager = AthenaConnectionManager
Relation = AthenaRelation
@@ -61,8 +61,8 @@ def clean_up_partitions(
client = conn.handle

with boto3_client_lock:
glue_client = boto3.client('glue', region_name=client.region_name)
s3_resource = boto3.resource('s3', region_name=client.region_name)
glue_client = client.session.client('glue', region_name=client.region_name)
s3_resource = client.session.resource('s3', region_name=client.region_name)
paginator = glue_client.get_paginator("get_partitions")
partition_params = {
"DatabaseName": database_name,
@@ -74,7 +74,8 @@ def clean_up_partitions(
partitions = partition_pg.build_full_result().get('Partitions')
s3_rg = re.compile('s3://([^/]*)/(.*)')
for partition in partitions:
logger.debug("Deleting objects for partition '{}' at '{}'", partition["Values"], partition["StorageDescriptor"]["Location"])
logger.debug("Deleting objects for partition '{}' at '{}'", partition["Values"],
partition["StorageDescriptor"]["Location"])
m = s3_rg.match(partition["StorageDescriptor"]["Location"])
if m is not None:
bucket_name = m.group(1)
@@ -90,7 +91,7 @@ def clean_up_table(
conn = self.connections.get_thread_connection()
client = conn.handle
with boto3_client_lock:
glue_client = boto3.client('glue', region_name=client.region_name)
glue_client = client.session.client('glue', region_name=client.region_name)
try:
table = glue_client.get_table(
DatabaseName=database_name,
@@ -108,7 +109,7 @@ def clean_up_table(
if m is not None:
bucket_name = m.group(1)
prefix = m.group(2)
s3_resource = boto3.resource('s3', region_name=client.region_name)
s3_resource = client.session.resource('s3', region_name=client.region_name)
s3_bucket = s3_resource.Bucket(bucket_name)
s3_bucket.objects.filter(Prefix=prefix).delete()

@@ -118,13 +119,33 @@ def quote_seed_column(
) -> str:
return super().quote_seed_column(column, False)

def _join_catalog_table_owners(self, table: agate.Table, manifest: Manifest) -> agate.Table:
owners = []
# Get the owner for each model from the manifest
for node in manifest.nodes.values():
if node.resource_type == "model":
owners.append({
"table_database": node.database,
"table_schema": node.schema,
"table_name": node.alias,
"table_owner": node.config.meta.get("owner"),
})
owners_table = agate.Table.from_object(owners)

# Join owners with the results from catalog
join_keys = ["table_database", "table_schema", "table_name"]
return table.join(
right_table=owners_table,
left_key=join_keys,
right_key=join_keys,
)

def _get_one_catalog(
self,
information_schema: InformationSchema,
schemas: Dict[str, Optional[Set[str]]],
manifest: Manifest,
) -> agate.Table:

kwargs = {"information_schema": information_schema, "schemas": schemas}
table = self.execute_macro(
GET_CATALOG_MACRO_NAME,
@@ -134,9 +155,8 @@ def _get_one_catalog(
manifest=manifest,
)

results = self._catalog_filter_table(table, manifest)
return results

filtered_table = self._catalog_filter_table(table, manifest)
return self._join_catalog_table_owners(filtered_table, manifest)

def _get_catalog_schemas(self, manifest: Manifest) -> AthenaSchemaSearchMap:
info_schema_name_map = AthenaSchemaSearchMap()
@@ -155,8 +175,8 @@ def _get_data_catalog(self, catalog_name):
conn = self.connections.get_thread_connection()
client = conn.handle
with boto3_client_lock:
athena_client = boto3.client('athena', region_name=client.region_name)
athena_client = client.session.client('athena', region_name=client.region_name)

response = athena_client.get_data_catalog(Name=catalog_name)
return response['DataCatalog']

@@ -175,15 +195,15 @@ def list_relations_without_caching(
conn = self.connections.get_thread_connection()
client = conn.handle
with boto3_client_lock:
glue_client = boto3.client('glue', region_name=client.region_name)
glue_client = client.session.client('glue', region_name=client.region_name)
paginator = glue_client.get_paginator('get_tables')

kwargs = {
'DatabaseName': schema_relation.schema,
}
# If the catalog is `awsdatacatalog` we don't need to pass CatalogId as boto3 infers it from the account Id.
if catalog_id:
kwargs['CatalogId'] = catalog_id
kwargs['CatalogId'] = catalog_id
page_iterator = paginator.paginate(**kwargs)

relations = []
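The new `_join_catalog_table_owners` helper implements the "Get model owner from manifest" feature: each model node's `config.meta["owner"]` is collected into an agate table and joined onto the catalog rows. A hedged, self-contained sketch of that join (the table contents and the `data-team` owner are made up for illustration):

```python
import agate

# Stand-ins for the rows returned by the get_catalog macro and the owner rows
# built from manifest.nodes; column names match the join keys in impl.py.
catalog = agate.Table.from_object([
    {"table_database": "awsdatacatalog", "table_schema": "analytics",
     "table_name": "orders", "column_name": "id"},
])
owners = agate.Table.from_object([
    {"table_database": "awsdatacatalog", "table_schema": "analytics",
     "table_name": "orders", "table_owner": "data-team"},
])

# Same join keys as the adapter code: database, schema and table name.
join_keys = ["table_database", "table_schema", "table_name"]
joined = catalog.join(owners, left_key=join_keys, right_key=join_keys)
joined.print_table()  # catalog columns plus the new table_owner column
```

This works together with the metadata.sql change further down, which stops selecting `table_owner` from `information_schema` now that the value comes from the manifest.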
5 changes: 3 additions & 2 deletions dbt/adapters/athena/query_headers.py
@@ -1,9 +1,10 @@
import dbt.adapters.base.query_headers


class _QueryComment(dbt.adapters.base.query_headers._QueryComment):
"""
Athena DDL does not always respect /* ... */ block quotations.
This function is the same as _QueryComment.add except that
Athena DDL does not always respect /* ... */ block quotations.
This function is the same as _QueryComment.add except that
a leading "-- " is prepended to the query_comment and any newlines
in the query_comment are replaced with " ". This allows the default
query_comment to be added to `create external table` statements.
3 changes: 2 additions & 1 deletion dbt/adapters/athena/relation.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Dict, Optional, Set
from typing import Dict, Optional, Set

from dbt.adapters.base.relation import BaseRelation, InformationSchema, Policy

@@ -16,6 +16,7 @@ class AthenaRelation(BaseRelation):
quote_character: str = ""
include_policy: Policy = AthenaIncludePolicy()


class AthenaSchemaSearchMap(Dict[InformationSchema, Dict[str, Set[Optional[str]]]]):
"""A utility class to keep track of what information_schema tables to
search for what schemas and relations. The schema and relation values are all
25 changes: 25 additions & 0 deletions dbt/adapters/athena/session.py
@@ -0,0 +1,25 @@
from typing import Optional

import boto3.session
from dbt.contracts.connection import Connection


__BOTO3_SESSION__: Optional[boto3.session.Session] = None


def get_boto3_session(connection: Connection = None) -> boto3.session.Session:
def init_session():
global __BOTO3_SESSION__
__BOTO3_SESSION__ = boto3.session.Session(
region_name=connection.credentials.region_name,
profile_name=connection.credentials.aws_profile_name,
)

if not __BOTO3_SESSION__:
if connection is None:
raise RuntimeError(
'A Connection object needs to be passed to initialize the boto3 session for the first time'
)
init_session()

return __BOTO3_SESSION__
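
The new `session.py` module backs the "share the same boto session" fix: the first caller must pass a dbt `Connection` to seed the module-level session, and every later call, including the `client.session.client(...)` sites in `impl.py`, reuses it. A standalone sketch of the same lazy-singleton pattern (names and the cached dict are illustrative, not part of the adapter):

```python
from typing import Optional

_SHARED: Optional[dict] = None  # stands in for the cached boto3 Session

def get_shared(region: Optional[str] = None) -> dict:
    """First call must supply the config; later calls return the cached object."""
    global _SHARED
    if _SHARED is None:
        if region is None:
            raise RuntimeError("region is required to initialize the shared object")
        _SHARED = {"region_name": region}
    return _SHARED

first = get_shared("eu-west-1")   # initializes the cache
again = get_shared()              # no argument needed afterwards
assert first is again
```

Note that `get_boto3_session` guards with `if not __BOTO3_SESSION__:` rather than `is None`; in practice this behaves the same, since a boto3 `Session` object is always truthy.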
4 changes: 1 addition & 3 deletions dbt/include/athena/macros/adapters/metadata.sql
@@ -17,7 +17,6 @@
else table_type
end as table_type,

null as table_owner,
null as table_comment

from {{ information_schema }}.tables
@@ -54,8 +53,7 @@
columns.column_name,
columns.column_index,
columns.column_type,
columns.column_comment,
tables.table_owner
columns.column_comment

from tables
join columns
@@ -40,6 +40,8 @@
{%- set value = "'" + col + "'" -%}
{%- elif column_type == 'date' -%}
{%- set value = "'" + col|string + "'" -%}
{%- elif column_type == 'timestamp' -%}
{%- set value = "'" + col|string + "'" -%}
{%- else -%}
{%- do exceptions.raise_compiler_error('Need to add support for column type ' + column_type) -%}
{%- endif -%}
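The macro hunk above adds the `timestamp` branch behind the "partition fields of type timestamp" fix: timestamp values, like dates and strings, are rendered as quoted literals before being used in the generated SQL. A rough Python equivalent of that quoting rule (the numeric branch is an assumption added for contrast and is not part of the diff shown):

```python
from datetime import datetime

def render_partition_value(col, column_type: str) -> str:
    # string, date and timestamp partition values become quoted literals
    if column_type in ("string", "date", "timestamp"):
        return "'" + str(col) + "'"
    # hypothetical numeric branch, only here to show the unquoted contrast
    if column_type in ("int", "integer", "bigint"):
        return str(col)
    raise ValueError("Need to add support for column type " + column_type)

print(render_partition_value(datetime(2022, 11, 15, 12, 0), "timestamp"))
# -> '2022-11-15 12:00:00'
```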
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@
package_name = "dbt-athena-community"

dbt_version = "1.0"
package_version = "1.0.3"
package_version = "1.0.4"
description = """The athena adapter plugin for dbt (data build tool)"""

if not package_version.startswith(dbt_version):