feat(ingest/superset): add datasets ingestion #10592

Open
wants to merge 5 commits into master
Changes from all commits
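For orientation while reviewing: a minimal sketch of how the new flag could be exercised once this merges. The connect_uri, credentials, and console sink below are illustrative assumptions, not values taken from this PR.

# Minimal sketch, assuming a reachable Superset instance; the URI and
# credentials are placeholders, not part of this PR.
from datahub.ingestion.run.pipeline import Pipeline

pipeline = Pipeline.create(
    {
        "source": {
            "type": "superset",
            "config": {
                "connect_uri": "http://localhost:8088",  # placeholder
                "username": "admin",  # placeholder
                "password": "admin",  # placeholder
                "extract_datasets": True,  # new flag added by this PR
            },
        },
        # The console sink prints emitted metadata instead of writing to DataHub.
        "sink": {"type": "console"},
    }
)
pipeline.run()
pipeline.raise_from_status()

With extract_datasets left at its default of False the source behaves as before, and use_superset_platform defaults to False so dataset URNs keep the upstream database platform.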
285 changes: 283 additions & 2 deletions metadata-ingestion/src/datahub/ingestion/source/superset.py
@@ -13,12 +13,14 @@
EnvConfigMixin,
PlatformInstanceConfigMixin,
)
import datahub.emitter.mce_builder as builder
from datahub.emitter.mce_builder import (
    make_chart_urn,
    make_dashboard_urn,
    make_dataset_urn,
    make_domain_urn,
    make_user_urn,
)
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.mcp_builder import add_domain_to_entity_wu
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.decorators import (
@@ -31,6 +33,7 @@
)
from datahub.ingestion.api.source import MetadataWorkUnitProcessor, Source
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.common.subtypes import DatasetSubTypes
from datahub.ingestion.source.sql.sqlalchemy_uri_mapper import (
get_platform_from_sqlalchemy_uri,
)
@@ -48,23 +51,117 @@
ChangeAuditStamps,
Status,
)
from datahub.metadata.com.linkedin.pegasus2avro.dataset import DatasetProperties
from datahub.metadata.com.linkedin.pegasus2avro.metadata.snapshot import (
ChartSnapshot,
DashboardSnapshot,
DatasetSnapshot,
)
from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent
from datahub.metadata.com.linkedin.pegasus2avro.schema import (
ArrayTypeClass,
BooleanTypeClass,
DateTypeClass,
NullTypeClass,
NumberTypeClass,
OtherSchema,
SchemaField,
SchemaFieldDataType,
SchemaMetadata,
StringTypeClass,
TimeTypeClass,
)
from datahub.metadata.schema_classes import (
ChartInfoClass,
ChartTypeClass,
DashboardInfoClass,
DashboardSnapshotClass,
DatasetLineageTypeClass,
GlobalTagsClass,
MetadataChangeEventClass,
OwnerClass,
OwnershipClass,
OwnershipTypeClass,
SubTypesClass,
TagAssociationClass,
UpstreamClass,
UpstreamLineageClass,
ViewPropertiesClass,
)
from datahub.utilities import config_clean
from datahub.utilities.registries.domain_registry import DomainRegistry
from datahub.utilities.urn_encoder import UrnEncoder

logger = logging.getLogger(__name__)

PAGE_SIZE = 25

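# Maps Superset column type names (largely Postgres types) to DataHub schema
# field type classes; unmapped types fall back to NullTypeClass.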
FIELD_TYPE_MAPPING = {
"BIGINT": NumberTypeClass,
"INT8": NumberTypeClass,
"BIGSERIAL": NumberTypeClass,
"SERIAL8": NumberTypeClass,
"BIT": StringTypeClass,
"VARBIT": StringTypeClass,
"BIT VARYING": StringTypeClass,
"BOOLEAN": BooleanTypeClass,
"BOOL": BooleanTypeClass,
"BOX": NullTypeClass,
"CHARACTER": StringTypeClass,
"CHAR": StringTypeClass,
"CHARACTER VARYING": StringTypeClass,
"VARCHAR": StringTypeClass,
"CIDR": StringTypeClass,
"CIRCLE": NullTypeClass,
"DATE": DateTypeClass,
"DOUBLE PRECISION": NumberTypeClass,
"FLOAT8": NumberTypeClass,
"INET": StringTypeClass,
"INTEGER": NumberTypeClass,
"INT": NumberTypeClass,
"INT4": NumberTypeClass,
"INTERVAL": StringTypeClass,
"JSON": StringTypeClass,
"JSONB": StringTypeClass,
"LINE": NullTypeClass,
"LSEG": NullTypeClass,
"MACADDR": StringTypeClass,
"MACADDR8": StringTypeClass,
"MONEY": NumberTypeClass,
"NUMERIC": NumberTypeClass,
"DECIMAL": NumberTypeClass,
"PATH": NullTypeClass,
"PG_LSN": StringTypeClass,
"POINT": NullTypeClass,
"POLYGON": NullTypeClass,
"REAL": NumberTypeClass,
"FLOAT4": NumberTypeClass,
"SMALLINT": NumberTypeClass,
"INT2": NumberTypeClass,
"SMALLSERIAL": NumberTypeClass,
"SERIAL2": NumberTypeClass,
"SERIAL": NumberTypeClass,
"SERIAL4": NumberTypeClass,
"STRING": StringTypeClass,
"TEXT": StringTypeClass,
"TIME": TimeTypeClass,
"TIME WITHOUT TIME ZONE": TimeTypeClass,
"TIME WITH TIME ZONE": TimeTypeClass,
"TIMETZ": TimeTypeClass,
"TIMESTAMP": TimeTypeClass,
"TIMESTAMP WITHOUT TIME ZONE": TimeTypeClass,
"TIMESTAMP WITH TIME ZONE": TimeTypeClass,
"TIMESTAMPTZ": TimeTypeClass,
"TSQUERY": StringTypeClass,
"TSVECTOR": StringTypeClass,
"TXID_SNAPSHOT": StringTypeClass,
"UUID": StringTypeClass,
"XML": StringTypeClass,
"ARRAY": ArrayTypeClass,
"HSTORE": StringTypeClass,
"RANGE": ArrayTypeClass,
"UNKNOWN": NullTypeClass,
}

chart_type_from_viz_type = {
"line": ChartTypeClass.LINE,
@@ -110,6 +207,15 @@ class SupersetConfig(
provider: str = Field(default="db", description="Superset provider.")
options: Dict = Field(default={}, description="")

extract_datasets: bool = Field(
    default=False,
    description="When enabled, extracts Superset datasets as DataHub datasets.",
)
use_superset_platform: bool = Field(
    default=False,
    description="Set to true to always use 'superset' as the platform for all datasets.",
)

# TODO: Check and remove this if no longer needed.
# Config database_alias is removed from sql sources.
database_alias: Dict[str, str] = Field(
@@ -151,6 +257,26 @@ def get_filter_name(filter_obj):
comparator = filter_obj.get("comparator")
return f"{clause} {column} {operator} {comparator}"

def get_dataset_id_from_metadata(json_data):
    """Collect the datasetId values referenced by a dashboard's native filters.

    `json_data` is the dashboard payload; its `json_metadata` field is itself
    a JSON-encoded string containing `native_filter_configuration`.
    """
    try:
        data = json.loads(json_data)
        json_metadata = data.get("json_metadata", "{}")
        d_metadata = json.loads(json_metadata)
        native_filters = d_metadata.get("native_filter_configuration", [])
        dataset_ids = []
        for native_filter in native_filters:
            for target in native_filter.get("targets", []):
                dataset_id = target.get("datasetId")
                if dataset_id is not None:
                    dataset_ids.append(dataset_id)
        return dataset_ids
    except json.JSONDecodeError:
        logger.warning("Error parsing dashboard json_metadata as JSON")
        return []


@platform_name("Superset")
@config_class(SupersetConfig)
@@ -246,10 +372,16 @@ def get_datasource_urn_from_id(self, datasource_id):
dataset_response.get("result", {}).get("database", {}).get("database_name")
)
database_name = self.config.database_alias.get(database_name, database_name)

if self.config.use_superset_platform:
    platform_nm = "superset"
else:
    platform_nm = self.get_platform_from_database_id(database_id)

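# Dataset URNs are keyed as <database>.<schema>.<table>, skipping empty parts.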
if database_id and table_name:
return make_dataset_urn(
platform=self.get_platform_from_database_id(database_id),
platform=platform_nm,
name=".".join(
name for name in [database_name, schema_name, table_name] if name
),
@@ -469,10 +601,159 @@ def emit_chart_mces(self) -> Iterable[MetadataWorkUnit]:
title=chart_data.get("slice_name", ""),
entity_urn=chart_snapshot.urn,
)

def get_schema_metadata(self, json_payload: dict) -> Optional[SchemaMetadata]:
    columns = json_payload.get("result", {}).get("columns", [])

fields = []
for field in columns:
column_name = field.get("column_name")
if column_name is None:
# Skip field as it has no name
continue

native_data_type = field.get("type") or "UNKNOWN"
type_class = FIELD_TYPE_MAPPING.get(native_data_type, NullTypeClass)

global_tags = None
if field.get("expression"):
    # Expression-backed (calculated) columns get a Calculated tag.
    global_tags = GlobalTagsClass(
        tags=[TagAssociationClass(tag="urn:li:tag:Calculated")]
    )

schema_field = SchemaField(
    fieldPath=column_name,
    type=SchemaFieldDataType(type=type_class()),
    nativeDataType=native_data_type,
    description=field.get("description"),
    globalTags=global_tags,
)
fields.append(schema_field)

schema_metadata = SchemaMetadata(
schemaName=json_payload.get("result", {}).get("datasource_name", "unknown"),
platform=f"urn:li:dataPlatform:{self.platform}",
version=0,
fields=fields,
hash="",
platformSchema=OtherSchema(rawSchema=""),
)

return schema_metadata

def _get_ownership(self, user: str) -> Optional[OwnershipClass]:
if user:
owner_urn = make_user_urn(user)
ownership: OwnershipClass = OwnershipClass(
owners=[
OwnerClass(
owner=owner_urn,
type=OwnershipTypeClass.DATAOWNER,
)
]
)
return ownership

return None

def construct_dataset_snapshot(self, dataset_data: dict) -> DatasetSnapshot:
    dataset_id = dataset_data.get("id")
    dataset_name = dataset_data.get("table_name")
    schema_name = dataset_data.get("schema", "unknown")
    database_name = dataset_data.get("database", {}).get("database_name")
    dataset_type = dataset_data.get("kind")
    dataset_query = dataset_data.get("sql")

    dataset_urn = self.get_datasource_urn_from_id(dataset_id)

dataset_status = Status(removed=False)

dataset_properties = DatasetProperties(
    name=dataset_name,
    description=dataset_data.get("description", ""),
    customProperties={
        "type": dataset_type,
        "database_name": database_name,
        "schema_name": schema_name if schema_name else "unknown",
        "owners": ", ".join(
            owner.get("username", "unknown")
            for owner in dataset_data.get("owners", [])
        ),
    },
    tags=[],
)

dataset_snapshot = DatasetSnapshot(
    urn=dataset_urn,
    aspects=[dataset_properties, dataset_status],
)

usernames = ", ".join(
    owner.get("username", "unknown")
    for owner in dataset_data.get("owners", [])
)
ownership = self._get_ownership(usernames) if usernames else None
if ownership is not None:
    dataset_snapshot.aspects.append(ownership)

if dataset_type == "virtual":
    # Virtual datasets are defined by SQL, so capture it as view logic.
    view_properties = ViewPropertiesClass(
        materialized=False,
        viewLanguage="SQL",
        viewLogic=UrnEncoder.encode_string(dataset_query or ""),
    )
    dataset_snapshot.aspects.append(view_properties)

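# The dataset list endpoint omits column details, so fetch the full dataset
# payload in order to build schema metadata.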
data_response = self.session.get(
f"{self.config.connect_uri}/api/v1/dataset/{dataset_id}"
)
if data_response.status_code != 200:
logger.warning(f"Couldn't get dataset data: {data_response.text}")
data_response.raise_for_status()

payload = data_response.json()
schema_metadata = self.get_schema_metadata(payload)
if schema_metadata is not None:
dataset_snapshot.aspects.append(schema_metadata)

return dataset_snapshot

def emit_dataset_mces(self) -> Iterable[MetadataWorkUnit]:
current_dataset_page = 0

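# Seed the total with one page size so the first request always runs; the
# real count from the API response replaces it below.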
total_datasets = PAGE_SIZE

while current_dataset_page * PAGE_SIZE <= total_datasets:
dataset_response = self.session.get(
f"{self.config.connect_uri}/api/v1/dataset/",
params=f"q=(page:{current_dataset_page},page_size:{PAGE_SIZE})"
)
if dataset_response.status_code != 200:
logger.warning(f"Couldn't get dataset data: {dataset_response.text}")
dataset_response.raise_for_status()

current_dataset_page += 1

payload = dataset_response.json()
total_datasets = payload["count"]
for dataset_data in payload["result"]:
dataset_snapshot = self.construct_dataset_snapshot(dataset_data)

mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot)
yield MetadataWorkUnit(id=dataset_snapshot.urn, mce=mce)

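# Virtual (SQL-defined) datasets additionally get the "View" subtype aspect.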
if dataset_data.get("kind") == "virtual":
yield MetadataChangeProposalWrapper(
entityUrn=dataset_snapshot.urn,
aspect=SubTypesClass(typeNames=[DatasetSubTypes.VIEW]),
).as_workunit()

def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
yield from self.emit_dashboard_mces()
yield from self.emit_chart_mces()
if self.config.extract_datasets:
yield from self.emit_dataset_mces()

def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
return [