In [None]:
%pip install adal==1.2.7

In [None]:
# Import packages
import adal
import datetime

In [None]:
# Set up database access
class sp_mssql_access:
    """Hold SQL Server JDBC settings and a lazily acquired ADAL token for a service principal.

    The token is requested with ADAL's client-credentials flow the first time
    ``token`` (or ``connectionProperties``) is read. ``connectionProperties``
    also re-acquires it once it is within ``token_refresh_margin_seconds`` of
    the expiry time recorded in ``token['expiresOn']``.

    Args:
        tenant_id: Tenant id; stored on the instance but not used by the token
            request itself (that uses ``authority``).
        authority: Authority URL passed to ``adal.AuthenticationContext``.
        resource_app_id_url: Resource the token is requested for.
        service_principal_id: Client id of the service principal.
        service_principal_secret: Client secret of the service principal.
        sql_hostname: Host placed in the JDBC URL.
        database_name: Database placed in the JDBC URL.
        database_port: Port placed in the JDBC URL.
        token_refresh_margin_seconds: Refresh the token when it expires within
            this many seconds. Defaults to 1800 (30 minutes), the value that
            was previously hard-coded.
    """

    def __init__(self, tenant_id, authority, resource_app_id_url, service_principal_id, service_principal_secret, sql_hostname, database_name, database_port, token_refresh_margin_seconds=1800):
        self.tenant_id = tenant_id
        self.authority = authority
        self.resource_app_id_url = resource_app_id_url
        self.service_principal_id = service_principal_id
        self.service_principal_secret = service_principal_secret
        self.sql_hostname = sql_hostname
        self.database_name = database_name
        self.database_port = database_port
        self.token_refresh_margin_seconds = token_refresh_margin_seconds
        self.jdbcUrl = "jdbc:sqlserver://{0}:{1};database={2}".format(self.sql_hostname, self.database_port, self.database_name)
        # Goes through the property setter; no token is requested until first use.
        self.token = None

    def _gen_token(self):
        """Acquire a new token and record its absolute (local-time) expiry in ``token['expiresOn']``."""
        context = adal.AuthenticationContext(self.authority)
        new_token = context.acquire_token_with_client_credentials(
            self.resource_app_id_url, self.service_principal_id, self.service_principal_secret)
        # Store a datetime under 'expiresOn' (replacing whatever value was there) so it
        # can be compared with datetime.now(); publish the token only once it is complete.
        new_token['expiresOn'] = datetime.datetime.now() + datetime.timedelta(seconds=new_token['expiresIn'])
        self.token = new_token

    @property
    def token(self):
        """The token dict, acquired on first access.

        This getter never refreshes an existing token; read ``connectionProperties``
        for an expiry-aware value.
        """
        if self._token is None:
            self._gen_token()
        return self._token

    @token.setter
    def token(self, value):
        # Kept assignable so __init__ (and callers) can reset or inject a token.
        self._token = value

    @property
    def connectionProperties(self):
        """JDBC properties (access token + SQL Server driver class), refreshing the token if near expiry."""
        refresh_at = self.token['expiresOn'] - datetime.timedelta(seconds=self.token_refresh_margin_seconds)
        if refresh_at <= datetime.datetime.now():
            self._gen_token()
        return {
            "accessToken": self.token['accessToken'],
            "driver": "com.microsoft.sqlserver.jdbc.SQLServerDriver",
        }

In [None]:
# Connection settings -- replace every placeholder before running.
# The *_scope / *_key entries name a Databricks secret scope and key (they are
# passed to dbutils.secrets.get when db_access is built), not the secret values.
connect_arg = dict(
    user_tenant_id="your_user_tenant_id",
    user_authority="your_user_authority_url",
    user_resource_app_id_url="https://database.windows.net/",
    user_service_principal_id_scope="your_user_service_principal_id_scope",
    user_service_principal_id_key="your_user_service_principal_id_key",
    user_service_principal_secret_scope="your_user_service_principal_secret_scope",
    user_service_principal_secret_key="your_user_service_principal_secret_key",
    user_sql_hostname="your_user_sql_hostname",
    user_database_name="your_user_database_name",
    user_database_port="your_user_database_port",
)

In [None]:
# Build the database access object. The service principal id and secret are
# read from Databricks secrets using the scope/key names in connect_arg.
db_access = sp_mssql_access(**{
    "tenant_id": connect_arg.get("user_tenant_id"),
    "authority": connect_arg.get("user_authority"),
    "resource_app_id_url": connect_arg.get("user_resource_app_id_url"),
    "service_principal_id": dbutils.secrets.get(
        scope=connect_arg.get("user_service_principal_id_scope"),
        key=connect_arg.get("user_service_principal_id_key"),
    ),
    "service_principal_secret": dbutils.secrets.get(
        scope=connect_arg.get("user_service_principal_secret_scope"),
        key=connect_arg.get("user_service_principal_secret_key"),
    ),
    "sql_hostname": connect_arg.get("user_sql_hostname"),
    "database_name": connect_arg.get("user_database_name"),
    "database_port": connect_arg.get("user_database_port"),
})

In [None]:
# Connection smoke test: load INFORMATION_SCHEMA.TABLES through a JDBC
# pushdown subquery using the token-based connection properties.
# NOTE(review): this assumes spark.read.jdbc contacts the server while
# resolving the schema, so a bad URL/token fails before the print -- confirm.
pushdown_query = "(select * from INFORMATION_SCHEMA.TABLES) myalias"
accessible_table_list = spark.read.jdbc(
    url=db_access.jdbcUrl,
    table=pushdown_query,
    properties=db_access.connectionProperties,
)
print('DB Connection is set up!')