In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import geopandas as gpd
import seaborn as sns

ideas:

A change in which party a county votes for (whether it flips or not) is spatially correlated with, or informed by, what nearby counties do.

Spatial correlation of the residuals from a model that predicts the winner in each county.

most religious groups are spatially correlated

todo:

- describe dataset and coverage
- describe what spatial correlation means for this dataset (or rather, for the specific features we are looking at)
    - describe what we are looking for in the spatial correlation
- show the main election results (2012-2024)
- show the religious groups correlations (multiple maps since lots of groups)
- show the spatial correlation of the residuals from a model that predicts the winner in each county (or another predictor)

In [None]:
# Show every column, and up to 100 rows, whenever a DataFrame is displayed.
pd.options.display.max_columns = None
pd.options.display.max_rows = 100

In [None]:
# Load the full census + election GeoJSON with geopandas. The path is relative,
# so it is resolved against the kernel's current working directory.
# NOTE(review): judging by the file name this is county-level demographics joined
# with 2012-2024 election results - confirm against the data pipeline.
census_election_data = gpd.read_file('../../data/election/final_data/county_demographics_with_elections_2012_2024.geojson')

#### Main Content

##### Using multi feature choropleth to visualize the raw data

In [None]:
# One togglable choropleth layer per election year, newest first. Every layer
# uses the linear scale and the blue-red 'political' palette.
feature_configs = [
    {
        'name': f'state_winner_{year}',
        'scale': 'linear',
        'color_scheme': 'political',
        'legend_name': f'State Winner {year}'
    }
    for year in (2024, 2020, 2016, 2012)
]

# NOTE: create_multi_feature_choropleth is defined in the "Code" section further
# down this notebook, so that cell has to be executed before this one.
map_obj = create_multi_feature_choropleth(
    census_election_data,
    feature_configs,
    zoom_start=6
)
map_obj

##### Using multi feature spatial correlation mapping

In [None]:
# Example usage: spatial autocorrelation of the county "flip" indicators.
# 'color_scheme' has to match a COLOR_SCHEMES key exactly; a name that is not
# found is not applied and the palette lookup falls back to 'sequential'.
feature_configs = [
    {
        'name': 'flipped_2012_2016',
        'scale': 'linear',
        'color_scheme': 'political',
        'legend_name': 'flipped_2012_2016'
    },
    {
        'name': 'flipped_2020_2024',
        'scale': 'linear',
        'color_scheme': 'political',
        'legend_name': 'flipped_2020_2024'
    },
    # Alternative layers (religious adherence) - uncomment to include them.
    # {
    #     'name': 'religion_catholic_church_percent_adherents_of_total_adherents',
    #     'scale': 'linear',
    #     'color_scheme': 'sequential',
    #     'legend_name': 'CATH Church Adherents %'
    # },
    # {
    #     'name': 'religion_church_of_jesus_christ_of_latter_day_saints_percent_adherents_of_total_adherents',
    #     'scale': 'linear',
    #     'color_scheme': 'sequential',
    #     'legend_name': 'LDS Church Adherents %'
    # },
    # {
    #     'name': 'percent_religious',
    #     'scale': 'linear',
    #     'color_scheme': 'sequential',
    #     'legend_name': 'Percent Population Religious'
    # }
]

# Returns the per-feature statistics and the folium map holding all layers.
results_multi_spatial, interactive_map_multi_spatial = analyze_multi_feature_spatial_correlation(
    gdf=census_election_data,
    feature_configs=feature_configs,
    zoom_start=6
)

# Display results
print_multi_feature_results(results_multi_spatial)

# Display map
interactive_map_multi_spatial

##### Using spatial correlation on residuals from a model that predicts a feature

In [None]:
# Step 1 - model training (only needs to be run once).
# The election-result columns listed in `exclude_columns` are withheld from the
# feature set; `exclude_patterns` additionally drops every column whose name
# contains one of the given substrings (case-insensitive).
trainer = ModelTrainer()
model_results = trainer.fit_predict(
    gdf=census_election_data,
    target_column='winner_2024_numeric',
    exclude_columns=[
        'winner_2020_numeric',
        'votes_dem_2012', 'votes_gop_2012', 'per_point_diff_2012',
        'votes_dem_2016', 'votes_gop_2016', 'per_point_diff_2016',
        'votes_dem_2024', 'votes_gop_2024', 'per_point_diff_2024',
        'per_dem_2012', 'per_gop_2012',
        'diff_2012', 'diff_2016', 'diff_2020',
        'per_dem_2016', 'per_gop_2016',
        'per_dem_2020', 'per_gop_2020',
        'per_dem_2024', 'per_gop_2024',
        'winner_2012_numeric', 'winner_2016_numeric',
        'state_winner_2012', 'state_winner_2016',
        'state_winner_2020', 'state_winner_2024',
        'per_point_diff_2020',
    ],
    exclude_patterns=['index', 'per_dem', 'per_gop', 'date'],  # patterns to exclude
    verbose=True
)

# Step 2 - spatial analysis of the residuals (run once after model training).
analyzer = SpatialAnalyzer()
spatial_results = analyzer.analyze_residuals(
    gdf=census_election_data,
    residuals=model_results.residuals
)

# Step 3 - build the map (can be re-run with different options).
visualizer = MapVisualizer()
m = visualizer.create_map(
    gdf=census_election_data,
    target_column='winner_2024_numeric',
    model_results=model_results,
    spatial_results=spatial_results,
    zoom_start=4
)

In [None]:
# First 20 rows of the importance table; fit_predict stores it sorted by
# 'importance' in descending order, so these are the 20 most important features.
model_results.feature_importances.head(20)

##### Getting all spatial correlations

In [None]:
# Spatial correlation for all numeric features of the dataset.
results_all_spatial = calculate_multi_feature_correlation(
    census_election_data,
    weight_type='queen',  # or 'knn'
    k_neighbors=5  # only needed if using 'knn'
)

# Rank the features twice - by global Moran's I first, then by local Moran's I -
# passing the same n=1000 / p_threshold=0.05 arguments for both rankings.
top_global, top_local = (
    get_top_correlations(
        results_all_spatial,
        n=1000,
        p_threshold=0.05,
        sort_by=ranking
    )
    for ranking in ('global', 'local')
)

#### Code

##### Multi feature Choropleth (folium)

In [None]:
import folium
import numpy as np
from branca.colormap import LinearColormap
import pandas as pd

# Named colour ramps: each key maps to a list of hex colours ordered from the
# low end to the high end of the scale. Feature configs pick one through their
# 'color_scheme' entry.
COLOR_SCHEMES = {
    'population': ['#fee5d9', '#fcae91', '#fb6a4a', '#de2d26', '#a50f15'],  # Reds (light to dark)
    'political': ['#0571b0', '#92c5de', '#f7f7f7', '#f4a582', '#ca0020'],   # Blue-Red diverging, near-white midpoint
    'sequential': ['#ffffcc', '#a1dab4', '#41b6c4', '#2c7fb8', '#253494'],  # Light yellow through green to dark blue
    'viridis': ['#440154', '#414487', '#2a788e', '#22a884', '#7ad151', '#fde725']  # Viridis (six stops)
}

# Fill colour for polygons whose feature value is missing (NaN).
_MISSING_VALUE_COLOR = '#cccccc'


def _scale_for_colormap(values, feature_name, scale_type):
    """Return `values` on the scale the colormap is built on (log or unchanged)."""
    if scale_type != 'log':
        return values

    min_val = values.min()
    if min_val <= 0:
        print(f"Warning: {feature_name} contains values ≤ 0. Adding offset for log scale.")
        # Shift EVERY value by the same offset so the ordering is preserved and
        # the smallest value maps to log(1) = 0.
        values = values - min_val + 1
    return np.log(values)


def _format_display_value(value, feature_name):
    """Return the tooltip text for one raw feature value."""
    if not isinstance(value, (int, float)):
        # Non-numeric values (e.g. strings) are shown unchanged.
        return value
    if pd.isna(value):
        return 'No data'
    if abs(value) >= 1000000:
        return f"{value/1000000:.2f}M"
    if abs(value) >= 1000:
        return f"{value/1000:.2f}K"
    if feature_name.startswith('per_'):  # Percentage values
        return f"{value:.2f}%"
    return f"{value:.2f}"


def create_multi_feature_choropleth(
    gdf,
    feature_configs,
    center=None,
    zoom_start=6,
    width='100%',
    height='100%'
):
    """
    Create an interactive choropleth map with multiple togglable feature layers using Folium.
    
    Parameters:
    -----------
    gdf : GeoDataFrame
        The geodataframe containing geographic and feature data
    feature_configs : list of dict
        List of feature configurations, each containing:
        {
            'name': str (column name in gdf),
            'scale': str ('linear' or 'log'),
            'color_scheme': str (optional, one of 'population', 'political',
                'sequential', 'viridis'; defaults to 'sequential'. Ignored for
                'log' scale, which always uses 'viridis'),
            'legend_name': str (optional, display name for the legend)
        }
    center : tuple, optional
        (lat, lon) center coordinates. If None, will use centroid of the data
    zoom_start : int, default 6
        Initial zoom level
    width : str, default '100%'
        Width of the map in percentage
    height : str, default '100%'
        Height of the map in percentage
        
    Returns:
    --------
    folium.Map
        The interactive map object. Each feature is a separate layer (hidden
        by default) with its own colour legend.

    Raises:
    -------
    ValueError
        If a configured feature is not a column of `gdf`, or its scale is not
        'linear' or 'log'.
    """
    # Input validation
    for config in feature_configs:
        if config['name'] not in gdf.columns:
            raise ValueError(f"Feature '{config['name']}' not found in the GeoDataFrame")
        if config.get('scale') not in ['linear', 'log']:
            raise ValueError(f"Scale for feature '{config['name']}' must be 'linear' or 'log'")
    
    # Calculate center if not provided
    if center is None:
        center = [
            gdf.geometry.centroid.y.mean(),
            gdf.geometry.centroid.x.mean()
        ]
    
    # Create base map
    m = folium.Map(
        location=center,
        zoom_start=zoom_start,
        width=width,
        height=height
    )
    
    for config in feature_configs:
        feature_name = config['name']
        scale_type = config['scale']
        # Use viridis for log scale, otherwise use specified color scheme
        color_scheme = 'viridis' if scale_type == 'log' else config.get('color_scheme', 'sequential')
        legend_name = config.get('legend_name', feature_name)
        
        # Get appropriate color scheme; say so when the name is not recognised
        # instead of quietly drawing the layer with a different palette.
        if color_scheme not in COLOR_SCHEMES:
            print(f"Warning: unknown color scheme '{color_scheme}' for {feature_name}. Using 'sequential'.")
        colors = COLOR_SCHEMES.get(color_scheme, COLOR_SCHEMES['sequential'])
        
        # Create feature group
        fg = folium.FeatureGroup(name=legend_name, show=False)
        
        # Transform the column once. The same transformed values feed both the
        # colormap range and the per-polygon colour lookup, so the two can
        # never disagree (previously only values <= 0 were offset per row).
        scaled_values = _scale_for_colormap(gdf[feature_name].copy(), feature_name, scale_type)
        
        # Create colormap
        colormap = LinearColormap(
            colors=colors,
            vmin=scaled_values.min(),
            vmax=scaled_values.max(),
            caption=legend_name + (' (log scale)' if scale_type == 'log' else '')
        )
        
        # Add choropleth layer. `scaled_values` has the same row order as `gdf`,
        # so the two are walked in lockstep (safe even with a non-unique index).
        for (_, row), scaled_value in zip(gdf.iterrows(), scaled_values):
            # A missing value cannot be placed on the colour ramp; draw it in a
            # neutral colour instead.
            if pd.isna(scaled_value):
                color = _MISSING_VALUE_COLOR
            else:
                color = colormap(scaled_value)
            
            # Tooltip shows the raw (untransformed) value
            display_value = _format_display_value(row[feature_name], feature_name)
            
            # Get state and county names
            state_name = row.get('state_name', row.get('STATE_NAME', ''))
            county_name = row.get('county_name', row.get('COUNTY_NAME', ''))
            
            # Create GeoJSON-style feature
            feature = {
                'type': 'Feature',
                'geometry': row.geometry.__geo_interface__,
                'properties': {
                    'value': display_value,
                    'state': state_name,
                    'county': county_name
                }
            }
            
            # Add polygon to feature group. `color=color` binds the current
            # colour as a default argument so every polygon keeps its own.
            folium.GeoJson(
                feature,
                style_function=lambda x, color=color: {
                    'fillColor': color,
                    'color': 'black',
                    'weight': 1,
                    'fillOpacity': 0.7
                },
                tooltip=folium.GeoJsonTooltip(
                    fields=['county', 'state', 'value'],
                    aliases=['County', 'State', legend_name],
                    localize=True,
                    sticky=False,
                    labels=True,
                    style="""
                        background-color: white;
                        border: 2px solid black;
                        border-radius: 3px;
                        box-shadow: 3px 3px 3px rgba(0,0,0,0.2);
                        padding: 5px;
                        font-size: 12px;
                    """
                )
            ).add_to(fg)
        
        # Add colormap to map
        colormap.add_to(m)
        
        # Add feature group to map
        fg.add_to(m)
    
    # Add custom JavaScript for radio button behavior
    # NOTE(review): the script reads `input._layer` and `#map`; confirm these
    # match what the rendered Leaflet layer control actually exposes.
    layer_control_html = """
    <script type="text/javascript">
    function setupLayerControl() {
        // Get layer control container
        var layerControlContainer = document.querySelector('.leaflet-control-layers-overlays');
        if (!layerControlContainer) {
            setTimeout(setupLayerControl, 100);
            return;
        }

        // Replace checkboxes with radio buttons
        var inputs = layerControlContainer.querySelectorAll('input[type="checkbox"]');
        inputs.forEach(function(input) {
            input.type = 'radio';
            input.name = 'layer-control';
            
            // Update the input's checked state based on layer visibility
            var layer = input._layer;
            if (layer && layer._map) {
                input.checked = layer._map.hasLayer(layer);
            }
            
            // Add change listener to handle layer toggling
            input.addEventListener('change', function(e) {
                if (this.checked) {
                    // Uncheck and hide all other layers
                    inputs.forEach(function(otherInput) {
                        if (otherInput !== input) {
                            otherInput.checked = false;
                            if (otherInput._layer) {
                                otherInput._layer.remove();
                            }
                        }
                    });
                    
                    // Show the selected layer
                    if (this._layer) {
                        this._layer.addTo(this._layer._map);
                    }
                }
            });
        });

        // Monitor layer visibility changes
        var map = document.querySelector('#map');  // Assuming map has id="map"
        if (map && map._leaflet_map) {
            map._leaflet_map.on('layeradd layerremove', function(e) {
                inputs.forEach(function(input) {
                    if (input._layer === e.layer) {
                        input.checked = e.type === 'layeradd';
                    }
                });
            });
        }
    }

    // Initialize when DOM is ready
    if (document.readyState === 'loading') {
        document.addEventListener('DOMContentLoaded', setupLayerControl);
    } else {
        setupLayerControl();
    }

    // Also set up observer to handle dynamic updates
    var observer = new MutationObserver(function(mutations) {
        mutations.forEach(function(mutation) {
            if (mutation.addedNodes.length) {
                setupLayerControl();
            }
        });
    });

    observer.observe(document.body, {
        childList: true,
        subtree: true
    });
    </script>
    """
    
    # Add layer control to map
    folium.LayerControl().add_to(m)
    
    # Add custom JavaScript to handle radio button behavior
    m.get_root().html.add_child(folium.Element(layer_control_html))
    
    return m

##### Spatial Correlation mapping

In [None]:
import folium
from folium import plugins
import pandas as pd
import geopandas as gpd
import numpy as np
from libpysal.weights import Queen, KNN
from esda.moran import Moran, Moran_Local
import warnings
from typing import Union, Tuple, Dict, List
from branca.colormap import LinearColormap

def _build_spatial_weights(gdf, weight_type, k_neighbors):
    """Build row-standardized spatial weights for `gdf`.

    Returns a `(gdf_used, weights)` pair. With queen contiguity, observations
    without any neighbour ("islands") are dropped from `gdf_used` and the
    weights are rebuilt, so the two always describe the same rows.
    """
    kind = weight_type.lower()
    if kind not in ('queen', 'knn'):
        raise ValueError("weight_type must be either 'queen' or 'knn'")

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        if kind == 'queen':
            weights = Queen.from_dataframe(gdf, use_index=True)
            islands = weights.islands
            if len(islands) > 0:
                print(f"Removing {len(islands)} isolated areas from analysis")
                gdf = gdf[~gdf.index.isin(islands)].copy()
                weights = Queen.from_dataframe(gdf, use_index=True)
        else:
            weights = KNN.from_dataframe(gdf, k=k_neighbors)

    weights.transform = 'r'  # Row-standardize weights
    return gdf, weights


def _add_polygon(feature_group, geometry, fill_color, tooltip):
    """Add one filled polygon with a tooltip to a folium feature group."""
    folium.GeoJson(
        geometry.__geo_interface__,
        # Bind the colour as a default argument so each polygon keeps its own.
        style_function=lambda x, color=fill_color: {
            'fillColor': color,
            'color': 'black',
            'weight': 1,
            'fillOpacity': 0.7
        },
        tooltip=tooltip
    ).add_to(feature_group)


def analyze_multi_feature_spatial_correlation(
    gdf: gpd.GeoDataFrame,
    feature_configs: List[Dict],
    weight_type: str = 'queen',
    k_neighbors: int = 5,
    center: Tuple[float, float] = None,
    zoom_start: int = 4
) -> Tuple[Dict, folium.Map]:
    """
    Creates an interactive Folium map showing spatial correlation analysis for multiple features.
    
    For every configured feature the values are optionally log-transformed,
    z-scored, and analysed with global and local Moran's I. Three map layers
    are added per feature: raw values, local Moran's I, and cluster type.
    
    Parameters:
    -----------
    gdf : gpd.GeoDataFrame
        GeoDataFrame containing geometry and data
    feature_configs : list of dict
        List of feature configurations, each containing:
        {
            'name': str (column name in gdf),
            'scale': str ('linear' or 'log', optional, default='linear'),
            'color_scheme': str (optional, one of 'population', 'political', 'sequential', 'viridis'),
            'legend_name': str (optional, display name for the legend)
        }
    weight_type : str, optional (default='queen')
        Type of spatial weights to use ('queen' or 'knn')
    k_neighbors : int, optional (default=5)
        Number of neighbors for KNN weights
    center : tuple, optional
        (lat, lon) center coordinates for the map
    zoom_start : int, optional (default=4)
        Initial zoom level for the map

    Returns:
    --------
    tuple of (dict, folium.Map)
        The dict maps each feature name to 'global_statistics' (morans_i,
        p_value, z_score), 'cluster_summary' (count per cluster type) and
        'local_statistics' (mean_local_i, significant_clusters,
        percent_significant). The map contains all layers plus a layer control.

    Raises:
    -------
    TypeError
        If `gdf` is not a GeoDataFrame.
    ValueError
        If a configured feature is not a column of `gdf`, or `weight_type` is
        neither 'queen' nor 'knn'.
    """
    # Input validation
    if not isinstance(gdf, gpd.GeoDataFrame):
        raise TypeError("Input must be a GeoDataFrame")
    
    for config in feature_configs:
        if config['name'] not in gdf.columns:
            raise ValueError(f"Feature '{config['name']}' not found in GeoDataFrame")
    
    # Work on a copy to avoid modifying the caller's frame; islands are
    # removed here when queen weights are used.
    gdf_analysis, weights = _build_spatial_weights(gdf.copy(), weight_type, k_neighbors)
    
    # Calculate map center if not provided
    if center is None:
        center = [
            gdf_analysis.geometry.centroid.y.mean(),
            gdf_analysis.geometry.centroid.x.mean()
        ]
    
    # Create base map
    m = folium.Map(location=center, zoom_start=zoom_start)
    
    # Color schemes
    COLOR_SCHEMES = {
        'population': ['#fee5d9', '#fcae91', '#fb6a4a', '#de2d26', '#a50f15'],
        'political': ['#0571b0', '#92c5de', '#f7f7f7', '#f4a582', '#ca0020'],
        'sequential': ['#ffffcc', '#a1dab4', '#41b6c4', '#2c7fb8', '#253494'],
        'viridis': ['#440154', '#414487', '#2a788e', '#22a884', '#7ad151', '#fde725']
    }
    
    # Defined before the loop because the legend below needs it even when
    # `feature_configs` is empty.
    cluster_colors = {
        'High-High': '#d7191c',
        'Low-Low': '#2c7bb6',
        'High-Low': '#fdae61',
        'Low-High': '#abd9e9',
        'Not Significant': '#ffffbf'
    }
    
    # Store results for all features
    all_results = {}
    
    for config in feature_configs:
        feature_name = config['name']
        scale_type = config.get('scale', 'linear')
        color_scheme = config.get('color_scheme', 'sequential')
        legend_name = config.get('legend_name', feature_name)
        
        std_col = f'{feature_name}_standardized'
        local_i_col = f'{feature_name}_local_i'
        p_value_col = f'{feature_name}_p_value'
        cluster_col = f'{feature_name}_cluster'
        
        # Handle missing values for this feature. Dropping rows changes the
        # set of observations, so the weights are rebuilt for the remaining
        # rows; Moran's I needs exactly one weights row per observation.
        n_missing = gdf_analysis[feature_name].isnull().sum()
        if n_missing > 0:
            print(f"Warning: {n_missing} missing values found in {feature_name}")
            feature_data, feature_weights = _build_spatial_weights(
                gdf_analysis.dropna(subset=[feature_name]).copy(),
                weight_type,
                k_neighbors
            )
        else:
            feature_data, feature_weights = gdf_analysis.copy(), weights
        
        # Optional log transform (values are shifted first when any are <= 0)
        if scale_type == 'log':
            min_val = feature_data[feature_name].min()
            if min_val <= 0:
                offset = abs(min_val) + 1
                feature_data[std_col] = np.log(feature_data[feature_name] + offset)
            else:
                feature_data[std_col] = np.log(feature_data[feature_name])
        else:
            feature_data[std_col] = feature_data[feature_name]
        
        # Z-score standardization
        feature_data[std_col] = (
            feature_data[std_col] - feature_data[std_col].mean()
        ) / feature_data[std_col].std()
        
        # Global and local Moran's I on the standardized values
        moran = Moran(feature_data[std_col], feature_weights)
        local_moran = Moran_Local(feature_data[std_col], feature_weights)
        
        # Add local indicators to dataframe
        feature_data[local_i_col] = local_moran.Is
        feature_data[p_value_col] = local_moran.p_sim
        
        # Classify clusters: a significant location is labelled by the sign of
        # its own standardized value and of its spatial lag (neighbour average).
        feature_data[cluster_col] = 'Not Significant'
        significant = local_moran.p_sim < 0.05
        std_val = feature_data[std_col].to_numpy()
        lag_val = feature_weights.sparse.dot(std_val)
        quadrants = {
            'High-High': (std_val > 0) & (lag_val > 0),
            'Low-Low': (std_val < 0) & (lag_val < 0),
            'High-Low': (std_val > 0) & (lag_val < 0),
            'Low-High': (std_val < 0) & (lag_val > 0),
        }
        for cluster_type, in_quadrant in quadrants.items():
            feature_data.loc[significant & in_quadrant, cluster_col] = cluster_type
        
        # Create feature groups for each layer type
        fg_original = folium.FeatureGroup(name=f"{legend_name} - Values")
        fg_moran = folium.FeatureGroup(name=f"{legend_name} - Local Moran's I")
        fg_clusters = folium.FeatureGroup(name=f"{legend_name} - Clusters")
        
        # Get colors for the scheme; say so when the name is not recognised
        # instead of quietly drawing the layer with a different palette.
        if color_scheme not in COLOR_SCHEMES:
            print(f"Warning: unknown color scheme '{color_scheme}' for {feature_name}. Using 'sequential'.")
        colors = COLOR_SCHEMES.get(color_scheme, COLOR_SCHEMES['sequential'])
        
        # Create colormaps
        colormap_original = LinearColormap(
            colors=colors,
            vmin=feature_data[feature_name].min(),
            vmax=feature_data[feature_name].max(),
            caption=f"{legend_name} Values"
        )
        
        colormap_moran = LinearColormap(
            colors=['#ca0020', '#f4a582', '#f7f7f7', '#92c5de', '#0571b0'],
            vmin=feature_data[local_i_col].min(),
            vmax=feature_data[local_i_col].max(),
            caption=f"{legend_name} Local Moran's I"
        )
        
        # Add layers
        for idx, row in feature_data.iterrows():
            # Value distribution layer
            _add_polygon(
                fg_original,
                row.geometry,
                colormap_original(row[feature_name]),
                f"{legend_name}: {row[feature_name]:.2f}"
            )
            
            # Local Moran's I layer
            _add_polygon(
                fg_moran,
                row.geometry,
                colormap_moran(row[local_i_col]),
                f"Local Moran's I: {row[local_i_col]:.3f}<br>P-value: {row[p_value_col]:.3f}"
            )
            
            # Cluster type layer
            _add_polygon(
                fg_clusters,
                row.geometry,
                cluster_colors[row[cluster_col]],
                f"Cluster Type: {row[cluster_col]}"
            )
        
        # Add feature groups to map
        fg_original.add_to(m)
        fg_moran.add_to(m)
        fg_clusters.add_to(m)
        
        # Add colormaps to map
        colormap_original.add_to(m)
        colormap_moran.add_to(m)
        
        # Store results for this feature
        is_significant = feature_data[p_value_col] < 0.05
        all_results[feature_name] = {
            'global_statistics': {
                'morans_i': moran.I,
                'p_value': moran.p_sim,
                'z_score': moran.z_sim
            },
            'cluster_summary': feature_data[cluster_col].value_counts().to_dict(),
            'local_statistics': {
                'mean_local_i': feature_data[local_i_col].mean(),
                'significant_clusters': is_significant.sum(),
                'percent_significant': is_significant.mean() * 100
            }
        }
    
    # Add cluster type legend
    legend_html = """
    <div style="position: fixed; bottom: 50px; right: 50px; z-index: 1000; background-color: white; 
                padding: 10px; border: 2px solid grey; border-radius: 5px">
    <p><strong>Cluster Types</strong></p>
    """
    for cluster_type, color in cluster_colors.items():
        legend_html += f"""
        <p><i class="fa fa-square fa-1x" style="color:{color}"></i> {cluster_type}</p>
        """
    legend_html += "</div>"
    m.get_root().html.add_child(folium.Element(legend_html))
    
    # Add layer control
    folium.LayerControl().add_to(m)
    
    return all_results, m

def print_multi_feature_results(results: Dict):
    """
    Print a text report for every feature in `results`.

    `results` maps a feature name to a dict with the keys 'global_statistics',
    'local_statistics' and 'cluster_summary'. One report per feature is written
    to standard output.
    """
    for feature_name, feature_results in results.items():
        overall = feature_results['global_statistics']
        local = feature_results['local_statistics']

        # Collect the report line by line, then emit it in one go.
        report = [
            f"\nSpatial Correlation Analysis Results for {feature_name}\n",
            "=" * 50,
            "\nGlobal Moran's I Statistics:",
            f"Moran's I: {overall['morans_i']:.3f}",
            f"P-value: {overall['p_value']:.3f}",
            f"Z-score: {overall['z_score']:.3f}",
            "\nLocal Statistics:",
            f"Mean Local Moran's I: {local['mean_local_i']:.3f}",
            f"Number of Significant Clusters: {local['significant_clusters']}",
            f"Percent Significant: {local['percent_significant']:.1f}%",
            "\nCluster Type Distribution:",
        ]
        report.extend(
            f"{cluster_type}: {count} areas"
            for cluster_type, count in feature_results['cluster_summary'].items()
        )
        print("\n".join(report))

def scale_morans_i(morans_i_values):
    """
    Rescale Moran's I values by their largest absolute value.

    After division the values lie within [-1, 1]. When every value is zero
    the input is handed back untouched, which avoids dividing by zero.
    """
    largest_magnitude = np.max(np.abs(morans_i_values))
    if largest_magnitude != 0:
        return morans_i_values / largest_magnitude
    return morans_i_values

##### Spatial correlation on residuals from predicting target feature

In [None]:
import pandas as pd
import numpy as np
import geopandas as gpd
import folium
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error
from libpysal.weights import Queen
from esda.moran import Moran, Moran_Local
from branca.colormap import LinearColormap
from typing import Tuple, Dict, List, Union, Optional
from dataclasses import dataclass

@dataclass
class ModelResults:
    """Container for model results to pass between components.

    Instances are produced by ModelTrainer.fit_predict.
    """
    # Actual minus predicted target for every input row. fit_predict computes
    # this as `y - y_pred` with `y` taken from a GeoDataFrame column, so in
    # practice it arrives as a pandas Series rather than a bare ndarray.
    residuals: np.ndarray
    # Columns 'feature' and 'importance', sorted by importance descending.
    feature_importances: pd.DataFrame
    # Model predictions for every input row (training and test rows alike).
    predictions: np.ndarray
    # Keys set by fit_predict: 'train_r2', 'test_r2', 'train_rmse', 'test_rmse'.
    metrics: Dict
    
@dataclass
class SpatialResults:
    """Container for spatial analysis results.

    Instances are produced by SpatialAnalyzer.analyze_residuals.
    """
    # Keys set by analyze_residuals: 'global_moran_i', 'p_value', 'z_score',
    # 'local_moran_i', 'local_p_values'.
    moran_stats: Dict
    # Spatial weights the statistics were computed with (row-standardized by
    # analyze_residuals).
    spatial_weights: Queen
    # Not populated by analyze_residuals; stays None unless a caller sets it.
    cluster_types: Optional[pd.Series] = None

class ModelTrainer:
    """Fits a random-forest regressor on the numeric columns of a GeoDataFrame
    and reports predictions, residuals, feature importances and fit metrics."""

    def __init__(self):
        self.model = RandomForestRegressor(n_estimators=100, random_state=42)
        self.scaler = StandardScaler()
        # Both are filled in by fit_predict; declared here so the attributes
        # exist on a freshly constructed trainer.
        self.feature_columns = None
        self.feature_importances = None
        
    def _get_numeric_columns(
        self,
        gdf: gpd.GeoDataFrame,
        target_column: str,
        exclude_columns: Optional[List[str]] = None,
        exclude_patterns: Optional[List[str]] = None
    ) -> List[str]:
        """
        Get all numeric columns except the target and excluded columns.
        
        Parameters:
        -----------
        gdf : GeoDataFrame
            Input geodataframe
        target_column : str
            Name of the target column to exclude
        exclude_columns : list, optional
            List of specific column names to exclude
        exclude_patterns : list, optional
            List of patterns to match for excluding columns. A column is
            dropped when its name contains any pattern (case-insensitive
            substring match).
            
        Returns:
        --------
        list : List of column names to use as features, in the frame's
            column order
        """
        # Only int64 / float64 columns are considered numeric here
        numeric_cols = gdf.select_dtypes(include=['int64', 'float64']).columns.tolist()
        
        # Exact names to leave out: the target plus the explicit exclusions
        excluded_names = set(exclude_columns or [])
        excluded_names.add(target_column)
        
        # Lower-case the patterns once instead of once per column
        patterns = [pattern.lower() for pattern in (exclude_patterns or [])]
        
        # Single pass: keep a column unless it is excluded by name or by pattern
        return [
            col for col in numeric_cols
            if col not in excluded_names
            and not any(pattern in col.lower() for pattern in patterns)
        ]
        
    def fit_predict(
        self,
        gdf: gpd.GeoDataFrame,
        target_column: str,
        exclude_columns: Optional[List[str]] = None,
        exclude_patterns: Optional[List[str]] = None,
        test_size: float = 0.2,
        verbose: bool = True
    ) -> ModelResults:
        """
        Fit the model and calculate residuals using all numeric columns as features.
        
        Missing feature values are replaced by the column mean and the features
        are standardized before a train/test split. The model is fitted on the
        training part only, but predictions and residuals are returned for
        every row of `gdf`.
        
        Parameters:
        -----------
        gdf : GeoDataFrame
            Input geodataframe containing features and target
        target_column : str
            Name of the column to predict
        exclude_columns : list, optional
            List of specific column names to exclude from features
        exclude_patterns : list, optional
            List of patterns to match for excluding columns
        test_size : float
            Proportion of data to use for testing
        verbose : bool
            Whether to print information about selected features
            
        Returns:
        --------
        ModelResults
            Residuals (actual minus predicted) and predictions for all rows,
            the feature-importance table sorted in descending order, and a
            metrics dict with 'train_r2', 'test_r2', 'train_rmse', 'test_rmse'.

        Raises:
        -------
        TypeError
            If `gdf` is not a GeoDataFrame.
        ValueError
            If the target column is missing or no feature column remains
            after the exclusions.
        """
        # Validate inputs
        if not isinstance(gdf, gpd.GeoDataFrame):
            raise TypeError("Input must be a GeoDataFrame")
        if target_column not in gdf.columns:
            raise ValueError(f"Target column '{target_column}' not found")
            
        # Get numeric feature columns
        self.feature_columns = self._get_numeric_columns(
            gdf, target_column, exclude_columns, exclude_patterns
        )
        
        if not self.feature_columns:
            raise ValueError("No numeric feature columns found after exclusions")
            
        if verbose:
            print(f"Using {len(self.feature_columns)} features:")
            print("\n".join(f"- {col}" for col in self.feature_columns))
            
        # Prepare data
        X = gdf[self.feature_columns].copy()
        y = gdf[target_column].copy()
        
        # Check for and handle missing values (mean imputation per column)
        missing_counts = X.isnull().sum()
        if missing_counts.any():
            if verbose:
                print("\nHandling missing values:")
                for col in missing_counts[missing_counts > 0].index:
                    print(f"- {col}: {missing_counts[col]} missing values")
            X = X.fillna(X.mean())
        
        # Scale features
        X_scaled = self.scaler.fit_transform(X)
        
        # Split data
        X_train, X_test, y_train, y_test = train_test_split(
            X_scaled, y, test_size=test_size, random_state=42
        )
        
        # Fit model
        if verbose:
            print("\nFitting Random Forest model...")
        self.model.fit(X_train, y_train)
        
        # Predict every row (train and test) so there is one residual per row
        # of `gdf`. With a target coded GOP=1 / Dem=0, a positive residual
        # means the model under-predicted GOP.
        y_pred = self.model.predict(X_scaled)
        residuals = y - y_pred

        # Calculate performance metrics
        train_pred = self.model.predict(X_train)
        test_pred = self.model.predict(X_test)

        # Store feature importances
        self.feature_importances = pd.DataFrame({
            'feature': self.feature_columns,
            'importance': self.model.feature_importances_
        }).sort_values('importance', ascending=False)
        
        metrics = {
            'train_r2': r2_score(y_train, train_pred),
            'test_r2': r2_score(y_test, test_pred),
            'train_rmse': np.sqrt(mean_squared_error(y_train, train_pred)),
            'test_rmse': np.sqrt(mean_squared_error(y_test, test_pred)),
        }
        
        return ModelResults(
            residuals=residuals,
            feature_importances=self.feature_importances,
            predictions=y_pred,
            metrics=metrics
        )

class SpatialAnalyzer:
    """Measures spatial autocorrelation of model residuals using Moran's I."""

    def analyze_residuals(
        self,
        gdf: gpd.GeoDataFrame,
        residuals: np.ndarray,
        weights_type: str = 'queen',
        verbose: bool = True,
        k_neighbors: int = 5
    ) -> SpatialResults:
        """
        Compute global and local Moran's I for a set of residuals.

        Parameters:
        -----------
        gdf : GeoDataFrame
            Geometries the residuals belong to; row order must match `residuals`
        residuals : array-like
            One residual per row of `gdf`
        weights_type : str, default='queen'
            Spatial weights to build: 'queen' (contiguity) or 'knn'
        verbose : bool, default=True
            If True, print a one-line summary of the global statistic
        k_neighbors : int, default=5
            Number of neighbours, only used when `weights_type` is 'knn'

        Returns:
        --------
        SpatialResults
            Container holding the Moran statistics dictionary and the
            row-standardized spatial weights

        Raises:
        -------
        ValueError
            If `weights_type` is not supported or `residuals` does not have
            one value per row of `gdf`
        """
        if len(residuals) != len(gdf):
            raise ValueError(
                f"residuals has {len(residuals)} values but gdf has {len(gdf)} rows"
            )

        # Create spatial weights matrix
        weights_kind = weights_type.lower()
        if weights_kind == 'queen':
            spatial_weights = Queen.from_dataframe(gdf)
        elif weights_kind == 'knn':
            # Imported lazily so the default 'queen' path needs nothing beyond
            # what this cell already relies on.
            from libpysal.weights import KNN
            spatial_weights = KNN.from_dataframe(gdf, k=k_neighbors)
        else:
            raise ValueError("weights_type must be either 'queen' or 'knn'")
        spatial_weights.transform = 'r'  # Row-standardize weights

        # Calculate global and local Moran's I
        moran = Moran(residuals, spatial_weights)
        local_moran = Moran_Local(residuals, spatial_weights)

        moran_stats = {
            'global_moran_i': moran.I,
            'p_value': moran.p_sim,
            'z_score': moran.z_sim,
            'local_moran_i': local_moran.Is,
            'local_p_values': local_moran.p_sim
        }

        if verbose:
            print(
                f"\nGlobal Moran's I of residuals: {moran.I:.4f} "
                f"(p-value: {moran.p_sim:.4f}, z-score: {moran.z_sim:.4f})"
            )

        return SpatialResults(moran_stats=moran_stats, spatial_weights=spatial_weights)

In [None]:
class MapVisualizer:
    """Builds an interactive folium map with election results, model residuals
    and local spatial-autocorrelation cluster layers."""

    @staticmethod
    def _add_polygon(feature_group, geometry, fill_color, tooltip):
        """Add one filled polygon with a hover tooltip to `feature_group`."""
        folium.GeoJson(
            geometry.__geo_interface__,
            # Bind the colour as a default argument so every polygon keeps its own.
            style_function=lambda x, color=fill_color: {
                'fillColor': color,
                'color': 'black',
                'weight': 1,
                'fillOpacity': 0.7
            },
            tooltip=tooltip
        ).add_to(feature_group)

    def create_map(
        self,
        gdf: gpd.GeoDataFrame,
        target_column: str,
        model_results: ModelResults,
        spatial_results: SpatialResults,
        center: Tuple[float, float] = None,
        zoom_start: int = 4
    ) -> folium.Map:
        """
        Create interactive map with election results and analysis layers.
        
        Parameters:
        -----------
        gdf : GeoDataFrame
            Input geodataframe
        target_column : str
            Name of the column containing actual results (0=Dem, 1=GOP)
        model_results : ModelResults
            Container with model predictions and residuals; both must be
            ordered like the rows of `gdf`
        spatial_results : SpatialResults
            Container with spatial analysis results computed for the same rows
        center : tuple, optional
            (lat, lon) center coordinates for the map; defaults to the mean
            of the geometry centroids
        zoom_start : int, default=4
            Initial zoom level for the map

        Returns:
        --------
        folium.Map
            Map with three toggleable layers (results, residuals, cluster
            patterns), two colour bars and a cluster legend

        Raises:
        -------
        ValueError
            If residuals, predictions or local p-values do not have one
            value per row of `gdf`
        """
        # Work positionally: predictions are a plain array, so looking values up
        # by DataFrame index label breaks as soon as `gdf` has a non-default index.
        residuals = np.asarray(model_results.residuals, dtype=float)
        predictions = np.asarray(model_results.predictions, dtype=float)
        local_p_values = np.asarray(spatial_results.moran_stats['local_p_values'])
        n_rows = len(gdf)
        if not (len(residuals) == len(predictions) == len(local_p_values) == n_rows):
            raise ValueError(
                "residuals, predictions and local p-values must each have "
                f"one value per row of gdf ({n_rows} rows)"
            )

        if center is None:
            centroids = gdf.geometry.centroid
            center = [centroids.y.mean(), centroids.x.mean()]
        
        # Create base map
        m = folium.Map(location=center, zoom_start=zoom_start)
        
        # Create feature groups
        fg_results = folium.FeatureGroup(name="2024 Election Results")
        fg_residuals = folium.FeatureGroup(name="Model Residuals")
        fg_clusters = folium.FeatureGroup(name="Spatial Correlation Patterns")
        
        # Create colormaps
        colormap_results = LinearColormap(
            colors=['#2b83ba', '#ffffff', '#de2d26'],  # Blue (Dem/0) to Red (GOP/1)
            vmin=0,
            vmax=1,
            caption="2024 Election Results (Blue = Democratic, Red = GOP)"
        )
        
        # Residual = actual - predicted, so positive means under-predicted GOP.
        # The scale runs blue (most negative) to red (most positive).
        colormap_residuals = LinearColormap(
            colors=['#2c7bb6', '#abd9e9', '#ffffbf', '#fdae61', '#d7191c'],
            vmin=model_results.residuals.min(),
            vmax=model_results.residuals.max(),
            caption="Residuals (Blue = Over-predicted GOP, Red = Under-predicted GOP)"
        )
        
        # Calculate standardized residuals and spatial lag
        std_residuals = (residuals - residuals.mean()) / residuals.std()
        lag_residuals = spatial_results.spatial_weights.sparse.dot(std_residuals)
        
        # Determine cluster types
        significance_level = 0.05
        sig_mask = local_p_values < significance_level
        
        cluster_types = np.full(n_rows, 'Not Significant', dtype=object)
        
        # High-High: both residual and lag are positive (clusters of GOP under-prediction)
        cluster_types[sig_mask & (std_residuals > 0) & (lag_residuals > 0)] = 'High-High'
        
        # Low-Low: both residual and lag are negative (clusters of GOP over-prediction)
        cluster_types[sig_mask & (std_residuals < 0) & (lag_residuals < 0)] = 'Low-Low'
        
        # High-Low: residual is positive, lag is negative (GOP under-prediction outliers)
        cluster_types[sig_mask & (std_residuals > 0) & (lag_residuals < 0)] = 'High-Low'
        
        # Low-High: residual is negative, lag is positive (GOP over-prediction outliers)
        cluster_types[sig_mask & (std_residuals < 0) & (lag_residuals > 0)] = 'Low-High'
        
        # Define colors for cluster types
        cluster_colors = {
            'High-High': '#d7191c',     # Red: Under-predicted GOP clusters
            'Low-Low': '#2c7bb6',       # Blue: Over-predicted GOP clusters
            'High-Low': '#fdae61',      # Orange: Under-predicted GOP outliers
            'Low-High': '#abd9e9',      # Light Blue: Over-predicted GOP outliers
            'Not Significant': '#ffffbf' # Beige: No significant pattern
        }

        # Plain-language meaning of each pattern, shared by tooltips and legend
        interpretation = {
            'High-High': 'Cluster of GOP under-prediction',
            'Low-Low': 'Cluster of GOP over-prediction',
            'High-Low': 'GOP under-prediction outlier',
            'Low-High': 'GOP over-prediction outlier',
            'Not Significant': 'No significant spatial pattern'
        }
        
        # Add one polygon per county to each of the three layers
        for pos, (_, row) in enumerate(gdf.iterrows()):
            result = row[target_column]
            residual = residuals[pos]
            prediction = predictions[pos]
            cluster_type = cluster_types[pos]

            # Election results layer
            winner = 'Democratic' if result == 0 else 'GOP'
            winner_color = '#2b83ba' if result == 0 else '#de2d26'
            self._add_polygon(
                fg_results,
                row.geometry,
                winner_color,
                tooltip=f"""
                    Winner: {winner}<br>
                    County: {row.get('county_name', 'N/A')}<br>
                    State: {row.get('state_name', 'N/A')}
                """
            )

            # Residuals layer
            predicted_winner = 'GOP' if prediction > 0.5 else 'Democratic'
            actual_winner = 'GOP' if result == 1 else 'Democratic'
            self._add_polygon(
                fg_residuals,
                row.geometry,
                colormap_residuals(residual),
                tooltip=f"""
                    County: {row.get('county_name', 'N/A')}<br>
                    State: {row.get('state_name', 'N/A')}<br>
                    Residual: {residual:.3f}<br>
                    Actual Winner: {actual_winner}<br>
                    Predicted Winner: {predicted_winner}<br>
                    Model Prediction: {prediction:.3f}
                """
            )

            # Cluster type layer
            self._add_polygon(
                fg_clusters,
                row.geometry,
                cluster_colors[cluster_type],
                tooltip=f"""
                    County: {row.get('county_name', 'N/A')}<br>
                    State: {row.get('state_name', 'N/A')}<br>
                    Pattern: {cluster_type}<br>
                    Interpretation: {interpretation[cluster_type]}<br>
                    Residual: {residual:.3f}
                """
            )
        
        # Add layers to map
        fg_results.add_to(m)
        fg_residuals.add_to(m)
        fg_clusters.add_to(m)
        colormap_results.add_to(m)
        colormap_residuals.add_to(m)
        
        # Add cluster type legend
        legend_html = """
        <div style="position: fixed; bottom: 50px; right: 50px; z-index: 1000; background-color: white; 
                    padding: 10px; border: 2px solid grey; border-radius: 5px">
        <p><strong>Spatial Correlation Patterns</strong></p>
        """
        for cluster_type, color in cluster_colors.items():
            legend_html += f"""
            <p><i class="fa fa-square fa-1x" style="color:{color}"></i> 
            {cluster_type}: {interpretation[cluster_type]}</p>
            """
        legend_html += "</div>"
        m.get_root().html.add_child(folium.Element(legend_html))
        
        # Add layer control
        folium.LayerControl().add_to(m)
        
        return m

##### All spatial correlations

In [None]:
import pandas as pd
import geopandas as gpd
import numpy as np
from libpysal.weights import Queen, KNN
from esda.moran import Moran, Moran_Local
import warnings
from typing import Union, Tuple, Dict, List
from dataclasses import dataclass
from collections import defaultdict

@dataclass
class SpatialCorrelationResult:
    """Spatial autocorrelation summary for a single feature.

    Instances are produced by `calculate_multi_feature_correlation`, one per
    analysed column, and flattened into a table by `get_top_correlations`.
    """
    # Name of the analysed column.
    feature: str
    # Global Moran's I statistic for the feature.
    global_morans_i: float
    # Simulation-based p-value of the global statistic (Moran.p_sim).
    global_p_value: float
    # Simulation-based z-score of the global statistic (Moran.z_sim).
    global_z_score: float
    # Mean of the per-observation local Moran's I values.
    mean_local_i: float
    # Number of observations whose local p-value is below the significance
    # threshold used by the analysis.
    significant_clusters: int
    # Share of observations that are significant, expressed in percent (0-100).
    percent_significant: float
    # Observation count per pattern label ('High-High', 'Low-Low', 'High-Low',
    # 'Low-High', 'Not Significant'); labels with no observations are absent.
    cluster_counts: Dict[str, int]

def _build_spatial_weights(
    gdf: gpd.GeoDataFrame,
    weight_type: str,
    k_neighbors: int
):
    """Build row-standardized weights for `gdf`; returns (gdf_used, weights).

    `weight_type` must already be normalised to 'queen' or 'knn'. For queen
    contiguity, observations without neighbours (islands) are dropped and the
    weights rebuilt, so `gdf_used` may have fewer rows than `gdf`.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        if weight_type == 'queen':
            weights = Queen.from_dataframe(gdf, use_index=True)
            # Handle islands
            islands = weights.islands
            if len(islands) > 0:
                gdf = gdf[~gdf.index.isin(islands)].copy()
                weights = Queen.from_dataframe(gdf, use_index=True)
        else:
            weights = KNN.from_dataframe(gdf, k=k_neighbors)

    weights.transform = 'r'  # Row-standardize weights
    return gdf, weights


def calculate_multi_feature_correlation(
    gdf: gpd.GeoDataFrame,
    features: List[str] = None,
    weight_type: str = 'queen',
    k_neighbors: int = 5,
    exclude_columns: List[str] = None,
    significance_level: float = 0.05
) -> List[SpatialCorrelationResult]:
    """
    Calculates spatial correlation metrics for multiple features in a GeoDataFrame.
    
    Parameters:
    -----------
    gdf : gpd.GeoDataFrame
        GeoDataFrame containing geometry and data
    features : List[str], optional
        List of column names to analyze. If None, will analyze all numeric columns
    weight_type : str, optional (default='queen')
        Type of spatial weights to use ('queen' or 'knn')
    k_neighbors : int, optional (default=5)
        Number of neighbors for KNN weights
    exclude_columns : List[str], optional
        List of column names to exclude from analysis
    significance_level : float, optional (default=0.05)
        Local Moran p-value below which an observation counts as part of a
        significant cluster or outlier
        
    Returns:
    --------
    List[SpatialCorrelationResult]
        List of correlation results for each feature. Features that are
        entirely missing, have fewer than two observations, or fail during
        analysis are omitted (failures are reported with a printed message).

    Raises:
    -------
    TypeError
        If `gdf` is not a GeoDataFrame
    ValueError
        If `weight_type` is neither 'queen' nor 'knn'
    """
    # Input validation
    if not isinstance(gdf, gpd.GeoDataFrame):
        raise TypeError("Input must be a GeoDataFrame")

    weights_kind = weight_type.lower()
    if weights_kind not in ('queen', 'knn'):
        raise ValueError("weight_type must be either 'queen' or 'knn'")
        
    # If no features specified, use all numeric columns
    if features is None:
        features = gdf.select_dtypes(include=[np.number]).columns.tolist()
    
    # Remove excluded columns
    if exclude_columns:
        features = [f for f in features if f not in exclude_columns]
    
    results = []
    
    # Build the weights for the full frame once; they are reused for every
    # feature that has no missing values.
    gdf, weights = _build_spatial_weights(gdf, weights_kind, k_neighbors)
    
    # Analyze each feature
    for feature in features:
        missing = gdf[feature].isnull()

        # Skip if feature has all missing values
        if missing.all():
            continue
            
        # Drop rows where this feature is missing
        has_missing = bool(missing.any())
        gdf_clean = gdf.dropna(subset=[feature]) if has_missing else gdf
        
        if len(gdf_clean) < 2:  # Skip if not enough data
            continue
            
        try:
            if has_missing:
                # The shared weights describe the full frame; after dropping
                # rows they no longer line up with the data, so rebuild them
                # for the reduced frame.
                gdf_clean, feature_weights = _build_spatial_weights(
                    gdf_clean, weights_kind, k_neighbors
                )
            else:
                feature_weights = weights

            values = gdf_clean[feature].to_numpy(dtype=float)

            # Calculate Global Moran's I
            moran = Moran(values, feature_weights)
            
            # Calculate Local Moran's I
            local_moran = Moran_Local(values, feature_weights)
            
            # Classify each observation by the sign of its standardized value
            # and of its spatial lag (weighted mean of its neighbours).
            sig_mask = local_moran.p_sim < significance_level
            std_val = (values - values.mean()) / values.std(ddof=1)
            lag_val = feature_weights.sparse.dot(std_val)
            
            cluster_types = np.full(len(values), 'Not Significant', dtype=object)
            cluster_types[sig_mask & (std_val > 0) & (lag_val > 0)] = 'High-High'
            cluster_types[sig_mask & (std_val < 0) & (lag_val < 0)] = 'Low-Low'
            cluster_types[sig_mask & (std_val > 0) & (lag_val < 0)] = 'High-Low'
            cluster_types[sig_mask & (std_val < 0) & (lag_val > 0)] = 'Low-High'
            
            # Plain Python ints keep the result matching its Dict[str, int] annotation
            cluster_counts = {
                label: int(count)
                for label, count in pd.Series(cluster_types).value_counts().items()
            }
            
            result = SpatialCorrelationResult(
                feature=feature,
                global_morans_i=float(moran.I),
                global_p_value=float(moran.p_sim),
                global_z_score=float(moran.z_sim),
                mean_local_i=float(np.mean(local_moran.Is)),
                significant_clusters=int(np.sum(sig_mask)),
                percent_significant=float(np.mean(sig_mask) * 100),
                cluster_counts=cluster_counts
            )
            
            results.append(result)
            
        except Exception as e:
            # Best effort: one failing feature must not abort the whole run.
            print(f"Error analyzing feature {feature}: {str(e)}")
            continue
            
    return results

def get_top_correlations(
    results: "List[SpatialCorrelationResult]",
    n: int = 5,
    p_threshold: float = 0.05,
    sort_by: str = 'global'
) -> pd.DataFrame:
    """
    Returns the top n features by spatial correlation.
    
    Parameters:
    -----------
    results : List[SpatialCorrelationResult]
        List of correlation results; may be empty
    n : int, optional (default=5)
        Number of top features to return
    p_threshold : float, optional (default=0.05)
        Only include results with p-value below this threshold
    sort_by : str, optional (default='global')
        Sort by 'global' or 'local' Moran's I; any value other than
        'global' sorts by the mean local Moran's I
        
    Returns:
    --------
    pd.DataFrame
        Sorted DataFrame of top correlations (descending), one row per
        feature. Empty, but with all columns present, when no result
        passes the p-value filter.
    """
    columns = [
        'feature',
        'global_morans_i',
        'global_p_value',
        'global_z_score',
        'mean_local_i',
        'significant_clusters',
        'percent_significant'
    ]

    # Passing the columns explicitly keeps them present when `results` is
    # empty, so the p-value filter below cannot raise KeyError.
    df = pd.DataFrame(
        [{col: getattr(r, col) for col in columns} for r in results],
        columns=columns
    )
    
    # Filter by p-value
    df = df[df['global_p_value'] < p_threshold]
    
    # Sort by specified metric
    sort_col = 'global_morans_i' if sort_by == 'global' else 'mean_local_i'
    df = df.sort_values(by=sort_col, ascending=False)
    
    return df.head(n)