<!--  
title: End2End Spatial RecSys
author: Juan Guillermo  
description: End-to-end implementation of a recommendation system for retail store placement.  
             The system analyzes historical store locations and integrates fine-grained  
             demographic and socioeconomic data from U.S. ZIP codes. Using machine learning,  
             it recommends locations for new stores by identifying ZIP codes with a high  
             predicted probability of a new opening. The underlying probability model ranks  
             potential store locations in U.S. states that were held out from the training data.  
image_path: assets/site_recommendation_system_network_plot.html
-->

In [1]:
excluded_states = "Pennsylvania,Ohio,North Carolina"
showoff_states = "Michigan,Florida"

# Settings read back by the generated scripts via load_dotenv()/os.getenv();
# each pair is written as one KEY="value" line, in this order.
env_settings = {
    "APP_HOME": "C:/Users/57320/Desktop/forecast_optimization_ux/thid_party_apps/location_recommendation_system",
    "HOLD_OUT_STATES": excluded_states,
    "SHOW_OFF_STATES": showoff_states,
}

with open(".env", "w") as env_file:
    env_file.writelines(f'{key}="{value}"\n' for key, value in env_settings.items())

print(".env file created successfully.")

.env file created successfully.


In [2]:
%%writefile plot_styles.py 
"""
title: plot styles
description: Defines several decorators that apply a custom style across the Plotly figures of a project,
             providing standardization and a cohesive look. Full support for Plotly; partial support for
             Folium maps.
"""

import plotly.graph_objects as go
import functools
import textwrap
import functools
import plotly.graph_objects as go
import folium


def transparent_background():
    """
    Build a decorator that makes the figure returned by the wrapped function
    fully transparent.

    Both the outer canvas (``paper_bgcolor``) and the plotting area
    (``plot_bgcolor``) are set to ``rgba(0,0,0,0)`` through
    ``fig.update_layout``; the same figure object is returned.
    """
    clear = 'rgba(0,0,0,0)'

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            figure = func(*args, **kwargs)
            figure.update_layout(paper_bgcolor=clear, plot_bgcolor=clear)
            return figure
        return wrapper
    return decorator

def methodological_clarification(clarification_text, words_per_line=100):
    """
    Decorator factory that adds a methodological note beneath a Plotly figure.

    Parameters:
    - clarification_text (str): The note to display.
    - words_per_line (int): Maximum line width in *characters*. The value is
      passed to ``textwrap.wrap(width=...)``, which counts characters, not
      words; the parameter name is kept for backward compatibility.

    Behavior:
    - The note is wrapped and joined with ``<br>`` line breaks.
    - It is added as a centered, paper-referenced annotation whose top edge
      sits at the bottom of the plotting area (x=0.5, y=0, yanchor="top").
    - The figure margins are set to l=15, r=15, t=55, b=50.
    """
    # The wrapped text does not depend on the figure, so compute it once when
    # the decorator is built instead of on every call.
    wrapped_text = "<br>".join(textwrap.wrap(clarification_text, width=words_per_line))

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            fig = func(*args, **kwargs)
            fig.add_annotation(
                text=wrapped_text,
                showarrow=False,
                xref="paper", yref="paper",
                # y=0 is the bottom of the plotting area; yanchor="top" hangs the note below it
                x=0.5, y=0,
                xanchor="center", yanchor="top",
                font=dict(size=12, color="black"),
                align="center"
            )
            # Bottom margin leaves room for the note that hangs below the plot area
            fig.update_layout(margin=dict(l=15, r=15, t=55, b=50))
            return fig
        return wrapper
    return decorator


def centered_title(title_text, title_coords="tightly_integrated"):
    """
    Decorator factory that overlays a centered title annotation on a Plotly figure.

    Parameters:
    - title_text (str): Text of the title.
    - title_coords (tuple or str): Either the string "tightly_integrated", which
      places the title at paper coordinates (x=0.5, y=0.85), or an explicit
      (x, y) tuple in paper coordinates.

    Behavior:
    - The annotation is horizontally centered on x.
    - Its bottom edge sits at y. With the default, the title lies inside the
      plotting area, near its top.
    - The figure's existing margins are kept, except that the top margin is
      raised to at least 50.

    Raises:
    - ValueError: When the decorator is built, if title_coords is neither
      "tightly_integrated" nor a 2-tuple.
    """
    is_preset = title_coords == "tightly_integrated"
    is_pair = isinstance(title_coords, tuple) and len(title_coords) == 2
    if not (is_preset or is_pair):
        raise ValueError("title_coords must be either 'tightly_integrated' or a tuple (x, y).")
    title_x, title_y = (0.5, 0.85) if is_preset else title_coords

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            figure = func(*args, **kwargs)
            figure.add_annotation(
                text=title_text,
                showarrow=False,
                xref="paper",
                yref="paper",
                x=title_x,
                y=title_y,
                xanchor="center",
                yanchor="bottom",
                font=dict(size=16, color="black", family="Arial"),
                align="center",
            )

            # Keep whatever margins the figure already has; only guarantee room at the top.
            margins = figure.layout.margin.to_plotly_json() if hasattr(figure.layout, "margin") else {}
            margins = {**margins, "t": max(margins.get("t", 0), 50)}
            figure.update_layout(margin=margins)
            return figure
        return wrapper
    return decorator


def apply_typography():
    """
    Build a decorator that applies a uniform black Lato typography to the
    figure returned by the wrapped function.

    Settings applied through ``fig.update_layout``:
    - Global font: Lato, size 14.
    - Title font: Lato, size 18, bold.
    - x/y axis title fonts: Lato, size 14.
    """
    family = "Lato, sans-serif"

    def _font(size, **extra):
        # One place for the shared family/colour; size and extras vary per element.
        return dict(family=family, size=size, color="black", **extra)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            figure = func(*args, **kwargs)
            figure.update_layout(
                font=_font(14),
                title=dict(font=_font(18, weight="bold")),
                xaxis=dict(title=dict(font=_font(14))),
                yaxis=dict(title=dict(font=_font(14))),
            )
            return figure
        return wrapper
    return decorator


def save_plot_as_html(filepath="plot.html"):
    """
    Build a decorator that writes the object returned by the wrapped function
    to ``filepath`` as HTML.

    Behavior:
    - A Plotly ``go.Figure`` is saved with ``write_html``.
    - A Folium ``Map`` is saved with ``save``.
    - Any other object is left untouched and a warning line is printed.

    An ``[INFO]``/``[WARNING]`` line reports which case matched. The original
    object is always returned unchanged.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)

            # Checked in order: (type, label for the log line, saving method name)
            savers = (
                (go.Figure, "Plotly Figure", "write_html"),
                (folium.Map, "Folium Map", "save"),
            )
            for kind, label, method_name in savers:
                if isinstance(result, kind):
                    print(f"[INFO] Matched: {label}. Saving to {filepath}.")
                    getattr(result, method_name)(filepath)
                    break
            else:
                print("[WARNING] No match found. Returning object as is.")

            return result
        return wrapper
    return decorator


Overwriting plot_styles.py


In [3]:
%%writefile spatial_features.py
"""
title: Spatial Features
description: Provides a dataframe with zip level spatial features from the US
"""
import os
import pandas as pd
import re
from dotenv import load_dotenv

# Load environment variables
load_dotenv()
APP_HOME = os.getenv("APP_HOME")

if not APP_HOME:
    raise ValueError("APP_HOME environment variable not found. Please set it in the .env file.")

# Change working directory so the relative data path below resolves
os.chdir(APP_HOME)

# Load the dataset. ZIP is read as text: parsing it as an integer would turn
# codes such as "02134" into "2134" and break ZIP-keyed joins.
zip_features = pd.read_csv('data/income_new.csv', dtype={"ZIP": str})

# Standardize ZIP codes: string type, zero-padded to 5 characters
# (padding is a no-op for codes that are already 5+ characters long).
zip_features["zip_code"] = zip_features["ZIP"].astype(str).str.zfill(5)

# Keep the key plus the household / family / married-couple / nonfamily columns
# (case-sensitive substring match on the column name).
zip_features = zip_features[
    [col for col in zip_features.columns if re.search(r"(zip_code)|(Hous)|(Fam)|(Marr)|(Non)", col)]
]

# Display DataFrame information (DataFrame.info prints directly and returns None)
zip_features.info()

# Dictionary for term replacements
abbreviations = {
    'Households': 'Hs',
    'Less Than': 'LT',
    'Nonfamily Households': 'NFHs',
    'Married-Couple Families': 'MCF',
    'Families': 'Fam',
    'Median Income (Dollars)': 'MedInc',
    'Mean Income (Dollars)': 'MeanInc',
    'Income in the Past 12 Months': 'Inc12M',
    ' to ': '-',
    '$': '',
}

# Function to replace all terms and abbreviate numbers
def abbreviate_names(name, mapping=None):
    """
    Shorten a column name using term replacements.

    Parameters:
    - name (str): Original column name.
    - mapping (dict | None): Long-term -> short-term replacements. Defaults to
      the module-level ``abbreviations`` table.

    Returns:
    - str: The abbreviated name. After the term replacements, the thresholds
      100,000 / 150,000 / 200,000 are shortened to 100k / 150k / 200k.

    Longer terms are replaced first. A term that contains another term is
    therefore abbreviated as a whole instead of being pre-empted by the
    shorter replacement. For example, 'Nonfamily Households' becomes 'NFHs'
    rather than 'Nonfamily Hs'.
    """
    if mapping is None:
        mapping = abbreviations
    # sorted() is stable even with reverse=True, so equal-length terms keep table order.
    for long, short in sorted(mapping.items(), key=lambda item: len(item[0]), reverse=True):
        name = name.replace(long, short)
    name = name.replace('100,000', '100k')
    name = name.replace('150,000', '150k')
    name = name.replace('200,000', '200k')
    return name

# Rename every column to its abbreviated form
zip_features.columns = zip_features.columns.map(abbreviate_names)

# Persisting the processed table is currently disabled:
#zip_features.to_csv("data/processed_income_features.csv", index=False)
#print("Processed file saved.")


Overwriting spatial_features.py


In [4]:
%%writefile stores_placements.py
"""
title: Store Placements Processing
description: This script processes store locations and standardizes ZIP codes.
"""

import os
import pandas as pd
import geopandas as gpd
from dotenv import load_dotenv
import plotly.express as px
import pandas as pd
import numpy as np
import plotly.graph_objects as go

# Load environment variables
load_dotenv()
APP_HOME = os.getenv("APP_HOME")

if not APP_HOME:
    raise ValueError("APP_HOME environment variable not found. Please set it in the .env file.")

# Change working directory so the relative data path below resolves
os.chdir(APP_HOME)

# Load the dataset; zip_code is read as text so leading zeros survive parsing
stores_df = pd.read_csv("data/trader_joes.csv", dtype={"zip_code": str})

# Standardize ZIP codes: string type, zero-padded to 5 characters so that
# "02134" is never reduced to "2134" (no-op for codes already 5+ long).
stores_df["zip_code"] = stores_df["zip_code"].astype(str).str.zfill(5)

# Convert to GeoDataFrame with point geometries built from longitude/latitude.
# NOTE(review): no CRS is assigned here; the coordinates look like WGS84
# lon/lat, but that is not established in this script.
stores_gdf = gpd.GeoDataFrame(stores_df, geometry=gpd.points_from_xy(stores_df.longitude, stores_df.latitude))

# Display GeoDataFrame information (info() prints directly and returns None)
stores_gdf.info()

# Save the processed GeoDataFrame (optional)
# stores_gdf.to_file("data/processed_trader_joes.geojson", driver="GeoJSON")
# print("Processed GeoDataFrame saved.")


Overwriting stores_placements.py


In [5]:
%%writefile spatial_frame.py
"""
title: Spatial Frame Processing
description: This script processes spatial data from shapefiles and standardizes ZIP codes.
"""

import os
import geopandas as gpd
from dotenv import load_dotenv

# Load environment variables
load_dotenv()
APP_HOME = os.getenv("APP_HOME")

if not APP_HOME:
    raise ValueError("APP_HOME environment variable not found. Please set it in the .env file.")

# Change working directory so the relative shapefile path below resolves
os.chdir(APP_HOME)

# Load the shapefile
spatial_frame = gpd.read_file("zips_with_states.shp")

# Standardize ZIP codes: string type, zero-padded to 5 characters so that a
# numeric field cannot turn "02134" into "2134" (no-op for codes already 5+ long).
spatial_frame["zip_code"] = spatial_frame["zip_code"].astype(str).str.zfill(5)

# Display DataFrame information (info() prints directly and returns None)
spatial_frame.info()

# Save the processed file (optional)
# spatial_frame.to_file("data/processed_zips_with_states.shp")
# print("Processed spatial data saved.")


Overwriting spatial_frame.py


In [6]:
%%writefile data_manager.py
"""
title: Data Manager
description: Handles ingestion, merging, and exploratory plotting of three data sources
             for spatial analysis.
"""

import numpy as np
import pandas as pd
import geopandas as gpd
import folium
import plotly.express as px
import plotly.express as px
import pandas as pd
import numpy as np
import plotly.graph_objects as go

class StoreLocationDataManager:
    """
    Container for the three ZIP-keyed inputs of the store-placement analysis:
    - store locations;
    - ZIP geometries, each with its state;
    - ZIP-level features.

    It also provides validation, joins and exploratory plots over them.
    """

    #
    # (0) initialization
    #
    def __init__(self, 
                 stores_df: pd.DataFrame, 
                 spatial_frame: gpd.GeoDataFrame, 
                 zip_features: pd.DataFrame, 
                 zip_id_col: str = "zip_code",
                 state_column: str = "NAME"):
        """
        Store the inputs, validate the join key and state column, and cache the
        set of state names present in ``spatial_frame``.

        Args:
            stores_df: Store table; must contain ``zip_id_col``. The plotting
                methods also read its ``latitude``/``longitude`` columns.
            spatial_frame: ZIP geometries; must contain ``zip_id_col`` and ``state_column``.
            zip_features: ZIP-level features; must contain ``zip_id_col``.
            zip_id_col: Join key shared by the three tables (string dtype, no NaN).
            state_column: Column of ``spatial_frame`` holding state names (string dtype, no NaN).

        Raises:
            ValueError: A required column is missing or contains NaN.
            TypeError: A required column is not string-typed.
        """
        #
        # inputs for spatial inference
        #
        self.stores_df = stores_df
        self.spatial_frame = spatial_frame
        self.zip_features = zip_features
        self.zip_id_col = zip_id_col
        self.state_column = state_column

        # Validate the common key and state column
        self._validate_zip_id_col()
        self._validate_state_column()

        # Precompute unique state names (used to validate state arguments)
        self.unique_states = set(self.spatial_frame[self.state_column].dropna().unique())
    #
    def _validate_zip_id_col(self):
        """Ensure the common key exists, is string-typed and has no NaN in all three tables."""
        datasets = {
            "stores_df": self.stores_df,
            "spatial_frame": self.spatial_frame,
            "zip_features": self.zip_features
        }
        
        for dataset_name, dataset in datasets.items():
            if self.zip_id_col not in dataset.columns:
                raise ValueError(f"Common key '{self.zip_id_col}' not found in {dataset_name}.")
            if not pd.api.types.is_string_dtype(dataset[self.zip_id_col]):
                raise TypeError(f"Common key '{self.zip_id_col}' in {dataset_name} must be of type str.")
            if dataset[self.zip_id_col].isna().any():
                raise ValueError(f"Common key '{self.zip_id_col}' contains NaN values in {dataset_name}.")
    #
    def _validate_state_column(self):
        """Ensure the state column exists in spatial_frame, is string-typed and has no NaN."""
        if self.state_column not in self.spatial_frame.columns:
            raise ValueError(f"State column '{self.state_column}' not found in spatial_frame.")
        if not pd.api.types.is_string_dtype(self.spatial_frame[self.state_column]):
            raise TypeError(f"State column '{self.state_column}' must be of type str.")
        if self.spatial_frame[self.state_column].isna().any():
            raise ValueError(f"State column '{self.state_column}' contains NaN values.")

    #
    # (1) Plotting store locations on a map
    #
    def plot_store_location(self, 
                            state_name: str, 
                            zoom_start: int = 6, 
                            min_zoom: int = 6
                            ):
        """Generate an interactive Folium map of one state's ZIP areas with a marker per store.

        Args:
            state_name: State to plot; must appear in ``spatial_frame[state_column]``.
            zoom_start: Initial zoom level of the map.
            min_zoom: Minimum zoom level allowed.

        Returns:
            folium.Map: Map centered on the mean of the state's ZIP centroids.

        Raises:
            ValueError: If ``state_name`` is not present in the spatial data.
        """
        
        # Validate state name
        if state_name not in self.unique_states:
            raise ValueError(f"State '{state_name}' not found in the spatial data.")

        # Filter spatial data for the given state
        state_frame = self.spatial_frame[self.spatial_frame[self.state_column] == state_name]

        # Compute map center
        centroid = state_frame.geometry.centroid
        map_center = [centroid.y.mean(), centroid.x.mean()]
        
        # Create map
        map_location = folium.Map(
            location=map_center, 
            zoom_start=zoom_start,
            min_zoom=min_zoom
        )
        
        # Add state boundary
        folium.GeoJson(
            state_frame.to_json(),
            name=state_name,
            style_function=lambda feature: {
                'fillColor': 'blue',
                'color': 'white',
                'weight': 0.5,
                'fillOpacity': 0.4
            }
        ).add_to(map_location)

        
        # Find matching stores in the state (a store belongs to the state if its ZIP does)
        matched_stores = self.stores_df[self.stores_df[self.zip_id_col].isin(state_frame[self.zip_id_col])]
        for _, row in matched_stores.iterrows():
            store_name = row.get('store_name', "Unnamed Store")
            folium.Marker(
                location=[row['latitude'], row['longitude']],
                popup=store_name
            ).add_to(map_location)
        
        return map_location
    
    #
    def _left_join(self, left_table: str, right_table: str) -> pd.DataFrame:
        """Left-join two of the class's tables (by attribute name) on ``zip_id_col``.

        Raises:
            ValueError: If a table name is unknown or the key is missing.
        """
        
        # Validate table names
        valid_tables = {
            "stores_df": self.stores_df,
            "spatial_frame": self.spatial_frame,
            "zip_features": self.zip_features
        }
        
        if left_table not in valid_tables:
            raise ValueError(f"Invalid left table '{left_table}'. Must be one of {list(valid_tables.keys())}.")
        if right_table not in valid_tables:
            raise ValueError(f"Invalid right table '{right_table}'. Must be one of {list(valid_tables.keys())}.")
        
        left_df = valid_tables[left_table]
        right_df = valid_tables[right_table]

        # Validate presence of zip_id_col
        if self.zip_id_col not in left_df.columns:
            raise ValueError(f"Key '{self.zip_id_col}' not found in {left_table}.")
        if self.zip_id_col not in right_df.columns:
            raise ValueError(f"Key '{self.zip_id_col}' not found in {right_table}.")
        
        return left_df.merge(right_df, on=self.zip_id_col, how='left')

    #
    def plot_stores_summary_per_state(self, household_col: str = "Hs", drop_highest_pct: float = 5.0):
        """Plot the number of stores vs the household count per state, with a quadratic fit.

        States are excluded in two steps:
        - States without a positive household total (e.g. no matching
          zip_features rows) are removed first, because they can be neither
          fitted nor turned into a density.
        - The top ``drop_highest_pct`` percent of the remaining states by
          household count are then dropped before plotting and fitting.

        Args:
            household_col (str): Column name in zip_features containing household counts. Defaults to "Hs".
            drop_highest_pct (float): Percentage of states with the highest household counts to drop before regression. Defaults to 5%.

        Returns:
            plotly.graph_objects.Figure: Scatter of states (marker size = stores
            per 10k households) plus the fitted curve.

        Raises:
            ValueError: If ``household_col`` is missing, or fewer than three
                states remain for the quadratic fit.
        """

        # (1) Get store count per state
        stores_per_state = self._left_join("stores_df", "spatial_frame")
        stores_count = stores_per_state[self.state_column].value_counts().reset_index()
        stores_count.columns = ["State", "Store Count"]

        # (2) Merge zip_features with spatial_frame to get households per zip
        zip_with_states = self._left_join("zip_features", "spatial_frame")

        # (3) Summarize household count per state
        if household_col not in zip_with_states.columns:
            raise ValueError(f"Column '{household_col}' is missing in zip_features.")
        households_per_state = zip_with_states.groupby(self.state_column)[household_col].sum().reset_index()
        households_per_state.columns = ["State", "Total Households"]

        # (4) Merge store count with household count
        merged_df = stores_count.merge(households_per_state, on="State", how="left")

        # (5) Keep states with a positive household total: NaN (no matching ZIP
        # features) breaks np.polyfit and zero yields an infinite density.
        merged_df = merged_df[merged_df["Total Households"] > 0]

        # (6) Remove top X% of states based on household count
        num_states_to_drop = int(len(merged_df) * (drop_highest_pct / 100))
        dropped_states = merged_df.nlargest(num_states_to_drop, "Total Households")["State"].tolist()
        # .copy() so the column added below goes into an independent frame,
        # not into a slice of merged_df (avoids SettingWithCopyWarning).
        filtered_df = merged_df[~merged_df["State"].isin(dropped_states)].copy()

        # (7) Compute store density (Stores per 10,000 households)
        filtered_df["Stores per 10k Households"] = (filtered_df["Store Count"] / filtered_df["Total Households"]) * 10000

        # (8) Fit Quadratic Regression (2nd-degree polynomial)
        if len(filtered_df) < 3:
            raise ValueError(
                f"At least 3 states are needed for a quadratic fit; only {len(filtered_df)} remain."
            )
        X = filtered_df["Total Households"]
        y = filtered_df["Store Count"]
        coeffs = np.polyfit(X, y, 2)  # Quadratic fit: y = ax² + bx + c
        poly_eq = np.poly1d(coeffs)
        X_fit = np.linspace(X.min(), X.max(), 100)
        y_fit = poly_eq(X_fit)

        # (9) Create Scatterplot
        fig = px.scatter(
            filtered_df,
            x="Total Households",
            y="Store Count",
            text="State",
            size="Stores per 10k Households",
            title="Number of Stores vs Household Count per State",
            labels={"Total Households": "Total Households", "Store Count": "Number of Stores"},
        )

        # (10) Add Regression Fit Line
        fig.add_trace(go.Scatter(
            x=X_fit,
            y=y_fit,
            mode="lines",
            line=dict(dash="dot", color="red", width=2),
            name="Quadratic Fit"
        ))

        # (11) Caption with dropped states
        dropped_caption = f"Dropped top {drop_highest_pct}% states (by household count): {', '.join(dropped_states)}"
        fig.add_annotation(
            text=dropped_caption,
            xref="paper", yref="paper",
            x=0.05, y=-0.15,
            showarrow=False,
            font=dict(size=12, color="gray"),
        )

        # (12) Style Updates
        fig.update_traces(marker=dict(color="blue", line=dict(width=2, color="white")), textposition="top center")
        fig.update_layout(title_x=0.5, xaxis=dict(showgrid=True), yaxis=dict(showgrid=True))

        return fig

Overwriting data_manager.py


In [7]:
%%writefile perform_exploratory_analysis.py
"""
title: Perform Exploratory Analysis
description: Merges the data sources for spatial analysis and produces exploratory plots
image_path: {stores_summary.html,plot_store_location_*}
"""

import os

from dotenv import load_dotenv

from spatial_features import zip_features
from stores_placements import stores_gdf
from spatial_frame import spatial_frame
from data_manager import StoreLocationDataManager
# Styling decorators live in plot_styles.py
from plot_styles import (
    save_plot_as_html, centered_title, methodological_clarification,
    transparent_background
    )

# Load environment variables
load_dotenv()

# States whose store maps are rendered below (comma-separated in .env)
_show_off_raw = os.getenv("SHOW_OFF_STATES")
if _show_off_raw is None:
    raise ValueError("SHOW_OFF_STATES environment variable not found. Please set it in the .env file.")
SHOW_OFF_STATES = [state.strip() for state in _show_off_raw.split(',') if state.strip()]

Overwriting perform_exploratory_analysis.py


In [8]:
%%writefile perform_exploratory_analysis.py -a

# Wire the three ZIP-keyed sources into one manager. Change zip_id_col /
# state_column here if the inputs use different column names.
manager_inputs = {
    "stores_df": stores_gdf,
    "spatial_frame": spatial_frame,
    "zip_features": zip_features,
    "zip_id_col": "zip_code",
    "state_column": "NAME",
}
data_manager = StoreLocationDataManager(**manager_inputs)

print("StoreLocationDataManager successfully instantiated.")


Appending to perform_exploratory_analysis.py


In [9]:
%%writefile perform_exploratory_analysis.py -a

# Decorators listed innermost-first: each one wraps the result of the previous.
summary_decorators = [
    transparent_background(),
    methodological_clarification(
        clarification_text="This analysis is based on 2024 store data.",
    ),
    centered_title(
        title_text="Store Summary per State",
        title_coords=(0.5, 0.9),  # Adjust position if needed
    ),
    save_plot_as_html(filepath="stores_summary.html"),
]

styled_summary = data_manager.plot_stores_summary_per_state
for apply_decorator in summary_decorators:
    styled_summary = apply_decorator(styled_summary)
data_manager.plot_stores_summary_per_state = styled_summary

# Generate, style and save the summary plot
data_manager.plot_stores_summary_per_state()

Appending to perform_exploratory_analysis.py


In [10]:
%%writefile perform_exploratory_analysis.py -a

for state_name in SHOW_OFF_STATES:
    # A fresh decorator per state so each map goes to its own HTML file
    save_map = save_plot_as_html(filepath=f"plot_store_location_{state_name}.html")

    print(f"Generating and saving plot for {state_name}...")
    save_map(data_manager.plot_store_location)(state_name)


Appending to perform_exploratory_analysis.py


In [11]:
%%writefile spatial_correlation.py
"""
title: Spatial Correlation
description: Extends StoreLocationDataManager with methods to assess the correlation pattern
             between ZIP-level features and store locations.
"""

import os
import warnings
import geopandas as gpd
import pandas as pd
import folium
import plotly.express as px
from shapely.geometry import Point
from folium.plugins import Draw, MeasureControl
from sklearn.model_selection import train_test_split

from data_manager import StoreLocationDataManager

import os
from dotenv import load_dotenv
load_dotenv()

# States excluded from training (comma-separated in .env). An empty value means
# no hold-out; a missing variable is a configuration error.
_hold_out_raw = os.getenv("HOLD_OUT_STATES")
if _hold_out_raw is None:
    raise ValueError("HOLD_OUT_STATES environment variable not found. Please set it in the .env file.")
HOLD_OUT_STATES = [state.strip() for state in _hold_out_raw.split(',') if state.strip()]

print(f"[INFO] Hold-out states: {HOLD_OUT_STATES}")

class SpatialCorrelation(StoreLocationDataManager):
    """
    Extends StoreLocationDataManager with store-proximity labels per ZIP and
    helpers to sample, correlate and map them.

    Label columns written to ``spatial_frame``:
    - ``raw_target``:
      - 1 = ZIP intersects a buffer around some store;
      - 0 = not near a store;
      - -1 = ZIP in a hold-out state.
    - ``target`` (set by ``get_samples``):
      - 1 = positive;
      - 0 = sampled negative;
      - -1 = neutral (not sampled, or in a hold-out state).
    """

    def __init__(self, spatial_resolution=0.2, hold_out_states=HOLD_OUT_STATES, *args, **kwargs):
        """
        Args:
            spatial_resolution: Buffer distance around each store, in the
                coordinate units of ``stores_df``'s geometry.
            hold_out_states: State names labelled -1 and never sampled. The list
                is read but never mutated here.
            *args, **kwargs: Forwarded to StoreLocationDataManager (pass them by
                keyword: the first two positional slots are taken above).
        """
        super().__init__(*args, **kwargs)  # Pass additional parameters to the parent class
        self.spatial_resolution = spatial_resolution
        self.hold_out_states = hold_out_states
        self.determine_target_classification()  # Ensure raw_target is set at instantiation
    
    def determine_target_classification(self):
        """Fill ``spatial_frame['raw_target']`` with 1 / 0 / -1 (see class docstring)."""
        print(f"[INFO] Excluding states: {self.hold_out_states}")
        self.spatial_frame['raw_target'] = 0  # Initialize everything as neutral
        
        # Flag states to be excluded
        self.spatial_frame.loc[self.spatial_frame[self.state_column].isin(self.hold_out_states), 'raw_target'] = -1

        # Compute store buffers
        store_buffers = self.stores_df.geometry.buffer(self.spatial_resolution)

        # Positive examples: ZIPs intersecting any store buffer. Excluded states are
        # filtered out first so they remain -1. `predicate` replaces the `op`
        # keyword, which was deprecated in geopandas 0.10 and removed in 1.0.
        positives_geo = gpd.sjoin(
            self.spatial_frame[self.spatial_frame['raw_target'] == 0],
            gpd.GeoDataFrame(geometry=store_buffers), 
            how='inner', predicate='intersects'
        )

        self.spatial_frame.loc[positives_geo.index, 'raw_target'] = 1
    
    def get_samples(self, neg_to_pos_ratio=5, **kwargs):
        """
        Build a train/test split of labelled ZIPs joined to their features.

        Sampling:
        - Every positive (raw_target == 1) is used.
        - ``len(positives) * neg_to_pos_ratio`` negatives are drawn (seed 42)
          from non-store ZIPs outside hold-out states.
        - The outcome is recorded in ``spatial_frame['target']``: 1 = positive,
          0 = sampled negative, -1 = neutral.
        - Rows with any NaN after the feature merge are dropped.

        Args:
            neg_to_pos_ratio: Negatives drawn per positive (cast to int). If
                there are too few candidates, sampling falls back to replacement
                with a warning. Duplicates collapse, so fewer distinct negatives
                are kept.
            **kwargs: Optional ``train_test_split`` options (test_size,
                train_size, random_state, shuffle, stratify). They override the
                defaults test_size=0.5 and random_state=42; other keys are ignored.

        Returns:
            (X_train, y_train, X_test, y_test), also stored on the instance.
        """
        print(f"[INFO] Generating samples while excluding states: {self.hold_out_states}")
        neg_to_pos_ratio = int(neg_to_pos_ratio)

        raw_target = self.spatial_frame['raw_target']
        positives = self.spatial_frame[raw_target == 1]
        potential_negatives = self.spatial_frame[raw_target == 0]  # hold-out ZIPs (-1) are never candidates

        # Everything that is not a positive starts out neutral (-1); only the
        # sampled negatives are promoted to 0 below, so neg_to_pos_ratio really
        # controls how many negatives enter the data set.
        self.spatial_frame['target'] = raw_target.where(raw_target == 1, -1)

        n_negatives = len(positives) * neg_to_pos_ratio
        replace = n_negatives > len(potential_negatives)
        if replace:
            warnings.warn('Not enough negatives to match the ratio without replacement. Enabling replacement.')
        if len(potential_negatives) > 0:
            sampled_negatives = potential_negatives.sample(n=n_negatives, replace=replace, random_state=42)
            self.spatial_frame.loc[sampled_negatives.index, 'target'] = 0

        filtered_data = self.spatial_frame[self.spatial_frame['target'].isin([1, 0])]
        merged_data = filtered_data.merge(self.zip_features, how='left', on=self.zip_id_col)
        merged_data = merged_data.dropna()

        feature_columns = self.zip_features.columns.tolist()
        X = merged_data[feature_columns]
        y = merged_data['target']

        valid_split_params = {'test_size', 'train_size', 'random_state', 'shuffle', 'stratify'}
        # Caller options override the defaults instead of colliding with them
        # (passing test_size/random_state used to raise a duplicate-keyword TypeError).
        split_params = {'test_size': 0.5, 'random_state': 42}
        split_params.update({k: v for k, v in kwargs.items() if k in valid_split_params})

        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(X, y, **split_params)

        return self.X_train, self.y_train, self.X_test, self.y_test
    
    def plot_correlation_bars(self, k=20):
        """
        Bar chart of the ``k`` features most correlated (absolute Pearson
        correlation) with the store-proximity label on a fresh training sample.

        Non-numeric columns (e.g. the string ZIP key) are excluded before
        correlating; ``DataFrame.corr`` raises on them in pandas 2.x.
        """
        X_train, y_train, _, _ = self.get_samples()
        data = pd.concat([X_train, y_train.rename('store_proximity')], axis=1)
        numeric_data = data.select_dtypes(include=['number', 'bool'])
        correlation_series = numeric_data.corr()['store_proximity'].drop('store_proximity')
        
        corr_df = correlation_series.abs().nlargest(k).reset_index()
        corr_df.columns = ['Feature', 'Correlation']
        
        fig = px.bar(
            corr_df, x='Feature', y='Correlation', text='Feature', color='Correlation',
            color_continuous_scale=px.colors.diverging.Tropic, title="Feature Correlation with Store Proximity"
        )
        
        fig.update_traces(
            marker=dict(line=dict(color='black', width=0.5), opacity=0.8),
            texttemplate='%{text}', textposition='inside',
            textfont=dict(color='white', size=12), insidetextanchor="start"
        )
        return fig
    
    def plot_store_location_with_proximity(self, state_name):
        """
        Folium map of one state's ZIPs colored by ``target``, with its stores.

        Colors: green = positive (1), red = sampled negative (0), gray = neutral (-1).
        Requires ``get_samples()`` to have run so that the ``target`` column exists.

        Raises:
            ValueError: If ``state_name`` is not present in the spatial data.
        """
        if state_name not in self.unique_states:
            raise ValueError(f"State '{state_name}' not found in the spatial data.")
        
        state_frame = self.spatial_frame[self.spatial_frame[self.state_column] == state_name]
        centroids = state_frame.geometry.centroid
        map_location = folium.Map(
            location=[centroids.y.mean(), centroids.x.mean()],
            zoom_start=6, min_zoom=6
        )
        
        Draw(export=True).add_to(map_location)
        MeasureControl().add_to(map_location)
        
        for target, color in zip([1, 0, -1], ['green', 'red', 'gray']):
            folium.GeoJson(
                state_frame[state_frame['target'] == target].to_json(),
                name=f"{'Positive' if target == 1 else 'Negative' if target == 0 else 'Neutral'} Examples",
                # color bound as a default argument to avoid late-binding in the loop
                style_function=lambda x, color=color: {'fillColor': color, 'color': color, 'weight': 1, 'fillOpacity': 0.5}
            ).add_to(map_location)
        
        state_stores = self.stores_df[self.stores_df[self.zip_id_col].isin(state_frame[self.zip_id_col])]
        for _, row in state_stores.iterrows():
            folium.CircleMarker(
                location=[row['latitude'], row['longitude']], radius=2.5, color='black', fill=True,
                fill_color='black', fill_opacity=0.5, popup=row.get('store_name', 'Store')
            ).add_to(map_location)
        
        return map_location

Overwriting spatial_correlation.py


In [12]:
%%writefile perform_spatial_correlation_analysis.py
"""
title: Perform Spatial Correlation Analysis
description: Performs the Spatial Correlation Analysis
image_path: {correlation_bars.html,plot_store_location_with_proximity_*}
"""

from spatial_features import zip_features
from stores_placements import stores_gdf
from spatial_frame import spatial_frame
from spatial_correlation import SpatialCorrelation

# Styling decorators live in plot_styles.py
from plot_styles import (
    save_plot_as_html, centered_title, methodological_clarification,
    transparent_background
    )

# Load environment variables
import os
from dotenv import load_dotenv
load_dotenv()

# States whose proximity maps are rendered below (comma-separated in .env)
_show_off_raw = os.getenv("SHOW_OFF_STATES")
if _show_off_raw is None:
    raise ValueError("SHOW_OFF_STATES environment variable not found. Please set it in the .env file.")
SHOW_OFF_STATES = [state.strip() for state in _show_off_raw.split(',') if state.strip()]

Overwriting perform_spatial_correlation_analysis.py


In [13]:
%%writefile perform_spatial_correlation_analysis.py -a
# SpatialCorrelation takes spatial_resolution/hold_out_states positionally
# first, so the data inputs must be passed by keyword. Change zip_id_col /
# state_column here if the inputs use different column names.
correlation_inputs = {
    "stores_df": stores_gdf,
    "spatial_frame": spatial_frame,
    "zip_features": zip_features,
    "zip_id_col": "zip_code",
    "state_column": "NAME",
}
sc = SpatialCorrelation(**correlation_inputs)

print("SpatialCorrelation successfully instantiated.")

Appending to perform_spatial_correlation_analysis.py


In [14]:
%%writefile perform_spatial_correlation_analysis.py -a

# Apply decorators one by one, with their specific parameters (innermost first)
plot_correlation_bars = transparent_background()(
    sc.plot_correlation_bars
)

plot_correlation_bars = methodological_clarification(
    clarification_text="This analysis is based on 2024 store data.", 
)(
    plot_correlation_bars
)

# Title must describe this chart (feature correlations), not the stores summary
plot_correlation_bars = centered_title(
    title_text="Feature Correlation with Store Proximity", 
    title_coords=(0.5, 0.9)  # Adjust position if needed
)(
    plot_correlation_bars
)

plot_correlation_bars = save_plot_as_html(
    filepath="correlation_bars.html"
)(
    plot_correlation_bars
)

# Generate, style and save the correlation chart
plot_correlation_bars()


Appending to perform_spatial_correlation_analysis.py


In [15]:
%%writefile perform_spatial_correlation_analysis.py -a

for state_name in SHOW_OFF_STATES:
    # A fresh decorator per state so each map goes to its own HTML file
    save_map = save_plot_as_html(filepath=f"plot_store_location_with_proximity_{state_name}.html")

    print(f"Generating and saving proximity map for {state_name}...")
    save_map(sc.plot_store_location_with_proximity)(state_name)


Appending to perform_spatial_correlation_analysis.py


In [16]:
%%writefile spatial_baseline.py
"""
title: Spatial Baseline
description: Defines SpatialBaseline, which extends SpatialCorrelation with a baseline ML model that
             predicts store placements from the underlying ZIP-code data.
"""

import warnings
import geopandas as gpd
import pandas as pd
import folium
import numpy as np
import plotly.express as px
from shapely.geometry import Point
from folium.plugins import Draw, MeasureControl
from branca.colormap import linear
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
from spatial_correlation import SpatialCorrelation

class SpatialBaseline(SpatialCorrelation):
    """
    Logistic-regression baseline that scores ZIP codes by their probability of hosting a store.

    Builds on SpatialCorrelation, which presumably sets the attributes read here
    (``zip_features``, ``zip_id_col``, ``spatial_frame``, ``state_column``,
    ``unique_states``, ``stores_df``) and provides the ``get_samples`` sampler.
    The baseline model is fitted as part of construction.
    """

    #
    # (0) Initialization
    #
    def __init__(self, *args, **kwargs):
        """
        Initialize the SpatialCorrelation parent, then fit the baseline model.

        All positional and keyword arguments are forwarded unchanged to
        ``SpatialCorrelation.__init__``.
        """
        super().__init__(*args, **kwargs)  # Pass additional parameters to the parent class

        # Filled in by fit_baseline()
        self.baseline_probabilities = None  # positive-class probabilities on the test split
        self.baseline_metrics = None        # {"accuracy": ..., "f1": ...} on the test split
        self.fitted_model = None
        self.fit_baseline()

    #
    # (1) Fit model baseline
    #
    def fit_baseline(self, neg_to_pos_ratio=5, **model_kwargs):
        """
        Fit a logistic-regression baseline and evaluate it on the test split.

        Parameters
        ----------
        neg_to_pos_ratio : int, default 5
            Forwarded to ``get_samples``; presumably the number of negative
            (no-store) samples drawn per positive one -- TODO confirm in the parent class.
        **model_kwargs
            Extra keyword arguments passed to ``LogisticRegression`` (for example
            ``max_iter=1000`` if the default solver does not converge). With none
            given, the model is identical to the original unparameterized one.

        Returns
        -------
        dict
            ``{"accuracy": float, "f1": float}`` computed on the test split; also
            stored on ``self.baseline_metrics``.
        """
        X_train, y_train, X_test, y_test = self.get_samples(neg_to_pos_ratio)

        self.fitted_model = LogisticRegression(**model_kwargs)
        self.fitted_model.fit(X_train, y_train)

        predictions = self.fitted_model.predict(X_test)
        # predict_proba columns follow the sorted class labels, so column 1 is the
        # positive class assuming binary {0, 1} labels.
        self.baseline_probabilities = self.fitted_model.predict_proba(X_test)[:, 1]

        accuracy = accuracy_score(y_test, predictions)
        f1_score_value = f1_score(y_test, predictions)
        self.baseline_metrics = {"accuracy": accuracy, "f1": f1_score_value}

        print("Accuracy:", accuracy)
        print("F1 Score:", f1_score_value)
        return self.baseline_metrics

    #
    # (2) plot probabilities
    #

    def _compute_probabilities(self, state_frame):
        """
        Score every row of ``state_frame`` with the fitted model.

        Parameters
        ----------
        state_frame : GeoDataFrame
            Rows of ``spatial_frame`` for a single state; must contain ``zip_id_col``.

        Returns
        -------
        GeoDataFrame
            ``state_frame`` left-merged with ``zip_features`` plus a ``probabilities``
            column. Rows missing any feature value are left as NaN.

        Raises
        ------
        ValueError
            If the model has not been fitted, or no row has a complete feature vector.
        """
        if self.fitted_model is None:
            raise ValueError("The model has not been fitted yet. Fit the model before plotting probabilities.")

        # Attach the ZIP-level features to each spatial row
        merged_data = state_frame.merge(self.zip_features, how='left', on=self.zip_id_col)

        # Only rows with a complete feature vector can be scored
        feature_columns = self.zip_features.columns.tolist()
        valid_data = merged_data.dropna(subset=feature_columns)
        if valid_data.empty:
            raise ValueError(
                "No ZIP codes in this state have complete feature data; cannot compute probabilities."
            )

        X = valid_data[feature_columns].values.astype(np.float32)

        # Write scores back by row index: unscored rows stay NaN, and duplicated ZIP
        # codes cannot multiply rows the way a second merge on the ZIP column would.
        merged_data["probabilities"] = np.nan
        merged_data.loc[valid_data.index, "probabilities"] = self.fitted_model.predict_proba(X)[:, 1]

        return merged_data

    #
    def plot_probabilities(self, state_name, probabilities=None):
        """
        Plot predicted store probabilities for one state as a Folium choropleth with store markers.

        Parameters
        ----------
        state_name : str
            Must be one of ``self.unique_states``.
        probabilities : array-like, optional
            One value per row of the state's spatial frame, in row order. When
            omitted, probabilities are computed with the fitted model and unscored
            ZIP codes are imputed with the state's median probability.

        Returns
        -------
        folium.Map
            Map with the probability choropleth, one circle marker per store whose
            ZIP code lies in the state (``stores_df`` needs ``latitude`` and
            ``longitude`` columns), and a 0-1 colour legend.

        Raises
        ------
        ValueError
            If the state is unknown, or ``probabilities`` has the wrong length.
        """
        # Ensure the state exists in spatial data
        if state_name not in self.unique_states:
            raise ValueError(f"State '{state_name}' not found in the spatial data.")

        # Extract state-specific spatial frame
        state_frame = self.spatial_frame[self.spatial_frame[self.state_column] == state_name].copy()

        if probabilities is not None:
            if len(probabilities) != len(state_frame):
                raise ValueError("Provided probabilities do not match the number of state entries.")
            # Assign positionally: a Series with a different index would otherwise be
            # label-aligned against state_frame and silently turn into NaNs.
            state_frame["probabilities"] = np.asarray(probabilities)

        else:
            state_frame = self._compute_probabilities(state_frame)

            # Impute unscored ZIP codes with the median. Reassign the column rather than
            # calling fillna(inplace=True) on it: that chained in-place call does not
            # update the frame under pandas Copy-on-Write.
            if state_frame["probabilities"].isnull().any():
                median_prob = state_frame["probabilities"].median()
                state_frame["probabilities"] = state_frame["probabilities"].fillna(median_prob)

        #
        # (1.1) base map
        #

        # Center the map on the mean of the ZIP polygon centroids (computed once)
        centroids = state_frame.geometry.centroid
        map_location = folium.Map(
            location=[centroids.y.mean(), centroids.x.mean()],
            zoom_start=6,
            min_zoom=6,  # Restrict zooming out
        )

        # Plot probabilities as a choropleth layer
        folium.Choropleth(
            geo_data=state_frame.to_json(),
            data=state_frame,  # probabilities live directly in state_frame
            columns=[self.zip_id_col, "probabilities"],
            key_on=f"feature.properties.{self.zip_id_col}",
            fill_color="YlOrRd",
            fill_opacity=0.7,
            line_opacity=0.2,
            legend_name="Probability of Store Presence"
        ).add_to(map_location)

        #
        # (1.2) Layer in the store locations
        #
        colormap = linear.YlOrRd_09.scale(0, 1)

        # ZIP -> probability lookup built once instead of re-filtering per store.
        # Keeping the first row per ZIP matches the previous `.values[0]` selection.
        zip_to_prob = (
            state_frame.drop_duplicates(subset=self.zip_id_col)
            .set_index(self.zip_id_col)["probabilities"]
        )

        # Stores whose ZIP code belongs to this state
        matched_stores = self.stores_df[self.stores_df[self.zip_id_col].isin(state_frame[self.zip_id_col])]

        for store_idx, row in matched_stores.iterrows():
            probability = zip_to_prob.get(row[self.zip_id_col], 0)  # 0 if the ZIP has no score
            store_name = row.get("store_name", f"Store {store_idx}")

            # Store marker filled with the colour of its ZIP's probability
            folium.CircleMarker(
                location=[row["latitude"], row["longitude"]],
                radius=8,
                popup=f"{store_name} - {probability:.2f}",
                color="black",  # Border color
                fill=True,
                fill_color=colormap(probability),
                fill_opacity=1  # Solid color fill
            ).add_to(map_location)

        # Add colormap legend
        colormap.add_to(map_location)

        return map_location  # Return the Folium map for display

Overwriting spatial_baseline.py


In [17]:
%%writefile perform_spatial_baseline.py
"""
title: Perform Spatial Baseline
description: Performs the training of baseline model to predict store placements from underlyng
             zip data
image_path: plot_probabilities_*
"""

from spatial_features import zip_features
from stores_placements import stores_gdf
from spatial_frame import spatial_frame
from spatial_baseline import SpatialBaseline
from plotly_styles import (
    save_plot_as_html
    )

# Load environment variables
import os
from dotenv import load_dotenv
load_dotenv()

# Load environment variables
HOLD_OUT_STATES = [state.strip() for state in os.getenv("HOLD_OUT_STATES").split(',')]
HOLD_OUT_STATES

Overwriting perform_spatial_baseline.py


In [18]:
%%writefile perform_spatial_baseline.py -a
# Build the baseline model. SpatialBaseline fits its logistic regression (and prints
# test accuracy / F1) as part of construction.
sb = SpatialBaseline(
    stores_df=stores_gdf,
    spatial_frame=spatial_frame,
    zip_features=zip_features,
    zip_id_col="zip_code",   # ZIP identifier column shared by the inputs
    state_column="NAME",     # column in spatial_frame holding the state name
    spatial_resolution=0.4,
)

print("SpatialBaseline successfully instantiated.")

Appending to perform_spatial_baseline.py


In [19]:
%%writefile perform_spatial_baseline.py -a

# Render one probability map per held-out state and export it via save_plot_as_html
for held_out_state in HOLD_OUT_STATES:
    output_path = f"plot_probabilities_{held_out_state}.html"

    # A fresh wrapper per state so each map is written to its own file
    render_and_save = save_plot_as_html(filepath=output_path)(
        sb.plot_probabilities
    )

    print(f"Generating and saving probability plot for {held_out_state}...")
    render_and_save(held_out_state)


Appending to perform_spatial_baseline.py
