In [9]:
import os
import subprocess
from pathlib import Path

"""
Dynamically find the project root (where .git exists) and set it as the current working directory.
"""
project_root = Path(subprocess.check_output(['git', 'rev-parse', '--show-toplevel'], text=True).strip())
os.chdir(project_root)

In [10]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import scipy.stats as stats
from scipy.stats import mannwhitneyu
import math
import geopandas as gpd
import matplotlib.patches as mpatches
from matplotlib.patches import Patch
import plotly.express as px
from matplotlib import cm
from matplotlib.colors import to_hex
import plotly.graph_objects as go




In [11]:
US_ratings = pd.read_csv('data/USData/BA_US_states_all.csv')
US_ratings.head(3)

Unnamed: 0,beer_name,beer_id,brewery_name,brewery_id,style,user_id,appearance,aroma,palate,taste,overall,rating,avg,user_state,beer_state
0,Kupfer Kolsch,289320.0,Copper State Brewing Company,49595.0,Kölsch,n2185.211743,2.5,4.0,4.0,3.75,3.75,3.76,3.76,North Carolina,Wisconsin
1,Northwestern Alt,289321.0,Copper State Brewing Company,49595.0,Altbier,n2185.211743,3.0,3.75,4.0,3.5,3.5,3.58,3.58,North Carolina,Wisconsin
2,One Cent Wheat,289319.0,Copper State Brewing Company,49595.0,Witbier,n2185.211743,3.75,3.25,3.75,3.5,3.5,3.48,3.48,North Carolina,Wisconsin


In [14]:
US_ratings.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6331638 entries, 0 to 6331637
Data columns (total 15 columns):
 #   Column        Dtype  
---  ------        -----  
 0   beer_name     object 
 1   beer_id       float64
 2   brewery_name  object 
 3   brewery_id    float64
 4   style         object 
 5   user_id       object 
 6   appearance    float64
 7   aroma         float64
 8   palate        float64
 9   taste         float64
 10  overall       float64
 11  rating        float64
 12  avg           float64
 13  user_state    object 
 14  beer_state    object 
dtypes: float64(9), object(6)
memory usage: 724.6+ MB


In [15]:
#csv file containing all the states and their neighbouring states
#empty list for states that have no neighbours (e.g Alazka, Hawaii)
neighbours_df = pd.read_csv('data/additionalData/bordering_states.csv', dtype={'state':'string', 
                                                                          'neighbours': 'string'})
neighbours_df["neighbours"] = neighbours_df["neighbours"].fillna("").apply(lambda x: x.split(";") if x else [])

neighbours_df.head(3)

Unnamed: 0,state,neighbours,nb_neighbours
0,Alabama,"[Florida, Georgia, Mississippi, Tennessee]",4
1,Alaska,[],0
2,Arizona,"[California, Colorado, Nevada, New Mexico, Utah]",5


In [16]:
def gather_region_ratings(state, US_ratings, neighbours_df):
    neighbours = neighbours_df.loc[neighbours_df['state'] == state, 'neighbours'].values[0]
    
    region_states = [state] + neighbours    
    region_ratings = US_ratings[US_ratings['beer_state'].isin(region_states)]

    in_region_ratings = region_ratings[region_ratings['user_state'].isin(region_states)]
    not_in_region_ratings = region_ratings[~region_ratings['user_state'].isin(region_states)]
    
    return in_region_ratings, not_in_region_ratings

states = US_ratings['beer_state'].unique()

all_ratings = {'region': [], 'rating': [], 'rating_type': [], 'user_state': []}

for state in states:
    in_region, non_region = gather_region_ratings(state, US_ratings, neighbours_df)

    region_name = f"{state}"
    
    # Add In-Region ratings to the dictionary
    all_ratings['region'].extend([region_name] * len(in_region))
    all_ratings['rating'].extend(in_region['rating'].tolist())
    all_ratings['rating_type'].extend(['In-Region'] * len(in_region))
    all_ratings['user_state'].extend(in_region['user_state'].tolist())
    
    # Add Non-Region ratings to the dictionary
    all_ratings['region'].extend([region_name] * len(non_region))
    all_ratings['rating'].extend(non_region['rating'].tolist())
    all_ratings['rating_type'].extend(['Non-Region'] * len(non_region))
    all_ratings['user_state'].extend(non_region['user_state'].tolist())

ratings_df = pd.DataFrame(all_ratings)

ratings_df.head()


Unnamed: 0,region,rating,rating_type,user_state
0,Wisconsin,4.04,In-Region,Wisconsin
1,Wisconsin,4.0,In-Region,Wisconsin
2,Wisconsin,3.75,In-Region,Wisconsin
3,Wisconsin,3.9,In-Region,Illinois
4,Wisconsin,3.25,In-Region,Wisconsin


In [17]:
ratings_df.head(200)

Unnamed: 0,region,rating,rating_type,user_state
0,Wisconsin,4.04,In-Region,Wisconsin
1,Wisconsin,4.00,In-Region,Wisconsin
2,Wisconsin,3.75,In-Region,Wisconsin
3,Wisconsin,3.90,In-Region,Illinois
4,Wisconsin,3.25,In-Region,Wisconsin
...,...,...,...,...
195,Wisconsin,3.50,In-Region,Minnesota
196,Wisconsin,4.25,In-Region,Wisconsin
197,Wisconsin,3.75,In-Region,Minnesota
198,Wisconsin,4.15,In-Region,Minnesota


In [18]:
ratings_df = ratings_df[ratings_df['rating_type'] == 'In-Region']
ratings_df = ratings_df.drop('rating_type', axis=1)
ratings_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 13278808 entries, 0 to 34011603
Data columns (total 3 columns):
 #   Column      Dtype  
---  ------      -----  
 0   region      object 
 1   rating      float64
 2   user_state  object 
dtypes: float64(1), object(2)
memory usage: 405.2+ MB


In [19]:
ratings_df['rating_type'] = np.where(ratings_df['region'] == ratings_df['user_state'], 'In-State', 'Non-State')
avg_ratings = ratings_df.groupby(['region', 'rating_type'])['rating'].mean().unstack()
avg_ratings = avg_ratings.reset_index()
avg_ratings.head()


rating_type,region,In-State,Non-State
0,Alabama,3.876441,3.942452
1,Alaska,3.970084,
2,Arizona,3.950191,3.971004
3,Arkansas,3.782575,3.784315
4,California,4.003972,3.97037


In [20]:
# Creating a scatter plot for each region displaying the average rating for each user state in region
for region in ratings_df['region'].unique():
    subset = ratings_df[ratings_df['region'] == region]
    
    avg_ratings = subset.groupby('user_state')['rating'].mean().reset_index()
    
    # plt.figure(figsize=(14, 6))
    
    # sns.scatterplot(data=avg_ratings, x='user_state', y='rating', color='blue', s=100, marker='o', label='Average Rating')

    # plt.title(f'Average Rating by User State for Region: {region}')
    # plt.xlabel('User State')
    # plt.ylabel('Average Rating')
    # plt.xticks(rotation=90)
    # plt.legend(title='Rating Type')
    # plt.tight_layout()
    # plt.ylim(0,5)

    # plt.show()


In [21]:
# Creating violin plot for each region displaying rating distributions for each user state in the region
unique_regions = ratings_df['region'].unique()

for region in unique_regions:
    subset_data = ratings_df[ratings_df['region'] == region]
    
    # plt.figure(figsize=(12, 8))
    # sns.violinplot(
    #     x='user_state',
    #     y='rating',
    #     data=subset_data,
    #     inner='quartile',
    #     palette='Set2',
    #     hue='user_state'
    # )
    
    # plt.title(f"Violin Plot of Ratings by User State Within Region: {region}")
    # plt.xlabel("User State")
    # plt.ylabel("Rating")
    # plt.xticks(rotation=90)
    # plt.tight_layout()
    # plt.ylim(0,5)
    # plt.show()


In [23]:
def cohen_d(x, y):
    x = np.array(x, dtype=float)
    y = np.array(y, dtype=float)
    nx, ny = len(x), len(y)
    mean_x, mean_y = np.mean(x), np.mean(y)
    std_x, std_y = np.std(x, ddof=1), np.std(y, ddof=1) 
    pooled_std = np.sqrt(((nx - 1) * std_x ** 2 + (ny - 1) * std_y ** 2) / (nx + ny - 2))
    return (mean_x - mean_y) / pooled_std

In [24]:
# Creating a plot of the Cohen's D values between every user_state of a region without calculating it twice.
all_cohen_results = []


for region in ratings_df['region'].unique():
    subset = ratings_df[ratings_df['region'] == region]
    cohen_results = []
    processed_pairs = set()
    
    for state1 in subset['user_state'].unique():
        for state2 in subset['user_state'].unique():
            if state1 != state2:
                state_pair = tuple(sorted([state1, state2]))
                
                if state_pair in processed_pairs:
                    continue
                
                ratings_state1 = subset[subset['user_state'] == state1]['rating']
                ratings_state2 = subset[subset['user_state'] == state2]['rating']
                
                if len(ratings_state1) == 0 or len(ratings_state2) == 0:
                    continue
                
                d_value = cohen_d(ratings_state1, ratings_state2)
                cohen_results.append((state1, state2, d_value, region))
                processed_pairs.add(state_pair)    
    
    cohen_df_region = pd.DataFrame(cohen_results, columns=['State1', 'State2', 'Cohen_d', 'Region'])
    
    cohen_df_region = cohen_df_region.sort_values(by='Cohen_d')
    
    all_cohen_results.append(cohen_df_region)


    
    # plt.figure(figsize=(12, 8))
    # sns.barplot(
    #     x='State1', 
    #     y='Cohen_d', 
    #     data=cohen_df_region, 
    #     palette='viridis', 
    #     hue='State2'
    # )

    # plt.axhline(y=0, color='black', linewidth=1)
    # plt.axhline(y=0.2, color='#FFA07A', linestyle=':', linewidth=2, label='Small effect (d=0.2)')
    # plt.axhline(y=-0.2, color='#FFA07A', linestyle=':', linewidth=2)
    # plt.axhline(y=0.5, color='#FF8C00', linestyle=':', linewidth=2, label='Medium effect (d=0.5)')
    # plt.axhline(y=-0.5, color='#FF8C00', linestyle=':', linewidth=2)
    # plt.axhline(y=0.8, color='#CD3700', linestyle=':', linewidth=2, label='Large effect (d=0.8)')
    # plt.axhline(y=-0.8, color='#CD3700', linestyle=':', linewidth=2)
    
    # plt.xlabel('State')
    # plt.ylabel('Cohen’s D')
    # plt.title(f'Cohen’s D for State pair Ratings by Region: Center state {region}')
    # plt.xticks(rotation=0)
    # plt.legend(title='State Pairs', loc='upper right')
    # plt.tight_layout()
    
    # plt.show()
    

final_cohen_df = pd.concat(all_cohen_results, ignore_index=True)

  final_cohen_df = pd.concat(all_cohen_results, ignore_index=True)


In [25]:
final_cohen_df.head()

Unnamed: 0,State1,State2,Cohen_d,Region
0,Wisconsin,Illinois,-0.109529,Wisconsin
1,Wisconsin,Minnesota,-0.061605,Wisconsin
2,Michigan,Minnesota,-0.058442,Wisconsin
3,Wisconsin,Iowa,-0.031834,Wisconsin
4,Iowa,Minnesota,-0.028861,Wisconsin


In [26]:
michigan_rows = final_cohen_df[(final_cohen_df['State1'] == 'Missouri') | (final_cohen_df['State2'] == 'Missouri')]
michigan_rows

Unnamed: 0,State1,State2,Cohen_d,Region
28,Arkansas,Missouri,-0.056528,Arkansas
29,Texas,Missouri,-0.046513,Arkansas
30,Oklahoma,Missouri,-0.046008,Arkansas
37,Missouri,Louisiana,0.091351,Arkansas
41,Missouri,Tennessee,0.189659,Arkansas
45,Missouri,Mississippi,0.268916,Arkansas
161,Arkansas,Missouri,-0.042045,Missouri
164,Oklahoma,Missouri,-0.021886,Missouri
167,Missouri,Iowa,-0.00318,Missouri
174,Missouri,Nebraska,0.084094,Missouri


In [27]:
# Sort the DataFrame by Cohen's d in ascending order
sorted_cohen_df = final_cohen_df.sort_values(by='Cohen_d', key=abs)

state_groups = []
assigned_states = set()

for _, row in sorted_cohen_df.iterrows():
    if row['Cohen_d'] >= 2:
        continue

    state1 = row['State1']
    state2 = row['State2']

    # Skip if both states are already assigned
    if state1 in assigned_states and state2 in assigned_states:
        continue

    group_found = False

    # Check if either state is in any existing group
    for group in state_groups:
        if state1 in group['States'] or state2 in group['States']:
            if state1 not in group['States']:
                group['States'].append(state1)
                assigned_states.add(state1)
            if state2 not in group['States']:
                group['States'].append(state2)
                assigned_states.add(state2)
            
            group_found = True
            break

    # If no group was found, create a new one
    if not group_found:
        state_groups.append({'States': [state1, state2]})
        assigned_states.update([state1, state2])

print("State Groups:")
print("-" * 50)
for group in state_groups:
    states = ", ".join(group['States'])
    print(f"States: {states}")
    print("-" * 50)


State Groups:
--------------------------------------------------
States: West Virginia, Ohio, Delaware, New York
--------------------------------------------------
States: Pennsylvania, Maryland, Virginia, Kentucky, New Jersey
--------------------------------------------------
States: South Dakota, Wyoming
--------------------------------------------------
States: Utah, Montana, Washington
--------------------------------------------------
States: Missouri, Iowa, Minnesota, New Mexico, Nevada, Idaho
--------------------------------------------------
States: Kansas, Nebraska
--------------------------------------------------
States: Texas, Oklahoma, Louisiana, Arizona, Oregon, Colorado, California
--------------------------------------------------
States: Illinois, Indiana
--------------------------------------------------
States: Wisconsin, Michigan, North Dakota
--------------------------------------------------
States: Alabama, Arkansas, South Carolina, Tennessee, Mississippi
-------

In [28]:
import pprint

# Pretty print the state groups
pprint.pprint(state_groups)

[{'States': ['West Virginia', 'Ohio', 'Delaware', 'New York']},
 {'States': ['Pennsylvania', 'Maryland', 'Virginia', 'Kentucky', 'New Jersey']},
 {'States': ['South Dakota', 'Wyoming']},
 {'States': ['Utah', 'Montana', 'Washington']},
 {'States': ['Missouri', 'Iowa', 'Minnesota', 'New Mexico', 'Nevada', 'Idaho']},
 {'States': ['Kansas', 'Nebraska']},
 {'States': ['Texas',
             'Oklahoma',
             'Louisiana',
             'Arizona',
             'Oregon',
             'Colorado',
             'California']},
 {'States': ['Illinois', 'Indiana']},
 {'States': ['Wisconsin', 'Michigan', 'North Dakota']},
 {'States': ['Alabama',
             'Arkansas',
             'South Carolina',
             'Tennessee',
             'Mississippi']},
 {'States': ['New Hampshire', 'Rhode Island', 'Maine']},
 {'States': ['Georgia', 'North Carolina', 'Florida']},
 {'States': ['Massachusetts', 'Connecticut', 'Vermont']}]


In [30]:
from collections import defaultdict

# List of all 50 U.S. states
us_states_list = [
    'Alabama', 'Alaska', 'Arizona', 'Arkansas', 'California', 'Colorado', 
    'Connecticut', 'Delaware', 'Florida', 'Georgia', 'Hawaii', 'Idaho', 
    'Illinois', 'Indiana', 'Iowa', 'Kansas', 'Kentucky', 'Louisiana', 
    'Maine', 'Maryland', 'Massachusetts', 'Michigan', 'Minnesota', 
    'Mississippi', 'Missouri', 'Montana', 'Nebraska', 'Nevada', 'New Hampshire', 
    'New Jersey', 'New Mexico', 'New York', 'North Carolina', 'North Dakota', 
    'Ohio', 'Oklahoma', 'Oregon', 'Pennsylvania', 'Rhode Island', 
    'South Carolina', 'South Dakota', 'Tennessee', 'Texas', 'Utah', 'Vermont', 
    'Virginia', 'Washington', 'West Virginia', 'Wisconsin', 'Wyoming'
]

# Flatten the list of states and count occurrences
state_occurrences = defaultdict(int)
for group in state_groups:
    for state in group['States']:
        state_occurrences[state] += 1

# Identify states present in multiple groups
states_in_multiple_groups = {state for state, count in state_occurrences.items() if count > 1}

# Check for missing and extra states
grouped_states = set(state_occurrences.keys())
missing_states = set(us_states_list) - grouped_states
extra_states = grouped_states - set(us_states_list)

# Print results nicely
print("Summary of State Groups Check:")
print("-" * 40)

if missing_states:
    print("Missing States (not included in state_groups):")
    for state in sorted(missing_states):
        print(f" - {state}")
else:
    print("All 50 states are included in state_groups.")

print("\n" + "-" * 40)

if extra_states:
    print("Extra States (invalid entries in state_groups):")
    for state in sorted(extra_states):
        print(f" - {state}")
else:
    print("No extra states found in state_groups.")

print("\n" + "-" * 40)

if states_in_multiple_groups:
    print("States Present in Multiple Groups:")
    for state in sorted(states_in_multiple_groups):
        print(f" - {state}")
else:
    print("No states are present in multiple groups.")

print("-" * 40)


Summary of State Groups Check:
----------------------------------------
Missing States (not included in state_groups):
 - Alaska
 - Hawaii

----------------------------------------
No extra states found in state_groups.

----------------------------------------
No states are present in multiple groups.
----------------------------------------


In [8]:
us_map_path = 'data/USData/map/us_states_modified.shp'
us_states_map = gpd.read_file(us_map_path)
# us_states_map = us_states_map.to_crs(epsg=4326)

# # Save to GeoJSON
# geojson_path = '../../USData/map/us_states_modified.geojson'
# us_states_map.to_file(geojson_path, driver="GeoJSON")


NameError: name 'gpd' is not defined

In [31]:
# # Flatten state_groups into a DataFrame
# grouped_states = []
# for idx, States in enumerate(state_groups):
#     for state in States['States']:
#         grouped_states.append({'user_state': state, 'States': f'Group {idx + 1}'})

# # Convert to DataFrame
# top_style_per_state = pd.DataFrame(grouped_states)

# # Identify all states from the GeoDataFrame
# all_states = us_states_map['name'].unique()

# # Identify matched states and missing states
# matched_states = top_style_per_state['user_state'].unique()
# missing_states = set(all_states) - set(matched_states)

# # Add missing states with "No Data"
# missing_states_df = pd.DataFrame({'user_state': list(missing_states), 'States': 'No Data'})
# top_style_per_state = pd.concat([top_style_per_state, missing_states_df], ignore_index=True)

# # Merge with the GeoDataFrame
# us_states_map = us_states_map.merge(
#     top_style_per_state,
#     left_on="name",
#     right_on="user_state",
#     how="left"
# )

# # Fill missing groups with "No Data"
# us_states_map['States'] = us_states_map['States'].fillna("No Data")

# # Map groups to colors
# unique_groups = us_states_map['States'].unique()
# group_colors = {
#     States: plt.cm.tab20(i / len(unique_groups)) if States != "No Data" else "red"
#     for i, States in enumerate(unique_groups)
# }

# # Assign colors to each state
# us_states_map['color'] = us_states_map['States'].map(group_colors)

# # Plot the map
# fig, ax = plt.subplots(1, 1, figsize=(15, 10))
# us_states_map.plot(
#     ax=ax,
#     color=us_states_map['color'],
#     edgecolor='black',
#     linewidth=0.5
# )

# # Create a legend for groups
# legend_patches = [mpatches.Patch(color=color, label=States) for States, color in group_colors.items()]
# ax.legend(handles=legend_patches, title="State Groups by Cohen's d", loc="lower left", fontsize=10)
# ax.set_title("State Groups by Cohen's d", fontsize=16)

# # Clean up the axes for better visuals
# ax.axis("off")
# plt.tight_layout()
# plt.show()


In [12]:
state_to_abbr = {
    'Alabama': 'AL', 'Alaska': 'AK', 'Arizona': 'AZ', 'Arkansas': 'AR', 'California': 'CA',
    'Colorado': 'CO', 'Connecticut': 'CT', 'Delaware': 'DE', 'Florida': 'FL', 'Georgia': 'GA',
    'Hawaii': 'HI', 'Idaho': 'ID', 'Illinois': 'IL', 'Indiana': 'IN', 'Iowa': 'IA',
    'Kansas': 'KS', 'Kentucky': 'KY', 'Louisiana': 'LA', 'Maine': 'ME', 'Maryland': 'MD',
    'Massachusetts': 'MA', 'Michigan': 'MI', 'Minnesota': 'MN', 'Mississippi': 'MS', 'Missouri': 'MO',
    'Montana': 'MT', 'Nebraska': 'NE', 'Nevada': 'NV', 'New Hampshire': 'NH', 'New Jersey': 'NJ',
    'New Mexico': 'NM', 'New York': 'NY', 'North Carolina': 'NC', 'North Dakota': 'ND', 'Ohio': 'OH',
    'Oklahoma': 'OK', 'Oregon': 'OR', 'Pennsylvania': 'PA', 'Rhode Island': 'RI', 'South Carolina': 'SC',
    'South Dakota': 'SD', 'Tennessee': 'TN', 'Texas': 'TX', 'Utah': 'UT', 'Vermont': 'VT',
    'Virginia': 'VA', 'Washington': 'WA', 'West Virginia': 'WV', 'Wisconsin': 'WI', 'Wyoming': 'WY'
}


In [32]:
# Step 1: Map each state to a group number
state_to_group = {}
for idx, group in enumerate(state_groups):
    for state in group['States']:
        state_to_group[state] = f'Group {idx + 1}'

all_states = set(state_to_abbr.keys())  # Get all states
grouped_states = [{'state_name': state, 'group': state_to_group.get(state, 'No Group')} for state in all_states]

grouped_states_df = pd.DataFrame(grouped_states)

In [33]:
grouped_states_df['state_abbreviation'] = grouped_states_df['state_name'].map(state_to_abbr)

In [34]:
def adjust_cohen_d(filtered_cohen_df):
    mask = filtered_cohen_df['State2'] == filtered_cohen_df['Region']
    
    # Switch State1 and State2, and invert Cohen_d
    filtered_cohen_df.loc[mask, ['State1', 'State2']] = filtered_cohen_df.loc[mask, ['State2', 'State1']].values
    filtered_cohen_df.loc[mask, 'Cohen_d'] = -filtered_cohen_df.loc[mask, 'Cohen_d']
    
    return filtered_cohen_df


def create_region_dict(adjusted_df):
    region_dict = {}
    for region, group in adjusted_df.groupby('Region'):
        region_dict[region] = group[['State2', 'Cohen_d']].to_dict(orient='records')
    return region_dict

In [35]:
# Add state abbreviation column
grouped_states_df['state_abbreviation'] = grouped_states_df['state_name'].map(state_to_abbr)

# Filter Cohen's d data based on region and state
filtered_cohen_df = final_cohen_df[
    (final_cohen_df['Region'] == final_cohen_df['State1']) | (final_cohen_df['Region'] == final_cohen_df['State2'])
]

adjusted_df = adjust_cohen_d(filtered_cohen_df)

region_dict = create_region_dict(adjusted_df)


In [36]:
# Create hover text and state abbreviation mapping for visualization
hover_data = []
for region, cohen_d_values in region_dict.items():
    for entry in cohen_d_values:
        hover_data.append({
            'state_name': region,
            'State2': entry['State2'],
            'Cohen_d': entry['Cohen_d'],
            'Region': region
        })



In [37]:
hover_df = pd.DataFrame(hover_data)

#hover_df

In [38]:
hover_df['hover_text'] = hover_df.apply(
    lambda row: "Cohen's d values:<br>" + 
                '<br>'.join([f"{state}: {cohen_d:.4f}" for state, cohen_d in zip(hover_df[hover_df['Region'] == row['Region']]['State2'], 
                                                                            hover_df[hover_df['Region'] == row['Region']]['Cohen_d']) ]),
                                                                            axis=1
                                                                            )



In [39]:
hover_df = hover_df.drop(columns=['state_name','State2', 'Cohen_d'])

# Drop duplicates, keeping only one row per region
hover_df = hover_df.drop_duplicates(subset='Region').reset_index(drop=True)


In [40]:
hover_df['state_abbreviation'] = hover_df['Region'].map(state_to_abbr)

grouped_states_df = grouped_states_df.merge(hover_df[['state_abbreviation', 'hover_text']], on='state_abbreviation', how='left')

grouped_states_df = grouped_states_df.sort_values('group')



In [41]:
grouped_states_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 50 entries, 25 to 31
Data columns (total 4 columns):
 #   Column              Non-Null Count  Dtype 
---  ------              --------------  ----- 
 0   state_name          50 non-null     object
 1   group               50 non-null     object
 2   state_abbreviation  50 non-null     object
 3   hover_text          48 non-null     object
dtypes: object(4)
memory usage: 2.0+ KB


In [46]:
# Step 10: Plot choropleth map
custom_palette = [
    "#a6cee3", "#1f78b4", "#b2df8a", "#33a02c", "#fb9a99", "#e31a1c", "#fdbf6f", "#ff7f00",
    "#cab2d6", "#6a3d9a", "#ffff99", "#b15928", "#f2a7c3", "#d1f2a5", "#ffb3b3", "#ffcc99",
    "#ccebc5", "#ffb6e6", "#d0f0c0", "#f9c9b6"
]

# Use the custom palette in your color map
num_groups = len(grouped_states_df['group'].unique())
color_map = {f'Group {idx + 1}': custom_palette[idx % len(custom_palette)] for idx in range(num_groups)}
color_map['No Group'] = 'grey'

fig = px.choropleth(
    grouped_states_df, 
    locations='state_abbreviation', 
    color='group', 
    color_discrete_map=color_map,
    title="USA State Groups",
    locationmode="USA-states",
    hover_name='state_name',
    hover_data={'state_name': False, 'group': False, 'state_abbreviation': False, 'hover_text': True}  # Disable state abbreviation and hover text
)

fig.update_traces(hovertemplate='%{customdata[0]}<br><br>%{customdata[3]}')


fig.update_layout(
    title_text = 'State Groups by Cohen\'s d',
    geo_scope='usa',
    width=700,  # Adjust width
    height=500,  # Adjust height
    title_font=dict(size=20),  # Optional: Increase title font size
    geo=dict(
        projection_type="albers usa"  # Optional: Adjust projection style if necessary
    )
)

fig.show()
directory = "img/question1/"
fig.write_html(os.path.join(directory, "custom_region.html"))