diff --git a/pyhive/hive.py b/pyhive/hive.py index c128748..5ce65f1 100644 --- a/pyhive/hive.py +++ b/pyhive/hive.py @@ -13,7 +13,7 @@ import re from decimal import Decimal from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context - +from typing import Dict, Optional from TCLIService import TCLIService from TCLIService import constants @@ -49,7 +49,7 @@ } -def get_sasl_client(host, sasl_auth, service=None, username=None, password=None): +def get_sasl_client(host, sasl_auth, service=None, username=None, password=None, extra_sasl_client_attrs=None): import sasl sasl_client = sasl.Client() sasl_client.setAttr('host', host) @@ -62,11 +62,15 @@ def get_sasl_client(host, sasl_auth, service=None, username=None, password=None) else: raise ValueError("sasl_auth only supports GSSAPI and PLAIN") + if extra_sasl_client_attrs: + for k, v in extra_sasl_client_attrs.items(): + sasl_client.setAttr(k, v) + sasl_client.init() return sasl_client -def get_pure_sasl_client(host, sasl_auth, service=None, username=None, password=None): +def get_pure_sasl_client(host, sasl_auth, service=None, username=None, password=None, extra_sasl_client_attrs=None): from pyhive.sasl_compat import PureSASLClient if sasl_auth == 'GSSAPI': @@ -75,17 +79,20 @@ def get_pure_sasl_client(host, sasl_auth, service=None, username=None, password= sasl_kwargs = {'username': username, 'password': password} else: raise ValueError("sasl_auth only supports GSSAPI and PLAIN") - + + if extra_sasl_client_attrs: + sasl_kwargs.update(extra_sasl_client_attrs) + return PureSASLClient(host=host, **sasl_kwargs) -def get_installed_sasl(host, sasl_auth, service=None, username=None, password=None): +def get_installed_sasl(host, sasl_auth, service=None, username=None, password=None, extra_sasl_client_attrs: Optional[Dict[str, str]]=None): try: - return get_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password) + return get_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password, extra_sasl_client_attrs=extra_sasl_client_attrs) # The sasl library is available except ImportError: # Fallback to pure-sasl library - return get_pure_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password) + return get_pure_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password, extra_sasl_client_attrs=extra_sasl_client_attrs) def _parse_timestamp(value): @@ -159,7 +166,8 @@ def __init__( password=None, check_hostname=None, ssl_cert=None, - thrift_transport=None + thrift_transport=None, + extra_sasl_client_attrs: Optional[Dict[str, str]] = None ): """Connect to HiveServer2 @@ -172,6 +180,7 @@ def __init__( :param password: Use with auth='LDAP' or auth='CUSTOM' only :param thrift_transport: A ``TTransportBase`` for custom advanced usage. Incompatible with host, port, auth, kerberos_service_name, and password. + :param extra_sasl_client_attrs: Extra SASL client attributes. The way to support LDAP and GSSAPI is originated from cloudera/Impyla: https://github.com/cloudera/impyla/blob/255b07ed973d47a3395214ed92d35ec0615ebf62 @@ -250,7 +259,11 @@ def __init__( # Password doesn't matter in NONE mode, just needs to be nonempty. password = 'x' - self._transport = thrift_sasl.TSaslClientTransport(lambda: get_installed_sasl(host=host, sasl_auth=sasl_auth, service=kerberos_service_name, username=username, password=password), sasl_auth, socket) + self._transport = thrift_sasl.TSaslClientTransport( + lambda: get_installed_sasl( + host=host, sasl_auth=sasl_auth, service=kerberos_service_name, username=username, password=password, extra_sasl_client_attrs=extra_sasl_client_attrs + ), sasl_auth, socket + ) else: # All HS2 config options: # https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2#SettingUpHiveServer2-Configuration