Derek Deming made this script to create a class that allows us to easily connect to Snowflake

Update this script if you need to change connection settings (schema, database, warehouse, etc.).

In [0]:
from pyspark.sql.types import *
from pyspark.sql.functions import *

In [0]:
class SnowflakeDataTool:
    """Convenience wrapper around the Spark Snowflake connector.

    Creating an instance builds a ``spark.read.format("snowflake")`` reader
    from the options returned by ``_connection_options`` and, when a probe
    query succeeds, caches the ``SCHEMA.TABLE`` names found in
    ``information_schema.tables``.

    The class relies on the notebook-provided globals ``spark`` and
    ``dbutils``; it does not import or create them itself.
    """

    def __init__(self):
        # Cached "SCHEMA.TABLE" names; filled by _query_table_list().
        self._sfTables = []
        # Snowflake reader object (spark.read with the connection options set).
        self._sfConnection = None
        self.connect()

    #----- PUBLIC METHODS-----#
    def connect(self) -> None:
        """Build the Snowflake reader and refresh the cached table list.

        Connection settings come from ``_connection_options``. If the probe
        query in ``is_connected`` fails, a message is printed and the table
        list is left as it was; no exception is raised.
        """
        options = self._connection_options()
        self._sfConnection = spark.read.format("snowflake").options(**options)
        if self.is_connected():
            self._query_table_list()
        else:
            print("Failed to establish a connection. Check your connection options.")
        return None

    def is_connected(self) -> bool:
        """Return True if a trivial probe query can be loaded, else False."""
        if self._sfConnection is None:
            # connect() has not built a reader yet.
            return False
        try:
            self._sfConnection.option("query", "SELECT current_version()").load()
        except Exception:
            # Deliberately broad: any failure of the probe means "not connected".
            return False
        return True

    def list_tables(self) -> list:
        """Return the cached list of ``SCHEMA.TABLE`` names."""
        return self._sfTables

    def query_column_metadata(self, tableName):
        """Retrieve column metadata for a table.

        :param tableName(str): Unqualified table name. It is upper-cased and
            compared against ``information_schema.columns.table_name`` only
            (no schema filter is applied).
        :return: The result of ``sql()`` with columns COLUMN_NAME, DATA_TYPE,
            ORDINAL_POSITION and IS_NULLABLE, ordered by ORDINAL_POSITION.
        """
        OP = 'ORDINAL_POSITION'
        columns = ['COLUMN_NAME', 'DATA_TYPE', OP, 'IS_NULLABLE']
        # Double any single quote so the name cannot terminate the SQL literal.
        table = tableName.upper().replace("'", "''")
        sqlQuery = f"SELECT {','.join(columns)} FROM information_schema.columns WHERE table_name = '{table}' ORDER BY {OP}"
        return self.sql(sqlQuery)

    def sql(self, query: str):
        """Execute a SQL query against the database.

        :param query(str): SQL text passed to the reader's ``query`` option.
        :return: Whatever the reader's ``load()`` returns for that query.
        :raises Exception: Any error raised by the reader propagates unchanged.
        """
        return self._sfConnection.option("query", query).load()

    #----- PRIVATE METHODS -----#
    def _connection_options(self) -> dict:
        """Internal function used to retrieve Snowflake connection options.

        User and password are read from the ``azure_kv_p_001`` secret scope via
        ``dbutils.secrets``; the remaining settings are fixed here.
        """
        connOptions = {
            "sfUrl": 'https://swire.east-us-2.azure.snowflakecomputing.com',
            "sfUser": dbutils.secrets.get(scope="azure_kv_p_001", key="sf-p-user-001"),
            "sfPassword": dbutils.secrets.get(scope="azure_kv_p_001", key="sf-p-pwd-001"),
            "sfDatabase": "DB_SWIRE_BI_P_EDW",
            "sfSchema": "TRANSFORMED",
            "sfWarehouse": "SWIRE_WH_ELDAMO_M"
        }
        return connOptions

    def _query_table_list(self) -> None:
        """Refresh the cached table list from ``information_schema.tables``.

        Rows in the INFORMATION_SCHEMA schema itself are excluded. The cache is
        replaced in place, so repeated calls do not accumulate duplicates and a
        list previously handed out by ``list_tables`` stays current.
        """
        query = "SELECT concat(t.TABLE_SCHEMA,'.',t.TABLE_NAME) AS SFTABLE FROM\
        information_schema.tables as t WHERE t.table_schema != 'INFORMATION_SCHEMA'"
        rows = self.sql(query).collect()
        # The key must match the alias in the query (SFTABLE, not SF_TABLE).
        self._sfTables[:] = [row["SFTABLE"] for row in rows]
        return None