In [None]:
!pip install mercury --quiet
## Errors printed by the install above can be ignored

In [None]:
import mercury as mr

## Classes for consuming Knowledge Engine API

In [2]:
"""
  ------------------------------------------
  Classes for consuming Knowledge Engine API
  ------------------------------------------
"""

from datetime import datetime
from typing import List, Optional, Any, Dict, ClassVar
from uuid import UUID

from pydantic import BaseModel, Field
from enum import Enum

"""
  Generic Models for KE Scans API
"""
class Data(BaseModel):
    """The data source a scan targets, identified by its full resource name.

    ``resource`` is expected to look like
    ``//bigquery.googleapis.com/projects/<p>/datasets/<d>`` (dataset scan) or
    ``//bigquery.googleapis.com/projects/<p>/datasets/<d>/tables/<t>`` (table scan).
    """
    RESOURCE_TYPE_TABLE: ClassVar[str] = "table"
    RESOURCE_TYPE_DATASET: ClassVar[str] = "dataset"

    resource: str

    def _resource_type(self) -> str:
        """Return the singular collection name that precedes the last path segment.

        ``.../tables/<t>`` yields ``"table"``; ``.../datasets/<d>`` yields
        ``"dataset"``. A resource without any ``/`` yields ``""`` so it is
        classified as neither, instead of raising ``IndexError``.
        """
        parts = self.resource.split('/')
        if len(parts) < 2:
            return ""
        # Drop the trailing plural "s" ("tables" -> "table").
        return parts[-2][:-1]

    @property
    def is_for_table(self) -> bool:
        """True when the resource names a single table."""
        return self._resource_type() == self.RESOURCE_TYPE_TABLE

    @property
    def is_for_dataset(self) -> bool:
        """True when the resource names a whole dataset."""
        return self._resource_type() == self.RESOURCE_TYPE_DATASET

    @property
    def resource_short_name(self) -> str:
        """The last path segment of the resource (the table or dataset id)."""
        return self.resource.split('/')[-1]


class OnDemand(BaseModel):
    """Represents an on-demand trigger configuration. Empty in the provided data."""


class Trigger(BaseModel):
    """How a scan is triggered.

    The Dataplex ``Trigger`` message holds either ``onDemand`` or ``schedule``.
    Both are optional so that scheduled scans listed in the same location do not
    fail validation.
    """
    on_demand: Optional[OnDemand] = Field(None, alias='onDemand')
    # Schedule settings are kept as a raw mapping; they are not interpreted here.
    schedule: Optional[Dict[str, Any]] = Field(None, alias='schedule')


class ExecutionSpec(BaseModel):
    """Represents the execution specification for a scan."""
    trigger: Trigger


class ExecutionStatus(BaseModel):
    """Timestamps of the latest job run for a scan.

    ``latestJobStartTime`` is not present in all scan types (e.g. KNOWLEDGE_ENGINE).
    The end and create times are optional too, so that a scan whose latest job
    has not finished still parses.
    NOTE(review): confirm which fields the API omits for in-flight jobs.
    """
    latest_job_start_time: Optional[datetime] = Field(None, alias='latestJobStartTime')
    latest_job_end_time: Optional[datetime] = Field(None, alias='latestJobEndTime')
    latest_job_create_time: Optional[datetime] = Field(None, alias='latestJobCreateTime')


class ScanTypeValue(Enum):
    """Dataplex data-scan types that may appear in a ``dataScans.list`` response.

    Every type the API can return must be listed. ``DataScan.type`` is validated
    against this enum, so an unknown type (e.g. a DATA_QUALITY scan on one of the
    dataset's tables) would otherwise make the whole listing fail.
    """
    DATA_SCAN_TYPE_UNSPECIFIED = "DATA_SCAN_TYPE_UNSPECIFIED"
    KNOWLEDGE_ENGINE = "KNOWLEDGE_ENGINE"
    DATA_DOCUMENTATION = "DATA_DOCUMENTATION"
    DATA_PROFILE = "DATA_PROFILE"
    DATA_QUALITY = "DATA_QUALITY"
    DATA_DISCOVERY = "DATA_DISCOVERY"


class DataScan(BaseModel):
    """Summary view of a single Dataplex data scan.

    ``description`` and ``displayName`` default to ``""``: the proto3 JSON
    mapping omits string fields left at their default (empty) value.
    ``executionStatus`` is optional so that a scan without any recorded run
    still parses.
    """
    name: str
    uid: UUID
    description: str = ''
    display_name: str = Field('', alias='displayName')
    state: str
    create_time: datetime = Field(..., alias='createTime')
    update_time: datetime = Field(..., alias='updateTime')
    data: Data
    execution_spec: ExecutionSpec = Field(..., alias='executionSpec')
    execution_status: Optional[ExecutionStatus] = Field(None, alias='executionStatus')
    type: ScanTypeValue

    @property
    def is_for_table(self) -> bool:
        """True when the scan targets a single table."""
        return self.data.is_for_table

    @property
    def is_for_dataset(self) -> bool:
        """True when the scan targets a whole dataset."""
        return self.data.is_for_dataset

    @property
    def resource_short_name(self) -> str:
        """Last path segment of the scanned resource (table or dataset id)."""
        return self.data.resource_short_name

    @property
    def resource_name(self) -> str:
        """Full resource name of the scanned table or dataset."""
        return self.data.resource


class DataScansResponse(BaseModel):
    """Root model for one page of the ``dataScans.list`` JSON response."""
    # An empty repeated field is omitted from proto3 JSON, so a location with no
    # scans yields a response without ``dataScans``.
    data_scans: List[DataScan] = Field(default_factory=list, alias='dataScans')
    # Present only when more pages remain; pass it back as ``pageToken``.
    next_page_token: Optional[str] = Field(None, alias='nextPageToken')


"""
  type KNOWLEDGE_ENGINE models
"""
class KESpec(BaseModel):
    """The ``knowledgeEngineSpec`` block of a KNOWLEDGE_ENGINE scan; no fields are modelled."""


class ColumnTuple(BaseModel):
    """One column taking part in a join: the table it belongs to plus its field path."""
    entry_fqn: str = Field(alias='entryFqn', description="Fully qualified name of the BigQuery table.")
    field_path: str = Field(alias='fieldPath', description="The name of the column.")


class SchemaRelationship(BaseModel):
    """A join between two tables, given as two lists of columns.

    Code in this notebook pairs the lists by position (left[i] with right[i]).
    """
    left_columns_tuple: List[ColumnTuple] = Field(alias='leftColumnsTuple')
    right_columns_tuple: List[ColumnTuple] = Field(alias='rightColumnsTuple')
    type: str = Field(description="The type of relationship, e.g., 'JOIN'.")


class BusinessTerm(BaseModel):
    """A glossary entry: a term's title together with its definition."""
    title: str
    description: str


class BusinessGlossary(BaseModel):
    """The business-glossary terms attached to a dataset."""
    terms: List[BusinessTerm]


class DatasetResult(BaseModel):
    """Dataset-level KE output: description, join relationships and glossary."""
    description: str
    schema_relationship: List[SchemaRelationship] = Field(alias='schemaRelationship')
    business_glossary: BusinessGlossary = Field(alias='businessGlossary')


class KEResult(BaseModel):
    """Top-level ``knowledgeEngineResult`` object of a KNOWLEDGE_ENGINE scan."""
    dataset_result: DatasetResult = Field(alias='datasetResult')


class KEScan(DataScan):
    """A KNOWLEDGE_ENGINE data scan together with its full result payload."""
    knowledge_engine_spec: Optional[KESpec] = Field(None, alias='knowledgeEngineSpec')
    knowledge_engine_result: KEResult = Field(..., alias='knowledgeEngineResult')

    @property
    def dataset_description(self) -> str:
        """Shortcut to the generated dataset description."""
        return self.knowledge_engine_result.dataset_result.description

    @property
    def business_glossary(self) -> BusinessGlossary:
        """Shortcut to the dataset's business glossary."""
        return self.knowledge_engine_result.dataset_result.business_glossary

    @property
    def schema_relationships(self) -> List[SchemaRelationship]:
        """Shortcut to the list of inferred join relationships."""
        return self.knowledge_engine_result.dataset_result.schema_relationship

"""
  type DATA_DOCUMENTATION generic models
"""

class DDSpec(BaseModel):
    """The ``dataDocumentationSpec`` block of a DATA_DOCUMENTATION scan; no fields are modelled."""


class Query(BaseModel):
    """An example SQL statement paired with a description of what it does."""
    sql: str
    description: str



"""
  type DATA_DOCUMENTATION table models
"""
class SchemaField(BaseModel):
    """One column of a documented table: its name and generated description."""
    name: str
    description: str


class Schema(BaseModel):
    """A table schema expressed as an ordered list of fields."""
    fields: List[SchemaField]


class TableResult(BaseModel):
    """Per-table documentation: overview, schema, example queries and query theme."""
    overview: str
    # Named ``the_schema`` to avoid colliding with the ``schema`` attribute of BaseModel.
    the_schema: Schema = Field(..., alias="schema")
    queries: List[Query]
    query_theme: Optional[Dict[str, Any]] = Field(default=None, alias='queryTheme')


class DDTableResult(BaseModel):
    """Top-level ``dataDocumentationResult`` object of a table-level DATA_DOCUMENTATION scan."""
    queries: List[Query]
    overview: str
    # Named ``the_schema`` to avoid colliding with the ``schema`` attribute of BaseModel.
    the_schema: Schema = Field(..., alias="schema")
    table_result: TableResult = Field(alias='tableResult')


class DDTableScan(DataScan):
    """A DATA_DOCUMENTATION scan whose target is a single BigQuery table."""
    data_documentation_spec: Optional[DDSpec] = Field(default=None, alias='dataDocumentationSpec')
    data_documentation_result: DDTableResult = Field(alias='dataDocumentationResult')

    @property
    def full_table_name(self) -> str:
        """``project.dataset.table`` taken from positions 4, 6 and 8 of the resource.

        Splitting ``//bigquery.googleapis.com/projects/<p>/datasets/<d>/tables/<t>``
        on ``/`` puts ``<p>``, ``<d>`` and ``<t>`` at those indexes.
        """
        segments = self.data.resource.split('/')
        return ".".join(segments[idx] for idx in (4, 6, 8))

    @property
    def overview(self) -> str:
        """Shortcut to the table overview text."""
        table = self.data_documentation_result.table_result
        return table.overview

    @property
    def fields(self) -> List[SchemaField]:
        """Shortcut to the documented schema fields."""
        table = self.data_documentation_result.table_result
        return table.the_schema.fields

    @property
    def queries(self) -> List[Query]:
        """Shortcut to the table's example queries."""
        table = self.data_documentation_result.table_result
        return table.queries

"""
  type DATA_DOCUMENTATION dataset models
"""
class DDDatasetResult(BaseModel):
    """Dataset-level documentation output, holding the example queries."""
    queries: List[Query]

class DDDataDocumentationResult(BaseModel):
    """Top-level ``dataDocumentationResult`` object of a dataset-level DATA_DOCUMENTATION scan."""
    queries: List[Query]
    dataset_result: DDDatasetResult = Field(alias='datasetResult')


class DDDatasetScan(DataScan):
    """A DATA_DOCUMENTATION scan whose target is a whole BigQuery dataset."""
    data_documentation_spec: Optional[DDSpec] = Field(default=None, alias='dataDocumentationSpec')
    data_documentation_result: DDDataDocumentationResult = Field(alias='dataDocumentationResult')

    @property
    def queries(self) -> List[Query]:
        """Shortcut to the dataset-level example queries."""
        ds_result = self.data_documentation_result.dataset_result
        return ds_result.queries



## KEDatasetScanHelper

### Classes for output from KEDatasetScanHelper

In [3]:
"""
  ------------------------------------------
  Classes for output from KEDatasetScanHelper
  ------------------------------------------
"""
import json

class KEDatasetTable(BaseModel):
    """
    Represents a single table.
    """
    name: str
    overview: str
    fields: List[SchemaField] = Field(..., description="A list of fields in the table.")
    queries: List[Query] = Field(..., description="A list of queries that can be run against the table.")

    @property
    def fields_json(self) -> str:
        """The ``fields`` list serialized as a JSON string."""
        return json.dumps(self.model_dump(include={'fields'})['fields'])

    @property
    def queries_json(self) -> str:
        """The ``queries`` list serialized as a JSON string."""
        return json.dumps(self.model_dump(include={'queries'})['queries'])

    @property
    def text_field_descriptions(self) -> str:
        """The fields rendered as a fenced text block, one line per field.

        Each line has the form ``` `name` -- Definition: description ```.
        """
        lines = [f"`{field.name}` -- Definition: {field.description}\n" for field in self.fields]
        # Join once instead of repeated string concatenation.
        return '```\n' + ''.join(lines) + '```'


class KEDatasetRelationship(BaseModel):
    """A join between two tables, rendered as a SQL join condition plus its provenance."""
    table1: str = Field(description="The name of the first table in the relationship.")
    table2: str = Field(description="The name of the second table in the relationship.")
    relationship: str = Field(description="The join condition that defines the relationship.")
    source: str = Field(description="The source that inferred or defined this relationship.")

class KEDatasetDetails(BaseModel):
    """Everything collected about one dataset.

    This covers identity, description, relationships, queries, glossary and
    per-table docs.
    """
    project_id: str = Field(description="Project ID of the dataset.")
    dataset_name: str = Field(description="Name of the dataset")
    dataset_location: str = Field(description="Location of the dataset.")
    dataset_description: str = Field(description="A brief overview of the dataset.")
    dataset_relationships: List[KEDatasetRelationship] = Field(description="A list of table relationships.")
    dataset_queries: List[Query] = Field(description="A list of queries that can be run against the dataset.")
    dataset_business_glossary: List[BusinessTerm] = Field(description="A list of business glossary terms.")
    dataset_tables: List[KEDatasetTable] = Field(description="A list of tables in the dataset.")

    def _field_as_json(self, field_name: str) -> str:
        """Serialize a single field of this model to a JSON string."""
        dumped = self.model_dump(include={field_name})
        return json.dumps(dumped[field_name])

    @property
    def dataset_relationships_json(self) -> str:
        """``dataset_relationships`` as a JSON string."""
        return self._field_as_json('dataset_relationships')

    @property
    def dataset_queries_json(self) -> str:
        """``dataset_queries`` as a JSON string."""
        return self._field_as_json('dataset_queries')

    @property
    def dataset_glossary_terms_json(self) -> str:
        """``dataset_business_glossary`` as a JSON string."""
        return self._field_as_json('dataset_business_glossary')

### Helper Authentication

In [4]:
import requests, re

from google.cloud import bigquery
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
import google.auth

class KEAuth:
    """Send authenticated HTTP GET requests using Application Default Credentials.

    Credentials are resolved lazily with ``google.auth.default()`` on first use.
    They are refreshed whenever they are no longer valid.
    """

    # OAuth scope requested when resolving ADC; service-account credentials
    # need an explicit scope before they can mint an access token.
    _SCOPES = ("https://www.googleapis.com/auth/cloud-platform",)
    # Seconds to wait for the server before a request is abandoned.
    DEFAULT_TIMEOUT_SECONDS = 60

    def __init__(self):
        self.__credentials = None
        self.__project = None

    def __refresh_credentials(self):
        """Resolve credentials on first use and refresh them when invalid.

        Returns:
            ``self``, to allow chaining.
        """
        if not self.__credentials:
            # Keep the detected default project alongside the credentials.
            self.__credentials, self.__project = google.auth.default(scopes=list(self._SCOPES))

        if not self.__credentials.valid:
            self.__credentials.refresh(Request())

        return self

    def __get_headers(self):
        """Return request headers carrying a fresh bearer token."""
        self.__refresh_credentials()
        return {
          "Authorization": f"Bearer {self.__credentials.token}",
          "Content-Type": "application/json"
        }

    def get_url_content(self, url: str, timeout: Optional[float] = None) -> str:
        """GET ``url`` with an OAuth bearer token and return the response body.

        Args:
            url: Absolute URL to fetch.
            timeout: Seconds to wait for the server. Defaults to
                ``DEFAULT_TIMEOUT_SECONDS``; without a timeout ``requests``
                may block forever on an unresponsive server.

        Returns:
            The response body as text.

        Raises:
            requests.exceptions.RequestException: On connection errors, timeouts
                or 4xx/5xx responses, after printing a short message.
        """
        if timeout is None:
            timeout = self.DEFAULT_TIMEOUT_SECONDS
        try:
            response = requests.get(url, headers=self.__get_headers(), timeout=timeout)
            response.raise_for_status()  # Raise an HTTPError for bad responses (4xx or 5xx)
        except requests.exceptions.RequestException as e:
            print(f"Error fetching URL {url}: {e}")
            raise
        return response.text

### KEDatasetScanHelper

In [5]:

"""
  ------------------------------------------
  KEDatasetScanHelper
  ------------------------------------------
"""

class KEDatasetScanHelper(KEAuth):
    """Collect Knowledge Engine and Data Documentation results for one BigQuery dataset.

    The helper works in four steps:

    1. List every Dataplex data scan in the dataset's location, following
       pagination.
    2. Keep the scans whose resource belongs to ``dataset_name``. Table scans
       are narrowed further by an optional allowlist/blocklist.
    3. Fetch the FULL view of each KNOWLEDGE_ENGINE / DATA_DOCUMENTATION scan.
    4. Expose the combined results, e.g. via ``dataset_all_details``.
    """
    DATAPLEX_BASE_URL = "https://dataplex.googleapis.com/v1"
    DATAPLEX_LIST_SCANS_URL = DATAPLEX_BASE_URL + "/projects/{project_id}/locations/{location}/dataScans"
    TABLE_RESOURCE_TEMPLATE = "//bigquery.googleapis.com/projects/{project_id}/datasets/{dataset_name}/tables/{table_name}"
    # Only these scan types are parsed from their FULL view; other types
    # (e.g. DATA_PROFILE) are skipped without an extra HTTP request.
    _SUPPORTED_SCAN_TYPES = (ScanTypeValue.KNOWLEDGE_ENGINE, ScanTypeValue.DATA_DOCUMENTATION)

    def __init__(self, project_id: str, dataset_name: str):
        super().__init__()
        self.dataset_name = dataset_name
        self.project_id = project_id
        self.__dataset_location = None
        self.__tables = []
        self.__data_scans = []
        # Distinguishes "not fetched yet" from "fetched, but nothing relevant".
        self.__scans_loaded = False
        # Both sets hold full table resource names (see TABLE_RESOURCE_TEMPLATE).
        self.__allowlist_tables = set()
        self.__blocklist_tables = set()

    def _flush(self):
        """Drop cached scans and table constraints so they are rebuilt on next use."""
        self.__tables.clear()
        self.__data_scans.clear()
        self.__scans_loaded = False
        self.__allowlist_tables.clear()
        self.__blocklist_tables.clear()

    def _table_is_allowed(self, table: str) -> bool:
        """True when ``table`` (a full resource name) passes both constraint lists."""
        return self._is_in_allowlist(table) and not self._is_in_blocklist(table)

    def _is_in_allowlist(self, table: str) -> bool:
        """An empty allowlist allows every table."""
        return not self.__allowlist_tables or table in self.__allowlist_tables

    def _is_in_blocklist(self, table: str) -> bool:
        """An empty blocklist blocks nothing."""
        return bool(self.__blocklist_tables) and table in self.__blocklist_tables

    def _list_all_scans(self) -> list:
        """Return every raw data-scan dict in the dataset's location, following ``nextPageToken``."""
        from urllib.parse import urlencode

        list_url = self.DATAPLEX_LIST_SCANS_URL.format(
            project_id=self.project_id,
            location=self.dataset_location
        )

        raw_scans = []
        page_token = None
        while True:
            url = list_url
            if page_token:
                url += "?" + urlencode({"pageToken": page_token})

            try:
                response = self.get_url_content(url)
            except Exception as e:
                print(f"Error fetching data scans: {e}")
                raise

            try:
                page = json.loads(response)
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON response: {e}")
                raise

            raw_scans.extend(page.get('dataScans', []))
            page_token = page.get('nextPageToken')
            if not page_token:
                return raw_scans

    def _get_scans_of_interest(self) -> list:
        """Return the scans that target this dataset or one of its allowed tables."""
        # Limit the scans to items in the requested dataset (per constructor)
        ds_test_string = f"/datasets/{self.dataset_name}"
        table_test_string = f"{ds_test_string}/tables/"

        scans_of_interest = []
        for scan in self._list_all_scans():
            resource = (scan.get('data') or {}).get('resource')
            if not resource:
                continue
            if not (resource.endswith(ds_test_string) or table_test_string in resource):
                continue

            new_scan = DataScan(**scan)
            # The constraint sets hold full resource names, so compare the full
            # resource (not the short table id).
            if new_scan.is_for_table and self._table_is_allowed(new_scan.resource_name):
                scans_of_interest.append(new_scan)
            if new_scan.is_for_dataset:
                scans_of_interest.append(new_scan)

        return scans_of_interest

    def _fetch_full_scan(self, scan: DataScan) -> dict:
        """Fetch and decode the FULL view of ``scan``."""
        full_scan_url = f"{self.DATAPLEX_BASE_URL}/{scan.name}?view=FULL"

        try:
            response = self.get_url_content(full_scan_url)
        except Exception as e:
            print(f"Error fetching data scans: {e}")
            raise

        try:
            return json.loads(response)
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON response: {e}")
            raise

    @staticmethod
    def _parse_full_scan(scan: DataScan, full_view_scan: dict):
        """Build the typed model matching ``scan``'s type and target, or None if unsupported."""
        if scan.type == ScanTypeValue.KNOWLEDGE_ENGINE:
            return KEScan(**full_view_scan)
        if scan.type == ScanTypeValue.DATA_DOCUMENTATION:
            if scan.is_for_table:
                return DDTableScan(**full_view_scan)
            if scan.is_for_dataset:
                return DDDatasetScan(**full_view_scan)
        return None

    def set_table_list_constraints(self, allowlist: Optional[list] = None, blocklist: Optional[list] = None):
        """Restrict which tables are included, by short table name.

        Args:
            allowlist: Table names to include; empty or None means all tables.
            blocklist: Table names to exclude.

        Returns:
            ``self``, for chaining.

        Raises:
            ValueError: If a table appears in both lists.
        """
        allowlist = list(allowlist or [])
        blocklist = list(blocklist or [])

        overlap = sorted(set(allowlist) & set(blocklist))
        if overlap:
            raise ValueError(f"Allowlist and blocklist cannot contain the same items: {overlap}")

        def table_name_to_resource(table_name: str) -> str:
            return self.TABLE_RESOURCE_TEMPLATE.format(
                project_id=self.project_id,
                dataset_name=self.dataset_name,
                table_name=table_name
            )

        # Cached scans were filtered with the old constraints; discard them.
        self._flush()
        self.__allowlist_tables.update(map(table_name_to_resource, allowlist))
        self.__blocklist_tables.update(map(table_name_to_resource, blocklist))

        return self

    @property
    def dataset_location(self) -> str:
        """The BigQuery location of the dataset, looked up once and cached."""
        if not self.__dataset_location:
            client = bigquery.Client(project=self.project_id)
            dataset = client.get_dataset(f'{self.project_id}.{self.dataset_name}')
            self.__dataset_location = dataset.location

        return self.__dataset_location

    @property
    def dataplex_scans(self) -> list:
        """Typed FULL-view scans for this dataset, fetched once and cached."""
        if not self.__scans_loaded:
            fetched = []
            for scan in self._get_scans_of_interest():
                if scan.type not in self._SUPPORTED_SCAN_TYPES:
                    continue
                new_scan = self._parse_full_scan(scan, self._fetch_full_scan(scan))
                if new_scan is not None:
                    fetched.append(new_scan)

            # Publish only after every fetch succeeded, so an error part-way
            # through never leaves a partial list cached.
            self.__data_scans.extend(fetched)
            self.__scans_loaded = True

        return self.__data_scans

    @property # dataset knowledge engine scan
    def dataset_ke_scan(self) -> Optional[KEScan]:
        """The first KNOWLEDGE_ENGINE scan found, or None."""
        return next((s for s in self.dataplex_scans if isinstance(s, KEScan)), None)

    @property # dataset data documentation scan
    def dataset_dd_scan(self) -> Optional[DDDatasetScan]:
        """The dataset-level DATA_DOCUMENTATION scan, or None."""
        return next((s for s in self.dataplex_scans if isinstance(s, DDDatasetScan)), None)

    @property
    def dataset_description(self) -> str:
        return self.dataset_ke_scan.dataset_description

    @property
    def dataset_tables(self) -> List[KEDatasetTable]:
        """Per-table documentation for every allowed table with a DATA_DOCUMENTATION scan."""
        return [
            KEDatasetTable(
                name=scan.full_table_name,
                overview=scan.overview,
                fields=scan.fields,
                queries=scan.queries
            )
            for scan in self.dataplex_scans
            if isinstance(scan, DDTableScan) and self._table_is_allowed(scan.resource_name)
        ]

    @property
    def dataset_queries(self) -> List[Query]:
        return self.dataset_dd_scan.queries

    @property
    def dataset_business_glossary(self) -> List[BusinessTerm]:
        return self.dataset_ke_scan.business_glossary.terms

    @property
    def dataset_relationships(self) -> List[KEDatasetRelationship]:
        """
          Join relationships between allowed tables, as SQL conditions.

          Left and right columns are paired by position and ANDed together. This
          will require update when the relation representation becomes more complex.
        """
        project_dataset = f"{self.project_id}.{self.dataset_name}"

        return_relationships = []
        for relationship in self.dataset_ke_scan.schema_relationships:
            left_tuples = relationship.left_columns_tuple
            right_tuples = relationship.right_columns_tuple

            if not left_tuples or not right_tuples:
                # No columns to join on; nothing meaningful to emit.
                continue
            if len(left_tuples) != len(right_tuples):
                print(f"Skipping relationship with mismatched column counts: {relationship}")
                continue

            table1_fqn = left_tuples[0].entry_fqn
            table2_fqn = right_tuples[0].entry_fqn
            if not (self._table_is_allowed(table1_fqn) and self._table_is_allowed(table2_fqn)):
                continue

            table1_sql_name = f"{project_dataset}.{table1_fqn.split('/')[-1]}"
            table2_sql_name = f"{project_dataset}.{table2_fqn.split('/')[-1]}"

            join_conditions = [
                f"{table1_sql_name}.{left.field_path} = {table2_sql_name}.{right.field_path}"
                for left, right in zip(left_tuples, right_tuples)
            ]

            return_relationships.append(KEDatasetRelationship(
                table1=table1_sql_name,
                table2=table2_sql_name,
                relationship=' AND '.join(join_conditions),
                source='LLM-inferred'
            ))

        return return_relationships

    @property
    def dataset_all_details(self) -> KEDatasetDetails:
        """All collected information about the dataset in a single model."""
        return KEDatasetDetails(**{
            "project_id": self.project_id,
            "dataset_name": self.dataset_name,
            "dataset_location": self.dataset_location,
            "dataset_description": self.dataset_description,
            "dataset_relationships": self.dataset_relationships,
            "dataset_queries": self.dataset_queries,
            "dataset_business_glossary": self.dataset_business_glossary,
            "dataset_tables": self.dataset_tables
        })

## Testing

In [None]:
# Smoke test against a live project: gather everything known about the
# dataset and print it as one JSON document.
ke_helper = KEDatasetScanHelper(project_id='ai-learning-agents', dataset_name='thelook')
ds_details = ke_helper.dataset_all_details
details_json = ds_details.model_dump_json()
print(details_json)

In [None]:
# Display with Mercury: render the dataset details as an interactive JSON tree.
app = mr.App(title="Display notebook", static_notebook=True)
# mr.JSON takes a JSON object (dict). A pre-serialized string would be shown as
# a single string value, so pass a JSON-compatible dict instead.
mr.JSON(ds_details.model_dump(mode="json"))