In [67]:
import os
import json
import requests
import pandas as pd
from mysql import connector # allows us to connect Python to MySQL
from dotenv import load_dotenv # load sensitive environmental variables from .env file

## Define Key Variables

In [68]:
load_dotenv() # copy the key/value pairs from the local .env file into os.environ so os.getenv() below can read them

# retrieve sensitive information stored in local .env file
API_KEY = os.getenv("API_KEY")
API_HOST = os.getenv("API_HOST")

# load MySQL credentials
db_host = os.getenv("MYSQL_HOST")
# os.getenv returns a string, or None when the variable is unset; fall back to MySQL's default port 3306
db_port = int(os.getenv("MYSQL_PORT") or 3306)
db_user = os.getenv("MYSQL_USER")
db_pssword = os.getenv("MYSQL_PASSWORD")
db_db = os.getenv("MYSQL_DATABASE")
db_table = os.getenv("SQL_TABLE")

# API Request -- EIA API v2, electricity retail sales.
# Full documentation here: https://www.eia.gov/opendata/documentation/APIv2.1.0.pdf
# The key goes in the first query parameter ("?api_key=<key>"); every later parameter is joined with "&".
EIA_RETAIL_SALES_URL = "https://api.eia.gov/v2/electricity/retail-sales/data/"

# Alternative annual query (all states, 2019 onward), kept for reference:
# url = f"{EIA_RETAIL_SALES_URL}?api_key={API_KEY}&frequency=annual&data[0]=customers&data[1]=price&data[2]=revenue&data[3]=sales&start=2019&sort[0][column]=period&sort[0][direction]=desc&offset=0&length=5000"

# QA switch -- pick exactly one:
#   ""                        -> QA 1: all states, monthly
#   "&facets[stateid][]=ME"   -> QA 2: Maine only, monthly
STATE_FACET = "&facets[stateid][]=ME"

url = (
    f"{EIA_RETAIL_SALES_URL}?api_key={API_KEY}&frequency=monthly"
    "&data[0]=customers&data[1]=price&data[2]=revenue&data[3]=sales"
    f"{STATE_FACET}"
    "&sort[0][column]=period&sort[0][direction]=desc&offset=0&length=5000"
)

## GET Request (Extract)

In [69]:
# GET request function
def get_data(url, timeout=30):
    """Fetch records from the EIA API (Extract stage).

    Parameters
    ----------
    url : str
        Fully built request URL, including the ``api_key`` query parameter.
    timeout : float, optional
        Seconds to wait for the server before giving up (default 30). Without a
        timeout, ``requests.get`` can block the notebook indefinitely.

    Returns
    -------
    list or None
        ``json["response"]["data"]`` on success. Returns ``None`` if the request,
        the HTTP status check or the JSON parsing fails. In that case the error
        is printed with the API key redacted.
    """
    import re  # local import: only needed to scrub the API key from error messages

    try:
        response = requests.get(url, timeout=timeout)
        # Turn 4xx/5xx into an HTTP error instead of a confusing KeyError on the error payload.
        response.raise_for_status()
        json_data = response.json()
        rows_returned = int(json_data["response"]["total"])
        # "total" counts all matching records; a single request returns at most 5,000 of them.
        if rows_returned > 5000:
            print(f"Success! {rows_returned} records exists, but only 5,000 can be returned due to API limit.")
        else:
            print(f"Success! {rows_returned} records were returned.")
        return json_data["response"]["data"]
    except Exception as e:
        # requests puts the full URL into many error messages; never echo the API key into notebook output.
        message = re.sub(r"(api_key=)[^&\s]+", r"\1<redacted>", str(e))
        print(f"Error Occurred during the extraction stage: {message}")
        return None

# raw_data = get_data(url)

## Transform

In [70]:
def store_data(raw_data):
    """Keep only the fields that get loaded into MySQL from each raw API record (Transform stage).

    Parameters
    ----------
    raw_data : iterable of dict
        Records as returned by ``get_data`` (``json["response"]["data"]``).

    Returns
    -------
    list of dict
        One dict per record with exactly these keys, in this order:
        period, stateid, sectorid, customers, price, revenue, sales.
        Any other fields in the raw record are discarded.

    Raises
    ------
    KeyError
        If a record is missing one of those fields.
    """
    # define column names -- same order as the table columns defined in create_table
    col_names = ['period', 'stateid', 'sectorid', 'customers', 'price', 'revenue', 'sales']
    return [{col: record[col] for col in col_names} for record in raw_data]

# data = store_data(raw_data)

In [71]:
# turning the data into a Python DataFrame
def to_df(data):
    """Build a DataFrame from ``data`` and drop every row that has a null in any column.

    Prints how many rows were dropped and how many remain, then returns the
    filtered DataFrame. The surviving rows keep their original index labels.
    """
    frame = pd.DataFrame(data)
    cleaned = frame.dropna()
    dropped = len(frame) - len(cleaned)
    print(f"{dropped} rows containing nulls were dropped, {len(cleaned)} records remain.")
    return cleaned


# df = to_df(data)

## Load into SQL

In [None]:
# connect to our local database
def db_connect(db_host, db_port, db_user, db_pssword, db_db):
    """Open a MySQL connection to schema ``db_db`` (Load stage).

    Returns the connection object on success. Returns ``None`` if connecting
    fails; the error is printed instead of raised.
    """
    try:
        db_connection = connector.connect(
            host=db_host,
            user=db_user,
            passwd=db_pssword,
            database=db_db,
            port=db_port,
            connection_timeout=10,
        )
    except Exception as e:
        print(f"Error Occurred during the loading stage: {e}")
        return None
    # Single-quoted f-string so the literal double quotes around the schema name do not end the string.
    print(f'Connection to schema: "{db_db}" successful ✅')
    return db_connection

In [85]:
# create a table in our database
def create_table(db_connection, table_name=None):
    """Create the retail-sales table.

    Parameters
    ----------
    db_connection :
        Open MySQL connection, as returned by ``db_connect``.
    table_name : str, optional
        Table to create. Defaults to the module-level ``db_table`` (``SQL_TABLE`` in .env).

    MySQL errors are printed, not raised. Error 1050 (table already exists)
    is reported as an informational message.
    """
    table = db_table if table_name is None else table_name
    # price, revenue and sales are presumably fractional (e.g. price per kWh), so INT would drop
    # the decimals; DOUBLE keeps them. NOTE: this only affects newly created tables -- an
    # existing table keeps its old column types until it is altered or recreated.
    SQL_CREATE_TABLE = f"""
    CREATE TABLE {table} (
        period VARCHAR(25) NOT NULL,
        stateid VARCHAR(5) NOT NULL,
        sectorid VARCHAR(10) NOT NULL,
        customers INT NOT NULL,
        price DOUBLE NOT NULL,
        revenue DOUBLE NOT NULL,
        sales DOUBLE NOT NULL,
        PRIMARY KEY (period, stateid, sectorid)
    );
    """
    cursor = None
    try:
        cursor = db_connection.cursor()
        cursor.execute(SQL_CREATE_TABLE)
        db_connection.commit()
        print(f"{table} created successfully ✅")
    except connector.Error as e:
        if e.errno == 1050:  # table already exists
            print(f"""The table: "{table}" has already been created.""")
        else:
            print(f"❌ [CREATING TABLE ERROR]: '{e}'")
    finally:
        # Always release the cursor, including when CREATE TABLE fails.
        if cursor is not None:
            cursor.close()

# create_table(db_connection)

In [None]:
# Insert or update data in the database from the dataframe
def insert_into_table(db_connection, df, db_table):
    """Upsert every row of ``df`` into ``db_table``.

    Parameters
    ----------
    db_connection :
        Open MySQL connection, as returned by ``db_connect``.
    df : pandas.DataFrame
        Must contain the columns period, stateid, sectorid, customers, price,
        revenue and sales. Extra columns and column order do not matter.
    db_table : str
        Target table name (interpolated into the SQL; must be a trusted value).

    Rows whose (period, stateid, sectorid) key already exists have their
    measures updated; new keys are inserted. On any database error the
    transaction is rolled back and the error is re-raised.
    """
    columns = ["period", "stateid", "sectorid", "customers", "price", "revenue", "sales"]

    if df.empty:
        print("0 records inserted or updated successfully ✅")
        return

    # Explicit column list so the statement does not depend on the table's physical column order.
    # The key columns identify the row, so only the measures are updated on a duplicate key.
    # The "AS src" row alias requires MySQL 8.0.19+.
    INSERT_DATA_SQL_QUERY = f"""
        INSERT INTO {db_table}
            (period, stateid, sectorid, customers, price, revenue, sales)
        VALUES (%s, %s, %s, %s, %s, %s, %s) AS src
        ON DUPLICATE KEY UPDATE
            customers = src.customers,
            price = src.price,
            revenue = src.revenue,
            sales = src.sales;
        """
    # Select the columns by name so each tuple matches the column list above,
    # whatever the DataFrame's own column order is.
    data_as_tuples = list(df[columns].itertuples(index=False, name=None))

    cursor = db_connection.cursor()
    try:
        cursor.executemany(INSERT_DATA_SQL_QUERY, data_as_tuples)
        db_connection.commit()
        # For ON DUPLICATE KEY UPDATE, MySQL's affected-rows value is 1 per inserted row,
        # 2 per changed row and 0 per unchanged row -- not a plain row count.
        print(f"{cursor.rowcount} records inserted or updated successfully ✅")
    except Exception:
        # Do not leave a half-applied batch behind; let the caller see the error.
        db_connection.rollback()
        raise
    finally:
        cursor.close()

# insert_into_table(db_connection, df, db_table)

In [81]:
# return how many rows exist in the database table
def return_db_rows(db_connection, table_name=None):
    """Print and return the number of rows in the sales table.

    Parameters
    ----------
    db_connection :
        Open MySQL connection, as returned by ``db_connect``.
    table_name : str, optional
        Table to count. Defaults to the module-level ``db_table`` (``SQL_TABLE`` in .env),
        the same table ``create_table`` and ``insert_into_table`` write to.

    Returns
    -------
    int
        The ``COUNT(*)`` result.
    """
    table = db_table if table_name is None else table_name
    RETURN_ROW_SQL_QUERY = f"SELECT COUNT(*) FROM {table};"
    cursor = db_connection.cursor()
    try:
        cursor.execute(RETURN_ROW_SQL_QUERY)
        results = cursor.fetchone()[0]
    finally:
        # Release the cursor even if the query fails.
        cursor.close()
    print(f"There are {results} rows in the table.")
    return results

In [82]:
# full pipeline
def run_pipeline():
    """Run Extract -> Transform -> Load end to end using the notebook-level configuration.

    Stops early if extraction returns nothing or if the database connection
    cannot be opened; the helpers print the reason. Once opened, the database
    connection is always closed, even if a load step raises.
    """
    raw_data = get_data(url)
    if raw_data is None:
        return  # get_data already printed the error

    data = store_data(raw_data)
    df = to_df(data)

    db_connection = db_connect(db_host, db_port, db_user, db_pssword, db_db)
    if db_connection is None:
        print("Failed to connect to DB")
        return

    try:
        create_table(db_connection)
        insert_into_table(db_connection, df, db_table)
        return_db_rows(db_connection)
    finally:
        # Do not leak the connection when the insert or the count raises.
        db_connection.close()

In [86]:
# Execute the full ETL: extract from the EIA API, transform into a DataFrame, load into MySQL.
run_pipeline()

Success! 1776 records were returned.
716 rows containing nulls were dropped, 1060 records remain.
Connection to schema: eia_electricity_sale successful ✅
The table: "e_sales" has already been created.
0 records inserted or updated successfully ✅
There are 6147 rows in the table.


In [None]:
# Upload the prepared tuples (rolling back on failure) and report how many rows were
# attempted vs. how many new rows actually landed in MySQL.
# NOTE(review): list_of_sales_tuples, cur, sql_table, UPSERT_SQL and db_connection are not
# defined in any visible cell, so this cell raises NameError on a fresh kernel. Define them
# (or reuse insert_into_table / return_db_rows) -- or delete this cell.
rows_uploaded = len(list_of_sales_tuples)

try:
    # Count before AND after the upload; a count taken only before the upload just
    # reports the rows that already existed.
    cur.execute(f"SELECT COUNT(*) FROM {sql_table}")
    count_before = cur.fetchone()[0]
    cur.executemany(UPSERT_SQL, list_of_sales_tuples)
    db_connection.commit()
    cur.execute(f"SELECT COUNT(*) FROM {sql_table}")
    # Assuming UPSERT_SQL is an INSERT ... ON DUPLICATE KEY UPDATE, rows whose key already
    # existed are updated in place and do not change COUNT(*), so this is the number of new rows.
    upload_count = cur.fetchone()[0] - count_before
    print(f"Success! Attempted to upload {rows_uploaded} records.")
    print(f"Actual records uploaded: {upload_count}")
except Exception as e:
    db_connection.rollback()
    print(f"Error! Rollback due to {e}")
finally:
    cur.close()
    db_connection.close()
    print("All database connects closed. Clean up completed.")
