# Predicting Citibike Trips Using Snowflake [Snowpark](https://docs.snowflake.com/en/LIMITEDACCESS/snowpark-python.html) & Tabnet
In this notebook we run inference with the model we created previously, calling it as a Snowflake UDF.


In [1]:
# Reference date (YYYY-MM-DD string) for the forecast; passed as `today` to
# generate_future_df further down in this notebook.
today = '2021-03-10'

### 0. Prerequisites

Make sure to install the following dependencies. 

In [2]:
!pip install -q '../snowflake_snowpark_python-0.2.0-py3-none-any.whl[pandas]'
!pip install -q pandas toml matplotlib seaborn pytorch-tabnet

In [14]:
import pandas as pd
from pytorch_tabnet.tab_model import TabNetRegressor
import snowflake.snowpark as snp
from snowflake.snowpark import functions as F
from snowflake.snowpark import types as T
from sklearn.metrics import mean_squared_error
import logging
logging.basicConfig(level=logging.WARN)
#logging.getLogger().setLevel(logging.DEBUG)


### 1. Connect to Snowflake


In [4]:
import toml
from os import path
# Read the 'citibike' entry from ~/.snowflake_config.toml; it holds the
# connection settings handed to the Snowpark session builder.
homedir = path.expanduser("~")
config_path = path.join(homedir, '.snowflake_config.toml')
snf_env = toml.load(config_path).get('citibike')

# Configure the builder and open the session in a single chained expression.
session = snp.session.Session.builder.configs(snf_env).create()

### 2. Feature Engineering


#### Let's update the feature function with holiday and precipitation features.

In [6]:
def _head_mean_count(daily_counts, n_rows):
    """Return the rounded mean of ``COUNT`` over the first ``n_rows`` dates.

    Falls back to 0 when the aggregate comes back as ``None`` (no rows for the
    station), where ``round`` would otherwise raise ``TypeError``.
    """
    mean_value = daily_counts.sort('DATE').limit(n_rows) \
                             .select(F.mean('COUNT')).collect()[0][0]
    return 0 if mean_value is None else round(mean_value)


def generate_features(snowdf, station_id):
    """Build the per-day feature frame for a single station.

    Trips are filtered to ``station_id``, truncated to day granularity and
    counted per ``DATE``. Lagged counts (1, 7, 30, 90 and 365 rows back in
    ``DATE`` order) are added, then the ``HOLIDAYS`` and ``WEATHER`` tables
    are joined on ``DATE``.

    Args:
        snowdf: Snowpark DataFrame of trips exposing ``STARTTIME`` and
            ``START_STATION_ID`` columns.
        station_id: Value compared against ``START_STATION_ID``.

    Returns:
        Snowpark DataFrame with ``DATE``, ``COUNT``, ``LAG_1``, ``LAG_7``,
        ``LAG_30``, ``LAG_90``, ``LAG_365`` and whatever columns the
        ``HOLIDAYS`` and ``WEATHER`` tables contribute.

    Note:
        Uses the module-level ``session`` to read the two lookup tables and
        runs one aggregate query per lag offset to compute its default value.
    """
    agg_period = 'DAY'
    lag_offsets = (1, 7, 30, 90, 365)

    snowdf = snowdf.filter(F.col('START_STATION_ID') == station_id) \
                   .withColumn('DATE',
                               F.call_builtin('DATE_TRUNC', (agg_period, F.col('STARTTIME')))) \
                   .groupBy('DATE') \
                   .count()

    holiday_df = session.table('HOLIDAYS')
    precip_df = session.table('WEATHER')

    # Impute missing lag values (rows with fewer than `offset` predecessors)
    # with the mean COUNT of the first `offset` dates. All defaults are
    # computed from the plain daily counts before any lag column is added.
    lag_defaults = {offset: _head_mean_count(snowdf, offset) for offset in lag_offsets}

    date_win = snp.Window.orderBy('DATE')

    # Column order matters: callers derive the model's feature list from it.
    for offset, default in lag_defaults.items():
        snowdf = snowdf.withColumn(
            'LAG_' + str(offset),
            F.lag('COUNT', offset=offset, default_value=default).over(date_win))

    # Left join keeps non-holiday dates (HOLIDAY filled with 0); the inner join
    # drops any date that has no WEATHER row.
    # NOTE(review): DAYOFWEEK / MONTH columns, na.drop() and a final sort on
    # DATE were disabled in the original cell; re-enabling the extra columns
    # would change the feature list fed to the model UDF - confirm first.
    snowdf = snowdf.join(holiday_df, 'DATE', join_type='left').na.fill({'HOLIDAY': 0}) \
                   .join(precip_df, 'DATE', 'inner')

    return snowdf


### Inference
We create functions to generate past and future dataframes for inference and evaluation.  These will be useful later in our inference pipeline.

In [7]:
# def generate_past_df(table, station_id, today, prediction_period):
    
#     newdf = session.table(table).filter((F.to_date('STARTTIME') <= F.to_date(F.lit(today))) 
#                                         &              
#                                         (F.to_date('STARTTIME') >= F.dateadd('DAY', 
#                                                                              F.lit(-365+prediction_period), 
#                                                                              F.to_date(F.lit(today)))))
    
#     newdf = generate_features(newdf, station_id)
#     newdf = newdf.filter(F.to_date('DATE') >= F.dateadd('DAY',F.lit(prediction_period),F.to_date(F.lit(today))))
  
#     return newdf


def generate_future_df(table, station_id, today, prediction_period):
    """Return feature rows for ``station_id`` from ``today`` onwards.

    Trips are read from ``table`` for the window starting 365 days before
    ``today`` (inclusive) and ending ``prediction_period`` days after it
    (exclusive). The window is passed through ``generate_features`` and only
    rows whose ``DATE`` is on or after ``today`` are kept.

    Args:
        table: Name of the trips table or view, read via the module-level
            ``session``.
        station_id: Station identifier forwarded to ``generate_features``.
        today: Reference date, converted with ``F.to_date(F.lit(today))``.
        prediction_period: Number of days past ``today`` to include.

    Returns:
        Snowpark DataFrame of feature rows with ``DATE >= today``.
    """
    reference_day = F.to_date(F.lit(today))
    trip_day = F.to_date('STARTTIME')

    # The year of history is what gives the lag columns something to look back on.
    window_end = F.dateadd('DAY', F.lit(prediction_period), reference_day)
    window_start = F.dateadd('DAY', F.lit(-365), reference_day)
    inside_window = (trip_day < window_end) & (trip_day >= window_start)

    trips = session.table(table).filter(inside_window)
    features = generate_features(trips, station_id)

    return features.filter(F.to_date('DATE') >= reference_day)

Let's build the feature rows for the next 7 days and see how our model performs on them.

In [8]:
# Forecast horizon in days, counted from `today`.
prediction_period = 7

# Feature rows for station 2006, from `today` up to the end of the horizon.
future_df = generate_future_df(
    table='trips_stations_vw',
    station_id='2006',
    today=today,
    prediction_period=prediction_period,
)

future_df.show()

--------------------------------------------------------------------------------------------------------------
|"DATE"               |"COUNT"  |"LAG_1"  |"LAG_7"  |"LAG_30"  |"LAG_90"  |"LAG_365"  |"HOLIDAY"  |"PRECIP"  |
--------------------------------------------------------------------------------------------------------------
|2021-03-10 00:00:00  |823      |693      |631      |328       |691       |615        |0          |0.0       |
|2021-03-11 00:00:00  |761      |823      |651      |376       |653       |587        |0          |0.38      |
|2021-03-12 00:00:00  |760      |761      |608      |423       |624       |606        |0          |0.0       |
|2021-03-13 00:00:00  |789      |760      |668      |395       |555       |655        |0          |0.0       |
|2021-03-14 00:00:00  |761      |789      |603      |389       |471       |635        |0          |0.0       |
|2021-03-15 00:00:00  |752      |761      |648      |380       |533       |581        |0          |0.0       |
|

In [12]:
# Column the model is meant to predict.
target = ['COUNT']

# Strip the double quotes from the reported column names, then drop the
# target and the DATE key; what remains, in dataframe order, is the feature list.
feature_columns = [column_name.replace('"', '') for column_name in future_df.columns]
for non_feature in (target[0], 'DATE'):
    feature_columns.remove(non_feature)

# Preview a single row restricted to the feature columns.
preview_df = future_df.select(feature_columns).limit(1)
preview_df.show()

------------------------------------------------------------------------------
|"LAG_1"  |"LAG_7"  |"LAG_30"  |"LAG_90"  |"LAG_365"  |"HOLIDAY"  |"PRECIP"  |
------------------------------------------------------------------------------
|693      |631      |328       |691       |615        |0          |0.0       |
------------------------------------------------------------------------------



In [13]:
# Smoke-test the model UDF with one hand-written row of seven literal values
# (presumably in feature_columns order - verify against the UDF signature).
sample_row = (377, 427, 550, 1016, 1091, 0, 23.87)
sample_args = [F.lit(value) for value in sample_row]

session.range(1).select(F.call_udf('station_2006_model_udf', *sample_args)).show()

------------------------------------------------------
|"STATION_2006_MODEL_UDF(377, 427, 550, 1016, 10...  |
------------------------------------------------------
|452                                                 |
------------------------------------------------------



In [11]:
# Score every row of future_df by passing its feature columns to the model UDF
# and attaching the result as PRED.
feature_inputs = [F.col(column_name) for column_name in feature_columns]
scored_df = future_df.withColumn('PRED', F.call_udf('station_2006_model_udf', feature_inputs))
scored_df.show()

-----------------------------------------------------------------------------------------------------------------------
|"DATE"               |"COUNT"  |"LAG_1"  |"LAG_7"  |"LAG_30"  |"LAG_90"  |"LAG_365"  |"HOLIDAY"  |"PRECIP"  |"PRED"  |
-----------------------------------------------------------------------------------------------------------------------
|2021-03-10 00:00:00  |823      |693      |631      |328       |691       |615        |0          |0.0       |715     |
|2021-03-11 00:00:00  |761      |823      |651      |376       |653       |587        |0          |0.38      |779     |
|2021-03-12 00:00:00  |760      |761      |608      |423       |624       |606        |0          |0.0       |731     |
|2021-03-13 00:00:00  |789      |760      |668      |395       |555       |655        |0          |0.0       |777     |
|2021-03-14 00:00:00  |761      |789      |603      |389       |471       |635        |0          |0.0       |741     |
|2021-03-15 00:00:00  |752      |761    

### Adventures
#### impute missing instead of drop
#### add additional lag features
#### check feature importance
#### build with most important features
#### train in UDF
#### ci/cd pipeline
#### Feature views and ZCC
#### build/train/deploy loop for top N stations.
#### parallelization
#### 