
[FeatureStore] Support SQL DBs as source and online target #2869

Merged
merged 18 commits into from Jan 19, 2023
1 change: 1 addition & 0 deletions dev-requirements.txt
@@ -19,3 +19,4 @@ scikit-learn~=1.0
# needed for frameworks tests
lightgbm~=3.0
xgboost~=1.1
sqlalchemy_utils~=0.39.0
1 change: 1 addition & 0 deletions mlrun/api/crud/client_spec.py
@@ -44,6 +44,7 @@ def get_client_spec(self):
generate_artifact_target_path_from_artifact_hash=config.artifacts.generate_target_path_from_artifact_hash,
redis_url=config.redis.url,
redis_type=config.redis.type,
sql_url=config.sql.url,
# These don't have a default value, but we don't send them if they are not set to allow the client to know
# when to use server value and when to use client value (server only if set). Since their default value is
# empty and not set is also empty we can use the same _get_config_value_if_not_default
1 change: 1 addition & 0 deletions mlrun/api/schemas/client_spec.py
@@ -56,6 +56,7 @@ class ClientSpec(pydantic.BaseModel):
function: typing.Optional[Function]
redis_url: typing.Optional[str]
redis_type: typing.Optional[str]
sql_url: typing.Optional[str]

# ce_mode is deprecated, we will use the full ce config instead and ce_mode will be removed in 1.6.0
ce_mode: typing.Optional[str]
3 changes: 3 additions & 0 deletions mlrun/config.py
@@ -117,6 +117,9 @@
"url": "",
"type": "standalone", # deprecated.
},
"sql": {
"url": "",
},
"v3io_framesd": "http://framesd:8080",
"datastore": {"async_source_mode": "disabled"},
# default node selector to be applied to all functions - json string base64 encoded format
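Not part of the diff: a minimal sketch of how a client could point at an SQL database through the new sql.url setting, assuming the MLRUN_SQL__URL environment-variable mapping referenced in the SQLSource docstring below; the host, credentials, and database name are placeholders.

# Hypothetical example: the MLRUN_SQL__URL environment variable populates
# config.sql.url, which SQLSource falls back to when no db_url argument is passed.
import os

os.environ["MLRUN_SQL__URL"] = "mysql+pymysql://user:password@db-host:3306/feature_db"

import mlrun

print(mlrun.mlconf.sql.url)  # prints the placeholder URL set above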
111 changes: 111 additions & 0 deletions mlrun/datastore/sources.py
@@ -19,6 +19,7 @@
from datetime import datetime
from typing import Dict, List, Optional, Union

import pandas as pd
import v3io
import v3io.dataplane
from nuclio import KafkaTrigger
@@ -858,6 +859,115 @@ def add_nuclio_trigger(self, function):
return func


class SQLSource(BaseSourceDriver):
kind = "sqldb"
support_storey = True
support_spark = False

def __init__(
self,
name: str = "",
chunksize: int = None,
key_field: str = None,
time_field: str = None,
schedule: str = None,
start_time: Optional[Union[datetime, str]] = None,
end_time: Optional[Union[datetime, str]] = None,
db_url: str = None,
table_name: str = None,
spark_options: dict = None,
time_fields: List[str] = None,
):
"""
Reads SqlDB as input source for a flow.
example::
db_path = "mysql+pymysql://<username>:<password>@<host>:<port>/<db_name>"
source = SqlDBSource(
collection_name='source_name', db_path=self.db, key_field='key'
)
:param name: source name
:param chunksize: number of rows per chunk (default large single chunk)
:param key_field: the column to be used as the key for the collection.
:param time_field: the column to be parsed as timestamp for events. Defaults to None
:param start_time: filters out data before this time
:param end_time: filters out data after this time
:param schedule: string to configure scheduling of the ingestion job.
For example '*/30 * * * *' will
cause the job to run every 30 minutes
:param db_url: url string connection to sql database.
If not set, the MLRUN_SQL__URL environment variable will be used.
:param table_name: the name of the collection to access,
from the current database
:param spark_options: additional spark read options
:param time_fields : all the field to be parsed as timestamp.
"""

db_url = db_url or mlrun.mlconf.sql.url
if not db_url:
raise mlrun.errors.MLRunInvalidArgumentError(
"db_url must be specified, either as an argument or via the MLRUN_SQL__URL configuration"
)
attrs = {
"chunksize": chunksize,
"spark_options": spark_options,
"table_name": table_name,
"db_path": db_url,
"time_fields": time_fields,
}
attrs = {key: value for key, value in attrs.items() if value is not None}
super().__init__(
name,
attributes=attrs,
key_field=key_field,
time_field=time_field,
schedule=schedule,
start_time=start_time,
end_time=end_time,
)

def to_dataframe(self):
import sqlalchemy as db

query = self.attributes.get("query", None)
db_path = self.attributes.get("db_path")
table_name = self.attributes.get("table_name")
if not query:
query = f"SELECT * FROM {table_name}"
if table_name and db_path:
engine = db.create_engine(db_path)
with engine.connect() as con:
return pd.read_sql(
query,
con=con,
chunksize=self.attributes.get("chunksize"),
parse_dates=self.attributes.get("time_fields"),
)
else:
raise mlrun.errors.MLRunInvalidArgumentError(
"table_name and db_url args must be specified"
)

def to_step(self, key_field=None, time_field=None, context=None):
import storey

attributes = self.attributes or {}
if context:
attributes["context"] = context

return storey.SQLSource(
key_field=self.key_field or key_field,
time_field=self.time_field or time_field,
end_filter=self.end_time,
start_filter=self.start_time,
filter_column=self.time_field or time_field,
**attributes,
)

def is_iterator(self):
return bool(self.attributes.get("chunksize"))


# map of sources (exclude DF source which is not serializable)
source_kind_to_driver = {
"": BaseSourceDriver,
@@ -869,4 +979,5 @@ def add_nuclio_trigger(self, function):
CustomSource.kind: CustomSource,
BigQuerySource.kind: BigQuerySource,
SnowflakeSource.kind: SnowflakeSource,
SQLSource.kind: SQLSource,
}
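
Not part of the diff: a short, hypothetical usage sketch of the new SQLSource based only on the code added above; the connection string, table, and column names are placeholders.

# Hypothetical usage of SQLSource: read the whole table into a pandas DataFrame.
from mlrun.datastore.sources import SQLSource

source = SQLSource(
    name="my_sql_source",
    db_url="mysql+pymysql://user:password@db-host:3306/feature_db",  # or rely on MLRUN_SQL__URL
    table_name="measurements",
    key_field="key",
    time_fields=["timestamp"],
)

df = source.to_dataframe()  # runs SELECT * FROM measurements and parses time_fields as timestamps
print(df.head())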