Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alexandre Ouellet's "Add minimal support for querying multiple catalogs" #403

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
15 changes: 15 additions & 0 deletions pyhive/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import datetime
import logging
import requests
import datetime
from requests.auth import HTTPBasicAuth
import os

Expand All @@ -41,6 +42,20 @@ def escape_datetime(self, item, format):
formatted = super(PrestoParamEscaper, self).escape_datetime(item, format, 3)
return "{} {}".format(_type, formatted)

def escape_item(self, item):
if isinstance(item, datetime.datetime):
return self.escape_datetime(item)
elif isinstance(item, datetime.date):
return self.escape_date(item)
else:
return super(PrestoParamEscaper, self).escape_item(item)

def escape_date(self, item):
return "date '{}'".format(item)

def escape_datetime(self, item):
return "timestamp '{}'".format(item)


_escaper = PrestoParamEscaper()

Expand Down
33 changes: 33 additions & 0 deletions pyhive/sqlalchemy_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
from sqlalchemy import exc
from sqlalchemy import types
from sqlalchemy import util

# TODO shouldn't use mysql type
from sqlalchemy.databases import mysql
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy.sql.compiler import SQLCompiler
from sqlalchemy.sql.expression import Alias

from pyhive import presto
from pyhive.common import UniversalSet
Expand Down Expand Up @@ -46,6 +48,37 @@ class PrestoCompiler(SQLCompiler):
def visit_char_length_func(self, fn, **kw):
return 'length{}'.format(self.function_argspec(fn, **kw))

def visit_column(self, column, add_to_result_map=None, include_table=True, **kwargs):
sql = super(PrestoCompiler, self).visit_column(
column, add_to_result_map, include_table, **kwargs
)
table = column.table
return self.__add_catalog(sql, table)

def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
fromhints=None, use_schema=True, **kwargs):
sql = super(PrestoCompiler, self).visit_table(
table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs
)
return self.__add_catalog(sql, table)

def __add_catalog(self, sql, table):
if table is None:
return sql

if isinstance(table, Alias):
return sql

if (
"presto" not in table.dialect_options
or "catalog" not in table.dialect_options["presto"]._non_defaults
):
return sql

catalog = table.dialect_options["presto"]._non_defaults["catalog"]
sql = "\"{catalog}\".{sql}".format(catalog=catalog, sql=sql)
return sql


class PrestoTypeCompiler(compiler.GenericTypeCompiler):
def visit_CLOB(self, type_, **kw):
Expand Down