In [180]:
import json
import csv 
import functools

import pandas as pd 
import geopandas as gpd 
import numpy as np
import plotly.express as px 
import plotly

In [181]:
def map_from_csv(fpath: str, drop_header=True) -> dict:
    """Build a lookup dict from the first two columns of a CSV file.

    Args:
        fpath: Path of the CSV file to read.
        drop_header: When True, the first row is discarded before mapping.

    Returns:
        A dict mapping each row's first field to its second field. If a key
        occurs more than once, the last occurrence wins.

    Raises:
        IndexError: If a non-blank row has fewer than two fields.
    """
    # newline="" leaves line-ending handling to the csv module, as its
    # documentation asks for file objects passed to csv.reader.
    with open(fpath, mode="r", newline="") as infile:
        reader = csv.reader(infile)
        if drop_header:
            next(reader, None)
        # csv.reader yields [] for blank lines (e.g. a trailing empty line);
        # skip those instead of failing on row[0].
        return {row[0]: row[1] for row in reader if row}
    
def create_code_columns(df) -> pd.DataFrame:
    """Add ``Energy_code``, ``Sector_code`` and ``Unit_code`` columns cut from ``MSN``.

    The columns hold characters 1-2, characters 3-4 and the 5th character of
    each ``MSN`` string, respectively. ``df`` is modified in place and the
    same object is returned.
    """
    msn_pieces = {
        "Energy_code": slice(0, 2),
        "Sector_code": slice(2, 4),
        "Unit_code": 4,
    }
    for column, piece in msn_pieces.items():
        df[column] = df["MSN"].str[piece]
    return df

def _determine_subset_list(arg, default):
    if arg is not None: 
        if isinstance(arg, str):
            arg = [arg]
    else: 
        arg = default
    return arg 

def data_subset(df, states=None, years=None, sectors=None, sources=None) -> pd.DataFrame:
    all_states = df["State"].unique()
    all_years = df["Year"].unique()
    all_sectors = df["Sector"].unique()
    all_sources = df["Source"].unique()
        
    states = _determine_subset_list(states, all_states)
    years = _determine_subset_list(years, all_years)
    sectors = _determine_subset_list(sectors, all_sectors)
    sources = _determine_subset_list(sources, all_sources)
    
    masks = [
        df["State"].isin(states),
        df["Year"].isin(years),
        df["Sector"].isin(sectors),
        df["Source"].isin(sources)
    ]
    final_mask = [all(row) for row in zip(*masks)]
    
    return df[final_mask]
    

In [191]:
# Main table, read straight from CSV.
df = pd.read_csv("use_all_btu.csv")

# Each lookup CSV becomes a dict via map_from_csv; files are read in the order listed.
(
    state_abbr_map,
    energy_codes_map,
    sector_codes_map,
    unit_codes_map,
    state_color_map,
) = (
    map_from_csv(csv_name)
    for csv_name in (
        "states.csv",
        "energy_codes.csv",
        "sector_codes.csv",
        "unit_codes.csv",
        "state_plot_colors.csv",
    )
)

state_color_map

{'Alabama': 'rgb(99, 110, 250)',
 'Alaska': 'rgb(239, 85, 59)',
 'Arizona': 'rgb(0, 204, 150)',
 'Arkansas': 'rgb(171, 99, 250)',
 'California': 'rgb(255, 161, 90)',
 'Colorado': 'rgb(25, 211, 243)',
 'Connecticut': 'rgb(255, 102, 146)',
 'Delaware': 'rgb(182, 232, 128)',
 'District of Columbia': 'rgb(255, 151, 255)',
 'Florida': 'rgb(254, 203, 82)',
 'Georgia': 'rgb(95, 70, 144)',
 'Hawaii': 'rgb(29, 105, 150)',
 'Idaho': 'rgb(56, 166, 165)',
 'Illinois': 'rgb(15, 133, 84)',
 'Indiana': 'rgb(115, 175, 72)',
 'Iowa': 'rgb(237, 173, 8)',
 'Kansas': 'rgb(225, 124, 5)',
 'Kentucky': 'rgb(204, 80, 62)',
 'Louisiana': 'rgb(148, 52, 110)',
 'Maine': 'rgb(111, 64, 112)',
 'Montana': 'rgb(102, 102, 102)',
 'Nebraska': 'rgb(136, 204, 238)',
 'Nevada': 'rgb(204, 102, 119)',
 'New Hampshire': 'rgb(221, 204, 119)',
 'New Jersey': 'rgb(17, 119, 51)',
 'New Mexico': 'rgb(51, 34, 136)',
 'New York': 'rgb(170, 68, 153)',
 'North Carolina': 'rgb(68, 170, 153)',
 'North Dakota': 'rgb(153, 153, 51)',
 'O

# Combining datasets

In [183]:
# Add the MSN-derived code columns, then move the original "State" values
# to "Abbreviation" so "State" can hold the mapped value.
df = create_code_columns(df).rename(columns={"State": "Abbreviation"})

# Mapping codes to full values
df = df.assign(
    State=df["Abbreviation"].map(state_abbr_map),
    Source=df["Energy_code"].map(energy_codes_map),
    Sector=df["Sector_code"].map(sector_codes_map),
    Unit=df["Unit_code"].map(unit_codes_map),
)

# Removing data not related to consumption

In [184]:
# Abbreviations that are not individual states
not_states = ["US", "DC"]

# Sector codes treated as consumption
consumption_codes = ["AC", "CC", "IC", "RC", "TC", "AP", "IP", "CP", "RP", "TP"]

df = (
    df
    # Remove non-state entities
    .loc[lambda frame: ~frame["Abbreviation"].isin(not_states)]
    # Drop MSNs whose last character is X or R
    # (per the original note: the GDP and generation series)
    .loc[lambda frame: ~frame["MSN"].str[-1].isin(["X", "R"])]
    # Keep only rows whose sector code is a consumption code
    .loc[lambda frame: frame["Sector_code"].isin(consumption_codes)]
    # Drop the TN, TP and P1 energy codes
    .loc[lambda frame: ~frame["Energy_code"].isin(["TN", "TP", "P1"])]
    # Remove MSN code names
    .drop(columns=["Data_Status", "MSN", "Abbreviation", "Energy_code", "Sector_code", "Unit_code"])
)

# Melting years into a Year column

Total across all sectors can be found in all MSN rows that have TC as the 3rd and 4th character.

In [185]:
# Wide -> long: every column not listed in id_vars becomes a (Year, BTU) pair.
id_vars = ["State", "Source", "Sector", "Unit"]
df = pd.melt(
    df,
    id_vars=id_vars,
    var_name="Year",
    value_name="BTU",
)
df

Unnamed: 0,State,Source,Sector,Unit,Year,BTU
0,Alaska,Aviation gasoline blending components,Industrial,Billion BTU,1960,0.0
1,Alaska,Asphalt and road oil,Industrial,Billion BTU,1960,312.0
2,Alaska,Asphalt and road oil,Total,Billion BTU,1960,312.0
3,Alaska,Aviation gasoline,Transportation,Billion BTU,1960,5209.0
4,Alaska,Aviation gasoline,Total,Billion BTU,1960,5209.0
...,...,...,...,...,...,...
442495,Wyoming,Wood and biomass waste,Total,Billion BTU,2018,4901.0
442496,Wyoming,Waxes,Industrial,Billion BTU,2018,0.0
442497,Wyoming,Wind energy,Commercial,Billion BTU,2018,0.0
442498,Wyoming,Wind energy,Industrial,Billion BTU,2018,0.0


# Getting total use across all sectors 

In [186]:
# Keep only rows whose Sector is "Total" and whose Source is "Fossil fuels".
total_df = data_subset(
    df,
    sources=["Fossil fuels"],
    sectors=["Total"],
)

In [187]:
# Sum BTU per (State, Year). "BTU" is selected explicitly so the text columns
# (Source, Sector, Unit) are not aggregated along with it; the result has
# exactly the columns State, Year and BTU.
summed_df = total_df.groupby(["State", "Year"], as_index=False)["BTU"].sum()

In [195]:
# Year is cast to int before it is used as the animation frame.
summed_df["Year"] = summed_df["Year"].astype(int)

# The y-axis spans 0 to the largest BTU value in the whole frame.
min_y, max_y = 0, summed_df["BTU"].max()

fig = px.bar(
    data_frame=summed_df,
    x="State",
    y="BTU",
    color="State",
    color_discrete_map=state_color_map,
    animation_frame="Year",
    range_y=(min_y, max_y),
    title="United States total energy consumption",
)
fig.update_layout(yaxis_title="Billion BTU")
fig.update_xaxes(categoryorder="total ascending", tickmode="linear")
fig.show()

# State choropleth

In [None]:
# NOTE(review): this whole cell is commented out and does nothing when run.
# Points to check before re-enabling it:
#   - the shapefile is read from an absolute path on one machine;
#   - the result of gdf.to_crs(...) is not assigned to anything -- presumably
#     to_crs returns a new frame rather than changing gdf; confirm;
#   - locations="Abbreviation" is used with summed_df, but the "Abbreviation"
#     column is dropped from df in the filtering cell before summed_df is built.
# gdf = gpd.read_file(r"C:\Users\Chris\Documents\Programming\pythonstuff\US-Energy\cb_2018_us_state_20m\cb_2018_us_state_20m.shp")
# gdf.to_crs(epsg=4326)

# fig = px.choropleth_mapbox(summed_df, locations="Abbreviation", color="BTU", geojson=json.loads(gdf.to_json()), featureidkey="properties.STUSPS", color_continuous_scale=plotly.colors.diverging.Temps, range_color=(0, 3_000_000), animation_frame="Year", title="Residential energy consumption (BTU) in the US")
# fig.update_layout(mapbox_style="carto-positron",
#                   mapbox_zoom=2.6, mapbox_center={"lat": 38, "lon": -98})
# fig.show()
# gdf

In [None]:
# Preview of the five qualitative palettes concatenated in order.
(
    px.colors.qualitative.Plotly
    + px.colors.qualitative.Prism
    + px.colors.qualitative.Safe
    + px.colors.qualitative.Vivid
    + px.colors.qualitative.Pastel
)

In [None]:
# Build the State -> Color table and write it to state_plot_colors.csv.
# The table lives in its own variable (state_colors_df) so that this cell no
# longer overwrites the notebook-wide `df` used by the analysis cells.

# The Plotly palette entries are parsed as "#rrggbb" hex strings and rewritten
# as "rgb(r, g, b)"; the other palettes are appended unchanged.
rgb_colors = []
for hex_color in px.colors.qualitative.Plotly:
    hex_digits = hex_color.lstrip("#")
    r, g, b = (int(hex_digits[i:i + 2], 16) for i in (0, 2, 4))
    rgb_colors.append(f"rgb({r}, {g}, {b})")

# Only the first 51 colours are kept (presumably 50 states plus DC -- confirm
# against states.csv).
all_colors = (
    rgb_colors
    + px.colors.qualitative.Prism
    + px.colors.qualitative.Safe
    + px.colors.qualitative.Vivid
    + px.colors.qualitative.Pastel
)[:51]

# Pair state names with colours in order; zip stops at the shorter sequence.
color_map = dict(zip(state_abbr_map.values(), all_colors))

state_colors_df = pd.DataFrame(list(color_map.items()), columns=["State", "Color"])
state_colors_df.to_csv("state_plot_colors.csv", index=False)