In [1]:
import datetime as dt
import gettsim as gt
import pandas as pd
import numpy as np
import inspect
import functools
from gettsim.config import ROOT_DIR
from gettsim.tests.test_zu_versteuerndes_eink import INPUT_COLS

In [2]:
# Test data for the taxable-income tests, restricted to the rows of 2018.
df = (
    pd.read_csv(
        ROOT_DIR / "tests" / "test_data" / "test_dfs_zve.csv",
        usecols=INPUT_COLS,
    )
    .query("jahr == 2018")
)

In [3]:
# Load the parameters and policy functions of the selected groups as of 2018-01-01.
policy_date = dt.date(year=2018, month=1, day=1)
params_dict, policy_func_dict = gt.get_policies_for_date(
    groups=["eink_st_abzuege", "soz_vers_beitr", "kindergeld"],
    policy_date=policy_date,
)

In [4]:
# Columns handed to ``compute_taxes_and_transfers`` as ``user_columns`` below.
# NOTE(review): presumably these are taken from ``df`` as given instead of being
# computed — confirm against the gettsim documentation.
user_columns = [
    "ges_krankenv_beitr_m", "arbeitsl_v_beitr_m",
    "pflegev_beitr_m", "rentenv_beitr_m",
]

In [5]:
# Compute the target and also return the DAG of functions/columns that produced
# it; the DAG is visualised further below.
result, dag = gt.compute_taxes_and_transfers(
    df,
    targets="sum_brutto_eink",
    user_columns=user_columns,
    user_functions=policy_func_dict,
    params=params_dict,
    return_dag=True,
)

In [6]:
import networkx as nx

from bokeh.io import output_file, show, output_notebook
from bokeh.models import (BoxZoomTool, Circle, HoverTool,
                          MultiLine, Plot, Range1d, ResetTool,)
from bokeh.palettes import Spectral4
from bokeh.plotting import from_networkx

In [7]:
def safe_pydot_layout(dag):
    """Create a hierarchical ``dot`` layout of ``dag`` with pydot, scaled to [-1, 1].

    In contrast to calling ``nx.drawing.nx_pydot.pydot_layout`` on ``dag``
    directly, the layout is computed on a copy whose nodes are relabeled to
    integers, so the original node labels are never passed to pydot/Graphviz.
    NOTE(review): presumably this avoids labels pydot cannot handle — confirm.

    Parameters
    ----------
    dag : networkx.DiGraph
        Graph to lay out.

    Returns
    -------
    dict
        Maps each node of ``dag`` to a ``numpy.ndarray`` ``[x, y]``. Each axis is
        min-max scaled into ``[-1, 1]``; an axis on which all nodes share the
        same coordinate (e.g. a single column of nodes) is mapped to 0. An empty
        graph yields an empty dict.

    """
    if len(dag) == 0:
        # Nothing to lay out; also avoids min/max over an empty array below.
        return {}

    # Compute the layout on integer-labeled nodes and map it back to the labels.
    # With the default ordering, integer ``i`` is the ``i``-th node of ``dag.nodes``.
    dag_w_integer_nodes = nx.relabel.convert_node_labels_to_integers(dag)
    integer_layout = nx.drawing.nx_pydot.pydot_layout(dag_w_integer_nodes, prog="dot")
    integers = list(dag_w_integer_nodes.nodes)
    labels = list(dag.nodes)
    coords = np.array([integer_layout[i] for i in integers], dtype=float).reshape(-1, 2)

    # Min-max scale each axis into [-1, 1] so the layout fills bokeh's plot range.
    # Scaling by the maximum alone would leave an offset whenever the smallest
    # coordinate is not 0, so the layout would not be centered.
    lower = coords.min(axis=0)
    span = coords.max(axis=0) - lower
    # Divide by 1 on zero-width axes to avoid a division by zero; those axes are
    # then set to the center (0) by ``np.where``.
    safe_span = np.where(span > 0, span, 1.0)
    scaled = np.where(span > 0, (coords - lower) / safe_span * 2 - 1, 0.0)

    return dict(zip(labels, scaled))

In [8]:
def replace_functions_with_source_code(dag):
    """Replace each node's ``"function"`` attribute by HTML-formatted source code.

    The ``"function"`` attribute is popped from every node that has one. The
    function's source is stored under ``"source_code"`` instead, HTML-escaped
    and with line breaks and indentation preserved, so it can be shown in an
    unescaped (``{safe}``) bokeh tooltip. Nodes without a function and not
    processed before get the placeholder ``"Column in data"``. Calling this twice
    on the same graph keeps the results of the first call.

    Parameters
    ----------
    dag : networkx.DiGraph
        Graph whose node attribute dicts are modified in place.

    Returns
    -------
    networkx.DiGraph
        The same ``dag`` object that was passed in.

    """
    import html

    for node in dag.nodes:
        attributes = dag.nodes[node]
        if "function" in attributes:
            function = attributes.pop("function")
            # ``inspect.getsource`` rejects partials; unwrap (possibly nested)
            # partials to reach the underlying function.
            while isinstance(function, functools.partial):
                function = function.func
            try:
                source = inspect.getsource(function)
            except (OSError, TypeError):
                # E.g. builtins or objects whose source file is unavailable.
                attributes["source_code"] = "Source code not available"
                continue
            # Escape first so code like ``a < b`` is not parsed as HTML by the
            # unescaped tooltip. Then keep line breaks and indentation, which HTML
            # whitespace collapsing would otherwise remove.
            attributes["source_code"] = (
                html.escape(source)
                .replace("\n", "<br>")
                .replace("    ", "&nbsp;" * 4)
            )
        elif "source_code" not in attributes:
            # Only label nodes that were not processed before. Otherwise a second
            # call (after "function" was popped) would overwrite real source code.
            attributes["source_code"] = "Column in data"

    return dag

In [9]:
# The helper pops each node's "function" attribute, so the graph cannot be
# processed from scratch a second time. Skip once every node already carries
# source code, which makes re-running this cell safe.
if not all("source_code" in attributes for _, attributes in dag.nodes(data=True)):
    dag = replace_functions_with_source_code(dag)

In [12]:
# Hover tooltip: node name plus the "source_code" node attribute, which is
# rendered as raw HTML (``{safe}``) so line breaks and indentation show up.
TOOLTIPS = """
column: @index <br>
source code: <br><br>@source_code{safe}
"""

# ``safe_pydot_layout`` maps coordinates into [-1, 1]; the ranges add a margin.
plot = Plot(
    plot_width=800,
    plot_height=800,
    x_range=Range1d(-1.1, 1.1),
    y_range=Range1d(-1.1, 1.1),
)
# Derive the year from ``policy_date`` so the title cannot drift from the data.
plot.title.text = f"Tax and Transfer System - {policy_date.year} - Target: sum_brutto_eink"

node_hover_tool = HoverTool(tooltips=TOOLTIPS)
plot.add_tools(node_hover_tool, BoxZoomTool(), ResetTool())

# ``from_networkx`` forwards extra keyword arguments only to a callable layout
# function. With a precomputed layout dict they are ignored, so none are passed.
graph_renderer = from_networkx(dag, safe_pydot_layout(dag))
graph_renderer.node_renderer.glyph = Circle(size=15, fill_color=Spectral4[0])
graph_renderer.edge_renderer.glyph = MultiLine(line_color="red", line_alpha=0.8, line_width=1)
plot.renderers.append(graph_renderer)

output_notebook()
show(plot)