In [24]:
import snowflake.snowpark as snp
from snowflake.snowpark import functions as F
from snowflake.snowpark import types as T

import pandas as pd
from datetime import datetime
import requests
from zipfile import ZipFile
from io import BytesIO
import os

In [25]:
# Project-local helper: called with the path to the state JSON file, it returns
# a (session, state_dict) pair that every later cell relies on.
# NOTE(review): presumably state_dict is the parsed content of that JSON file -- confirm in dags/snowpark_connection.py.
from dags.snowpark_connection import snowpark_connect
session, state_dict = snowpark_connect('./include/state.json')

In [26]:
# Notebook-level settings: stage name, download location and table names.
# They are merged into the shared state dictionary and written back to the
# state file so that the same values can be read again from disk.
state_dict.update({
    'load_stage_name': 'LOAD_STAGE',
    'download_base_url': 'https://s3.amazonaws.com/tripdata/',
    'trips_table_name': 'TRIPS',
    'load_table_name': 'RAW_',
})

import json
with open('./include/state.json', 'w') as state_file:
    json.dump(state_dict, state_file)

In [27]:
def reset_database(session, state_dict:dict, prestaged=False):
    """Recreate the database, schema and load stage described by ``state_dict``.

    WARNING: this issues ``CREATE OR REPLACE DATABASE``, so everything in the
    existing database is dropped.

    Args:
        session: Object exposing ``sql(text).collect()`` (a Snowpark session).
        state_dict: Must contain ``connection_parameters`` (with ``database``
            and ``schema``) and ``load_stage_name``. When ``prestaged`` is
            true it must also provide ``download_base_url``.
        prestaged: If true, (re)create the stage as an external stage pointing
            at ``download_base_url``; otherwise create an internal stage if it
            does not exist yet.

    Raises:
        KeyError: If a required key is missing from ``state_dict``.
    """
    connection_parameters = state_dict['connection_parameters']
    _ = session.sql('CREATE OR REPLACE DATABASE '+connection_parameters['database']).collect()
    _ = session.sql('CREATE SCHEMA '+connection_parameters['schema']).collect()

    if prestaged:
        # The notebook stores download_base_url at the top level of state_dict,
        # not under connection_parameters; the old lookup raised KeyError.
        # Fall back to the previous location for backward compatibility.
        download_base_url = state_dict.get('download_base_url')
        if download_base_url is None:
            download_base_url = connection_parameters['download_base_url']
        sql_cmd = 'CREATE OR REPLACE STAGE '+state_dict['load_stage_name']+\
                  ' url='+download_base_url
        _ = session.sql(sql_cmd).collect()
    else:
        _ = session.sql('CREATE STAGE IF NOT EXISTS '+state_dict['load_stage_name']).collect()

In [28]:
# Destructive: replaces the whole database, recreates the schema and makes sure
# the internal load stage exists (prestaged keeps its default of False).
reset_database(session, state_dict)

When I tried running the following code snippets, I got errors saying that the files don't exist. I later figured out that the file structure within the S3 bucket had been updated: data up to 2023 has been regrouped into yearly archives.

In [6]:
# Monthly archives from June 2013 through December 2016 are named like
# 201306-citibike-tripdata.zip
file_name_end1 = '-citibike-tripdata.zip'
date_range1 = pd.period_range(
    start=datetime.strptime("201306", "%Y%m"),
    end=datetime.strptime("201612", "%Y%m"),
    freq='M',
).strftime("%Y%m")
files_to_download = [f"{month}{file_name_end1}" for month in date_range1.to_list()]

In [13]:
# Monthly archives from January 2017 through December 2021 use the
# "-citibike-tripdata.csv.zip" suffix; their names are appended to the list
# built by the previous cell (re-running this cell appends them again).
file_name_end2 = '-citibike-tripdata.csv.zip'
date_range2 = pd.period_range(
    start=datetime.strptime("201701", "%Y%m"),
    end=datetime.strptime("202112", "%Y%m"),
    freq='M',
).strftime("%Y%m")
files_to_download = [
    *files_to_download,
    *(f"{month}{file_name_end2}" for month in date_range2.to_list()),
]

In [14]:
# Keep only two sample archives (positions 0 and 102 of the full list).
# NOTE(review): with the 43 + 60 names built by the two cells above, position
# 102 is the 202112 archive, while the recorded output shows 202005 -- the
# output was probably produced from different kernel state; confirm.
files_to_download = [files_to_download[0], files_to_download[102]]
files_to_download

['201306-citibike-tripdata.zip', '202005-citibike-tripdata.csv.zip']

In [29]:
# Run the following statements on the warehouse named by
# compute_parameters.fe_warehouse in the state dictionary.
session.use_warehouse(state_dict['compute_parameters']['fe_warehouse'])

In [16]:
# Split the archive names by the month encoded before the first "-":
# anything earlier than February 2021 goes to schema1, the rest to schema2.
schema2_start_date = datetime.strptime('202102', "%Y%m")
schema1_download_files, schema2_download_files = [], []

for file_name in files_to_download:
    file_start_date = datetime.strptime(file_name.partition("-")[0], "%Y%m")
    target_files = (schema2_download_files
                    if file_start_date >= schema2_start_date
                    else schema1_download_files)
    target_files.append(file_name)

In [17]:
# Display both partitions as a (schema1, schema2) tuple.
schema1_download_files, schema2_download_files

(['201306-citibike-tripdata.zip', '202005-citibike-tripdata.csv.zip'], [])

It seems the structure of the source S3 bucket has changed, so the test files the presenter provided no longer work.

In [30]:
# Manual override of the computed partitions: one archive name per load stage.
schema1_download_files, schema2_download_files = (
    ['202502-citibike-tripdata.zip'],
    ['202503-citibike-tripdata.csv.zip'],
)

In [31]:
schema1_load_stage = state_dict['load_stage_name']+'/schema1/'
schema2_load_stage = state_dict['load_stage_name']+'/schema2/'


def _download_and_stage(zip_file_names, load_stage):
    """Download each zip archive, extract its first member and PUT it to a stage.

    Uses the notebook-level ``session`` and ``state_dict`` objects.

    Args:
        zip_file_names: Archive names relative to ``state_dict['download_base_url']``.
        load_stage: Stage location the extracted file is uploaded to.

    Returns:
        List of the extracted file names, in upload order.

    Raises:
        requests.HTTPError: If a download does not return a success status.
        zipfile.BadZipFile: If the downloaded payload is not a zip archive.
    """
    staged_file_names = list()
    for zip_file_name in zip_file_names:
        url = state_dict['download_base_url']+zip_file_name

        print('Downloading and unzipping: '+url)
        r = requests.get(url, timeout=60)
        # Fail with a clear HTTP error for a missing archive instead of a
        # confusing BadZipFile raised on the error page body.
        r.raise_for_status()

        # NOTE: only the first archive member is extracted, as before; any
        # additional members of a multi-part archive are not staged.
        with ZipFile(BytesIO(r.content)) as archive:
            csv_file_name = archive.namelist()[0]
            archive.extract(csv_file_name)

        try:
            print('Putting '+csv_file_name+' to stage: '+load_stage)
            session.file.put(local_file_name=csv_file_name,
                             stage_location=load_stage,
                             source_compression='NONE',
                             overwrite=True)
        finally:
            # Always clean up the local copy, even when the upload fails.
            os.remove(csv_file_name)
        staged_file_names.append(csv_file_name)
    return staged_file_names


schema1_files_to_load = _download_and_stage(schema1_download_files, schema1_load_stage)
schema2_files_to_load = _download_and_stage(schema2_download_files, schema2_load_stage)

Downloading and unzipping: https://s3.amazonaws.com/tripdata/202502-citibike-tripdata.zip
Putting 202502-citibike-tripdata_3.csv to stage: LOAD_STAGE/schema1/
Downloading and unzipping: https://s3.amazonaws.com/tripdata/202503-citibike-tripdata.csv.zip
Putting 202503-citibike-tripdata.csv to stage: LOAD_STAGE/schema2/


In [32]:

# List the files in the load stage whose names match the pattern
# (contain "20" and end in ".gz").
session.sql("list @"+state_dict['load_stage_name']+" pattern='.*20.*[.]gz'").collect()

[Row(name='load_stage/schema1/202502-citibike-tripdata_3.csv.gz', size=1218176, md5='ba700f6ccef6e67134de7c440a2797b3', last_modified='Sat, 26 Apr 2025 06:16:21 GMT'),
 Row(name='load_stage/schema2/202503-citibike-tripdata.csv.gz', size=123891344, md5='66aa886404ebada6e88653433838b655', last_modified='Sat, 26 Apr 2025 06:17:39 GMT')]

In [33]:
# Starting in February 2021 the schema changed.
# Every column is read as a string; the column order below is the order of the
# fields in the resulting StructType.
_load_column_names = [
    "ride_id",
    "rideable_type",
    "STARTTIME",
    "STOPTIME",
    "START_STATION_NAME",
    "START_STATION_ID",
    "END_STATION_NAME",
    "END_STATION_ID",
    "START_STATION_LATITUDE",
    "START_STATION_LONGITUDE",
    "END_STATION_LATITUDE",
    "END_STATION_LONGITUDE",
    "USERTYPE",
]
load_schema = T.StructType(
    [T.StructField(column_name, T.StringType()) for column_name in _load_column_names]
)

# The trips table has the same columns minus ride_id and rideable_type.
_trips_column_names = [
    "STARTTIME",
    "STOPTIME",
    "START_STATION_NAME",
    "START_STATION_ID",
    "END_STATION_NAME",
    "END_STATION_ID",
    "START_STATION_LATITUDE",
    "START_STATION_LONGITUDE",
    "END_STATION_LATITUDE",
    "END_STATION_LONGITUDE",
    "USERTYPE",
]
trips_table_schema = T.StructType(
    [T.StructField(column_name, T.StringType()) for column_name in _trips_column_names]
)

In [34]:
# Create one empty load table per schema suffix: build a single all-None row
# with the load schema, drop it again, and save the resulting empty frame.
for schema_suffix in ['schema1', 'schema2']:
    empty_load_df = session.create_dataframe(
        [[None]*len(load_schema.names)], schema=load_schema
    ).na.drop()
    empty_load_df.write.save_as_table(state_dict['load_table_name'] + schema_suffix)

In [35]:
# Copy the staged, gzip-compressed CSV files of each stage folder into the
# matching load table (stage folder and table suffix are paired by position).
csv_file_format_options = {"FIELD_OPTIONALLY_ENCLOSED_BY": "'\"'", "skip_header": 1}
load_stages = [schema1_load_stage, schema2_load_stage]
table_name_suffix_ls = ['schema1', 'schema2']
for load_stage, table_name_suffix in zip(load_stages, table_name_suffix_ls):
    stage_reader = session.read
    # NOTE(review): NULL_IF is set twice; presumably the second value replaces
    # the first -- confirm whether both null markers were intended.
    for option_name, option_value in (("SKIP_HEADER", 1),
                                      ("FIELD_OPTIONALLY_ENCLOSED_BY", "\042"),
                                      ("COMPRESSION", "GZIP"),
                                      ("NULL_IF", "\\\\N"),
                                      ("NULL_IF", "NULL"),
                                      ("pattern", "'.*20.*[.]gz'")):
        stage_reader = stage_reader.option(option_name, option_value)
    staged_csv_df = stage_reader.schema(load_schema).csv('@'+load_stage)
    loaddf = staged_csv_df.copy_into_table(
        state_dict['load_table_name']+str(table_name_suffix),
        format_type_options=csv_file_format_options)