In [37]:
import os
import sys
import pendulum
import re
from pymongo.mongo_client import MongoClient
from pyspark.sql import SparkSession
from airflow.decorators import dag, task
import logging
import pandas as pd
from pyspark.sql.functions import col
import subprocess

In [7]:
# Make the project-local `config` and `scripts` packages (imported just below)
# importable: take the current working directory, strip the '\airflow'
# segment from it, and put the result at the front of sys.path.
# NOTE(review): the '\\airflow' pattern only matches Windows-style separators,
# and str.replace removes every occurrence in the path -- confirm the notebook
# is always launched from <project>\airflow.
config_path = os.getcwd().replace('\\airflow','')
sys.path.insert(0, config_path)

from config import settings
from scripts import data_extraction as de
from scripts import data_transform as dm

In [26]:
def extract():
    try:
        de.drop_db_if_exists(settings.mongo_default_db)

        mongo_instance = de.MongoDB(username= settings.mongo_username,
                                password= settings.mongo_password, \
                                default_db = settings.mongo_default_db, \
                                default_col = settings.mongo_default_colname, \
                                default_clusterName= settings.mongo_default_clusterName, \
                                schema = settings.mongo_default_schema)
    
        mongo_instance.test_connectivity()
        mongo_instance.initialize()

        data = de.USDA_API(settings.usda_key)
        data.add_params('state_alpha','US')

        for commodity_desc in data.get_param_values('commodity_desc'):
            for year in de.create_mongo_year_list(2015):
                try:
                    col_title = re.sub(r"[ ,&()]","", commodity_desc).replace(" ", "_")
                    data.add_params('commodity_desc', commodity_desc)
                    data.add_params('year', year)

                    current_doc = data.call()
                    connection = mongo_instance.test_connectivity()

                    if  mongo_instance.test_connectivity() == '1' and type(current_doc) != str:
                        mongo_instance.add_new_col(col_title)
                        mongo_instance.add_record(current_doc, col_title)
                        mongo_instance.drop_col(settings.mongo_default_colname)

                    data.remove_params('commodity_desc')
                    data.remove_params('year')
                except Exception as e:
                    print(f"Error {e}, {data.call()}")
                    data.remove_params('commodity_desc')
                    data.remove_params('year')   
        return True
    except Exception as e:
        logging.error(f"Error {e}")
        return False

In [16]:
def transform(extract_success: bool):
    """Read every non-excluded Mongo collection into a typed Spark DataFrame.

    For each collection in the default Mongo database that is not listed in
    ``settings.excluded_commodities`` this:
      * loads all documents into pandas and passes them through
        ``dm.drop_first_row``;
      * adds a string ``id`` column (via ``dm.stringify_id``), drops ``_id``
        and moves ``id`` to the first column;
      * converts to a Spark DataFrame and removes rows whose ``Value`` or
        ``CV (%)`` contains ``'('`` (presumably USDA placeholder codes such as
        ``(D)`` -- confirm);
      * casts ``Value``/``CV (%)`` to float, ``year``/``zip_5`` to int and
        ``load_time`` to date.

    Args:
        extract_success: Result of the upstream extract step.

    Returns:
        dict | str: Mapping of collection name to Spark DataFrame, or the
        string ``'Upstream workflow failure!'`` when ``extract_success`` is
        falsy.
    """
    if not extract_success:
        return 'Upstream workflow failure!'

    # Raw string: '\j' and '\p' are not valid escape sequences, so the plain
    # literal triggers an invalid-escape warning on current Python versions.
    spark = SparkSession.builder \
        .config("spark.jars", r"C:\jdbc\postgresql-42.7.5.jar") \
        .getOrCreate()

    mongo_conn = MongoClient(settings.mongo_client)[settings.mongo_default_db]

    spark_rdd_list = {}
    for collection in mongo_conn.list_collection_names():
        if collection in settings.excluded_commodities:
            # Excluded collections must not appear in the result at all.
            continue

        test_df = dm.drop_first_row(
            pd.DataFrame(list(mongo_conn.get_collection(collection).find()))
        )

        # Converts id column to a string to make it easier to be created as a pyspark rdd.
        test_df['id'] = test_df.apply(dm.stringify_id, axis=1)
        test_df = test_df.drop('_id', axis=1)

        # Moves id to the front of the dataframe
        id_column = test_df.pop('id')
        test_df.insert(0, 'id', id_column)

        spark_df = spark.createDataFrame(test_df)

        spark_df = spark_df.where(~col('Value').like('%(%'))
        spark_df = spark_df.where(~col('CV (%)').like('%(%'))

        spark_df = (
            spark_df
            .withColumn("Value", spark_df['Value'].cast('float'))
            .withColumn("CV (%)", spark_df['CV (%)'].cast('float'))
            .withColumn('year', spark_df['year'].cast('int'))
            .withColumn('zip_5', spark_df['zip_5'].cast('int'))
            .withColumn('load_time', spark_df['load_time'].cast('date'))
        )

        # Registered inside the non-excluded path only, so an excluded
        # collection can never pick up the previous collection's DataFrame.
        spark_rdd_list[collection] = spark_df
    return spark_rdd_list

In [36]:
def load(collection_dict):
    """Write each Spark DataFrame in ``collection_dict`` to PostgreSQL via JDBC.

    Every entry is written in ``overwrite`` mode to the table
    ``ag.<collection name>`` (hyphens in the name replaced by underscores) in
    the ``USDA_DB`` database on ``localhost:5432``, using the credentials in
    ``settings.pgadmin_user`` / ``settings.pgadmin_password``.

    Args:
        collection_dict: Mapping of collection name to Spark DataFrame. Any
            non-dict value (for example the failure string returned by
            ``transform``) is rejected.

    Returns:
        bool: ``True`` when every write succeeded, ``False`` when the input was
        not a dict or any write raised (the error is logged).
    """
    # Previously a non-dict input fell through and returned None; report it
    # explicitly so callers always get a boolean.
    if not isinstance(collection_dict, dict):
        logging.error(
            f"Error load() expected a dict of DataFrames, got "
            f"{type(collection_dict).__name__}: {collection_dict!r}"
        )
        return False

    try:
        for collection, spark_df in collection_dict.items():
            (
                spark_df.write.format("jdbc")
                .mode('overwrite')
                .option("url", "jdbc:postgresql://localhost:5432/USDA_DB")
                .option("driver", "org.postgresql.Driver")
                .option("dbtable", f"ag.{collection.replace('-','_')}")
                .option("user", str(settings.pgadmin_user))
                .option("password", str(settings.pgadmin_password))
                .save()
            )
        return True
    except Exception as e:
        logging.error(f"Error {e}")
        return False

In [None]:
# Run the pipeline end to end: extract -> transform -> load.
# transform() is handed extract()'s boolean result, and load() is handed
# whatever transform() returns; load() only writes when that value is a dict.
# The intermediate names are kept at module level so they can be inspected
# from later cells.
extract_success = extract()
collection_dict = transform(extract_success)
load(collection_dict)