In [None]:

# Source CSVs for the population endpoint, hosted in the project's GitHub repo.
# Adjacent string literals are concatenated at compile time, so these are the
# same single URLs, just split to keep the lines readable.
POPULATION_CSV = (
    'https://raw.githubusercontent.com/jiobu1/labspt15-cityspire-g-ds/main/'
    'notebooks/model/population2010-2019/csv/population_cleaned.csv'
)
FORECAST_CSV = (
    'https://raw.githubusercontent.com/jiobu1/labspt15-cityspire-g-ds/main/'
    'notebooks/model/population2010-2019/csv/population_prediction.csv'
)

# Calendar years that appear as columns in the historical population CSV.
_HISTORY_YEARS = [str(year) for year in range(2010, 2020)]


def _load_population_history(location):
    """Load observed population for ``location`` and reshape it to long form.

    args:
    - location: list[str] -> one-element list holding the "City, State" key

    Returns:
    DataFrame with columns 'City,State', 'ds' (year as int) and 'y'
    (population), with rows for each year in ``_HISTORY_YEARS``.
    """
    population = pd.read_csv(POPULATION_CSV)
    population = population.loc[
        population['City,State'].isin(location),
        ['City,State', *_HISTORY_YEARS],
    ]
    history = population.melt(id_vars=['City,State'], var_name='ds', value_name='y')
    history['ds'] = history['ds'].astype(int)
    return history


def _load_population_forecast(location):
    """Load forecast rows for ``location`` from ``FORECAST_CSV``.

    args:
    - location: list[str] -> one-element list holding the "City, State" key

    Returns:
    An independent DataFrame containing (at least) 'year' cast to int,
    'yhat', 'yhat_lower' and 'yhat_upper'.
    """
    forecast = pd.read_csv(FORECAST_CSV)
    matches = forecast[forecast['City,State'].isin(location)]
    # Drop the first 9 matching rows (positional).
    # NOTE(review): presumably these are in-sample fits for 2010-2018, so the
    # forecast line starts where the observed series ends. Confirm against the CSV.
    # .copy() detaches the slice so the cast below does not raise
    # SettingWithCopyWarning or write through to ``forecast``.
    predictions = matches.iloc[9:].copy()
    predictions['year'] = predictions['year'].astype(int)
    return predictions


def _build_population_figure(history, predictions, title):
    """Return a Plotly figure of observed population, forecast and its bounds.

    The shaded uncertainty band comes from the 'Upper Bound' trace filling
    ('tonexty') down to the trace added immediately before it ('Lower Bound').
    Those two traces must therefore stay adjacent and in this order.
    """
    fig = go.Figure()

    fig.add_trace(go.Scatter(
        name='Original',
        x=history['ds'],
        y=history['y'],
        fill=None,
        mode='lines',
        line_color='black',
        showlegend=True,
    ))

    fig.add_trace(go.Scatter(
        name='Forecast',
        x=predictions['year'],
        y=predictions['yhat'],
        fill=None,
        mode='lines',
        line_color='red',
        showlegend=True,
    ))

    fig.add_trace(go.Scatter(
        name='Lower Bound',
        x=predictions['year'],
        y=predictions['yhat_lower'],
        fill=None,
        mode='lines',
        line_color='gray',
    ))

    fig.add_trace(go.Scatter(
        name='Upper Bound',
        x=predictions['year'],
        y=predictions['yhat_upper'],
        fill='tonexty',
        mode='lines',
        line_color='gray',
    ))

    fig.update_layout({
        'autosize': True,
        'title': title,
        'title_x': 0.5,
        'xaxis_title': 'Year',
        'yaxis_title': 'Population',
    })
    fig.update_yaxes(automargin=True)
    fig.update_xaxes(automargin=True, nticks=20)
    return fig


@router.post('/api/population_forecast_graph')
async def population_forecast_graph(city: City):
    """
    Create a visualization of historical and forecasted population

    args:
    - city: City -> the target city. Its `city` and `state` fields are joined
      into the "City, State" key used to look up both CSVs.

    Returns:
    JSON string (Plotly `Figure.to_json()`) of a figure containing
    - observed population for 2010-2019
    - the forecast line plus its lower/upper bound band from FORECAST_CSV
    """
    # validate_city is defined elsewhere; presumably it checks/normalizes the
    # input before lookup. Verify.
    city = validate_city(city)
    location = [city.city + ', ' + city.state]

    # NOTE(review): pd.read_csv fetches over HTTP synchronously, which blocks
    # the event loop inside this `async def`. Consider a plain `def` handler
    # (run in a threadpool by FastAPI) and/or caching the two CSVs.
    history = _load_population_history(location)
    predictions = _load_population_forecast(location)

    fig = _build_population_figure(
        history, predictions, f'{location[0]} Population Forecast'
    )
    # No fig.show(): this runs inside an API server, where rendering a viewer
    # is a pointless side effect. The client renders the returned JSON.
    return fig.to_json()
