# In [None]:
# needed libraries
import os
import re
import argparse
import logging 
import time
from google.oauth2 import service_account
from google.cloud import bigquery
from dict_helpers import schema_yellow_dict2009, schema_yellow_dict2010, schema_yellow_dict, schema_green_dict, schema_fhv_dict, schema_fhvhv_dict
# from dict_helpers import schema_yellow_dict2009_types, schema_yellow_dict2010_types, schema_yellow_dict_types, schema_green_dict_types, schema_fhv_dict_types, schema_fhvhv_dict_types
from datetime import datetime
from dateutil.relativedelta import relativedelta
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from py4j.protocol import Py4JJavaError

def parquet_downloaded(data):
    """Return True when *data* looks like a successfully loaded DataFrame.

    ``load_trip_data`` returns ``None`` when a download could not be read, so
    ``None`` (and any string) is treated as "not downloaded"; every other
    object counts as a loaded frame.
    """
    return data is not None and not isinstance(data, str)
        
def parquet_not_downloaded(data):
    """Return True when *data* is ``None`` (``load_trip_data``'s failure value).

    Uses an identity check so objects that overload ``==`` (e.g. frame types
    whose comparison is element-wise or raises) are handled safely.
    """
    return data is None

def abort_pipeline():
    """Print a notice that the current iteration is being skipped.

    This only prints; it neither raises nor stops the surrounding loop.
    """
    notice = "faulty parquet file downloaded, iteration run aborted"
    print(notice)

################################################local######################################################################
def load_trip_data(spark, url, filename):
    """Download one trip-data parquet file and read it into Spark.

    The file is fetched with ``curl -O`` into the current working directory,
    so it is saved under the last path segment of *url*; the caller passes
    that name as *filename*.

    Parameters
    ----------
    spark : SparkSession
        Active session used to read the file.
    url : str
        Download URL of the parquet file.
    filename : str
        Local name the download is expected to be saved under.

    Returns
    -------
    DataFrame or None
        The loaded DataFrame, or ``None`` when no file was downloaded or
        Spark could not read it as parquet.
    """
    import subprocess

    # download the parquet; an argument list keeps the URL away from a shell
    print(f'fetching {filename} from {url}')
    subprocess.run(['curl', '-O', url], check=False)

    # curl wrote nothing (e.g. network failure); reading a missing path would
    # raise an AnalysisException rather than the Py4JJavaError handled below
    if not os.path.isfile(filename):
        print(f"{filename} not a valid parquet file, aborted pipeline!!!!")
        return None

    try:
        # read parquet into environment
        df = spark.read.parquet(filename)
        return df
    except Py4JJavaError:
        # without -f, curl also saves HTTP error bodies under filename; such a
        # file is not parquet, so remove it before giving up on this month
        os.remove(filename)
        print(f"{filename} not a valid parquet file, aborted pipeline!!!!")
        return None

#############################################dataproc######################################################################
# def load_trip_data(spark, url, filename, bucket_name):
    
#     # down load parquet
#     print(f'fetching {filename} from {url}')
#     os.system(f'curl -O {url}') 
#     os.system(f'gsutil -m cp {filename} gs://original_parquets_url/')
    
#     try:
#         # read parquet into environment 
#         df = spark.read \
#             .option("header", "true") \
#             .parquet(f"gs://original_parquets_url/{filename}")
#         print(f'loaded gs://original_parquets_url/{filename} in spark session ')
#         return df
#     except Py4JJavaError as e:
#         # os.system(f"rm -r {filename}")
#         print(f"aborted pipeline for gs://original_parquets_url/{filename} due to error")
#         return None
###########################################################################################################################

def dimension_name_cleanup(df):
    """Rename every column of *df* from CamelCase to snake_case.

    An underscore is inserted between a lowercase letter and a following
    uppercase letter, and before the last capital of an uppercase run that is
    followed by a lowercase letter (``PULocationID`` -> ``pu_location_id``);
    the name is then lowercased. The column list is printed before and after.

    Returns the DataFrame with the renamed columns.
    """
    before = ', '.join(df.columns)
    print(f"Spark DF currently has the following columns: {before}")

    boundary = re.compile('(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z]{1}[a-z])')

    def to_snake(name):
        return boundary.sub('_', name).lower()

    column_count = len(df.columns)
    for position in range(column_count):
        # read the name from the *current* frame: each rename yields a new frame
        current = df.columns[position]
        df = df.withColumnRenamed(current, to_snake(current))

    after = ', '.join(df.columns)
    print(f"Spark DF has been updated with the following columns: {after}")
    return df

###########################################local##########################################################################

def data_2_gcp_cloud_storage(df, table_name, year_month, root_path, filename):
    """Write *df* as day-partitioned parquet files and copy them to Cloud Storage.

    Adds a ``parquet_source_path`` column holding *root_path*. It also adds a
    day-truncated copy of the first column whose name contains ``'pickup'``.
    That copy is named like the pickup column with ``datetime``/``date_time``
    replaced by ``date``. The frame is repartitioned on that day column and
    written locally to ``{table_name}_{year_month}/``. The ``*.parquet`` files
    are then uploaded under *root_path*. On success, the local folder and the
    downloaded source file *filename* are removed.

    Returns
    -------
    bool
        True if the upload succeeded. On failure the local files are kept so
        the run can be inspected or retried.

    Raises
    ------
    ValueError
        If *df* has no column whose name contains ``'pickup'``.
    """
    import glob
    import shutil
    import subprocess

    # pick the pickup timestamp column and derive the partition column name from it
    pickup_cols = [col for col in df.columns if 'pickup' in col]
    if not pickup_cols:
        raise ValueError(f"no pickup column found among: {', '.join(df.columns)}")
    pickup_col = pickup_cols[0]
    col_name = re.sub('datetime|date_time', 'date', pickup_col)

    df = df.withColumn('parquet_source_path', F.lit(root_path)) \
            .alias('a') \
            .select('a.*', \
                    F.date_trunc('day', f'a.{pickup_col}') \
                    .alias(col_name))

    print(f"number of unique partitions for DF: {df.select(col_name).distinct().count()}")

    # repartition and save locally; Spark creates the folder and 'overwrite'
    # clears leftovers from an earlier run of the same month
    local_dir = f'{table_name}_{year_month}'
    print(f"saving partitions in {local_dir}")
    df.repartition(col_name) \
        .write.parquet(local_dir, mode = 'overwrite')

    # drop Spark's _SUCCESS marker so only parquet files remain in the folder
    success_marker = os.path.join(local_dir, '_SUCCESS')
    if os.path.exists(success_marker):
        os.remove(success_marker)

    print(f"loading parquets to {root_path}")
    parquet_files = sorted(glob.glob(os.path.join(local_dir, '*.parquet')))
    if not parquet_files:
        print(f"no parquet files written to {local_dir}, nothing uploaded to {root_path}")
        return False

    # the trailing '/' makes gsutil treat root_path as a folder even when only
    # one file is copied; otherwise that file would become the object root_path
    upload = subprocess.run(['gsutil', '-m', 'cp', *parquet_files, f'{root_path}/'], check=False)
    if upload.returncode != 0:
        print(f"upload to {root_path} failed, keeping {local_dir} and {filename} locally")
        return False

    print('cleaning up environment for next run')
    shutil.rmtree(local_dir, ignore_errors=True)
    if os.path.exists(filename):
        os.remove(filename)
    return True

#########################################dataproc##########################################################################
#
# def data_2_gcp_cloud_storage(df, root_path):
#     # repartition df for export 
#     pickup_col = [col for col in df.columns if 'pickup' in col][0]
#     col_name = [re.sub('datetime|date_time', 'date', col) for col in df.columns if 'pickup' in col][0]
    
#     df = df.alias('a') \
#             .select('a.*', \
#                     F.date_trunc('day', f'a.{pickup_col}') \
#                     .alias(col_name))
    
#     print(f"number of unique partitions for DF: {df.select(col_name).distinct().count()}")

#     # repartition and save locally 
#     os.system(f'gcloud storage folders create {root_path}')

#     print(f"loading parquets to {root_path}")
    
#     df.repartition(col_name) \
#         .write.parquet(root_path, mode = 'overwrite')

#     print(f'removing unneeded _SUCCESS file from folder')
#     os.system(f'gsutil rm -r {root_path}/_SUCCESS')
    
#     print(f"loading partitions to {root_path} complete")
###########################################################################################################################

def _stage_schema_for(root_path):
    """Pick the stage-table schema dict and table-name suffix for *root_path*.

    Matching is by substring of *root_path*, most specific first. The 2009 and
    2010 yellow layouts get their own schemas and a ``_2009``/``_2010`` suffix.
    ``'fhvhv'`` is tested before ``'fhv'`` because every fhvhv path also
    contains ``'fhv'``.

    Raises ValueError when no known trip type appears in *root_path*.
    """
    if 'yellow_tripdata_2009' in root_path:
        return schema_yellow_dict2009, '_2009'
    if 'yellow_tripdata_2010' in root_path:
        return schema_yellow_dict2010, '_2010'
    if 'yellow' in root_path:
        return schema_yellow_dict, ''
    if 'green' in root_path:
        return schema_green_dict, ''
    if 'fhvhv' in root_path:
        return schema_fhvhv_dict, ''
    if 'fhv' in root_path:
        return schema_fhv_dict, ''
    raise ValueError(f"cannot determine trip type (yellow/green/fhv/fhvhv) from {root_path}")


def bucket_2_bigquery(gcp_project_name, table_name, root_path):
    """Expose the parquet files under *root_path* to BigQuery and append them to a stage table.

    Each step waits for the previous one to finish:
      1. create datasets nytaxi_raw / nytaxi_stage / nytaxi_transform /
         nytaxi_prod (location EU) if missing;
      2. (re)create the external table nytaxi_raw.external_{table_name} over
         ``{root_path}/*``;
      3. create the stage table nytaxi_stage.{table_name}[suffix] if missing,
         using the schema dict chosen from *root_path*;
      4. insert every external row plus ``current_timestamp()`` into it.

    Credentials are read from the service-account file named by the
    ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable.

    Raises
    ------
    ValueError
        If *root_path* matches none of the known trip types.
    Exception
        Any BigQuery job failure is re-raised by ``QueryJob.result()`` instead
        of being silently ignored.
    """
    # connecting to bigquery
    credentials = service_account.Credentials.from_service_account_file(os.getenv('GOOGLE_APPLICATION_CREDENTIALS'))
    project_id = gcp_project_name

    # the schema dict maps column name -> BigQuery type; presumably its last
    # entry is the load-timestamp column filled by current_timestamp() below
    schema, table_suffix = _stage_schema_for(root_path)
    col_param = ', '.join(f'{col} {col_type}' for col, col_type in schema.items())
    col_names = ', '.join(schema)

    q2 = f"""create or replace external table `{project_id}`.`nytaxi_raw.external_{table_name}`
    options (
    format = 'PARQUET',
    uris = ['{root_path}/*']
    )
    """

    q3a = f"""create table if not exists `{project_id}`.`nytaxi_stage`.`{table_name}{table_suffix}`
    ({col_param})
    """

    q3b = f"""insert into `{project_id}`.`nytaxi_stage`.`{table_name}{table_suffix}`
    ({col_names})
    select *, current_timestamp() from `{project_id}.nytaxi_raw.external_{table_name}`
    """

    # get BigQuery connection
    client = bigquery.Client(credentials = credentials, project = project_id)

    # client.query() only submits a job; .result() blocks until it finishes and
    # raises on failure, so each statement sees the objects created before it
    print('creating if not already present schemas')
    for dataset in ('nytaxi_raw', 'nytaxi_stage', 'nytaxi_transform', 'nytaxi_prod'):
        client.query(f"""create schema if not exists `{project_id}`.`{dataset}`
    options (location = 'EU')
    """).result()

    print('creating external table')
    client.query(q2).result()

    print('populating stage table')
    client.query(q3a).result()
    client.query(q3b).result()

    print('loading data to stage complete')

    print('cleanuped up env for next run')

if __name__ == '__main__':
    # Script entry point: for each month in [start_dt, end_dt], download the
    # trip parquet, clean column names, stage it in GCS and load it into BigQuery.

    # print(os.system('pwd'))
    # print('parsing input arguments')
    # parser = argparse.ArgumentParser()

    # parser.add_argument('--table_name')
    # parser.add_argument('--start_date')
    # parser.add_argument('--end_date')

    # args = parser.parse_args()

    # date ranges (inclusive, stepped one month at a time)
    start_dt = datetime.strptime('2010-01-01','%Y-%m-%d')#datetime.strptime(args.start_date,'%Y-%m-%d')
    end_dt = datetime.strptime('2010-01-01','%Y-%m-%d')
    delta = relativedelta(months=1)

    # var that stays the same through out run
    table_name = 'yellow_tripdata'#args.table_name
    bucket_name_data = 'taxi-data-extract'
    gcp_project_name = 'pipeline-analysis-446021'

    print(f"will be fetching parquest for {table_name}, from {start_dt} - {end_dt}")

#########################################local#############################################################################
    os.system('gcloud auth activate-service-account --key-file $GOOGLE_APPLICATION_CREDENTIALS')
###########################################################################################################################

    # start spark session
    print('starting spark session')
    spark = SparkSession.builder \
        .master("local[4]") \
        .appName('extract-load-spark') \
        .getOrCreate()

    while start_dt <= end_dt:

        # vars for fetching parquet
        year_month = start_dt.strftime("%Y-%m")
        filename = f"{table_name}_{year_month}.parquet"
        url = f"https://d37ci6vzurychx.cloudfront.net/trip-data/{filename}"

        print(f'starting iteration for {table_name}, {year_month}')

        # vars for gcp: a timestamped folder per run and month
        root_path = f"gs://{bucket_name_data}/{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}_{table_name}_{year_month}"

        # running through the ETL pipeline
#########################################local#############################################################################
        df = load_trip_data(spark, url, filename)
#########################################dataproc##########################################################################
        # df = load_trip_data(spark, url, filename, bucket_name)
###########################################################################################################################
        if parquet_not_downloaded(df):
            # load_trip_data returned None: report it and move on to the next month
            abort_pipeline()
        elif parquet_downloaded(df):
            df = dimension_name_cleanup(df)
#########################################local#############################################################################
            data_2_gcp_cloud_storage(df, table_name, year_month, root_path, filename)
#########################################dataproc##########################################################################
            # data_2_gcp_cloud_storage(df, root_path)
###########################################################################################################################
            bucket_2_bigquery(gcp_project_name, table_name, root_path)
        else:
            print('another issue encountered not yet considered, halting iterations')
            break

        start_dt += delta

    # close spark session
    # spark.stop()