In [None]:
import json
import requests
import re
from delta.tables import *
from pyspark.sql import functions as F


In [None]:
# Read the Key Vault and storage account names from the Spark configuration.
keyvault_name = spark.conf.get('key_vault_name')
# NOTE(review): storage_account_name is not used in this cell; the OAuth settings below are
# not account-scoped, so they apply to every storage account this session touches.
storage_account_name = spark.conf.get('storage_account_name')

# The Databricks secret scope is expected to share its name with the Azure Key Vault backing it.
# The secrets utility is `dbutils.secrets` (plural) -- `dbutils.secret` raises AttributeError.
# NOTE(review): the key 'tennt_id' looks like a typo of 'tenant_id'; confirm the actual secret
# name in the vault before renaming it here.
tenant_id = dbutils.secrets.get(keyvault_name, 'tennt_id')

# Configure ABFS OAuth (service-principal client credentials) for Azure Storage access.
spark.conf.set("fs.azure.account.auth.type", "OAuth")
spark.conf.set("fs.azure.account.oauth.provider.type", "org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider")
spark.conf.set("fs.azure.account.oauth2.client.id", dbutils.secrets.get(keyvault_name, 'adls-client-id'))
spark.conf.set("fs.azure.account.oauth2.client.secret", dbutils.secrets.get(keyvault_name, 'adls-client-secret'))
spark.conf.set("fs.azure.account.oauth2.client.endpoint", f"https://login.microsoftonline.com/{tenant_id}/oauth2/token")


In [None]:
from azure.identity import ClientSecretCredential
from azure.monitor.query import LogsQueryClient,LogsQueryStatus

class sp_cred:
    """Holds service-principal credentials and builds an Azure credential from them.

    Args:
        tenant_id: Azure AD tenant id.
        client_id: Application (client) id of the service principal.
        client_secret: Client secret of the service principal.
    """

    def __init__(self, tenant_id, client_id, client_secret):
        self.tenant_id = tenant_id
        self.client_id = client_id
        self.client_secret = client_secret

    def get_credential(self):
        """Return a new ``ClientSecretCredential`` built from the stored values."""
        return ClientSecretCredential(
            tenant_id=self.tenant_id,
            client_id=self.client_id,
            client_secret=self.client_secret,
        )


class Log_fetch(sp_cred):
    """Runs KQL queries against a Log Analytics workspace with service-principal auth.

    Args:
        tenant_id, client_id, client_secret: Forwarded to ``sp_cred``.
        workspace_id: Log Analytics workspace id to query.
        timespan: Passed unchanged as ``timespan`` to ``LogsQueryClient.query_workspace``.
    """

    # Sentinel strings returned instead of raising; callers may compare against them.
    _FETCH_FAILED = 'Response Fail'
    _COUNT_FAILED = 'Response Failed, Please check and try again'

    def __init__(self, tenant_id, client_id, client_secret, workspace_id, timespan):
        super().__init__(tenant_id, client_id, client_secret)
        self.timespan = timespan
        self.workspace_id = workspace_id

    def fetch_logs(self, query_):
        """Execute the KQL query ``query_`` against the workspace.

        Returns:
            ``response.tables`` when the query status is ``LogsQueryStatus.SUCCESS``;
            otherwise the string ``'Response Fail'`` (any non-success status,
            including partial results).
        """
        log_analytics_client = LogsQueryClient(self.get_credential())
        response = log_analytics_client.query_workspace(
            workspace_id=self.workspace_id,
            query=query_,
            timespan=self.timespan,
        )
        if response.status == LogsQueryStatus.SUCCESS:
            return response.tables
        return self._FETCH_FAILED

    def fetch_log_gt_xuid_cnt(self, query_base):
        """Return the row count of ``query_base`` by appending a KQL ``| count``.

        Returns:
            The value in the first row/column of the last result table, or the
            string ``'Response Failed, Please check and try again'`` when the query
            failed, returned no tables, or the last table has no rows (these
            cases used to raise UnboundLocalError / IndexError).
        """
        response_data = self.fetch_logs(query_base + """| count""")
        if response_data == self._FETCH_FAILED or not response_data:
            return self._COUNT_FAILED
        # As before, when several tables are returned the last one determines the result.
        last_table = response_data[-1]
        if not last_table.rows:
            return self._COUNT_FAILED
        return last_table.rows[0][0]


In [None]:
# Extract the max read id from the watermark table so data can be loaded incrementally.
# The watermark table is maintained as a Delta table.

def get_max_uid(log_name, tab_name, env_name):
    """Look up the stored watermark for one pipeline/environment pair.

    Args:
        log_name: Value matched against the ``pipeline_name`` column.
        tab_name: Name of the watermark table.
        env_name: Value matched against the ``env`` column.

    Returns:
        ``max(uid_max)`` over the matching rows, or the string ``'0'`` when the
        table does not exist or no row matches.
    """
    fallback_uid = '0'

    # No watermark table yet: fall back to '0'.
    if not spark._jsparkSession.catalog().tableExists(tab_name):
        return fallback_uid

    # NOTE(review): log_name / env_name are interpolated into the SQL unescaped --
    # only pass trusted values. If uid_max is a string column, max() is lexicographic.
    watermark_sql = f"""select max(uid_max) as max_uid from {tab_name} where pipeline_name='{log_name}' and env='{env_name}'"""
    first_row = spark.sql(watermark_sql).collect()[0]
    latest_uid = first_row[0]
    return fallback_uid if latest_uid is None else latest_uid

# COMMAND ----------

#UPSERT
def merge_into(max_uid_last, tab_name, env_name):
    """Upsert watermark rows into the Delta watermark table.

    Creates ``tab_name`` (string columns ``uid_max``, ``pipeline_name``, ``env``)
    when it does not exist. It then MERGEs ``max_uid_last`` into the table, keyed on
    ``(pipeline_name, env)``: matched rows take the new values and unmatched rows
    are inserted.

    Args:
        max_uid_last: Spark DataFrame with columns ``uid_max``, ``pipeline_name``
            and ``env``. Presumably at most one row per key; Delta MERGE errors
            when several source rows match the same target row -- verify upstream.
        tab_name: Name of the watermark table.
        env_name: Not used by this function; kept for interface compatibility.
    """
    if not spark._jsparkSession.catalog().tableExists(tab_name):
        # IF NOT EXISTS keeps this safe if another run creates the table between
        # the existence check above and this statement.
        spark.sql(
            f'''create table if not exists {tab_name}
            ( uid_max string,
            pipeline_name string,
            env string)
            using DELTA'''
        )

    # Same column mapping for both the update and the insert branch.
    column_map = {
        'uid_max': "updates.uid_max",
        'pipeline_name': "updates.pipeline_name",
        'env': "updates.env",
    }

    watermark_table = DeltaTable.forName(spark, tab_name)
    (
        watermark_table.alias('target')
        .merge(
            max_uid_last.alias('updates'),
            'target.pipeline_name = updates.pipeline_name and target.env = updates.env',
        )
        .whenMatchedUpdate(set=column_map)
        .whenNotMatchedInsert(values=column_map)
        .execute()
    )
    print('Rows has been updated.')

