"""``AbstractDataset`` implementation to access Snowflake using Snowpark dataframes
"""
from __future__ import annotations
import logging
from copy import deepcopy
from typing import Any
import snowflake.snowpark as sp
from kedro.io.core import AbstractDataset, DatasetError
logger = logging.getLogger(__name__)
class SnowparkTableDataset(AbstractDataset):
"""``SnowparkTableDataset`` loads and saves Snowpark dataframes.
As of Mar-2023, the snowpark connector only works with Python 3.8.
Example usage for the
`YAML API <https://kedro.readthedocs.io/en/stable/data/\
data_catalog_yaml_examples.html>`_:
.. code-block:: yaml
weather:
type: kedro_datasets.snowflake.SnowparkTableDataset
table_name: "weather_data"
database: "meteorology"
schema: "observations"
credentials: db_credentials
save_args:
mode: overwrite
column_order: name
table_type: ''
You can skip everything but "table_name" if the database and
schema are provided via credentials. That way catalog entries can be shorter
if, for example, all used Snowflake tables live in same database/schema.
Values in the dataset definition take priority over those defined in credentials.
Example:
Credentials file provides all connection attributes, catalog entry
"weather" reuses credentials parameters, "polygons" catalog entry reuses
all credentials parameters except providing a different schema name.
Second example of credentials file uses ``externalbrowser`` authentication.
catalog.yml
.. code-block:: yaml
weather:
type: kedro_datasets.snowflake.SnowparkTableDataset
table_name: "weather_data"
database: "meteorology"
schema: "observations"
credentials: snowflake_client
save_args:
mode: overwrite
column_order: name
table_type: ''
polygons:
type: kedro_datasets.snowflake.SnowparkTableDataset
table_name: "geopolygons"
credentials: snowflake_client
schema: "geodata"
credentials.yml
.. code-block:: yaml
snowflake_client:
account: 'ab12345.eu-central-1'
port: 443
warehouse: "datascience_wh"
database: "detailed_data"
schema: "observations"
user: "service_account_abc"
password: "supersecret"
credentials.yml (with externalbrowser authenticator)
.. code-block:: yaml
snowflake_client:
account: 'ab12345.eu-central-1'
port: 443
warehouse: "datascience_wh"
database: "detailed_data"
schema: "observations"
user: "john_doe@wdomain.com"
authenticator: "externalbrowser"
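
    A minimal sketch of Python API usage is shown below; the credential
    values are illustrative placeholders, not defaults shipped with the
    dataset:

    .. code-block:: python

        from kedro_datasets.snowflake import SnowparkTableDataset

        credentials = {
            "account": "ab12345.eu-central-1",
            "user": "service_account_abc",
            "password": "supersecret",
            "warehouse": "datascience_wh",
        }

        dataset = SnowparkTableDataset(
            table_name="weather_data",
            database="meteorology",
            schema="observations",
            credentials=credentials,
            save_args={"mode": "overwrite"},
        )

        df = dataset.load()  # a snowflake.snowpark.DataFrame
        dataset.save(df)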
"""
    # This dataset cannot be used with ``ParallelRunner``, hence the
    # ``_SINGLE_PROCESS = True`` attribute. For parallelism within a
    # pipeline, consider ``ThreadRunner`` instead.
    _SINGLE_PROCESS = True
    DEFAULT_LOAD_ARGS: dict[str, Any] = {}
    DEFAULT_SAVE_ARGS: dict[str, Any] = {}

    def __init__(  # noqa: PLR0913
self,
*,
table_name: str,
schema: str | None = None,
database: str | None = None,
load_args: dict[str, Any] | None = None,
save_args: dict[str, Any] | None = None,
credentials: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
"""Creates a new instance of ``SnowparkTableDataset``.
Args:
table_name: The table name to load or save data to.
schema: Name of the schema where ``table_name`` is.
Optional as can be provided as part of ``credentials``
dictionary. Argument value takes priority over one provided
in ``credentials`` if any.
database: Name of the database where ``schema`` is.
Optional as can be provided as part of ``credentials``
dictionary. Argument value takes priority over one provided
in ``credentials`` if any.
load_args: Currently not used
save_args: Provided to underlying snowpark ``save_as_table``
To find all supported arguments, see here:
https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/api/snowflake.snowpark.DataFrameWriter.saveAsTable.html
credentials: A dictionary with a snowpark connection string.
To find all supported arguments, see here:
https://docs.snowflake.com/en/user-guide/python-connector-api.html#connect
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.
"""
if not table_name:
raise DatasetError("'table_name' argument cannot be empty.")
if not credentials:
raise DatasetError("'credentials' argument cannot be empty.")
        if not database:
            if not credentials.get("database"):
                raise DatasetError(
                    "'database' must be provided by credentials or dataset."
                )
            database = credentials["database"]
        if not schema:
            if not credentials.get("schema"):
                raise DatasetError(
                    "'schema' must be provided by credentials or dataset."
                )
            schema = credentials["schema"]
# Handle default load and save arguments
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)
self._table_name = table_name
self._database = database
self._schema = schema
        # Copy so the caller's credentials dict is not mutated by the update
        connection_parameters = deepcopy(credentials)
        connection_parameters.update(
            {"database": self._database, "schema": self._schema}
        )
        self._connection_parameters = connection_parameters
self.metadata = metadata

    def _describe(self) -> dict[str, Any]:
return {
"table_name": self._table_name,
"database": self._database,
"schema": self._schema,
}

    @staticmethod
    def _get_session(connection_parameters) -> sp.Session:
        """Given a dictionary of connection parameters, returns a singleton
        Snowpark session to be shared across all instances of
        ``SnowparkTableDataset`` that need to connect to the same source.

        ``connection_parameters`` is a dictionary of any values supported by
        the snowflake python connector:
        https://docs.snowflake.com/en/user-guide/python-connector-api.html#connect

        Example:

        .. code-block:: python

            connection_parameters = {
                "account": "",
                "user": "",
                "password": "",  # optional
                "role": "",  # optional
                "warehouse": "",  # optional
                "database": "",  # optional
                "schema": "",  # optional
                "authenticator": "",  # optional
            }
        """
try:
logger.debug("Trying to reuse active snowpark session...")
session = sp.context.get_active_session()
except sp.exceptions.SnowparkSessionException:
logger.debug("No active snowpark session found. Creating...")
session = sp.Session.builder.configs(connection_parameters).create()
return session
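
    # Illustrative note (hypothetical ``creds`` dict, not executed): because
    # ``get_active_session`` is consulted before creating a new session, two
    # datasets built with the same connection parameters end up sharing one
    # session, e.g.:
    #
    #     ds_a = SnowparkTableDataset(table_name="table_a", credentials=creds)
    #     ds_b = SnowparkTableDataset(table_name="table_b", credentials=creds)
    #     ds_a._session is ds_b._session  # True once a session has been created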

    @property
def _session(self) -> sp.Session:
return self._get_session(self._connection_parameters)

    def _load(self) -> sp.DataFrame:
table_name: list = [
self._database,
self._schema,
self._table_name,
]
sp_df = self._session.table(".".join(table_name))
return sp_df

    def _save(self, data: sp.DataFrame) -> None:
table_name = [
self._database,
self._schema,
self._table_name,
]
data.write.save_as_table(table_name, **self._save_args)

    def _exists(self) -> bool:
        session = self._session
        query = (
            "SELECT COUNT(*) FROM {database}.INFORMATION_SCHEMA.TABLES "
            "WHERE TABLE_SCHEMA = '{schema}' "
            "AND TABLE_NAME = '{table_name}'"
        )
rows = session.sql(
query.format(
database=self._database,
schema=self._schema,
table_name=self._table_name,
)
).collect()
return rows[0][0] == 1
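
# Illustrative note: for the "weather" example in the class docstring,
# ``_exists`` would issue a query of this shape (identifiers are the assumed
# example values, not defaults):
#
#     SELECT COUNT(*) FROM meteorology.INFORMATION_SCHEMA.TABLES
#     WHERE TABLE_SCHEMA = 'observations'
#     AND TABLE_NAME = 'weather_data'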