# Citibike ML
In this example we use the [Citibike dataset](https://ride.citibikenyc.com/system-data). Citibike is a bicycle-sharing system in New York City. Every day, riders choose from 20,000 bicycles at 1,300 stations across the city.

To ensure customer satisfaction, Citibike needs to predict how many bicycles will be needed at each station. Citibike maintenance teams check each station and repair or replace bicycles. They also relocate bicycles between stations based on predicted demand. The business therefore needs to run reports showing how many bicycles will be needed at a given station on a given day.

## Streamlit Application
In this section of the demo, we use Streamlit with the Snowpark Python client-side DataFrame API to build a visual front end that lets the Citibike operations team consume the insights from the ML forecast.

For this demo flow we assume that the organization has the following **policies and processes**:
- **Dev Tools**: The ML engineer can develop in their tool of choice (e.g., VS Code, IntelliJ, PyCharm, Eclipse). Snowpark Python works in any environment that has a Python kernel. For this demo we use Jupyter.
- **Data Governance**: To preserve customer privacy, no data can be stored locally. The ingest system may store data temporarily, but it must be assumed that, in production, the ingest system will not preserve intermediate data products between runs. Snowpark Python lets the user push down all operations to Snowflake and bring the code to the data.
- **Automation**: Although the ML engineer can use any IDE or notebook during development, the final product of the work stream must be Python code. Well-documented, modularized code is necessary for good ML operations and for interfacing with the company's CI/CD and orchestration tools.
- **Compliance**: Any ML model must be traceable back to the original dataset used for training. The business needs to be able to easily remove specific user data from training datasets and retrain models.
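The **Compliance** policy implies that every trained model carries a record of exactly which data it saw. A minimal sketch, assuming training rows can be represented as dictionaries and that a hypothetical `USER_ID` field identifies a customer: fingerprint the training set with an order-independent content hash, and when a user's data must be removed, filter it out and record a new fingerprint for the retrained model. The helpers `dataset_fingerprint`, `remove_user`, and `training_record` are hypothetical and are not part of Snowpark or the original pipeline.

```python
import hashlib
import json
from typing import Dict, Iterable, List


def dataset_fingerprint(rows: Iterable[Dict]) -> str:
    """Return a SHA-256 hash of the rows that does not depend on row order.

    Each row is serialized to JSON with sorted keys, and the serialized rows
    are sorted before hashing, so the same data always yields the same hash.
    """
    serialized = sorted(json.dumps(row, sort_keys=True, default=str) for row in rows)
    digest = hashlib.sha256()
    for line in serialized:
        digest.update(line.encode("utf-8"))
        digest.update(b"\n")
    return digest.hexdigest()


def remove_user(rows: List[Dict], user_id: str) -> List[Dict]:
    """Drop every row that belongs to user_id (assumes a hypothetical USER_ID field)."""
    return [row for row in rows if row.get("USER_ID") != user_id]


def training_record(model_name: str, rows: List[Dict]) -> Dict:
    """Metadata to store alongside a model so it can be traced to its training data."""
    return {
        "model": model_name,
        "row_count": len(rows),
        "dataset_sha256": dataset_fingerprint(rows),
    }


if __name__ == "__main__":
    # Illustrative rows only; not real Citibike data
    trips = [
        {"USER_ID": "u1", "STATION_ID": "519", "DATE": "2020-03-01"},
        {"USER_ID": "u2", "STATION_ID": "545", "DATE": "2020-03-01"},
    ]
    print(training_record("station_model_v1", trips))
    print(training_record("station_model_v2", remove_user(trips, "u1")))
```

In a Snowpark setting the same idea would more likely be applied inside Snowflake (for example, hashing a table snapshot), but the stored record plays the same role: it ties a model version to a specific, reproducible dataset.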

In [None]:
!pip -q install streamlit

In [None]:
from snowflake.snowpark import functions as F
from dags.snowpark_connection import snowpark_connect
import logging
logging.basicConfig(level=logging.WARN)
logging.getLogger().setLevel(logging.WARN)
session, state_dict = snowpark_connect('./include/state.json')

In [None]:
import pandas as pd
import streamlit as st

In [None]:
%%writefile streamlit_app.py
import streamlit as st
import pandas as pd
from datetime import timedelta
import altair as alt
from snowflake.snowpark import functions as F
from dags.snowpark_connection import snowpark_connect
import logging
logging.basicConfig(level=logging.WARN)
logging.getLogger().setLevel(logging.WARN)

session, state_dict = snowpark_connect('./include/state.json')


def update_forecast_table(forecast_df, stations: list, start_date, end_date):
    """Show predicted demand for the selected stations, one column per date."""
    # Filtering and projection run in Snowflake; only the result is pulled into pandas
    df = forecast_df.where((F.col('DATE') >= start_date) &
                           (F.col('DATE') <= end_date))\
                    .select('STATION_ID', 'DATE', 'PRED')\
                    .filter(F.col('STATION_ID').in_(stations))\
                    .to_pandas()

    data = df.pivot(index="STATION_ID", columns="DATE", values="PRED")
    st.write("### Weekly Forecast", data)


def update_eval_table(eval_df, stations: list):
    """Show the RMSE of each station's model for every training run."""
    df = eval_df.select('STATION_ID', 'RUN_DATE', 'RMSE')\
                .filter(F.col('STATION_ID').in_(stations))\
                .to_pandas()

    data = df.pivot(index="STATION_ID", columns="RUN_DATE", values="RMSE")
    st.write("### Model Monitor (RMSE)", data)


forecast_df = session.table('FLAT_FORECAST')
eval_df = session.table('FLAT_EVAL')

# Bound the date picker by the range of dates that have forecasts
min_date = forecast_df.select(F.min('DATE')).collect()[0][0]
max_date = forecast_df.select(F.max('DATE')).collect()[0][0]

start_date = st.date_input('Start Date', value=min_date, min_value=min_date, max_value=max_date)
show_days = st.number_input('Number of days to show', value=7, min_value=1, max_value=30)
# The date filter is inclusive on both ends, so subtract one day to show exactly show_days dates
end_date = start_date + timedelta(days=show_days - 1)

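stations_df = forecast_df.select(F.col('STATION_ID')).distinct().to_pandas()

stations = st.multiselect('Choose stations', stations_df['STATION_ID'], ["519", "545"])
if not stations:
    # Nothing selected: show all stations. Convert the pandas Series to a plain list
    # so it is passed to Column.in_() the same way as the multiselect result.
    stations = stations_df['STATION_ID'].tolist()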

update_forecast_table(forecast_df, stations, start_date, end_date)

update_eval_table(eval_df, stations)

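download_file_names = st.multiselect(label='Monthly ingest file(s):',
                                     options=['202003-citibike-tripdata.csv.zip'],
                                     default=['202003-citibike-tripdata.csv.zip'])

# args is only forwarded to an on_click callback and must be a tuple (note the trailing comma).
# No callback is wired up yet, so clicking the button currently just reruns the app.
st.button('Run Ingest Taskflow', args=(download_file_names,))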
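Because the app's filter keeps dates at both ends (`>= start_date` and `<= end_date`), the window length is easy to get off by one: adding `show_days` days to the start date keeps `show_days + 1` dates, while adding `show_days - 1` keeps exactly `show_days`. The sketch below checks that arithmetic in plain Python, without Snowflake; `date_window` is a hypothetical helper, not part of the app.

```python
from datetime import date, timedelta
from typing import List


def date_window(start: date, num_days: int) -> List[date]:
    """Return the dates an inclusive filter start <= d <= end would keep."""
    if num_days < 1:
        raise ValueError("num_days must be at least 1")
    end = start + timedelta(days=num_days - 1)  # inclusive upper bound
    days = []
    current = start
    while current <= end:  # same comparison the Snowpark filter uses
        days.append(current)
        current += timedelta(days=1)
    return days


window = date_window(date(2020, 3, 1), 7)
print(window[0], window[-1], len(window))  # 2020-03-01 2020-03-07 7
```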

In [None]:
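from datetime import timedelta
forecast_df = session.table('FLAT_FORECAST')
start_date = forecast_df.select(F.min('DATE')).collect()[0][0]
end_date, stations = start_date + timedelta(days=6), ["519", "545"]  # mirror the app's defaults
forecast_df.where((F.col('DATE') >= start_date) & (F.col('DATE') <= end_date))\
           .select(F.to_char(F.col('DATE'))).show()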

In [None]:
logging.getLogger().setLevel(logging.WARN)

# Long format: one row per station and date (DATE cast to a string)
data = (session.table('FLAT_FORECAST')
        .where((F.col('DATE') >= start_date) & (F.col('DATE') <= end_date))
        .select('STATION_ID', F.to_char(F.col('DATE')).alias('DATE'), 'PRED')
        .filter(F.col('STATION_ID').in_(stations))
        .to_pandas())
# data = data.pivot(index="STATION_ID", columns="DATE", values="PRED")
data
# Heatmap of predicted demand (requires `import altair as alt`)
# alt.Chart(data).mark_rect().encode(
#     alt.X('DATE:T'),
#     alt.Y('STATION_ID:N'),
#     alt.Color('PRED:Q',
#         scale=alt.Scale(scheme='greenblue'),
#         legend=alt.Legend(title='Predicted demand')
#     )
# )
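The commented-out `pivot` call reshapes the long result (one row per station and date) into the wide layout that `update_forecast_table` displays in the app. A minimal pandas sketch with invented sample values, not real forecasts, shows the reshaping. Note that `pivot` raises a `ValueError` if a (station, date) pair appears more than once, so this assumes `FLAT_FORECAST` holds one prediction per station per day.

```python
import pandas as pd

# Long format, as returned by .to_pandas(): one row per (station, date).
# The numbers are illustrative only.
long_df = pd.DataFrame({
    "STATION_ID": ["519", "519", "545", "545"],
    "DATE": ["2020-03-01", "2020-03-02", "2020-03-01", "2020-03-02"],
    "PRED": [120.0, 95.0, 60.0, 72.0],
})

# Wide format: stations as rows, dates as columns, predictions as cell values
wide_df = long_df.pivot(index="STATION_ID", columns="DATE", values="PRED")
print(wide_df)
```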


In [None]:



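# Line chart of each station's model RMSE per training run; expects the
# FLAT_EVAL columns (e.g. eval_df.to_pandas()), not the forecast data
# alt.Chart(data).mark_line().encode(
#             x="RUN_DATE:T",
#             y=alt.Y("RMSE:Q", stack=None),
#             color="STATION_ID:N")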


In [None]:
import altair as alt
help(alt.Chart.mark_line)

In [None]:
import streamlit as st
import pandas as pd
import altair as alt

from urllib.error import URLError

@st.cache_data  # st.cache is deprecated; st.cache_data is its replacement for functions returning data
def get_UN_data():
    AWS_BUCKET_URL = "http://streamlit-demo-data.s3-us-west-2.amazonaws.com"
    df = pd.read_csv(AWS_BUCKET_URL + "/agri.csv.gz")
    return df.set_index("Region")

try:
    df = get_UN_data()
    countries = st.multiselect(
        "Choose countries", list(df.index), ["China", "United States of America"]
    )
    if not countries:
        st.error("Please select at least one country.")
    else:
        data = df.loc[countries]
        data /= 1000000.0
        st.write("### Gross Agricultural Production ($B)", data.sort_index())

        data = data.T.reset_index()
        data = pd.melt(data, id_vars=["index"]).rename(
            columns={"index": "year", "value": "Gross Agricultural Product ($B)"}
        )
        chart = (
            alt.Chart(data)
            .mark_area(opacity=0.3)
            .encode(
                x="year:T",
                y=alt.Y("Gross Agricultural Product ($B):Q", stack=None),
                color="Region:N",
            )
        )
        st.altair_chart(chart, use_container_width=True)
except URLError as e:
    st.error(
        """
        **This demo requires internet access.**

        Connection error: %s
    """
        % e.reason
    )

In [None]:
AWS_BUCKET_URL = "http://streamlit-demo-data.s3-us-west-2.amazonaws.com"
df1 = pd.read_csv(AWS_BUCKET_URL + "/agri.csv.gz")
df1 = df1.set_index("Region")
countries = ["China", "United States of America"]
data = df1.loc[countries]
data /= 1000000.0
data = data.T.reset_index()
data = pd.melt(data, id_vars=["index"]).rename(
            columns={"index": "year", "value": "Gross Agricultural Product ($B)"})
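Altair expects long (tidy) data, which is why the Streamlit example transposes the wide table and then melts it. A small sketch with made-up values shows the result: because the columns of the transposed table carry the index name `Region`, `pd.melt` uses that name for its variable column, which is what the chart's `color="Region:N"` encoding refers to.

```python
import pandas as pd

# Wide table in the same shape as df1.loc[countries]: regions as rows,
# years as columns. Values are made up for illustration.
wide = pd.DataFrame(
    {"1961": [1.0, 2.0], "1962": [1.5, 2.5]},
    index=pd.Index(["China", "United States of America"], name="Region"),
)

# Transpose so years become rows, expose them as an "index" column, then melt
long_df = pd.melt(wide.T.reset_index(), id_vars=["index"]).rename(
    columns={"index": "year", "value": "Gross Agricultural Product ($B)"}
)
print(long_df)
```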

In [None]:
data