In [33]:
import glob
import os
import pandas as pd
import numpy as np
from scipy import stats
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import ipywidgets as widgets
from IPython.display import display, clear_output

# ========================================================
# 1. Data Preparation
# ========================================================

folder_path = '/Users/jerryshen/Documents/Medical School Application Prep/Residency Analysis'

# Load the uscities.csv reference table; it supplies lat, lng, population and
# density for each (city, state_id) pair.
uscities_file = os.path.join(folder_path, "uscities.csv")
uscities_df = pd.read_csv(uscities_file)

# Strip the join keys so they compare equal to the stripped match-list values.
uscities_df['city'] = uscities_df['city'].str.strip()
uscities_df['state_id'] = uscities_df['state_id'].str.strip()

# Keep one lookup row per (city, state_id). If the lookup repeated a key, the
# left merge below would emit several rows for a single match-list entry and
# inflate every per-school count and proportion computed from combined_df.
# NOTE(review): keep='first' assumes the first listed row is the preferred
# one when a key repeats - confirm against the uscities.csv ordering.
city_lookup = (
    uscities_df[['city', 'state_id', 'lat', 'lng', 'population', 'density']]
    .drop_duplicates(subset=['city', 'state_id'], keep='first')
)

# Find all CSV files that match the pattern "*_match_list.csv". Sorting makes
# the row order of combined_df independent of the directory listing order.
csv_files = sorted(glob.glob(os.path.join(folder_path, "*_match_list.csv")))
if not csv_files:
    # Fail here with the folder name rather than later inside pd.concat.
    raise FileNotFoundError(
        f"No '*_match_list.csv' files found in {folder_path!r}"
    )

# List to store DataFrames from each match list file
schoolData = []
for file in csv_files:
    df = pd.read_csv(file, encoding='latin1')

    # Rank is written like "#12". Going through str first keeps this working
    # when a file has no '#' at all and pandas parsed the column as numbers
    # (the .str accessor raises on numeric dtypes). Anything that is not a
    # number after stripping becomes NaN.
    rank_text = df['Rank'].astype(str).str.strip().str.lstrip('#')
    df['Rank'] = pd.to_numeric(rank_text, errors='coerce')
    df['SourceFile'] = os.path.basename(file)

    # Split "Boston, MA" into city and state_id on the LAST comma only, and
    # force exactly two columns: a plain split(',', expand=True) yields one
    # column when no row has a comma and three or more when a location holds
    # extra commas, and either case makes the two-column assignment raise.
    location_parts = (
        df['Location'].astype(object)
        .str.rsplit(',', n=1, expand=True)
        .reindex(columns=[0, 1])
        .astype(object)  # keeps .str usable when a column is entirely NaN
    )
    df['city'] = location_parts[0].str.strip()
    df['state_id'] = location_parts[1].str.strip()

    # Add lat, lng, population, and density; unmatched locations keep NaN.
    df = pd.merge(df, city_lookup, on=['city', 'state_id'], how='left')

    schoolData.append(df)

# Combine all DataFrames into one
combined_df = pd.concat(schoolData, ignore_index=True)

# Create a School column from SourceFile: the text before the first underscore.
combined_df['School'] = combined_df['SourceFile'].apply(lambda x: x.split('_')[0])

# Define a function to assign geographic region
# Region name paired with the state codes that belong to it, in lookup order.
_REGION_MEMBERS = (
    ('New England', ('ME', 'NH', 'VT', 'MA', 'CT', 'RI')),
    ('Middle Atlantic', ('NJ', 'NY', 'PA')),
    ('South Atlantic', ('MD', 'DE', 'DC', 'WV', 'VA', 'NC', 'SC', 'GA', 'FL')),
    ('East South Central', ('KY', 'TN', 'MS', 'AL')),
    ('East North Central', ('WI', 'MI', 'IL', 'IN', 'OH')),
    ('West North Central', ('ND', 'SD', 'NE', 'KS', 'MN', 'IA', 'MO')),
    ('West South Central', ('TX', 'OK', 'AR', 'LA')),
    ('Mountain West', ('MT', 'WY', 'ID', 'UT', 'CO', 'NV', 'NM', 'AZ')),
    ('Pacific West', ('WA', 'OR', 'CA', 'AK', 'HI')),
)


def assign_region(state):
    """Return the geographic region name for a two-letter state code.

    The comparison is exact (case-sensitive). Any value that is not one of
    the listed codes, including a missing value, yields ``np.nan``.
    """
    for region_name, member_codes in _REGION_MEMBERS:
        if state in member_codes:
            return region_name
    return np.nan

# Tag every match row with the region of its state (NaN when unrecognised).
combined_df['geographic US location'] = combined_df['state_id'].map(assign_region)

# Count matches per (school, region), spread the regions into columns with 0
# for combinations that never occur, then divide each row by its own total so
# a school's region shares sum to 1.
region_counts = (
    combined_df.groupby(['School', 'geographic US location'])
    .size()
    .reset_index(name='count')
)
region_pivot = (
    region_counts.set_index(['School', 'geographic US location'])['count']
    .unstack(fill_value=0)
)
school_totals = region_pivot.sum(axis=1)
region_proportions = region_pivot.div(school_totals, axis=0).reset_index()

# ========================================================
# 2. Specialty-Based Scores Calculation
# ========================================================

# For every specialty, compute per match-list file:
#   proportion  - share of that file's rows matched into the specialty
#   median_rank - median Rank of those rows
# and z-score both across files. Files with no match in the specialty keep
# NaN in every metric so they can be told apart from files with a low share.

source_files = combined_df['SourceFile']

# Reuse the School label already attached to combined_df so this table always
# joins cleanly against it, instead of re-deriving the label from the name.
school_by_file = (
    combined_df.drop_duplicates('SourceFile')
    .set_index('SourceFile')['School']
)

# A missing Specialty is not a specialty: comparing against NaN never matches,
# so scoring it only produced an all-NaN block (and z-score warnings).
specialties = combined_df['Specialty'].dropna().unique()
if specialties.size == 0:
    raise ValueError(
        "No non-missing values in the 'Specialty' column; nothing to score."
    )

all_specialty_results = []
for specialty in specialties:
    is_match = combined_df['Specialty'] == specialty

    # Mean of the boolean mask per file is the matched share; a share of 0
    # means the file has no such match and is recorded as NaN instead.
    proportion = is_match.groupby(source_files).mean()
    proportion = proportion.where(proportion > 0)

    # Median rank over matching rows only, re-expanded to every file.
    median_rank = (
        combined_df.loc[is_match]
        .groupby('SourceFile')['Rank']
        .median()
        .reindex(proportion.index)
    )

    specialty_prop = pd.DataFrame({
        'SourceFile': proportion.index,
        'proportion': proportion.to_numpy(),
        'median_rank': median_rank.to_numpy(),
    })
    specialty_prop['School'] = specialty_prop['SourceFile'].map(school_by_file)

    # Leave NaN for schools without the specialty.
    specialty_prop['Normalized Proportion'] = stats.zscore(specialty_prop['proportion'], nan_policy='omit')
    specialty_prop['Normalized Median Rank'] = stats.zscore(specialty_prop['median_rank'], nan_policy='omit')
    specialty_prop['Specialty'] = specialty
    all_specialty_results.append(specialty_prop)

all_specialties_df = pd.concat(all_specialty_results, ignore_index=True)

competitive_specialties = ["Dermatology", "General Surgery", "Neurological Surgery",
                           "Orthopaedic Surgery", "Plastic Surgery", "Otolaryngology",
                           "Vascular Surgery", "Thoracic Surgery"]
# 1 when the specialty is in the list above, otherwise 0.
all_specialties_df['Competitiveness'] = (
    all_specialties_df['Specialty'].isin(competitive_specialties).astype(int)
)

# ========================================================
# 3. Interactive Controls Setup
# ========================================================

# Three dropdowns for picking schools to highlight; "None" means no pick.
school_list = sorted(combined_df['School'].unique())
school_options = [("None", None)]
school_options.extend((name, name) for name in school_list)

highlight_school1, highlight_school2, highlight_school3 = (
    widgets.Dropdown(
        options=school_options,
        description=label,
        layout=widgets.Layout(width='300px'),
        style={'description_width': '150px'},
    )
    for label in ("Highlight School 1:", "Highlight School 2:", "Highlight School 3:")
)

# One checkbox per specialty, laid out four to a row.
specialties_sorted = sorted(all_specialties_df['Specialty'].dropna().astype(str).unique())
specialty_checkboxes = []
for specialty_name in specialties_sorted:
    specialty_checkboxes.append(widgets.Checkbox(value=False, description=specialty_name))
specialty_grid = widgets.GridBox(
    children=specialty_checkboxes,
    layout=widgets.Layout(grid_template_columns="repeat(4, 200px)",
                          grid_gap="10px 10px"),
)

# Settings shared by every slider: a 0-1 range in 0.01 steps that only
# reports a new value once the handle is released.
unit_slider_settings = dict(min=0, max=1, step=0.01, continuous_update=False)

# Slider read by the plot as the weight of median rank in the quality score.
w_slider = widgets.FloatSlider(
    value=0.25,
    description='Rank Weight',
    layout=widgets.Layout(width='400px'),
    style={'description_width': '150px'},
    **unit_slider_settings,
)

quality_section = widgets.VBox([
    widgets.HTML("<h3>Residency Quality Metric</h3>"),
    specialty_grid,
    w_slider,
])

# One preference slider per region, all starting at 0.5, three to a row.
regions = ['New England', 'Middle Atlantic', 'South Atlantic', 'East North Central',
           'East South Central', 'West North Central', 'West South Central',
           'Mountain West', 'Pacific West']
default_region_weights = dict.fromkeys(regions, 0.5)
region_sliders = {}
for region_name in regions:
    region_sliders[region_name] = widgets.FloatSlider(
        value=default_region_weights[region_name],
        description=region_name,
        layout=widgets.Layout(width='350px'),
        style={'description_width': '150px'},
        **unit_slider_settings,
    )
region_slider_widgets = widgets.GridBox(
    children=list(region_sliders.values()),
    layout=widgets.Layout(grid_template_columns="repeat(3, 350px)",
                          grid_gap="10px 10px"),
)

# Slider read by the plot as the share of the score given to geography.
geo_importance_slider = widgets.FloatSlider(
    value=0.1,
    description='Geo Importance',
    layout=widgets.Layout(width='400px'),
    style={'description_width': '150px'},
    **unit_slider_settings,
)

# Button that puts every slider back to its starting value.
reset_button = widgets.Button(
    description="Reset Sliders",
    button_style='info',
    layout=widgets.Layout(width='100px'),
)

# Define the function to reset sliders.
def reset_sliders(b):
    """Restore every slider to its default value and redraw the plot once.

    Assigning ``.value`` fires each slider's ``update_plot`` observer, so
    resetting the sliders one by one would rebuild the figure once per
    changed slider and then again at the end. The observer is detached while
    the values are written and re-attached afterwards, leaving a single
    redraw.

    Args:
        b: The button instance passed by ``on_click``; not used.
    """
    targets = [(w_slider, 0.25), (geo_importance_slider, 0.1)]
    targets.extend(
        (slider, default_region_weights[r]) for r, slider in region_sliders.items()
    )

    # Detach update_plot, remembering which sliders actually had it so only
    # those get it back (unobserve raises if the handler was never attached).
    detached = []
    for slider, _ in targets:
        try:
            slider.unobserve(update_plot, 'value')
        except (KeyError, ValueError):
            continue
        detached.append(slider)

    try:
        for slider, default_value in targets:
            slider.value = default_value
    finally:
        # Re-attach even if an assignment failed, so the UI stays live.
        for slider in detached:
            slider.observe(update_plot, 'value')

    # Update the plot.
    update_plot()

# Clicking the button runs reset_sliders.
reset_button.on_click(reset_sliders)

# Geographic panel: heading, region slider grid, importance slider, reset.
geo_section = widgets.VBox(children=[
    widgets.HTML("<h3>Geographic Preference</h3>"),
    region_slider_widgets,
    geo_importance_slider,
    reset_button,
])

# Full control panel: quality section, geographic section, then the three
# highlight dropdowns side by side.
highlight_row = widgets.HBox(
    children=[highlight_school1, highlight_school2, highlight_school3]
)
controls = widgets.VBox(children=[quality_section, geo_section, highlight_row])

# ========================================================
# 4. Update Plot Function (Two-Panel Figure)
# ========================================================

# Region name -> the state codes it contains. Order is kept so the flattened
# mapping below lists states in the same sequence as before.
_STATES_BY_REGION = {
    'New England': ('CT', 'ME', 'MA', 'NH', 'RI', 'VT'),
    'Middle Atlantic': ('NJ', 'NY', 'PA'),
    'South Atlantic': ('DE', 'DC', 'MD', 'VA', 'WV', 'NC', 'SC', 'GA', 'FL'),
    'East North Central': ('IN', 'IL', 'MI', 'OH', 'WI'),
    'East South Central': ('KY', 'TN', 'MS', 'AL'),
    'West North Central': ('ND', 'SD', 'NE', 'KS', 'MN', 'IA', 'MO'),
    'West South Central': ('OK', 'TX', 'AR', 'LA'),
    'Mountain West': ('ID', 'MT', 'WY', 'NV', 'UT', 'CO', 'NM', 'AZ'),
    'Pacific West': ('AK', 'CA', 'HI', 'OR', 'WA'),
}

# Flat lookup from state code to region name, used to colour the map.
state_to_region = {
    state_code: region_name
    for region_name, state_codes in _STATES_BY_REGION.items()
    for state_code in state_codes
}

# Write the merged match-list rows (combined_df) to a CSV in the current
# working directory, without the index column.
# NOTE(review): the file name suggests the per-specialty table
# (all_specialties_df), but the frame written is combined_df - confirm which
# one is intended.
combined_df.to_csv('all_specialties_withRank.csv', index=False)

# Score shown for schools that have no usable data for the selected specialties.
_NO_DATA_SCORE = -1


def _min_max_scale(values, upper):
    """Linearly rescale ``values`` onto the range [0, upper].

    When there is no spread to scale by (a single value, or all values equal)
    every entry is set to ``upper``.
    """
    low = values.min()
    spread = values.max() - low
    if not np.isfinite(spread) or spread == 0:
        return pd.Series(float(upper), index=values.index)
    return (values - low) / spread * upper


def _quality_totals(schools, selected_specialties, rank_weight):
    """Return each school's summed specialty score, NaN where none is usable.

    A specialty's score is ``(1 - rank_weight) * Normalized Proportion
    - rank_weight * Normalized Median Rank``. Scores are summed over the
    selected specialties, skipping NaN; a school whose scores are all NaN
    (or when nothing is selected) gets NaN.
    """
    if not selected_specialties:
        return pd.Series(np.nan, index=schools)
    picked = all_specialties_df[all_specialties_df['Specialty'].isin(selected_specialties)]
    score = ((1 - rank_weight) * picked['Normalized Proportion']
             - rank_weight * picked['Normalized Median Rank'])
    # min_count=1 keeps an all-NaN group as NaN instead of turning it into 0.
    return score.groupby(picked['School']).sum(min_count=1).reindex(schools)


def _region_scores(schools, region_weights):
    """Return each school's region-weighted share of matches.

    Schools or regions absent from ``region_proportions`` count as a share of
    0, so the result never contains NaN.
    """
    shares = (
        region_proportions.set_index('School')
        .reindex(index=schools, columns=regions)
        .fillna(0)
    )
    return shares.mul(pd.Series(region_weights), axis=1).sum(axis=1)


def _highlight_colors():
    """Map each highlighted school to its marker colour.

    Later entries overwrite earlier ones, so listing dropdown 1 last makes it
    win when the same school is picked in more than one dropdown.
    """
    picks = (
        (highlight_school3.value, 'teal'),
        (highlight_school2.value, 'purple'),
        (highlight_school1.value, 'goldenrod'),
    )
    return {school: color for school, color in picks if school is not None}


def update_plot(change=None):
    """Recompute every school's 0-10 score from the controls and redraw.

    The score combines a quality part (selected specialties, weighted by the
    rank slider) and a geographic part (region shares weighted by the region
    sliders), split by the geo-importance slider, and is rescaled so the
    scored schools span 0 to 10. Schools with no usable quality data are
    shown at -1. The left panel is a lollipop chart of the scores; the right
    panel is a US map coloured by normalised region weight.

    Args:
        change: The change notification passed by widget observers; not used.
    """
    clear_output(wait=True)
    display(controls)

    schools = pd.Index(combined_df['School'].unique(), name='School')

    # --- Quality component ---
    selected_specialties = [cb.description for cb in specialty_checkboxes if cb.value]
    totals = _quality_totals(schools, selected_specialties, w_slider.value)
    # An explicit mask, rather than comparing scores with the -1 sentinel,
    # so a real total that happens to equal -1 is still treated as valid.
    has_quality = totals.notna()

    # --- Geographic component ---
    raw_slider_sum = sum(region_sliders[r].value for r in regions)
    normalized_region_weights = {
        r: (region_sliders[r].value / raw_slider_sum if raw_slider_sum else 0)
        for r in regions
    }
    region_score = _region_scores(schools, normalized_region_weights)

    # --- Allocation based on geo importance (the two parts share 10 points) ---
    geo_importance = geo_importance_slider.value
    quality_alloc = (1 - geo_importance) * 10
    region_alloc = geo_importance * 10

    # Min-max scaling is unchanged by a positive linear transform, so the raw
    # totals and raw region scores are scaled directly. Z-scoring the totals
    # first changes nothing when they differ, and turns them into NaN when
    # they are all equal or there is only one school.
    scaled_region = _min_max_scale(region_score, region_alloc)

    final = pd.Series(float(_NO_DATA_SCORE), index=schools)
    if has_quality.any():
        scaled_quality = _min_max_scale(totals[has_quality], quality_alloc)
        combined = scaled_quality + scaled_region[has_quality]
        # Stretch the scored schools so the lowest is 0 and the highest is 10.
        final.loc[has_quality] = _min_max_scale(combined, 10)

    # Ascending order puts the best school at the top of the chart; a stable
    # sort keeps tied schools in a repeatable order.
    ranked = final.sort_values(kind='stable')
    school_names = ranked.index.tolist()
    scores = [float(score) for score in ranked]

    # --- Lollipop plot (left panel) ---
    # All stems go in one trace; a None entry breaks the line between schools.
    stem_x, stem_y = [], []
    for school, score in zip(school_names, scores):
        stem_x.extend((0, score, None))
        stem_y.extend((school, school, None))
    stem_trace = go.Scatter(
        x=stem_x,
        y=stem_y,
        mode='lines',
        line=dict(color='lightgray', width=2),
        showlegend=False,
        hoverinfo='skip'
    )

    highlight_colors = _highlight_colors()
    marker_trace = go.Scatter(
        x=scores,
        y=school_names,
        mode='markers+text',
        text=[f"{score:.2f}" for score in scores],
        textposition='middle right',
        marker=dict(
            color=[highlight_colors.get(school, 'gray') for school in school_names],
            size=[12 if school in highlight_colors else 8 for school in school_names],
        ),
        name='Final Score'
    )

    # --- US map (right panel) ---
    state_list = list(state_to_region.keys())
    map_weights = [normalized_region_weights.get(state_to_region[st], 0) for st in state_list]
    state_df = pd.DataFrame({'state': state_list, 'region_weight': map_weights})

    map_trace = go.Choropleth(
        locations=state_df['state'],
        z=state_df['region_weight'],
        locationmode="USA-states",
        colorscale="Greens",
        zmin=0,
        zmax=1,
        colorbar=dict(title="Region Weight")
    )

    combined_fig = make_subplots(
        rows=1, cols=2,
        column_widths=[0.5, 0.5],
        specs=[[{"type": "xy"}, {"type": "geo"}]],
        subplot_titles=("School Composite Scores", "US Geographic Preference")
    )

    combined_fig.add_trace(stem_trace, row=1, col=1)
    combined_fig.add_trace(marker_trace, row=1, col=1)
    combined_fig.add_trace(map_trace, row=1, col=2)

    combined_fig.update_geos(scope="usa", projection_type="albers usa", row=1, col=2)
    combined_fig.update_layout(
        # 50px per school, capped for readability and floored so the map
        # stays usable when only a few schools are loaded.
        height=max(400, min(1200, 50 * len(school_names))),
        width=1500, template="plotly_white",
        title_text="Composite Scores and Geographic Preferences"
    )
    combined_fig.update_xaxes(title_text="Final Score (0-10)", row=1, col=1)
    combined_fig.update_yaxes(title_text="School", row=1, col=1, tickfont=dict(size=10))

    combined_fig.show()

# ========================================================
# 5. Attach Observers & Display Controls
# ========================================================

# Every control that feeds the plot, in the order the handlers are attached.
plot_inputs = [
    *specialty_checkboxes,
    w_slider,
    *region_sliders.values(),
    geo_importance_slider,
    highlight_school1,
    highlight_school2,
    highlight_school3,
]
# Redraw whenever any of these controls changes value.
for control in plot_inputs:
    control.observe(update_plot, 'value')

# Show the control panel and draw the initial figure.
display(controls)
update_plot()

VBox(children=(VBox(children=(HTML(value='<h3>Residency Quality Metric</h3>'), GridBox(children=(Checkbox(valu…