In [None]:
-- Switch the current role to SYSADMIN before the cells below run.
use role sysadmin;

In [None]:
import streamlit as st
from snowflake.snowpark import Session
# Obtain the Snowpark session used by every later cell; going by the builder
# method's name this reuses an existing session and only creates one if needed.
session = Session.builder.getOrCreate()

In [None]:
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col
import snowflake.connector
from snowflake.connector.cursor import SnowflakeCursor
from snowflake.connector.converter import SnowflakeConverter
from sqlalchemy import create_engine
import sqlalchemy
import pandas as pd
import logging
import re
import time

In [None]:
# Set the connector's module-level parameter style to "pyformat";
# create_sqlalchemy_engine notes that SQLAlchemy needs pyformat binding.
snowflake.connector.paramstyle = "pyformat"
import wheel_loader

In [None]:
from snowflake.snowpark._internal.utils import is_in_stored_procedure

In [None]:
def _process_params_dict(self, params, cursor):
    try:
        res = {k: self._process_single_param(v) for k, v in params.items()}
        return res
    except Exception as e:
        raise Exception("Failed processing pyformat-parameters: {e}")

In [None]:
def _process_params_pyformat(self,params, cursor):
        #todoremove print("----->")
        #todoremove print(params)
        if params is None:
            return {}
        if isinstance(params, dict):
            return self._process_params_dict(params,cursor)
        if not isinstance(params, (tuple, list)):
            params = [params,]
        try:
            res = map(self._process_single_param, params)
            ret = tuple(res)
            return ret
        except Exception as e:
            raise Exception(f"Failed processing pyformat-parameters; {self}{params} {e}")

In [None]:
def _process_single_param(self, param):
        to_snowflake = self.converter.to_snowflake
        escape = self.converter.escape
        _quote = self.converter.quote
        return _quote(escape(to_snowflake(param)))

In [None]:
def create_sqlalchemy_engine(session: Session):
    """Build a SQLAlchemy engine on top of a Snowpark session's connection.

    No new connection is opened: the engine's ``creator`` always returns the
    connector object already held by ``session``. The URL is assembled from
    the session's current user, account, database, schema, warehouse and role.

    Side effects:
        * Sets a ``url`` attribute on SQLAlchemy's ``Connection`` class.
        * Inside a stored procedure, loads wheels via ``wheel_loader``,
          rebinds ``snowflake.connector.connection.SnowflakeConnection`` to
          ``StoredProcConnection``, and installs the three
          ``_process_*`` helpers on that class.
        * Forces ``_interpolate_empty_sequences = False`` and
          ``_paramstyle = "pyformat"`` on the session's connection.

    Args:
        session: Snowpark session whose underlying connection is reused.

    Returns:
        The SQLAlchemy engine bound to that connection.
    """
    import snowflake.connector.connection
    from sqlalchemy.engine.url import URL
    from sqlalchemy.engine.base import Connection

    # NOTE(review): presumably something downstream reads ``Connection.url``
    # and does not find it otherwise - confirm which caller needs this.
    setattr(Connection, "url", URL.create("snowflake"))

    if is_in_stored_procedure():
        # Per the original author's note, this downloads wheels from PyPI.
        wheel_loader.add_wheels()

        # Make the stored-procedure connection stand in for the regular one
        # and give it the pyformat helpers it does not have.
        connection_module = snowflake.connector.connection
        stored_proc_cls = connection_module.StoredProcConnection
        connection_module.SnowflakeConnection = stored_proc_cls
        for helper in (
            _process_params_pyformat,
            _process_params_dict,
            _process_single_param,
        ):
            setattr(stored_proc_cls, helper.__name__, helper)

    # Reuse the connector object that backs the Snowpark session.
    existing_snowflake_connection = session._conn._conn
    setattr(existing_snowflake_connection, "_interpolate_empty_sequences", False)
    # SQLAlchemy needs pyformat binding.
    existing_snowflake_connection._paramstyle = "pyformat"

    # Collect the optional query arguments and join them, so the URL never
    # contains "?&..." or a dangling "?" when nothing is set.
    query_parts = []
    warehouse = session.get_current_warehouse()
    if warehouse is not None:
        query_parts.append(f"warehouse={warehouse}")
    role = session.get_current_role()
    if role is not None:
        query_parts.append(f"role={role}")

    conn_url = (
        f"snowflake://{session.get_current_user() or ''}"
        f"@{session.get_current_account()}"
        f"/{session.get_current_database() or ''}"
        f"/{session.get_current_schema() or ''}"
    )
    if query_parts:
        conn_url += "?" + "&".join(query_parts)

    # Create an engine whose creator hands back the existing connection.
    engine = create_engine(
        url=conn_url,
        creator=lambda: existing_snowflake_connection
    )
    return engine

In [None]:
from sqlalchemy.dialects import registry
# Register the "snowflake" dialect name so SQLAlchemy resolves it to the
# `dialect` object in the `snowflake.sqlalchemy` module.
registry.register('snowflake', 'snowflake.sqlalchemy', 'dialect')

In [None]:
import sqlalchemy
# NOTE(review): this is not the last line of the cell, so the version string
# is evaluated but never displayed.
sqlalchemy.__version__
# NOTE(review): the module is not referenced by name in the visible cells;
# presumably imported to confirm the dialect package is installed - confirm.
import snowflake.sqlalchemy

In [None]:
# Build the SQLAlchemy engine on top of the notebook's Snowpark session;
# the load loop below uses it for `engine.begin()` and `DataFrame.to_sql`.
engine = create_sqlalchemy_engine(session)

In [None]:
# Stage path prefix the input files are fetched from.
stage_path = '@DMAS.AOEC.AOEC_FILES_HS12/'

# Stata (.dta) file names to load, one target table per file.
input_file_list = [
    'hs12_country_country_product_year_1.dta',
    'hs12_country_country_product_year_2.dta',
    'hs12_country_country_product_year_4_2012_2016.dta',
    'hs12_country_country_product_year_4_2017_2021.dta',
    'hs12_country_country_product_year_4_2022.dta',
    'hs12_country_country_product_year_6_2012_2016.dta',
    'hs12_country_country_product_year_6_2017_2021.dta',
    'hs12_country_country_product_year_6_2022.dta',
    'hs12_country_product_year_1.dta',
    'hs12_country_product_year_2.dta',
    'hs12_country_product_year_4.dta',
    'hs12_country_product_year_6.dta',
]

In [None]:
# Fetch each staged Stata file, read it with pandas, and write it to a
# `raw_<file base name>` table, printing fetch and load timings per file.
prog = re.compile(r'^(.*)\.dta')  # compiled once instead of on every iteration
for file in input_file_list:
    file_stage = file+'.gz'  # the staged copy carries an extra .gz suffix
    file_base = prog.match(file).group(1)
    table_name = 'raw_'+file_base
    # No newline here: the timing (or error) text below continues this line.
    print('--- (Info):  {0:s}'.format(table_name), end='')

    # Phase A -> B: download from the stage and parse into a DataFrame.
    ts_A_secs = time.time()
    ts_A_time = time.localtime(ts_A_secs)
    session.file.get(stage_path+file_stage, 'TMP_STATA')
    df = pd.read_stata('TMP_STATA/'+file_stage)
    ts_B_secs = time.time()
    ts_B_time = time.localtime(ts_B_secs)

    # Phase B -> C: write the DataFrame to the table.
    with engine.begin() as tx:
        try:
            # NOTE(review): to_sql is given `engine`, not `tx` - confirm the
            # write really runs inside the transaction opened above.
            df.to_sql(table_name, engine, index = False, chunksize = 100000, if_exists = 'replace')
        except Exception as load_exc:
            # Use the bound exception rather than sys.exception(): `sys` is
            # not imported in this notebook, so the old handler raised
            # NameError and hid the real failure.
            print(": Get:{0:02d}:{1:02d}:{2:02d} ({3:,d} sec)".format(ts_A_time.tm_hour, ts_A_time.tm_min, ts_A_time.tm_sec, int(ts_B_secs - ts_A_secs)))
            print("*** (Error): {0:s}".format(repr(load_exc)))
            try:
                tx.rollback()
            except Exception as rollback_exc:
                print("*** (Error): {0:s}".format(repr(rollback_exc)))
                print("*** (Error): rollback failed.")
                raise

    # As in the original, a failed load still falls through to these lines.
    ts_C_secs = time.time()
    ts_C_time = time.localtime(ts_C_secs)
    print(' Get:  {0:02d}:{1:02d}:{2:02d} ({3:,d} sec)'.format(ts_A_time.tm_hour, ts_A_time.tm_min, ts_A_time.tm_sec, int(ts_B_secs - ts_A_secs)))
    print('--- (Info):  {0:s} Load: {1:02d}:{2:02d}:{3:02d} ({4:,d} sec)'.format(table_name, ts_B_time.tm_hour, ts_B_time.tm_min, ts_B_time.tm_sec, int(ts_C_secs - ts_B_secs)))
