In [1]:
import json
from hashlib import md5
from collections import deque
from typing import Dict, Tuple, Union, List

from moz_sql_parser import parse, ParseException
import networkx as nx
from pyvis.network import Network

In [2]:
query = parse("""
with scores_entities as (
    select
        score_type,
        score_id,
        score_name
    from "onefootball"."analytics"."stg_scores__entities"
),
articles as (
    select
        language_code,
        article_id,
        publish_time,
        title
    from
        "onefootball"."analytics_datasets"."train_set_tagging"
),
per_language_count as (
    select
        language_code,
        count(article_id) as per_language_count
    from articles
    group by 1
),
article_entities_counts as (
    select
        language as language_code,
        score_type,
        score_id,
        count(feed_item_id) as per_entity_count
    from "onefootball"."analytics"."stg_cms__feed_item_streams" as fis
    join articles on fis.feed_item_id = articles.article_id
    where
        score_type in ('team', 'competition', 'player')
    group by 1, 2, 3
)
select
    language_code,
    score_type,
    score_id,
    score_name,
    per_language_count,
    per_entity_count,
    log((1 + per_language_count) / (1 + per_entity_count)) + 1 as idf
from article_entities_counts
join per_language_count using (language_code)
join scores_entities using (score_type, score_id)
""")

In [5]:
def find_col(query: Dict, col_name: str) -> Tuple[Dict, Union[str, List[Union[str, Dict]]]]:
    """Get matching colum from query
    
    Arguments:
        query: moz-sql-parser query with 'select' key and 'from' key
        col_name: col we want to get lineage info from
        
    Returns:
        lineage info of column
    """
    matching_cols = [col for col in query['select'] if col.get('name', col['value']) == col_name]
    # TODO: handle '*'
    if len(matching_cols):
        return (matching_cols[0], query['from'])
    else:
        return None

    
def slice_ctes(query: Dict, cte_name: str):
    return {w['name']: w['value'] for w in query['with']}.get(cte_name)


class LineagePoint():
    def __init__(self, x, y):
        self.x: Dict = x
        self.y: Union[str, Dict, List[Union[str, Dict]]] = y

    @property
    def col_name(self):
        return self.x.get('name', self.x['value'])

    @property
    def table_name(self):
        if type(self.y) == list:
            # more than 1 table
            return md5(json.dumps(self.y).encode("utf-8")).hexdigest()
        elif type(self.y) == str:
            return self.y
        elif type(self.y) == dict:
            # TODO: check when JOIN + rename
            # TODO: check for left, inner and right JOIN
            return self.y.get('join') or self.y.get('value')

    def __repr__(self):
        return f'{self.table_name}.{self.col_name}'

In [7]:
# nodes = []
edges = []

# initialise the deque https://docs.python.org/3/library/collections.html#collections.deque
root = LineagePoint(*find_col(query, 'language_code'))
# nodes.append(str(root))
to_be_parsed = deque([root])

while len(to_be_parsed):
    lp = to_be_parsed.popleft()
    print(lp)
    
    if type(lp.y) == list:
        # flatten
        new_lps = [LineagePoint(lp.x, tbl) for tbl in lp.y]
        to_be_parsed.extend(new_lps)
        edges.extend([(str(lp), str(nlp)) for nlp in new_lps])
        continue
    else:
        q = slice_ctes(query, lp.table_name)
        if q:
            found_col = find_col(
                query=q,
                col_name=lp.col_name
            )
            if found_col:
                new_lp = LineagePoint(*found_col)
                to_be_parsed.append(new_lp)
#                 nodes.append(str(new_lp))
                edges.append((str(lp), str(new_lp)))
                continue
        else:
            # no CTE with the table_name was found
            continue

1c7a0d3e5990366b820db55a43c6c179.language_code
article_entities_counts.language_code
per_language_count.language_code
scores_entities.language_code
84ae4007fef5247e919e3c50e7e8c424.language_code
articles.language_code
onefootball.analytics.stg_cms__feed_item_streams.language_code
articles.language_code
onefootball.analytics_datasets.train_set_tagging.language_code
onefootball.analytics_datasets.train_set_tagging.language_code


In [13]:
G = nx.DiGraph()
G.add_edges_from(edges)

nt = Network("500px", "1000px", notebook=True, directed=True)
nt.toggle_physics(False)
nt.from_nx(G)
nt.show("nx.html")