In [None]:
from typing import List, Dict

from lattedb.config.settings import PROJECT_APPS, GRAPH_MODELS
from django_extensions.management.modelviz import ModelGraph


from bokeh.plotting import figure, show
from bokeh.io import output_notebook
from bokeh.models.glyphs import Text
from bokeh.models import ColumnDataSource, LabelSet

import networkx as nx
from networkx.drawing.nx_agraph import write_dot

In [None]:
# Load the blackcellmagic IPython extension (presumably a black code-formatting
# helper); no later cell depends on it.
%load_ext blackcellmagic

In [None]:
def get_graph_data(
    all_applications=GRAPH_MODELS.get("all_applications"),
    exclude_models=GRAPH_MODELS.get("exclude_models"),
    exclude_columns=GRAPH_MODELS.get("exclude_columns"),
    **kwargs,
) -> List[Dict]:
    """Collect per-app model graph data through the `django_extensions` ModelGraph API.

    The default argument values are read from ``GRAPH_MODELS`` once, when this
    function is defined (standard Python default-argument evaluation).

    Args:
        all_applications: Forwarded to ``ModelGraph``.
        exclude_models: Forwarded to ``ModelGraph``.
        exclude_columns: Forwarded to ``ModelGraph``.
        **kwargs: Any further ``ModelGraph`` options, forwarded unchanged.

    Returns:
        The ``"graphs"`` entry of ``ModelGraph.get_graph_data(as_json=True)``.
    """
    graph_options = dict(
        all_applications=all_applications,
        exclude_models=exclude_models,
        exclude_columns=exclude_columns,
    )
    # An empty positional list is passed as ModelGraph's first argument, as before.
    builder = ModelGraph([], **graph_options, **kwargs)
    builder.process_apps()
    return builder.get_graph_data(as_json=True)["graphs"]

In [None]:
from bokeh.models import (
    Plot,
    Range1d,
    MultiLine,
    Circle,
    HoverTool,
    BoxZoomTool,
    ResetTool,
    PanTool,
    CustomJSHover,
    TapTool,
    BoxSelectTool,
    CustomJS,
)
from bokeh.models.graphs import from_networkx, NodesAndLinkedEdges, EdgesAndLinkedNodes
from bokeh.palettes import Spectral4


In [None]:
# Send Bokeh output to this notebook so the final `show(plot)` renders inline.
output_notebook()

In [None]:
# One entry per Django app (the ModelGraph "graphs" list), built with the
# GRAPH_MODELS defaults.
lattedb_data= get_graph_data()

In [None]:
# Collapse the model-level data into an app-level DiGraph: one node per Django
# app (labelled with a Graphviz HTML-like table of its models) and one edge per
# pair of distinct apps linked by at least one model relation.
G = nx.DiGraph()

edges = []  # one dict per model relation; resolved to app-level edges below
node_map = {}  # app name -> integer node id in G
node_groups = {}  # app name -> `model["app_name"]` values of its models
# `model["app_name"]` value -> first app (in data order) whose models carry it.
# Gives the same first match as scanning `node_groups`, in one dict lookup.
app_by_module = {}

counter = 1
for app_data in lattedb_data:
    app_name = app_data["app_name"]

    model_names = []
    for model in app_data["models"]:
        node_groups.setdefault(app_name, []).append(model["app_name"])
        app_by_module.setdefault(model["app_name"], app_name)
        model_name = model["name"]
        model_names.append(model_name)

        for relation in model.get("relations", []):
            edges.append(
                {
                    "app_name": app_name,
                    "model": model_name,
                    # Matched against `model["app_name"]` values when resolving edges.
                    "target_cluster_name": relation.get("target_app"),
                    "target_model": relation.get("target"),
                    "column": relation.get("name"),
                }
            )

    # NOTE(review): this HTML summary is not attached to the node and is not
    # used by the Bokeh hover tool below.
    tooltip = "<h2>models:</h2><ul><li>" + "</li><li>".join(model_names) + "</li></ul>"
    G.add_node(
        counter,
        # The outer "<" ... ">" wraps the table so DOT treats it as an HTML-like label.
        label="<"
        + "<table border='0' cellborder='0' cellspacing='1'>"
        + "<tr><td align='left'><b>" + app_name + "</b></td></tr>"
        + "<tr><td align='left'>"
        + "</td></tr><tr><td align='left'>".join(model_names)
        + "</td></tr></table>>",
        shape="plaintext",
    )
    node_map[app_name] = counter
    counter += 1

# DiGraph keeps a single edge per (start, end) pair, and re-adding an edge
# overwrites its attributes. Collect every linking column first, so the
# tooltip lists all of them instead of only the last one seen.
edge_columns = {}  # (start, end) -> distinct column names, first-seen order
for edge in edges:
    start = node_map.get(edge["app_name"])
    end = node_map.get(app_by_module.get(edge["target_cluster_name"]))
    if start is None or end is None or start == end:
        continue  # target app not in the graph, or relation within one app
    columns = edge_columns.setdefault((start, end), [])
    if edge["column"] not in columns:
        columns.append(edge["column"])

for (start, end), columns in edge_columns.items():
    # A single column keeps the original "'name'" tooltip format.
    G.add_edge(start, end, tooltip=", ".join(f"'{column}'" for column in columns))


In [None]:
# Export the app-level graph G as Graphviz DOT to `test.dot` in the working directory.
write_dot(G, "test.dot")

In [None]:
# Open the exported DOT file in the `atom` editor for manual inspection.
# Requires `atom` on PATH; no later cell depends on it.
!atom test.dot

In [None]:
# NOTE(review): newer Bokeh releases renamed `plot_width`/`plot_height` to
# `width`/`height`; update these when upgrading Bokeh.
plot = Plot(
    plot_width=800,
    plot_height=800,
    x_range=Range1d(-1.1, 1.1),
    y_range=Range1d(-1.1, 1.1),
)
plot.title.text = "Graph Interaction Demonstration"

# The EdgesAndLinkedNodes inspection policy below makes hovering target edges.
# `from_networkx` exposes G's edge attributes as columns of the edge data source.
# The only edge attribute set when building G is "tooltip" (the linking column
# names). Nodes only carry "label"/"shape", so there is no "models" field.
edge_hover_tool = HoverTool(tooltips=[("column", "@tooltip")], line_policy="next")

plot.add_tools(edge_hover_tool, BoxZoomTool(), ResetTool(), PanTool(), BoxSelectTool())

graph_renderer = from_networkx(G, nx.spring_layout, k=2, scale=1)

graph_renderer.node_renderer.glyph = Circle(size=80, fill_color="white", line_color="white")
graph_renderer.node_renderer.selection_glyph = Circle(size=80, fill_color=Spectral4[2])

graph_renderer.edge_renderer.glyph = MultiLine(line_alpha=0.8, line_width=1)
graph_renderer.edge_renderer.hover_glyph = MultiLine(line_color=Spectral4[1], line_width=2)

graph_renderer.selection_policy = NodesAndLinkedEdges()
graph_renderer.inspection_policy = EdgesAndLinkedNodes()

plot.renderers.append(graph_renderer)

In [None]:
# Overlay each app's name at its node's laid-out position.
# G's nodes have no "module" attribute, so `nx.get_node_attributes(G, "module")`
# would always be empty. Derive the names from `node_map` (app -> node id)
# instead. Positions and names are looked up per node id, so the columns stay
# aligned and equally long (also for an empty graph).
node_positions = graph_renderer.layout_provider.graph_layout
node_labels = {node_id: app_name for app_name, node_id in node_map.items()}
labelled_nodes = [node_id for node_id in node_positions if node_id in node_labels]
source = ColumnDataSource(
    {
        "x": [node_positions[node_id][0] for node_id in labelled_nodes],
        "y": [node_positions[node_id][1] for node_id in labelled_nodes],
        "module": [node_labels[node_id] for node_id in labelled_nodes],
    }
)
labels = LabelSet(
    x="x",
    y="y",
    text="module",
    source=source,
    background_fill_color="white",
    border_line_color=None,
    text_baseline="middle",
    text_align="center",
)
# Annotations such as LabelSet are attached with add_layout.
plot.add_layout(labels)


In [None]:
# Inspect the node labels used for the plot overlay.
node_labels

In [None]:
# NOTE(review): this duplicates the earlier `node_labels` display cell; consider removing it.
node_labels

In [None]:
# Render the interactive graph; output goes to the notebook via the earlier output_notebook().
show(plot)