# In [None]:
import math
from airflow.hooks.mysql_hook import MySqlHook
import logging
from datetime import datetime,timezone
import boto3
from botocore.exceptions import ClientError
import concurrent.futures


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
# basicConfig() does nothing when the root logger already has handlers (as it
# may inside Jupyter or an Airflow worker), so set the level explicitly too.
root = logging.getLogger()
root.setLevel(logging.INFO)


class RestoreTables:
    """Find and remove the latest S3 delete markers under one Hive table.

    Partition names and the table location are read from the Hive metastore
    through ``msc``. For every partition, S3 object versions under
    ``<table location>/<PART_NAME>/`` are listed (all pages). Each delete
    marker that is the latest version of its key is collected.
    ``restore_objects`` deletes such a marker version, which (per S3
    versioning semantics) makes the previous object version current again.

    NOTE(review): partitions are assumed to live at
    ``<table location>/<PART_NAME>``; partitions registered with a custom
    location are not covered.

    Args:
        table: ``"schema.table"`` name.
        msc: hook exposing ``get_pandas_df(sql, parameters)`` against the
            metastore database (the driver passes an Airflow ``MySqlHook``).
        client: boto3 S3 client.
    """
    __schema_name = None
    __table_name = None
    msc = None

    def __init__(self, table, msc, client):
        schema_name, _, table_name = table.partition('.')
        if not schema_name or not table_name:
            raise ValueError(
                "Expected a 'schema.table' name, got {0!r}".format(table))
        self.__schema_name = schema_name
        self.__table_name = table_name
        self.msc = msc
        self.s3_client = client
        # Filled in by get_versioned_obj() from the table location;
        # restore_objects() depends on it.
        self.s3_bucket = None

    def get_partitions(self):
        """Return a DataFrame with one ``PART_NAME`` row per table partition."""
        get_partition_query = """
        select PART_NAME FROM hive.PARTITIONS WHERE 
        TBL_ID=(SELECT tbl_id FROM hive.TBLS t, hive.DBS d WHERE t.TBL_NAME=%s and t.DB_ID=d.DB_ID and d.NAME=%s);"""
        param = (self.__table_name, self.__schema_name)
        return self.msc.get_pandas_df(get_partition_query, param)

    def get_table_location(self):
        """Return a DataFrame with the table's storage ``location`` column."""
        get_table_location = """
        SELECT location FROM hive.TBLS t JOIN hive.DBS d ON t.DB_ID = d.DB_ID JOIN hive.SDS s ON t.SD_ID = s.SD_ID
        WHERE TBL_NAME = %s AND d.NAME= %s"""
        param = (self.__table_name, self.__schema_name)
        return self.msc.get_pandas_df(get_table_location, param)

    @staticmethod
    def _split_s3_location(location):
        """Split ``s3://bucket/a/b`` into ``('bucket', 'a/b')`` (no edge slashes)."""
        parts = location.split('/', 3)
        if (len(parts) < 3 or parts[0] not in ('s3:', 's3a:', 's3n:')
                or parts[1] or not parts[2]):
            raise ValueError(
                "Unrecognised S3 location: {0!r}".format(location))
        # Keep the whole path after the bucket, not just its first component.
        key_root = parts[3].strip('/') if len(parts) > 3 else ''
        return parts[2], key_root

    def get_versioned_obj(self):
        """Collect ``(key, version_id)`` for every latest delete marker.

        Side effect: sets ``self.s3_bucket`` from the table location.

        Returns:
            list of ``(key, version_id)`` tuples, in partition order.

        Raises:
            ValueError: the metastore returned no location for the table, or
                the location is not an ``s3://``-style URI.
            Any error from the S3 client is logged and re-raised.
        """
        partition_df = self.get_partitions()
        location_df = self.get_table_location()
        locations = list(location_df['location'])
        if not locations:
            raise ValueError(
                "No location found in the metastore for {0}.{1}".format(
                    self.__schema_name, self.__table_name))
        location = locations[0]
        logging.info("Location: %s", location)
        self.s3_bucket, key_root = self._split_s3_location(location)

        part_names = list(partition_df['PART_NAME'])
        total_partitions = len(part_names)
        # ListObjectVersions returns at most 1000 entries per call; the
        # paginator follows the continuation markers so nothing is missed.
        paginator = self.s3_client.get_paginator('list_object_versions')
        restore_list = []
        for position, part_name in enumerate(part_names, start=1):
            # Trailing '/' so that e.g. '.../ver=1' does not also match
            # '.../ver=10'.
            prefix = '/'.join(
                p for p in (key_root, part_name.strip('/')) if p) + '/'
            try:
                markers = [
                    (marker['Key'], marker['VersionId'])
                    for page in paginator.paginate(
                        Bucket=self.s3_bucket, Prefix=prefix)
                    for marker in page.get('DeleteMarkers', [])
                    if marker['IsLatest']
                ]
            except Exception:
                logging.exception(
                    "Error listing object versions for s3://%s/%s",
                    self.s3_bucket, prefix)
                raise
            if markers:
                logging.info('Processing %d/%d:%s',
                             position, total_partitions, prefix)
                restore_list.extend(markers)
        return restore_list

    def restore_objects(self, s3_obj):
        """Delete one delete-marker version in ``self.s3_bucket``.

        Args:
            s3_obj: ``(key, version_id)`` as returned by get_versioned_obj().

        Returns:
            The raw ``delete_object`` response.

        Raises:
            RuntimeError: get_versioned_obj() has not set the bucket yet.
        """
        if self.s3_bucket is None:
            raise RuntimeError(
                "get_versioned_obj() must run before restore_objects()")
        logging.info("Restoring object: %s", s3_obj)
        key, version_id = s3_obj[0], s3_obj[1]
        return self.s3_client.delete_object(
            Bucket=self.s3_bucket, Key=key, VersionId=version_id)
        

            
# Driver: for each listed Hive table, remove the latest S3 delete markers
# found under its partitions.
msc = MySqlHook('etl_metastore')
s3_client = boto3.client('s3')
table_list = ['company.cust_orders']
restore_list = []

for table in table_list:
    rt = RestoreTables(table, msc, s3_client)
    list_restore = rt.get_versioned_obj()
    if not list_restore:
        logging.info("No delete markers to remove for %s", table)
        continue
    logging.info("Objects to be restored: %s", len(list_restore))
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        # Executor.map() only re-raises a worker's exception when that result
        # is retrieved, so drain it; otherwise failed deletes go unnoticed.
        for _ in executor.map(rt.restore_objects, list_restore):
            pass
    
