Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional extra_sasl_client_attrs to hive Connection #466

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 22 additions & 9 deletions pyhive/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import re
from decimal import Decimal
from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context

from typing import Dict, Optional

from TCLIService import TCLIService
from TCLIService import constants
Expand Down Expand Up @@ -49,7 +49,7 @@
}


def get_sasl_client(host, sasl_auth, service=None, username=None, password=None):
def get_sasl_client(host, sasl_auth, service=None, username=None, password=None, extra_sasl_client_attrs=None):
import sasl
sasl_client = sasl.Client()
sasl_client.setAttr('host', host)
Expand All @@ -62,11 +62,15 @@ def get_sasl_client(host, sasl_auth, service=None, username=None, password=None)
else:
raise ValueError("sasl_auth only supports GSSAPI and PLAIN")

if extra_sasl_client_attrs:
for k, v in extra_sasl_client_attrs.items():
sasl_client.setAttr(k, v)

sasl_client.init()
return sasl_client


def get_pure_sasl_client(host, sasl_auth, service=None, username=None, password=None):
def get_pure_sasl_client(host, sasl_auth, service=None, username=None, password=None, extra_sasl_client_attrs=None):
from pyhive.sasl_compat import PureSASLClient

if sasl_auth == 'GSSAPI':
Expand All @@ -75,17 +79,20 @@ def get_pure_sasl_client(host, sasl_auth, service=None, username=None, password=
sasl_kwargs = {'username': username, 'password': password}
else:
raise ValueError("sasl_auth only supports GSSAPI and PLAIN")


if extra_sasl_client_attrs:
sasl_kwargs.update(extra_sasl_client_attrs)

return PureSASLClient(host=host, **sasl_kwargs)


def get_installed_sasl(host, sasl_auth, service=None, username=None, password=None):
def get_installed_sasl(host, sasl_auth, service=None, username=None, password=None, extra_sasl_client_attrs: Optional[Dict[str, str]]=None):
try:
return get_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password)
return get_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password, extra_sasl_client_attrs=extra_sasl_client_attrs)
# The sasl library is available
except ImportError:
# Fallback to pure-sasl library
return get_pure_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password)
return get_pure_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password, extra_sasl_client_attrs=extra_sasl_client_attrs)


def _parse_timestamp(value):
Expand Down Expand Up @@ -159,7 +166,8 @@ def __init__(
password=None,
check_hostname=None,
ssl_cert=None,
thrift_transport=None
thrift_transport=None,
extra_sasl_client_attrs: Optional[Dict[str, str]] = None
):
"""Connect to HiveServer2

Expand All @@ -172,6 +180,7 @@ def __init__(
:param password: Use with auth='LDAP' or auth='CUSTOM' only
:param thrift_transport: A ``TTransportBase`` for custom advanced usage.
Incompatible with host, port, auth, kerberos_service_name, and password.
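:param extra_sasl_client_attrs: Optional dict of extra SASL client attributes (name -> value), applied to the SASL client on top of the attributes derived from the other arguments.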

The approach to supporting LDAP and GSSAPI originated from cloudera/Impyla:
https://github.com/cloudera/impyla/blob/255b07ed973d47a3395214ed92d35ec0615ebf62
Expand Down Expand Up @@ -250,7 +259,11 @@ def __init__(
# Password doesn't matter in NONE mode, just needs to be nonempty.
password = 'x'

self._transport = thrift_sasl.TSaslClientTransport(lambda: get_installed_sasl(host=host, sasl_auth=sasl_auth, service=kerberos_service_name, username=username, password=password), sasl_auth, socket)
self._transport = thrift_sasl.TSaslClientTransport(
lambda: get_installed_sasl(
host=host, sasl_auth=sasl_auth, service=kerberos_service_name, username=username, password=password, extra_sasl_client_attrs=extra_sasl_client_attrs
), sasl_auth, socket
)
else:
# All HS2 config options:
# https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2#SettingUpHiveServer2-Configuration
Expand Down