In [1]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# All input CSVs live under one data directory relative to this notebook.
# Keeping the prefix in one place avoids path typos such as a doubled slash.
DATA_DIR = "../../../data"
TUITION_DIR = f"{DATA_DIR}/National_education_cost"

df_salary = pd.read_csv(f"{DATA_DIR}/salary_potential.csv")
df_two_years = pd.read_csv(f"{TUITION_DIR}/CP3-pub-2y-current-dollars.csv")
df_four_years = pd.read_csv(f"{TUITION_DIR}/CP3-pub-4y-current-dollars.csv")

import warnings
# NOTE(review): this silences every warning for the rest of the session,
# including deprecation notices from pandas/plotly -- consider narrowing it
# to the specific category that is noisy.
warnings.filterwarnings("ignore")

# Preview the four-year tuition table.
df_four_years.head()

Unnamed: 0,State,2004-05,2005-06,2006-07,2007-08,2008-09,2009-10,2010-11,2011-12,2012-13,...,2016-17,2017-18,2018-19,2019-20,2020-21,2021-22,2022-23,2023-24,1-Year % Change,5-Year % Change
0,Alabama,4510,4782,4906,5244,5968,6487,7373,8001,8734,...,10083,10650,10777,10922,10996,11338,11683,11890,0.02,0.1
1,Alaska,3435,3793,4194,4425,4678,4922,5261,5455,5785,...,7128,7440,7821,8233,8610,8815,9024,9163,0.02,0.17
2,Arizona,4078,4434,4674,4959,5584,6554,8075,9435,9728,...,10931,11210,11545,11879,11811,11820,12184,12583,0.03,0.09
3,Arkansas,4581,4980,5314,5599,5914,5980,6304,6654,6995,...,8254,8550,8701,9036,9078,9250,9478,9734,0.03,0.12
4,California,4195,4526,4549,4951,5436,6550,7485,8933,8986,...,9302,9800,9875,9854,9924,9943,10353,10641,0.03,0.08


In [2]:
def get_data_ready(data: pd.DataFrame):
    data = data.rename({
        "In Current Dollars":"State"
        }, 
        axis=1, 
        # inplace=True
        ).dropna()
    column_to_drop = data.columns[-7:-2].to_list()
    # print(column_to_drop)
    data.drop(column_to_drop, axis=1, inplace=True)
    # print(data.columns)
    
    data = pd.melt(data, id_vars=["State", "1-Year % Change", "5-Year % Change"], var_name="Year", value_name="Tuition")
    data = data.sort_values(by=["State", "Year"]).reset_index(drop=True)
    relevant = pd.DataFrame(data.groupby(["State"])["Tuition"].mean(""))
    return relevant




In [3]:
# Reduce each wide tuition table to one mean-tuition row per state.
relevant_four_year, relevant_two_year = (
    get_data_ready(tuition_table)
    for tuition_table in (df_four_years, df_two_years)
)



In [4]:
# Average the pay columns per state (the first CSV column is skipped by
# position), expose the key as "State", and swap hyphens in state names for
# spaces.
relevant_salary = (
    df_salary.iloc[:, 1:]
    .groupby("state_name")
    .mean()
    .reset_index()
    .rename(columns={"state_name": "State"})
    .assign(State=lambda frame: frame["State"].str.replace("-", " ", regex=False))
)

In [5]:
# List every state in the salary table that has no row in the two-year
# tuition table (which is indexed by state name).
for state in relevant_salary["State"]:
    if state in relevant_two_year.index:
        continue
    print(state)

Alaska


In [6]:
# Pair the four-year and two-year mean tuition per state. Both frames carry a
# "Tuition" column, so the suffixes tell them apart; reset_index turns the
# State index into a regular column.
merged_df = relevant_four_year.merge(
    relevant_two_year,
    on="State",
    suffixes=("-four years", "-two years"),
).reset_index()

# Default (inner) join: only states present in both the salary table and the
# combined tuition table are kept.
df = relevant_salary.merge(merged_df, on="State")


In [7]:
# Preview the merged table: one row per state with its mean pay columns and
# its mean four-year / two-year tuition columns.
df.head()

Unnamed: 0,State,early_career_pay,mid_career_pay,Tuition-four years,Tuition-two years
0,Alabama,44992.0,81592.0,7726.8,3639.94
1,Arizona,50228.571429,90642.857143,8154.333333,2087.564
2,Arkansas,44110.526316,79242.105263,6703.533333,2742.738
3,California,67232.0,123976.0,7473.6,1088.828
4,Colorado,51857.894737,93257.894737,7613.4,3311.248


In [8]:
# Column groups of the merged frame, selected by position: columns 1-2 are
# the pay measures and everything from column 3 onward is a tuition measure.
import plotly.express as px
pay_list = list(df.columns[1:3])
univ_type_list = list(df.columns[3:])
pay_list, univ_type_list


(['early_career_pay', 'mid_career_pay'],
 ['Tuition-four years', 'Tuition-two years'])

In [10]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# One scatter panel per (tuition type, pay measure) pair: rows follow
# univ_type_list and columns follow pay_list, so the grid always matches the
# number of entries in each list. Panels share the figure equally.
fig = make_subplots(rows=len(univ_type_list), cols=len(pay_list))

for row_idx, univ_type in enumerate(univ_type_list, start=1):
    for col_idx, pay in enumerate(pay_list, start=1):
        fig.add_trace(
            go.Scatter(
                x=df[univ_type],
                y=df[pay],
                mode="markers",
                name=f"{univ_type}-{pay}",
            ),
            row=row_idx,
            col=col_idx,
        )
        # Label each panel's axes so it can be read without the legend.
        fig.update_xaxes(title_text=univ_type, row=row_idx, col=col_idx)
        fig.update_yaxes(title_text=pay, row=row_idx, col=col_idx)

fig.update_layout(
    title_text="Correlation between salary and tuition",
    showlegend=True,
    height=1000,
)
# Uncomment to export the figure as standalone HTML:
# fig.write_html("../../../graphs/html/five/scatter_plot.html")
fig.show()

In [None]:
# Pairwise correlations between every column after the leading State column,
# drawn as an annotated heatmap and exported as standalone HTML.
correlation_matrix = df.iloc[:, 1:].corr()
fig = px.imshow(
    correlation_matrix,
    text_auto=True,
    aspect="auto",
    width=1000,
)
fig.write_html("../../../graphs/html/five/correlation.html")
fig.show()