In [152]:
# This script collects statistics on the Delta Lake tables registered in the metastore and determines which of them need maintenance.
# It then runs OPTIMIZE and VACUUM on every Delta Lake table that qualifies.

**Configs for Delta Lake maintenance**

In [153]:
# Configs for Delta Lake maintenance

# Retention window, in hours, used in the generated VACUUM statements ("retain N hours").
# Default is 7 days (24 * 7 = 168 hours).
vacuumRetentionInHours = 168

# Minimum number of history versions a table must have to be selected for maintenance.
# A low version count indicates the table has not been through many updates and does not
# need maintenance (OPTIMIZE and VACUUM).
minVersionCount = 3

# Tables whose oldest history entry is not more than this many days old are ignored
# by maintenance (OPTIMIZE and VACUUM).
ignore_New_Tabledays = 7

In [None]:
# Configs for Delta lake execution

# By default threads is set equal to the number of executor cores. "threads" variable can be tuned to a particular number depending on requirements

import os
# One entry per block manager is reported; the "- 1" presumably discounts the driver's entry.
# NOTE(review): os.cpu_count() is evaluated on the driver, so this assumes executors have
# the same number of cores as the driver - confirm for the pool in use.
executorCount = sc._jsc.sc().getExecutorMemoryStatus().keySet().size() - 1

# Clamp to 1: a driver-only session would otherwise yield 0 and make
# ThreadPoolExecutor raise "max_workers must be greater than 0".
# os.cpu_count() may return None, hence the "or 1".
executorCores = max(1, executorCount * (os.cpu_count() or 1))

# Heavy tasks (OPTIMIZE / VACUUM): at least 2 cores per thread.
# Integer division keeps the thread count an int; never below 1.
min_Workers_Heavyprocess = max(1, executorCores // 2)

# Light tasks (metadata lookups): at least 1 core per thread.
max_Workers_Lightprocess = executorCores

**List all Databases of Spark**

In [154]:
# First column of every row returned by SHOW DATABASES, as a plain Python list.
fullDatabaseList = [row[0] for row in spark.sql("Show Databases").collect()]

**Blacklist databases** that don't need to be optimized through this script

In [155]:
# Databases listed here are excluded from maintenance by this script.
# Example: ignoreDatabases = ['onehrsi','crm','aurora']
ignoreDatabases = []

# Same order as fullDatabaseList, minus the excluded names.
dbList = list(filter(lambda dbName: dbName not in ignoreDatabases, fullDatabaseList))

print(dbList)

In [156]:
# temp restriction for debugging (optional)

# dbList = [ 'genie', 'temp']

In [157]:
class TableDetail:
    """Name, storage location and Delta flag of a single table."""

    def __init__(self, tableName, location, isDelta):
        # Plain instance attributes, stored exactly as received.
        self.tableName = tableName
        self.location = location
        self.isDelta = isDelta

**Key Method to run tasks in parallel**

In [158]:
import concurrent.futures

workerCount=4
exceptionDetails = []

def runParallelTasks(taskName, paramList, workerCount):
    """Run taskName(param) for every param in paramList on a thread pool.

    Args:
        taskName: callable taking a single argument.
        paramList: iterable of arguments; each one is used as a dictionary key,
            so it must be hashable.
        workerCount: desired number of threads. Floats are truncated and values
            below 1 are raised to 1, so computed values such as ``cores / 2`` or
            0 cannot make the pool constructor fail. None keeps the executor's
            own default.

    Returns:
        dict mapping each param to its task result. Only truthy results are
        kept; params whose task returned a falsy value are absent.

    Raises:
        Whatever a task raises: the exception propagates out of this function
        (tasks that handle their own errors never trigger this).
    """
    if workerCount is not None:
        workerCount = max(1, int(workerCount))

    resultDict = {}
    with concurrent.futures.ThreadPoolExecutor(workerCount) as executor:
        pending = {executor.submit(taskName, param): param for param in paramList}
        for future in concurrent.futures.as_completed(pending):
            result = future.result()
            if result:
                resultDict[pending[future]] = result
    return resultDict

In [159]:
def flattenDictionaryValues(resultDict):
    """Concatenate every value of resultDict (each an iterable) into one flat list, in key order."""
    return [item for key in resultDict.keys() for item in resultDict[key]]

**Lists all tables in the Hive metastore**

In [160]:
def getAllTableNames(databaseName):
    """List the tables of one database as "database.table" names.

    Args:
        databaseName: name of the database to inspect.

    Returns:
        list of qualified names for every row of SHOW TABLES whose third column
        is false. The "default" database is skipped: an empty list is returned
        for it without querying Spark.
    """
    print("Initiate get tables list for database: "+ databaseName)
    if databaseName == "default":
        return []

    newTables = []
    for row in spark.sql("Show tables from " + databaseName).collect():
        qualifiedName = databaseName + "." + row[1]
        # row[2] is the third column of SHOW TABLES - presumably the isTemporary
        # flag, so temporary views are left out. Every row is still logged below.
        if not row[2]:
            newTables.append(qualifiedName)

        print("TableFullName : " + qualifiedName)
    return newTables

In [161]:
# Metadata lookups are light-weight: use the larger thread budget.
max_workers = max_Workers_Lightprocess

# Database name -> list of its table names (databases yielding an empty list are dropped).
tResults = runParallelTasks(getAllTableNames, dbList, max_workers)

allTables = flattenDictionaryValues(tResults)
print(f"All Tables found : {len(allTables)}")

**Blacklist tables** that don't need to be optimized through this script

In [162]:
# Tables listed here (as "database.table") are excluded from maintenance by this script.
# Example: ignoreTables = ['ihr.tabfilesvw']
ignoreTables = []

# Same order as allTables, minus the excluded names.
allApprovedTables = list(filter(lambda tableName: tableName not in ignoreTables, allTables))

# print(allApprovedTables)

In [163]:
import sys

# Filled by getDeltaDetails with one TableDetail per table detected as Delta.
deltaTablesDetails = []

def getDeltaDetails(tableName):
    """Tell whether tableName is a Delta table and, if so, record its details.

    Runs DESCRIBE DETAIL on the table. When the first result row has
    format == "delta" and a non-null id, a TableDetail (name, location) is
    appended to the module-level deltaTablesDetails list.

    Args:
        tableName: table name as accepted by DESCRIBE DETAIL.

    Returns:
        True when the table was recorded as Delta, False otherwise - including
        when the lookup raised.
    """
    try:
        detail = spark.sql("Describe detail " + tableName).collect()[0]
        if detail.format == "delta" and detail.id is not None:
            # NOTE(review): isDelta is stored as the string 'True', not a bool.
            deltaTablesDetails.append(TableDetail(tableName, detail.location, 'True'))
            return True
    except Exception:
        # Deliberate best effort: non-Delta tables make DESCRIBE DETAIL throw, so
        # a failure is treated as "not a Delta table" and not logged. To debug,
        # bind the exception ("except Exception as e") and record it, e.g.:
        # exceptionDetails.append((tableName, "TableDetailStage", str(e)))
        pass

    return False

In [164]:
max_workers = max_Workers_Lightprocess

# Start from an empty list so re-running this cell does not duplicate entries.
deltaTablesDetails = []

# Only truthy task results are kept, so tResults holds exactly the tables reported as Delta.
tResults = runParallelTasks(getDeltaDetails, allApprovedTables, max_workers)

print(f"All Delta Tables found : {len(tResults)}")

In [165]:
class TableHistoryDetail:
    """One history entry of a table: table name, version and timestamp."""

    def __init__(self, tableName, version, timestamp):
        # Plain instance attributes, stored exactly as received.
        self.tableName = tableName
        self.version = version
        self.timestamp = timestamp

In [166]:
# Filled by getTableHistory with one TableHistoryDetail per (table, version).
tableHistoryDetails = []

def getTableHistory(deltaTableDetail):
    """Collect the DESCRIBE HISTORY rows of one table into tableHistoryDetails.

    Args:
        deltaTableDetail: object exposing a ``tableName`` attribute.

    Returns:
        Always True. On failure the error is appended to exceptionDetails as
        (tableName, "TableHistoryVersionStage", message) and a line is printed;
        nothing is added to tableHistoryDetails for that table.
    """
    tabName = deltaTableDetail.tableName
    try:
        historyRows = spark.sql("Describe history "+ tabName).collect()
        # Build the full list first and publish it in one call, so a failure
        # part-way through cannot leave a partial history for this table.
        tableHistoryDetails.extend(
            [TableHistoryDetail(tabName, row.version, row.timestamp) for row in historyRows]
        )
    except Exception as e:
        exceptionDetails.append((tabName,"TableHistoryVersionStage", str(e)))
        print("Exception occurred at gathering History :" + tabName)

    return True

**Get History and version of tables**

In [167]:
# History lookups are light-weight: use the larger thread budget.
max_workers = max_Workers_Lightprocess
# Reset the shared list so re-running this cell does not accumulate duplicates.
tableHistoryDetails = []
# getTableHistory always returns True, so tResults maps every processed TableDetail to True.
tResults = runParallelTasks(getTableHistory, deltaTablesDetails, max_workers)


In [168]:
from pyspark.sql.types import *

# Shape of one table-history row: table name, version number, commit timestamp (all nullable).
field = [
    StructField("tableName", StringType(), True),
    StructField("version", LongType(), True),
    StructField("timestamp", TimestampType(), True),
]
schema = StructType(field)

# Column names, in the same order as the schema fields.
cols = ['tableName', 'version', 'timestamp']

In [169]:
# One [tableName, version, timestamp] row per collected history entry.
tableHistoryList = [[k.tableName, k.version, k.timestamp] for k in tableHistoryDetails]

# Use the explicit schema rather than bare column names: with names only, Spark has to
# infer the types from the data and fails when no history rows were collected.
# The schema carries the same column names (tableName, version, timestamp).
thDF = spark.createDataFrame(tableHistoryList, schema)

In [170]:
# Expose the history rows to the SQL cells below as the temp view TabHistoryVw.
thDF.createOrReplaceTempView("TabHistoryVw")

In [171]:
%%sql
-- Per-table summary of TabHistoryVw: number of history versions plus the oldest
-- and the most recent version timestamps.
CREATE OR REPLACE TEMPORARY VIEW TabVersionVW AS (
Select TableName, count(*) as VersionCount, min(Timestamp) as OldestVersionDate,  max(Timestamp) as MostRecentVersionDate  from TabHistoryVw
group by tableName
order by 2 desc, 3 ASC
)

In [172]:
%%sql
-- Remove any TabVersion table left by a previous run so the next cell rebuilds it.
drop table if exists TabVersion;

In [173]:
%%sql
-- Materialize the per-table version summary (TabVersionVW) into the table TabVersion.
Create table if not exists TabVersion
Select * from TabVersionVW

In [174]:
class TableFiles:
    """Storage footprint of a table: location, entry count and total size in bytes and GBs."""

    def __init__(self, tableName, location, fileCount, totalSizeInBytes, totalSizeInGBs):
        # Plain instance attributes, stored exactly as received.
        self.tableName = tableName
        self.location = location
        self.fileCount = fileCount
        self.totalSizeInBytes = totalSizeInBytes
        self.totalSizeInGBs = totalSizeInGBs

In [175]:
def get_dir_content(ls_path):
  """Recursively list everything stored under ls_path.

  Args:
      ls_path: path accepted by mssparkutils.fs.ls.

  Returns:
      Flat list holding the entries of ls_path itself (directory entries
      included) followed by the contents of each sub-directory, in listing order.
  """
  def _is_dir(entry):
    # mssparkutils documents isDir as a plain attribute, while Databricks
    # dbutils exposes it as a method; accept both so the walk works on either.
    flag = entry.isDir
    return flag() if callable(flag) else flag

  baseFilesInfo = list(mssparkutils.fs.ls(ls_path))
  contents = list(baseFilesInfo)
  for entry in baseFilesInfo:
    # Guard against a listing that returns the directory itself.
    if _is_dir(entry) and entry.path != ls_path:
      contents.extend(get_dir_content(entry.path))
  return contents

In [176]:
# Filled by getTableFileDetails with one TableFiles record per table.
tabFiles = []
def getTableFileDetails(tabDetail):
    """Measure the storage footprint of one table and record it in tabFiles.

    Args:
        tabDetail: object exposing ``tableName`` and ``location`` attributes.

    Returns:
        Always True. On failure the error is appended to exceptionDetails as
        (tableName, "ListFilesInfoStage", message) and a line is printed;
        nothing is added to tabFiles for that table.
    """
    bytesPerGB = 1073741824  # 1024 ** 3
    tabName = tabDetail.tableName
    tabLocation = tabDetail.location

    try:
        files = get_dir_content(tabLocation)
        # NOTE(review): get_dir_content also returns directory entries, so
        # fileCount includes them - confirm that is the intended count.
        fileCount = len(files)
        totalSizeInBytes = float(sum(f.size for f in files))
        totalSizeInGBs = totalSizeInBytes / bytesPerGB
        tabFile = TableFiles(tabName, tabLocation, fileCount, totalSizeInBytes, totalSizeInGBs)
        print(tabName+" : "+ str(totalSizeInGBs) + "GB")
        tabFiles.append(tabFile)
    except Exception as e:
        exceptionDetails.append((tabName,"ListFilesInfoStage", str(e)))
        print("Exception occurred at gathering files info stage :" + tabName)

    return True

In [177]:
# File listings are treated as light-weight work: use the larger thread budget.
max_workers = max_Workers_Lightprocess
# Reset the shared list so re-running this cell does not accumulate duplicates.
tabFiles = []

# getTableFileDetails always returns True, so tResults maps every processed TableDetail to True.
tResults = runParallelTasks(getTableFileDetails, deltaTablesDetails, max_workers)

In [178]:
# Build the DataFrame from plain tuples with an explicit schema. Letting Spark infer the
# schema from the TableFiles objects fails when tabFiles is empty (nothing to infer from).
# Columns and types match what inference produces for TableFiles (int -> long, float -> double).
tabFilesSchema = "fileCount long, location string, tableName string, totalSizeInBytes double, totalSizeInGBs double"
tabFilesRows = [
    (t.fileCount, t.location, t.tableName, t.totalSizeInBytes, t.totalSizeInGBs)
    for t in tabFiles
]
tabFilesDF = spark.createDataFrame(tabFilesRows, tabFilesSchema)
tabFilesDF.createOrReplaceTempView("tabFilesVW")

In [179]:
%%sql
-- Combine the per-table version summary (TabVersion) with the per-table file statistics
-- (tabFilesVW). The left join keeps tables that have no file statistics; their size and
-- file-count columns are NULL.
CREATE OR REPLACE TEMPORARY VIEW DeltaLakeTablesVW AS 
(
Select TV.TableName, totalSizeInGBs, OldestversionDate, VersionCount, FileCount, location, totalSizeInBytes, MostRecentVersionDate
-- from TabVersionVW TV
from TabVersion TV
left Join tabFilesVW TF on TF.tableName = TV.tableName 
order by totalSizeInGBs desc, VersionCount desc
)

In [180]:
%%sql
-- Remove any DeltaLakeTables table left by a previous run so the next cell rebuilds it.
Drop table if exists DeltaLakeTables

In [181]:
%%sql
-- Materialize the combined statistics (DeltaLakeTablesVW) into the table DeltaLakeTables.
Create table if not exists DeltaLakeTables
Select * from DeltaLakeTablesVW

In [182]:
%%sql
-- Display the statistics of the Delta Lake tables, largest tables first
-- (ties broken by the higher version count).
Select * from DeltaLakeTables
order by totalSizeInGBs desc, VersionCount desc

In [183]:
%%sql
-- Total size, in GBs, of all Delta Lake tables recorded in DeltaLakeTables
select sum(totalSizeInGBs) from DeltaLakeTables 

In [184]:
%%sql
-- Total size, in GBs, of the tables with at least 7 versions whose oldest version is more than 7 days old.
-- NOTE(review): these thresholds are hard-coded here and are not tied to the minVersionCount /
-- ignore_New_Tabledays settings used when the maintenance scripts are generated - confirm they should differ.
Select sum(totalSizeInGBs) from DeltaLakeTables where VersionCount >= 7  and DateDiff(Current_TimeStamp(),OldestVersionDate) > 7 --  

In [185]:
# Query that produces, per qualifying table, the OPTIMIZE statement and the VACUUM statement to run.
# A table qualifies when it has at least minVersionCount versions and its oldest version is more
# than ignore_New_Tabledays days old; the VACUUM retention comes from vacuumRetentionInHours.
scriptGenerateQuery = f"""Select distinct Concat('Optimize ', tableName) as optimizeScript, Concat('Vacuum ', tableName, ' retain {vacuumRetentionInHours} hours') As vacuumScript 
from DeltaLakeTables 
where VersionCount >= {minVersionCount}  and DateDiff(Current_TimeStamp(),OldestVersionDate) > {ignore_New_Tabledays} """

# Show the final SQL text for review before it is executed.
print(scriptGenerateQuery) 

In [186]:
# Default to "nothing to run" so a failing query cannot leave stale scripts from an earlier run.
runScripts = []

# One (optimizeScript, vacuumScript) tuple per table selected by scriptGenerateQuery.
runScripts = [
    (row[0], row[1])
    for row in spark.sql(scriptGenerateQuery).select("optimizeScript", "vacuumScript").collect()
]

In [187]:
def execOptimizeVacuum(query):
    """Run the OPTIMIZE statement and then the VACUUM statement of one table.

    Args:
        query: pair (optimizeScript, vacuumScript) of SQL statements.

    Returns:
        True when both statements ran. False when either failed; the failure is
        appended to exceptionDetails as (statement, "OptimizeVacuumStage", message)
        and printed, matching how the other parallel tasks in this notebook report
        errors, so one failing table no longer aborts result collection.
    """
    statement = query[0]
    try:
        print("Initiate optimize :  "+ statement)
        spark.sql(statement)
        print("completed optimize :  "+ statement)

        # VACUUM only runs once OPTIMIZE has succeeded.
        statement = query[1]
        spark.sql(statement)
        print("completed vacuum :  "+ statement)
    except Exception as e:
        exceptionDetails.append((statement, "OptimizeVacuumStage", str(e)))
        print("Exception occurred at optimize/vacuum stage :" + statement + " : " + str(e))
        return False

    return True

In [188]:
# OPTIMIZE / VACUUM are the heavy tasks: use the smaller thread budget.
max_workers = min_Workers_Heavyprocess

# tResults maps each (optimizeScript, vacuumScript) tuple to True once both statements have run.
tResults = runParallelTasks(execOptimizeVacuum, runScripts, max_workers)

In [189]:
%%sql
-- Clean-up: drop the two working tables created by this notebook once the run is done.
drop table if exists TabVersion;
drop table if exists DeltaLakeTables;