Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add az cli auth #35

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,29 @@ client_id: clientid
client_secret: ActiveDirectoryIntegrated
```

##### CLI
Use the authentication of the Azure command line interface (CLI). First log in:

```bash
az login
```

Then, set `authentication` in `profiles.yml` to `CLI`:

```
authentication: CLI
```

To authenticate as a service principal through the CLI instead, log in with:

```
az login --service-principal --username $CLIENTID --password $SECRET --tenant $TENANTID
```

Logging in this way keeps the client secret out of `profiles.yml`, where it would otherwise be stored as plain text.

Source: https://docs.microsoft.com/en-us/cli/azure/create-an-azure-service-principal-azure-cli#sign-in-using-a-service-principal

## Table Materializations
CTAS allows you to materialize tables with indices and distributions at creation time, which obviates the need for post-hooks to set indices.

Expand Down
203 changes: 142 additions & 61 deletions dbt/adapters/sqlserver/connections.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,21 @@
from contextlib import contextmanager

import pyodbc
import os
import time
import struct
import time
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import chain, repeat
from typing import Callable, Mapping, Optional

import dbt.exceptions
import pyodbc
from dbt.adapters.base import Credentials
from dbt.adapters.sql import SQLConnectionManager
from azure.identity import DefaultAzureCredential

from dbt.logger import GLOBAL_LOGGER as logger

from dataclasses import dataclass
from typing import Optional
from azure.core.credentials import AccessToken
from azure.identity import AzureCliCredential, DefaultAzureCredential


def create_token(tenant_id, client_id, client_secret):
    """
    Build the access token structure for a service principal.

    Parameters
    ----------
    tenant_id : str
        Value exported as ``AZURE_TENANT_ID``.
    client_id : str
        Value exported as ``AZURE_CLIENT_ID``.
    client_secret : str
        Value exported as ``AZURE_CLIENT_SECRET``.

    Returns
    -------
    out : bytes
        The token, with a zero byte after every encoded character,
        prefixed by its length as a native-order 4-byte integer.
    """
    # DefaultAzureCredential looks the service principal up in the
    # environment, so the values are exported before it is created.
    exported = (
        ("AZURE_TENANT_ID", tenant_id),
        ("AZURE_CLIENT_ID", client_id),
        ("AZURE_CLIENT_SECRET", client_secret),
    )
    for env_name, env_value in exported:
        os.environ[env_name] = env_value

    access_token = DefaultAzureCredential().get_token(
        "https://database.windows.net//.default"
    )

    # Encode character by character; a zero byte separates the characters
    # and one more terminates the sequence.
    zero = b"\x00"
    padded = zero.join(ch.encode("UTF-8") for ch in access_token.token) + zero

    # Length prefix followed by the padded token.
    return struct.pack("=i", len(padded)) + padded
# Scope passed to `get_token` when requesting an Azure access token.
AZURE_CREDENTIAL_SCOPE = "https://database.windows.net//.default"


@dataclass
Expand Down Expand Up @@ -74,7 +59,7 @@ def _connection_keys(self):
# raise NotImplementedError
if self.windows_login is True:
self.authentication = "Windows Login"

return (
"server",
"database",
Expand All @@ -84,10 +69,101 @@ def _connection_keys(self):
"client_id",
"authentication",
"encrypt",
"trust_cert"
"trust_cert",
)


def convert_bytes_to_mswindows_byte_string(value: bytes) -> bytes:
    """
    Expand bytes into a length-prefixed, zero-padded byte string.

    Every input byte is followed by one zero byte, and the result is
    prefixed with its own length as a little-endian 4-byte integer.

    Parameters
    ----------
    value : bytes
        The bytes to expand.

    Returns
    -------
    out : bytes
        The 4-byte length prefix followed by the padded bytes.
    """
    # Pair each byte with a trailing zero byte.
    padded = b"".join(bytes((byte, 0)) for byte in value)
    length_prefix = struct.pack("<i", len(padded))
    return length_prefix + padded


def convert_access_token_to_mswindows_byte_string(token: AccessToken) -> bytes:
    """
    Turn an access token into a Microsoft windows byte string.

    The token text is UTF-8 encoded and then handed to
    ``convert_bytes_to_mswindows_byte_string``.

    Parameters
    ----------
    token : AccessToken
        The access token; only its ``token`` attribute is read.

    Returns
    -------
    out : bytes
        The Microsoft byte string.
    """
    return convert_bytes_to_mswindows_byte_string(token.token.encode("UTF-8"))


def get_cli_access_token(credentials: SQLServerCredentials) -> AccessToken:
    """
    Fetch an Azure access token through the Azure CLI login.

    Log in beforehand with:

    ```bash
    az login
    ```

    Parameters
    ----------
    credentials : SQLServerCredentials
        The credentials. Not read by this function; the parameter is
        accepted so the signature matches the other token functions.

    Returns
    -------
    out : AccessToken
        Access token.
    """
    del credentials  # intentionally unused
    cli_credential = AzureCliCredential()
    return cli_credential.get_token(AZURE_CREDENTIAL_SCOPE)


def get_sp_access_token(credentials: SQLServerCredentials) -> AccessToken:
    """
    Get an Azure access token using the service principal credentials.

    The credential is built directly from the given tenant id, client id
    and client secret. Nothing is written to ``os.environ``, so the
    secret is not exposed to the rest of the process or to child
    processes, and no other credential source can be picked up by
    accident.

    Parameters
    ----------
    credentials : SQLServerCredentials
        Credentials; ``tenant_id``, ``client_id`` and ``client_secret``
        are read.

    Returns
    -------
    out : AccessToken
        The access token.
    """
    # Imported locally so this function carries its own dependency.
    from azure.identity import ClientSecretCredential

    sp_credential = ClientSecretCredential(
        tenant_id=credentials.tenant_id,
        client_id=credentials.client_id,
        client_secret=credentials.client_secret,
    )
    return sp_credential.get_token(AZURE_CREDENTIAL_SCOPE)


# Signature shared by the token functions: each takes the credentials and
# returns an access token.
AZURE_AUTH_FUNCTION_TYPE = Callable[[SQLServerCredentials], AccessToken]
# Lookup table from an authentication type name to the function that
# obtains the access token for it.
AZURE_AUTH_FUNCTIONS: Mapping[str, AZURE_AUTH_FUNCTION_TYPE] = {
"ServicePrincipal": get_sp_access_token,
"CLI": get_cli_access_token,
}


class SQLServerConnectionManager(SQLConnectionManager):
TYPE = "sqlserver"
TOKEN = None
Expand Down Expand Up @@ -135,8 +211,9 @@ def open(cls, connection):
con_str.append(f"DRIVER={{{credentials.driver}}}")

if "\\" in credentials.host:
# if there is a backslash \ in the host name the host is a sql-server named instance
# in this case then port number has to be omitted
# if there is a backslash \ in the host name the host is a
# sql-server named instance in this case then port number has
# to be omitted
con_str.append(f"SERVER={credentials.host}")
else:
con_str.append(f"SERVER={credentials.host},{credentials.port}")
Expand All @@ -159,55 +236,53 @@ def open(cls, connection):
elif type_auth == "ActiveDirectoryMsi":
raise ValueError("ActiveDirectoryMsi is not supported yet")

elif type_auth == "ServicePrincipal":
app_id = getattr(credentials, "AppId", None)
app_secret = getattr(credentials, "AppSecret", None)

elif getattr(credentials, "windows_login", False):
con_str.append(f"trusted_connection=yes")
con_str.append("trusted_connection=yes")
elif type_auth == "sql":
con_str.append("Authentication=SqlPassword")
con_str.append(f"UID={{{credentials.UID}}}")
con_str.append(f"PWD={{{credentials.PWD}}}")

if not getattr(credentials, "encrypt", False):
con_str.append(f"Encrypt=yes")
con_str.append("Encrypt=yes")
if not getattr(credentials, "trust_cert", False):
con_str.append(f"TrustServerCertificate=yes")

con_str.append("TrustServerCertificate=yes")

con_str_concat = ";".join(con_str)

con_str_concat = ';'.join(con_str)

index = []
for i, elem in enumerate(con_str):
if 'pwd=' in elem.lower():
if "pwd=" in elem.lower():
index.append(i)

if len(index) !=0 :
con_str[index[0]]="PWD=***"

con_str_display = ';'.join(con_str)

logger.debug(f'Using connection string: {con_str_display}')
if len(index) != 0:
con_str[index[0]] = "PWD=***"

if type_auth != "ServicePrincipal":
handle = pyodbc.connect(con_str_concat, autocommit=True)
con_str_display = ";".join(con_str)

elif type_auth == "ServicePrincipal":
logger.debug(f"Using connection string: {con_str_display}")

# create token if it does not exist
if type_auth in AZURE_AUTH_FUNCTIONS.keys():
if cls.TOKEN is None:
tenant_id = getattr(credentials, "tenant_id", None)
client_id = getattr(credentials, "client_id", None)
client_secret = getattr(credentials, "client_secret", None)
azure_auth_function = AZURE_AUTH_FUNCTIONS[type_auth]
token = azure_auth_function(credentials)
cls.TOKEN = convert_access_token_to_mswindows_byte_string(
token
)

cls.TOKEN = create_token(tenant_id, client_id, client_secret)
# Source:
# https://docs.microsoft.com/en-us/sql/connect/odbc/using-azure-active-directory?view=sql-server-ver15#authenticating-with-an-access-token
SQL_COPT_SS_ACCESS_TOKEN = 1256

handle = pyodbc.connect(
con_str_concat, attrs_before={1256: cls.TOKEN}, autocommit=True
)
attrs_before = {SQL_COPT_SS_ACCESS_TOKEN: cls.TOKEN}
else:
attrs_before = {}

handle = pyodbc.connect(
con_str_concat,
attrs_before=attrs_before,
autocommit=True,
)

connection.state = "open"
connection.handle = handle
Expand Down Expand Up @@ -235,22 +310,28 @@ def add_commit_query(self):
# return self.add_query('COMMIT TRANSACTION', auto_begin=False)
pass

def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False):
def add_query(
self, sql, auto_begin=True, bindings=None, abridge_sql_log=False
):

connection = self.get_thread_connection()

if auto_begin and connection.transaction_open is False:
self.begin()

logger.debug('Using {} connection "{}".'.format(self.TYPE, connection.name))
logger.debug(
'Using {} connection "{}".'.format(self.TYPE, connection.name)
)

with self.exception_handler(sql):
if abridge_sql_log:
logger.debug("On {}: {}....".format(connection.name, sql[0:512]))
logger.debug(
"On {}: {}....".format(connection.name, sql[0:512])
)
else:
logger.debug("On {}: {}".format(connection.name, sql))
pre = time.time()

cursor = connection.handle.cursor()

# pyodbc does not handle a None type binding!
Expand Down