# Citibike ML
In this example we use the [Citibike dataset](https://ride.citibikenyc.com/system-data). Citibike is a bicycle sharing system in New York City. Every day, users choose from 20,000 bicycles at 1,300 stations around the city.

To ensure customer satisfaction, Citibike needs to predict how many bicycles will be needed at each station. Maintenance teams check each station, repair or replace bicycles, and relocate bicycles between stations based on predicted demand. The business needs to be able to run reports showing how many bicycles will be needed at a given station on a given day.



## Data Engineering
We begin where all ML use cases do: data engineering. In this section of the demo, we will use Snowpark's Python client-side DataFrame API to build an **ELT pipeline**. We will extract the data from the source system (S3), load it into Snowflake, and add transformations to clean the data before analysis.

The data engineer has been told that there is historical data going back to 2013 and new data will be made available at the end of each month. 

For this demo flow we will assume that the organization has the following **policies and processes**:
- **Dev Tools**: The data engineer can develop in their tool of choice (e.g. VS Code, IntelliJ, PyCharm, Eclipse). Snowpark Python makes it possible to use any environment that has a Python kernel. For the sake of this demo we will use Jupyter.
- **Data Governance**: To preserve customer privacy, no data can be stored locally. The ingest system may store data temporarily, but it must be assumed that, in production, the ingest system will not preserve intermediate data products between runs. Snowpark Python allows the user to push down all operations to Snowflake and bring the code to the data.
- **Automation**: Although the data engineer can use any IDE or notebook for development, the final product of the work stream must be Python code. Well-documented, modularized code is necessary for good ML operations and to interface with the company's CI/CD and orchestration tools.
- **Compliance**: Any ML model must be traceable back to the original data set used for training. The business needs to be able to easily remove specific user data from training datasets and retrain models.

Input: Historical bulk data at `https://s3.amazonaws.com/tripdata/`. Incremental data to be loaded one month at a time.  
Output: `trips` table

In [1]:
import snowflake.snowpark as snp
from snowflake.snowpark import functions as F
from snowflake.snowpark import types as T

import pandas as pd
from datetime import datetime
import requests
from zipfile import ZipFile
from io import BytesIO
import os

#import logging
#logging.basicConfig(level=logging.WARN)
#logging.getLogger().setLevel(logging.DEBUG)

### 1. Load credentials and connect to Snowflake


We will use a simple JSON file to store our credentials. This should **never** be done in production and is for demo purposes only.

In [2]:
from dags.snowpark_connection import snowpark_connect
session, compute_parameters, state_dict = snowpark_connect('./include/creds.json')
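
Outside a demo, one option is to keep credentials out of the repository and read them from the environment instead. The sketch below is an assumption and not part of the demo code: the `SNOWFLAKE_*` variable names and the `connection_parameters_from_env` helper are hypothetical, and the dictionary keys simply mirror the connection parameters used later in this notebook.

```python
import os

# Hypothetical environment variable names; adjust them to your own conventions.
ENV_KEYS = {
    "account": "SNOWFLAKE_ACCOUNT",
    "user": "SNOWFLAKE_USER",
    "password": "SNOWFLAKE_PASSWORD",
    "role": "SNOWFLAKE_ROLE",
    "database": "SNOWFLAKE_DATABASE",
    "schema": "SNOWFLAKE_SCHEMA",
    "warehouse": "SNOWFLAKE_WAREHOUSE",
}

def connection_parameters_from_env(env=None):
    """Build a connection-parameter dict from environment variables.

    Raises KeyError naming every missing variable rather than only the first.
    """
    env = os.environ if env is None else env
    missing = [name for name in ENV_KEYS.values() if not env.get(name)]
    if missing:
        raise KeyError("Missing environment variables: " + ", ".join(missing))
    return {key: env[name] for key, name in ENV_KEYS.items()}

# Example with an explicit mapping so the sketch runs without real credentials.
example_env = {name: "example-" + key for key, name in ENV_KEYS.items()}
params = connection_parameters_from_env(example_env)
print(sorted(params))
```

Passing the mapping explicitly keeps the helper easy to test; calling it with no argument reads from `os.environ`.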

### 2. Create the load stage

In [3]:
#fq_load_stage_name=session.get_fully_qualified_current_schema()+'.\"'+state_dict['load_stage_name']+'\"'
#session.sql('CREATE STAGE IF NOT EXISTS '+fq_load_stage_name).collect()
load_stage_name=state_dict['load_stage_name']
session.sql('CREATE STAGE IF NOT EXISTS '+load_stage_name).collect()
#session.sql('CREATE OR REPLACE STAGE '+state_dict['load_stage_name']).collect()

[Row(status='LOAD_STAGE already exists, statement succeeded.')]

### 3. Extract:  


Create a list of files to download and upload to the stage.

In [4]:
import pandas as pd
from datetime import datetime

#For files like 201306-citibike-tripdata.zip
date_range1 = pd.period_range(start=datetime.strptime("201306", "%Y%m"), 
                             end=datetime.strptime("201612", "%Y%m"), 
                             freq='M').strftime("%Y%m")
file_name_end1 = '-citibike-tripdata.zip'
files_to_download = [date+file_name_end1 for date in date_range1.to_list()]


Starting in January 2017, Citibike changed the file name format.

In [5]:
#For files like 201701-citibike-tripdata.csv.zip
date_range2 = pd.period_range(start=datetime.strptime("201701", "%Y%m"), 
                             end=datetime.strptime("202112", "%Y%m"), 
                             freq='M').strftime("%Y%m")
file_name_end2 = '-citibike-tripdata.csv.zip'
files_to_download = files_to_download + [date+file_name_end2 for date in date_range2.to_list()]
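
The two naming rules above can also be folded into one helper, which is convenient for the monthly incremental loads. This is a minimal sketch using only the standard library; `trip_file_name` and `trip_file_names` are hypothetical names, and the January 2017 cutoff is taken from the cells above.

```python
def trip_file_name(year: int, month: int) -> str:
    """Return the archive name for one month of trip data."""
    stamp = f"{year:04d}{month:02d}"
    # Files before January 2017 end in .zip; later files end in .csv.zip.
    if (year, month) < (2017, 1):
        return stamp + "-citibike-tripdata.zip"
    return stamp + "-citibike-tripdata.csv.zip"

def trip_file_names(start, end):
    """List archive names for every month from start to end, both inclusive.

    start and end are (year, month) tuples.
    """
    names = []
    year, month = start
    while (year, month) <= end:
        names.append(trip_file_name(year, month))
        year, month = (year + 1, 1) if month == 12 else (year, month + 1)
    return names

all_files = trip_file_names((2013, 6), (2021, 12))
print(len(all_files), all_files[0], all_files[-1])
```

For a monthly incremental run, `trip_file_name(year, month)` gives the single file to fetch.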

For development purposes we will start by loading just a couple of files. We will create a bulk load process afterwards.

In [6]:
files_to_download = [files_to_download[i] for i in [0,102]] #19,50,100,102]]
files_to_download

['201306-citibike-tripdata.zip', '202112-citibike-tripdata.csv.zip']

In [7]:
session.use_warehouse(compute_parameters['fe_warehouse'])

In [8]:
schema1_download_files = list()
schema2_download_files = list()
schema2_start_date = datetime.strptime('202102', "%Y%m")

for file_name in files_to_download:
    file_start_date = datetime.strptime(file_name.split("-")[0], "%Y%m")
    if file_start_date < schema2_start_date:
        schema1_download_files.append(file_name)
    else:
        schema2_download_files.append(file_name)

In [9]:
schema1_download_files, schema2_download_files

(['201306-citibike-tripdata.zip'], ['202112-citibike-tripdata.csv.zip'])

In [10]:
schema1_load_stage = state_dict['load_stage_name']+'/schema1/'
schema2_load_stage = state_dict['load_stage_name']+'/schema2/'

schema1_files_to_load = list()
for zip_file_name in schema1_download_files:
    
    url = state_dict['download_base_url']+zip_file_name
    
    print('Downloading and unzipping: '+url)
    r = requests.get(url)
    file = ZipFile(BytesIO(r.content))
    csv_file_name=file.namelist()[0]
    file.extract(csv_file_name)
    file.close()
    
    print('Putting '+csv_file_name+' to stage: '+schema1_load_stage)
    session.file.put(local_file_name=csv_file_name, 
                     stage_location=schema1_load_stage, 
                     source_compression='NONE', 
                     overwrite=True)
    schema1_files_to_load.append(csv_file_name)
    os.remove(csv_file_name)
    
schema2_files_to_load = list()
for zip_file_name in schema2_download_files:
    
    url = state_dict['download_base_url']+zip_file_name
    
    print('Downloading and unzipping: '+url)
    r = requests.get(url)
    file = ZipFile(BytesIO(r.content))
    csv_file_name=file.namelist()[0]
    file.extract(csv_file_name)
    file.close()
    
    print('Putting '+csv_file_name+' to stage: '+schema2_load_stage)
    session.file.put(local_file_name=csv_file_name, 
                     stage_location=schema2_load_stage, 
                     source_compression='NONE', 
                     overwrite=True)
    schema2_files_to_load.append(csv_file_name)
    os.remove(csv_file_name)

Downloading and unzipping: https://s3.amazonaws.com/tripdata/201306-citibike-tripdata.zip
Putting 201306-citibike-tripdata.csv to stage: LOAD_STAGE/schema1/
Downloading and unzipping: https://s3.amazonaws.com/tripdata/202112-citibike-tripdata.csv.zip
Putting 202112-citibike-tripdata.csv to stage: LOAD_STAGE/schema2/
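
The two loops above differ only in their target stage, so the unzip-and-upload step could be factored into one function. The sketch below is an illustration under assumptions, not the notebook's code: `unzip_first_member`, `stage_archives` and the `put_file` callback are hypothetical names, and the upload is stubbed so the example runs without a Snowflake session. Working in a temporary directory also fits the governance rule that nothing is kept locally between runs.

```python
import os
import tempfile
from io import BytesIO
from zipfile import ZipFile

def unzip_first_member(zip_bytes, dest_dir):
    """Extract the first file of an in-memory zip archive and return its local path."""
    with ZipFile(BytesIO(zip_bytes)) as archive:
        member = archive.namelist()[0]
        return archive.extract(member, path=dest_dir)

def stage_archives(archives, put_file, stage_location):
    """Unzip each archive into a temporary directory and pass the file to put_file.

    archives: dict of zip file name -> raw zip bytes (e.g. requests.get(url).content)
    put_file: callable(local_path, stage_location); hypothetical hook for the upload
    """
    staged = []
    with tempfile.TemporaryDirectory() as tmp_dir:
        for zip_bytes in archives.values():
            local_path = unzip_first_member(zip_bytes, tmp_dir)
            put_file(local_path, stage_location)
            staged.append(os.path.basename(local_path))
    # The temporary directory and its contents are deleted at this point.
    return staged

# Demo with a tiny in-memory archive and a stub upload that records its calls.
buffer = BytesIO()
with ZipFile(buffer, "w") as demo_archive:
    demo_archive.writestr("201306-citibike-tripdata.csv", "col_a,col_b\n1,2\n")

uploads = []
staged_files = stage_archives(
    {"201306-citibike-tripdata.zip": buffer.getvalue()},
    lambda path, stage: uploads.append((os.path.basename(path), stage)),
    "LOAD_STAGE/schema1/",
)
print(staged_files, uploads)
```

With a live session, `put_file` could wrap the same `session.file.put(...)` call used in the cell above.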


In [11]:
session.sql("list @"+load_stage_name+" pattern='.*20.*[.]gz'").collect()

[Row(name='load_stage/schema1/201306-citibike-tripdata.csv.gz', size=16218896, md5='bd979640f17f10a3bf42f449aff29ad6', last_modified='Mon, 21 Feb 2022 23:51:40 GMT'),
 Row(name='load_stage/schema2/202112-citibike-tripdata.csv.gz', size=60670624, md5='f32f0bd73c8304fda217fa4b9aa554fa', last_modified='Mon, 21 Feb 2022 23:52:09 GMT')]

### 4. Load: 
Load the raw data with every column typed as a string. We will fix the data types in the transform stage.

There are two schema types so we will create two ingest tables.

In [12]:
#Upper case fields are common to both schemas.
#Schema from June 2013 through January 2021
load_schema1 = T.StructType([T.StructField("tripduration", T.StringType()),
                             T.StructField("STARTTIME", T.StringType()), 
                             T.StructField("STOPTIME", T.StringType()), 
                             T.StructField("START_STATION_ID", T.StringType()),
                             T.StructField("START_STATION_NAME", T.StringType()), 
                             T.StructField("START_STATION_LATITUDE", T.StringType()),
                             T.StructField("START_STATION_LONGITUDE", T.StringType()),
                             T.StructField("END_STATION_ID", T.StringType()),
                             T.StructField("END_STATION_NAME", T.StringType()), 
                             T.StructField("END_STATION_LATITUDE", T.StringType()),
                             T.StructField("END_STATION_LONGITUDE", T.StringType()),
                             T.StructField("bike_id", T.StringType()),
                             T.StructField("USERTYPE", T.StringType()), 
                             T.StructField("birth_year", T.StringType()),
                             T.StructField("gender", T.StringType())])

#starting in February 2021 the schema changed
load_schema2 = T.StructType([T.StructField("ride_id", T.StringType()), 
                             T.StructField("rideable_type", T.StringType()), 
                             T.StructField("STARTTIME", T.StringType()), 
                             T.StructField("STOPTIME", T.StringType()), 
                             T.StructField("START_STATION_NAME", T.StringType()), 
                             T.StructField("START_STATION_ID", T.StringType()),
                             T.StructField("END_STATION_NAME", T.StringType()), 
                             T.StructField("END_STATION_ID", T.StringType()),
                             T.StructField("START_STATION_LATITUDE", T.StringType()),
                             T.StructField("START_STATION_LONGITUDE", T.StringType()),
                             T.StructField("END_STATION_LATITUDE", T.StringType()),
                             T.StructField("END_STATION_LONGITUDE", T.StringType()),
                             T.StructField("USERTYPE", T.StringType())])

trips_table_schema = T.StructType([T.StructField("STARTTIME", T.StringType()), 
                             T.StructField("STOPTIME", T.StringType()), 
                             T.StructField("START_STATION_NAME", T.StringType()), 
                             T.StructField("START_STATION_ID", T.StringType()),
                             T.StructField("END_STATION_NAME", T.StringType()), 
                             T.StructField("END_STATION_ID", T.StringType()),
                             T.StructField("START_STATION_LATITUDE", T.StringType()),
                             T.StructField("START_STATION_LONGITUDE", T.StringType()),
                             T.StructField("END_STATION_LATITUDE", T.StringType()),
                             T.StructField("END_STATION_LONGITUDE", T.StringType()),
                             T.StructField("USERTYPE", T.StringType())])
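
The conformed `trips_table_schema` keeps only the upper-case fields that appear in both raw schemas. As a quick plain-Python check of that relationship, the sketch below uses just the column names copied from the definitions above (no Snowpark types are needed for the comparison).

```python
# Column names copied from load_schema1 and load_schema2 above.
schema1_columns = [
    "tripduration", "STARTTIME", "STOPTIME",
    "START_STATION_ID", "START_STATION_NAME",
    "START_STATION_LATITUDE", "START_STATION_LONGITUDE",
    "END_STATION_ID", "END_STATION_NAME",
    "END_STATION_LATITUDE", "END_STATION_LONGITUDE",
    "bike_id", "USERTYPE", "birth_year", "gender",
]
schema2_columns = [
    "ride_id", "rideable_type", "STARTTIME", "STOPTIME",
    "START_STATION_NAME", "START_STATION_ID",
    "END_STATION_NAME", "END_STATION_ID",
    "START_STATION_LATITUDE", "START_STATION_LONGITUDE",
    "END_STATION_LATITUDE", "END_STATION_LONGITUDE",
    "USERTYPE",
]

# Keep the schema 2 ordering and drop anything schema 1 does not have.
common_columns = [name for name in schema2_columns if name in schema1_columns]
only_schema1 = [name for name in schema1_columns if name not in schema2_columns]
only_schema2 = [name for name in schema2_columns if name not in schema1_columns]

print(common_columns)
print(only_schema1, only_schema2)
```

The resulting `common_columns` list has the same names, in the same order, as `trips_table_schema`.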

Create empty tables in order to set up CDC (change data capture).

In [17]:
session.createDataFrame([[None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]], 
                        schema=load_schema1)\
       .na.drop()\
       .write.mode('overwrite')\
       .saveAsTable(state_dict['load_table_name']+'schema1')

In [18]:
session.createDataFrame([[None, None, None, None, None, None, None, None, None, None, None, None, None]], 
                        schema=load_schema2)\
       .na.drop()\
       .write.mode('overwrite')\
       .saveAsTable(state_dict['load_table_name']+'schema2')

In [15]:
# loaddf = session.read.option("SKIP_HEADER", 1)\
#                      .option("FIELD_OPTIONALLY_ENCLOSED_BY", "\042")\
#                      .option("COMPRESSION", "GZIP")\
#                      .option("NULL_IF", "\\\\N")\
#                      .option("NULL_IF", "NULL")\
#                      .option("pattern", "'.*20.*[.]gz'")\
#                      .schema(load_schema1)\
#                      .csv(load_stage_name)

In [19]:
schema1_pipe_name = 'SCHEMA1_PIPE'
schema1_load_table = state_dict['load_table_name']+'schema1'
session.sql('CREATE OR REPLACE PIPE '+schema1_pipe_name+\
            ' AS COPY INTO '+schema1_load_table+\
            ' FROM @'+schema1_load_stage).collect()

[Row(status='Pipe SCHEMA1_PIPE successfully created.')]

In [20]:
schema2_pipe_name = 'SCHEMA2_PIPE'
schema2_load_table = state_dict['load_table_name']+'schema2'
session.sql('CREATE OR REPLACE PIPE '+schema2_pipe_name+\
              ' AS COPY INTO '+schema2_load_table+' FROM @'+schema2_load_stage).collect()

[Row(status='Pipe SCHEMA2_PIPE successfully created.')]

In [None]:
#!pip install -q snowflake-ingest

In [None]:
# from cryptography.hazmat.primitives import serialization as crypto_serialization
# from cryptography.hazmat.primitives.asymmetric import rsa
 
# keySize = 2048
 
# key = rsa.generate_private_key(public_exponent=65537, key_size=keySize)
 
# private_key = key.private_bytes(
#     crypto_serialization.Encoding.PEM,
#     crypto_serialization.PrivateFormat.PKCS8,
#     crypto_serialization.NoEncryption()
# )
# private_key = private_key.decode('utf-8')
 
# public_key = key.public_key().public_bytes(
#     crypto_serialization.Encoding.PEM,
#     crypto_serialization.PublicFormat.SubjectPublicKeyInfo
# )
# public_key = public_key.decode('utf-8')

# public_key_str=''.join(public_key.split('\n')[1:-2])
# private_key_str=''.join(private_key.split('\n')[1:-2])

# with open('citibike-ml-john-private-key.pem', 'w') as fh:
#     fh.write(private_key)
# with open('citibike-ml-john-public-key.pem', 'w') as fh:
#     fh.write(public_key)

# session.use_role('ACCOUNTADMIN')
# session.sql('ALTER USER '+connection_parameters['user']+' SET RSA_PUBLIC_KEY=\"'+public_key_str+'\"').collect()
# session.use_role('DBA_CITIBIKE')

In [24]:
from snowflake.ingest import SimpleIngestManager
from snowflake.ingest import StagedFile
from snowflake.ingest.utils.uris import DEFAULT_SCHEME
import json
with open('include/creds.json') as f:
    data = json.load(f)
    connection_parameters = {
      'account': data['account'],
      'user': data['username'],
      'password': data['password'],
      'role': data['role'],
      'schema': data['schema'],
      'database': data['database'],
      'warehouse': data['warehouse']}

In [48]:
# with open('citibike-ml-john-private-key.pem', 'r') as fh:
#     private_key=fh.read()
# private_key_str=''.join(private_key.split('\n')[1:-2])
from cryptography.hazmat.primitives.serialization import (
    load_pem_private_key, Encoding, PrivateFormat, NoEncryption)

# The demo key file is written without encryption, so no password is needed.
with open("./citibike-ml-john-private-key.pem", 'rb') as pem_in:
    pemlines = pem_in.read()
    private_key_obj = load_pem_private_key(pemlines, password=None)

# Re-serialize the key as PEM text for the ingest manager.
private_key_text = private_key_obj.private_bytes(
    Encoding.PEM, PrivateFormat.PKCS8, NoEncryption()).decode('utf-8')

In [26]:
schema1_ingest_manager = SimpleIngestManager(account=connection_parameters['account'],
                                             host=connection_parameters['account']+'.snowflakecomputing.com',
                                             user=connection_parameters['user'],
                                             pipe=schema1_pipe_name,
                                             private_key=private_key_text)

In [28]:
staged_file_list = []
for file_name in schema1_files_to_load:
    staged_file_list.append(StagedFile(file_name, None))

In [35]:
session.sql('describe pipe '+schema1_pipe_name).collect()

[Row(created_on=datetime.datetime(2022, 2, 21, 15, 54, 6, 165000, tzinfo=<DstTzInfo 'America/Los_Angeles' PST-1 day, 16:00:00 STD>), name='SCHEMA1_PIPE', database_name='CITIBIKEML', schema_name='DEMO', definition='COPY INTO RAW_schema1 FROM @LOAD_STAGE/schema1/', owner='DBA_CITIBIKE', notification_channel=None, comment='', integration=None, pattern=None, error_integration=None)]

In [31]:
schema1_ingest_manager.ingest_files(staged_file_list)

IngestResponseError: Http Error: 404, Vender Code: 390404, Message: Specified object does not exist or not authorized. Pipe not found
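
One possible cause of this 404, offered as an assumption rather than a confirmed diagnosis: the Snowpipe REST API that `SimpleIngestManager` calls generally expects a fully qualified pipe name (`database.schema.pipe`), whereas the manager above is given only `SCHEMA1_PIPE`. The helper below (`fully_qualified_pipe_name` is a hypothetical name) sketches how the qualified name could be assembled. Key-pair authentication for the user must also be in place for the call to succeed.

```python
def fully_qualified_pipe_name(database, schema, pipe):
    """Join database, schema and pipe into the dotted name form."""
    parts = [database, schema, pipe]
    if not all(parts):
        raise ValueError("database, schema and pipe are all required")
    return ".".join(parts)

# Values taken from the DESCRIBE PIPE output shown earlier.
qualified_pipe = fully_qualified_pipe_name("CITIBIKEML", "DEMO", "SCHEMA1_PIPE")
print(qualified_pipe)
```

The result could then be passed as the `pipe` argument when constructing the ingest manager.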

In [None]:
# import time
# from datetime import timedelta
# import datetime

# try:
#     resp = schema1_ingest_manager.ingest_files(staged_file_list)
# except HTTPError as e:
#     # HTTP error, may need to retry
#     logger.error(e)
#     exit(1)

# # This means Snowflake has received file and will start loading
# assert(resp['responseCode'] == 'SUCCESS')

# # Needs to wait for a while to get result in history
# while True:
#     history_resp = schema1_ingest_manager.get_history()

#     if len(history_resp['files']) > 0:
#         print('Ingest Report:\n')
#         print(history_resp)
#         break
#     else:
#         # wait for 20 seconds
#         time.sleep(20)

#     hour = timedelta(hours=1)
#     date = datetime.datetime.utcnow() - hour
#     history_range_resp = schema1_ingest_manager.get_history_range(date.isoformat() + 'Z')

#     print('\nHistory scan report: \n')
#     print(history_range_resp)

In [33]:
session.sql('list @LOAD_STAGE/schema1/').collect()

[Row(name='load_stage/schema1/201306-citibike-tripdata.csv.gz', size=16218896, md5='bd979640f17f10a3bf42f449aff29ad6', last_modified='Mon, 21 Feb 2022 23:51:40 GMT')]

In [32]:
session.get_current_role()

'"DBA_CITIBIKE"'

In [None]:
def load_trips_to_raw(session, files_to_load:list, load_stage_name:str, load_table_name:str):
    from snowflake.snowpark import functions as F
    from snowflake.snowpark import types as T
    from datetime import datetime

    stage_table_names = list()
    schema1_files = list()
    schema2_files = list()
    schema2_start_date = datetime.strptime('202102', "%Y%m")
    
    for file_name in files_to_load:
        file_start_date = datetime.strptime(file_name.split("-")[0], "%Y%m")
        if file_start_date < schema2_start_date:
            schema1_files.append(file_name)
        else:
            schema2_files.append(file_name)

    if len(schema1_files) > 0:
        # Uses load_schema1, the StructType defined in the Load step above.
        csv_file_format_options = {"FIELD_OPTIONALLY_ENCLOSED_BY": "'\"'", "skip_header": 1}
        
        stage_table_name = load_table_name + str('schema1')
        
        loaddf = session.read.option("SKIP_HEADER", 1)\
                              .option("FIELD_OPTIONALLY_ENCLOSED_BY", "\042")\
                              .option("COMPRESSION", "GZIP")\
                              .option("NULL_IF", "\\\\N")\
                              .option("NULL_IF", "NULL")\
                              .schema(load_schema1)\
                              .csv("@"+load_stage_name)\
                              .copy_into_table(stage_table_name, 
                                               files=schema1_files, 
                                               format_type_options=csv_file_format_options)
        stage_table_names.append(stage_table_name)


    if len(schema2_files) > 0:
        # Uses load_schema2, the StructType defined in the Load step above.
        csv_file_format_options = {"FIELD_OPTIONALLY_ENCLOSED_BY": "'\"'", "skip_header": 1}

        stage_table_name = load_table_name + str('schema2')
        loaddf = session.read.option("SKIP_HEADER", 1)\
                              .option("FIELD_OPTIONALLY_ENCLOSED_BY", "\042")\
                              .option("COMPRESSION", "GZIP")\
                              .option("NULL_IF", "\\\\N")\
                              .option("NULL_IF", "NULL")\
                              .schema(load_schema2)\
                              .csv("@"+load_stage_name)\
                              .copy_into_table(stage_table_name, 
                                               files=schema2_files, 
                                               format_type_options=csv_file_format_options)
        stage_table_names.append(stage_table_name)
        
    return list(set(stage_table_names))

In [None]:
from datetime import datetime
interim_target_table_names = list()
for stage_table_name in stage_table_names:
    schema = stage_table_name.split("_")[1]
    if schema == 'schema1':
        interim_target_table_name = 'INTERIM_schema1'
        stream_name = 'STREAM_schema1'
        task_name = 'TRIPSCDCTASK_schema1'
        procedure_name = 'TRIPSCDCPROC_schema1'
        create_processcdc_procedure_statement = schema1_spoc_str(procedure_name, 
                                                                 interim_target_table_name, 
                                                                 stream_name)

    elif schema == 'schema2':
        interim_target_table_name = 'INTERIM_schema2'
        stream_name = 'STREAM_schema2'
        task_name = 'TRIPSCDCTASK_schema2'
        procedure_name = 'TRIPSCDCPROC_schema2'
        create_processcdc_procedure_statement = schema2_spoc_str(procedure_name, 
                                                                 interim_target_table_name, 
                                                                 stream_name)

    #outside the if else condition but still inside the for loop
    interim_target_table_names.append(interim_target_table_name)
    create_stream_sql ='CREATE OR REPLACE STREAM ' + stream_name + \
                   ' ON TABLE ' + stage_table_name + \
                   ' APPEND_ONLY = FALSE SHOW_INITIAL_ROWS = TRUE'

    create_interim_target_table_sql = 'CREATE OR REPLACE TABLE ' + interim_target_table_name +\
                                ' LIKE ' + stage_table_name
    create_task_statement = "CREATE OR REPLACE TASK " + task_name + \
                        " WAREHOUSE='" + cdc_task_warehouse_name +"'"+ \
                        " SCHEDULE = '1 minute'"+ \
                        " WHEN SYSTEM$STREAM_HAS_DATA('" + stream_name + "')"+\
                        " AS CALL " + procedure_name + "()"
    resume_task_statement = "ALTER TASK " + task_name + " RESUME"

    _ = session.sql(create_stream_sql).collect()
    _ = session.sql(create_interim_target_table_sql).collect() 
    _ = session.sql(create_processcdc_procedure_statement).collect()
    _ = session.sql(create_task_statement).collect()
    _ = session.sql(resume_task_statement).collect()
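
The loop above relies on names that are not defined in this notebook section (`stage_table_names`, `cdc_task_warehouse_name`, `schema1_spoc_str`, `schema2_spoc_str`). To inspect the stream and task SQL without a live session, the string assembly can be isolated in a pure function. This is a sketch only: `cdc_statements` is a hypothetical helper and `EXAMPLE_WH` is a placeholder warehouse name.

```python
def cdc_statements(stage_table_name, stream_name, task_name, procedure_name, warehouse_name):
    """Return the stream, task and resume statements as plain strings."""
    create_stream = ("CREATE OR REPLACE STREAM " + stream_name +
                     " ON TABLE " + stage_table_name +
                     " APPEND_ONLY = FALSE SHOW_INITIAL_ROWS = TRUE")
    create_task = ("CREATE OR REPLACE TASK " + task_name +
                   " WAREHOUSE='" + warehouse_name + "'" +
                   " SCHEDULE = '1 minute'" +
                   " WHEN SYSTEM$STREAM_HAS_DATA('" + stream_name + "')" +
                   " AS CALL " + procedure_name + "()")
    resume_task = "ALTER TASK " + task_name + " RESUME"
    return [create_stream, create_task, resume_task]

# EXAMPLE_WH is a placeholder; the other names follow the loop above.
statements = cdc_statements("RAW_schema1", "STREAM_schema1",
                            "TRIPSCDCTASK_schema1", "TRIPSCDCPROC_schema1",
                            "EXAMPLE_WH")
for statement in statements:
    print(statement)
```

Printing the statements first makes it easier to review them before handing each one to `session.sql(...).collect()`.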


In [None]:
%%time
csv_file_format_options = {"FIELD_OPTIONALLY_ENCLOSED_BY": "'\"'", "skip_header": 1}

print('Loading '+str(loaddf.count())+' records to table '+state_dict['load_table_name']+str('schema1'))
loaddf.copy_into_table(state_dict['load_table_name']+str('schema1'), 
                       files=files_to_load, 
                       format_type_options=csv_file_format_options)

### 5. Transform:
We have the raw data loaded. Now let's transform and clean it. This will push the data to a final "transformed" table to be consumed by our Data Science team.


In [None]:
transdf = session.table(state_dict['load_table_name']+'schema1')

There are three different date formats: "2014-08-10 15:21:22", "1/1/2015 1:30" and "12/1/2014 02:04:53".

In [None]:
date_format_2 = "1/1/2015 [0-9]:.*$"      #1/1/2015 1:30 -> #M*M/D*D/YYYY H*H:M*M(:SS)*
date_format_3 = "1/1/2015 [0-9][0-9]:.*$" #1/1/2015 10:30 -> #M*M/D*D/YYYY H*H:M*M(:SS)*
date_format_4 = "12/1/2014.*"             #12/1/2014 02:04:53 -> M*M/D*D/YYYY 

#Change all dates to YYYY-MM-DD HH:MI:SS format
date_format_match = "^([0-9]?[0-9])/([0-9]?[0-9])/([0-9][0-9][0-9][0-9]) ([0-9]?[0-9]):([0-9][0-9])(:[0-9][0-9])?.*$"
date_format_repl = "\\3-\\1-\\2 \\4:\\5\\6"
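
To see what this pattern does to each of the three formats, it can be tried locally with Python's `re` module. This is only a stand-in for the Snowflake regular expression function used below; the two engines are assumed to agree on this simple pattern, and the names `pattern`, `replacement` and `normalize` are hypothetical.

```python
import re

# Same pattern and replacement as above, written as raw strings.
pattern = (r"^([0-9]?[0-9])/([0-9]?[0-9])/([0-9][0-9][0-9][0-9]) "
           r"([0-9]?[0-9]):([0-9][0-9])(:[0-9][0-9])?.*$")
replacement = r"\3-\1-\2 \4:\5\6"

def normalize(value):
    """Rewrite M/D/YYYY H:MM[:SS] as YYYY-M-D H:MM[:SS]; leave other strings alone."""
    return re.sub(pattern, replacement, value)

samples = ["2014-08-10 15:21:22", "1/1/2015 1:30", "12/1/2014 02:04:53"]
for sample in samples:
    print(sample, "->", normalize(sample))
```

Values already in the year-first format do not match and pass through unchanged. The rewritten values are not zero-padded, so the final parsing is left to `to_timestamp`.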

In [None]:
transdf.withColumn('STARTTIME', F.regexp_replace(F.col('STARTTIME'),
                                            F.lit(date_format_match), 
                                            F.lit(date_format_repl)))\
      .withColumn('STARTTIME', F.to_timestamp('STARTTIME'))\
      .withColumn('STOPTIME', F.regexp_replace(F.col('STOPTIME'),
                                            F.lit(date_format_match), 
                                            F.lit(date_format_repl)))\
      .withColumn('STOPTIME', F.to_timestamp('STOPTIME'))\
      .select(F.col('STARTTIME'), 
              F.col('STOPTIME'), 
              F.col('START_STATION_ID'), 
              F.col('START_STATION_NAME'), 
              F.col('START_STATION_LATITUDE'), 
              F.col('START_STATION_LONGITUDE'), 
              F.col('END_STATION_ID'), 
              F.col('END_STATION_NAME'), F.col('END_STATION_LATITUDE'), 
              F.col('END_STATION_LONGITUDE'), 
              F.col('USERTYPE'))\
      .write.mode('overwrite').saveAsTable(state_dict['trips_table_name'])

In [None]:
testdf = session.table(state_dict['trips_table_name'])
testdf.schema

In [None]:
testdf.count()

### 6. Export code in functional modules for MLOps and orchestration

In [None]:
%%writefile dags/elt.py
def schema1_definition():
    from snowflake.snowpark import types as T
    load_schema1 = T.StructType([T.StructField("TRIPDURATION", T.StringType()),
                             T.StructField("STARTTIME", T.StringType()), 
                             T.StructField("STOPTIME", T.StringType()), 
                             T.StructField("START_STATION_ID", T.StringType()),
                             T.StructField("START_STATION_NAME", T.StringType()), 
                             T.StructField("START_STATION_LATITUDE", T.StringType()),
                             T.StructField("START_STATION_LONGITUDE", T.StringType()),
                             T.StructField("END_STATION_ID", T.StringType()),
                             T.StructField("END_STATION_NAME", T.StringType()), 
                             T.StructField("END_STATION_LATITUDE", T.StringType()),
                             T.StructField("END_STATION_LONGITUDE", T.StringType()),
                             T.StructField("BIKEID", T.StringType()),
                             T.StructField("USERTYPE", T.StringType()), 
                             T.StructField("BIRTH_YEAR", T.StringType()),
                             T.StructField("GENDER", T.StringType())])
    return load_schema1

def schema2_definition():
    from snowflake.snowpark import types as T
    load_schema2 = T.StructType([T.StructField("ride_id", T.StringType()), 
                             T.StructField("rideable_type", T.StringType()), 
                             T.StructField("STARTTIME", T.StringType()), 
                             T.StructField("STOPTIME", T.StringType()), 
                             T.StructField("START_STATION_NAME", T.StringType()), 
                             T.StructField("START_STATION_ID", T.StringType()),
                             T.StructField("END_STATION_NAME", T.StringType()), 
                             T.StructField("END_STATION_ID", T.StringType()),
                             T.StructField("START_STATION_LATITUDE", T.StringType()),
                             T.StructField("START_STATION_LONGITUDE", T.StringType()),
                             T.StructField("END_STATION_LATITUDE", T.StringType()),
                             T.StructField("END_STATION_LONGITUDE", T.StringType()),
                             T.StructField("USERTYPE", T.StringType())])
    return load_schema2

def conformed_schema():
    from snowflake.snowpark import types as T
    trips_table_schema = T.StructType([T.StructField("STARTTIME", T.StringType()), 
                             T.StructField("STOPTIME", T.StringType()), 
                             T.StructField("START_STATION_NAME", T.StringType()), 
                             T.StructField("START_STATION_ID", T.StringType()),
                             T.StructField("END_STATION_NAME", T.StringType()), 
                             T.StructField("END_STATION_ID", T.StringType()),
                             T.StructField("START_STATION_LATITUDE", T.StringType()),
                             T.StructField("START_STATION_LONGITUDE", T.StringType()),
                             T.StructField("END_STATION_LATITUDE", T.StringType()),
                             T.StructField("END_STATION_LONGITUDE", T.StringType()),
                             T.StructField("USERTYPE", T.StringType())])
    return trips_table_schema


def setup_cdc():
    # Placeholder: CDC setup (streams, tasks and procedures) is not implemented here yet.
    return


def extract_trips_to_stage(session, files_to_download: list, download_base_url: str, load_stage_name:str):
    import os 
    import requests
    from zipfile import ZipFile
    import gzip
    
    files_to_load = list()
    
    for zip_file_name in files_to_download:
        gz_file_name = os.path.splitext(zip_file_name)[0]+'.gz'
        url = download_base_url+zip_file_name

        print('Downloading file '+url)
        r = requests.get(url)
        with open(zip_file_name, 'wb') as fh:
            fh.write(r.content)

        with ZipFile(zip_file_name, 'r') as zipObj:
            csv_file_names = zipObj.namelist()
            with zipObj.open(name=csv_file_names[0], mode='r') as zf:
                print('Gzipping file '+csv_file_names[0])
                with gzip.open(gz_file_name, 'wb') as gzf:
                    gzf.write(zf.read())

        print('Putting file '+gz_file_name+' to stage '+load_stage_name)
        session.file.put(gz_file_name, '@'+load_stage_name)
        
        files_to_load.append(gz_file_name)
        os.remove(zip_file_name)
        os.remove(gz_file_name)
    
    return load_stage_name, files_to_load
        
def load_trips_to_raw(session, files_to_load:list, load_stage_name:str, load_table_name:str):
    from snowflake.snowpark import functions as F
    from snowflake.snowpark import types as T
    from datetime import datetime

    stage_table_names = list()
    schema1_files = list()
    schema2_files = list()
    schema2_start_date = datetime.strptime('202102', "%Y%m")
    
    for file_name in files_to_load:
        file_start_date = datetime.strptime(file_name.split("-")[0], "%Y%m")
        if file_start_date < schema2_start_date:
            schema1_files.append(file_name)
        else:
            schema2_files.append(file_name)

    if len(schema1_files) > 0:
        load_schema1 = schema1_definition()
        csv_file_format_options = {"FIELD_OPTIONALLY_ENCLOSED_BY": "'\"'", "skip_header": 1}
        
        stage_table_name = load_table_name + str('schema1')
        
        loaddf = session.read.option("SKIP_HEADER", 1)\
                              .option("FIELD_OPTIONALLY_ENCLOSED_BY", "\042")\
                              .option("COMPRESSION", "GZIP")\
                              .option("NULL_IF", "\\\\N")\
                              .option("NULL_IF", "NULL")\
                              .schema(load_schema1)\
                              .csv("@"+load_stage_name)\
                              .copy_into_table(stage_table_name, 
                                               files=schema1_files, 
                                               format_type_options=csv_file_format_options)
        stage_table_names.append(stage_table_name)


    if len(schema2_files) > 0:
        load_schema2 = schema2_definition()
        csv_file_format_options = {"FIELD_OPTIONALLY_ENCLOSED_BY": "'\"'", "skip_header": 1}

        stage_table_name = load_table_name + str('schema2')
        loaddf = session.read.option("SKIP_HEADER", 1)\
                              .option("FIELD_OPTIONALLY_ENCLOSED_BY", "\042")\
                              .option("COMPRESSION", "GZIP")\
                              .option("NULL_IF", "\\\\N")\
                              .option("NULL_IF", "NULL")\
                              .schema(load_schema2)\
                              .csv("@"+load_stage_name)\
                              .copy_into_table(stage_table_name, 
                                               files=schema2_files, 
                                               format_type_options=csv_file_format_options)
        stage_table_names.append(stage_table_name)
        
    return list(set(stage_table_names))
    
def transform_trips(session, stage_table_names:list, trips_table_name:str):
    from snowflake.snowpark import functions as F
        
    #Change all dates to YYYY-MM-DD HH:MI:SS format
    date_format_match = "^([0-9]?[0-9])/([0-9]?[0-9])/([0-9][0-9][0-9][0-9]) ([0-9]?[0-9]):([0-9][0-9])(:[0-9][0-9])?.*$"
    date_format_repl = "\\3-\\1-\\2 \\4:\\5\\6"
    
    for stage_table_name in stage_table_names:
        
        transdf = session.table(stage_table_name)
        transdf.withColumn('STARTTIME', F.regexp_replace(F.col('STARTTIME'),
                                                F.lit(date_format_match), 
                                                F.lit(date_format_repl)))\
               .withColumn('STARTTIME', F.to_timestamp('STARTTIME'))\
               .withColumn('STOPTIME', F.regexp_replace(F.col('STOPTIME'),
                                                F.lit(date_format_match), 
                                                F.lit(date_format_repl)))\
               .withColumn('STOPTIME', F.to_timestamp('STOPTIME'))\
               .select(F.col('STARTTIME'), 
                       F.col('STOPTIME'), 
                       F.col('START_STATION_ID'), 
                       F.col('START_STATION_NAME'), 
                       F.col('START_STATION_LATITUDE'), 
                       F.col('START_STATION_LONGITUDE'), 
                       F.col('END_STATION_ID'), 
                       F.col('END_STATION_NAME'), F.col('END_STATION_LATITUDE'), 
                       F.col('END_STATION_LONGITUDE'), 
                       F.col('USERTYPE'))\
               .write.mode('append').saveAsTable(trips_table_name)

    return trips_table_name
    

In [None]:
session.close()