In [1]:
import dataclasses
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import List, Optional

from IPython.display import JSON, display
import ipywidgets as widgets
import plotly.graph_objects as go
from ruamel.yaml import YAML

from openfisca_core.simulation_builder import SimulationBuilder
from openfisca_france import FranceTaxBenefitSystem

In [2]:
# OpenFisca-France tax-benefit system; it is passed to SimulationBuilder when the simulation is built.
tax_benefit_system = FranceTaxBenefitSystem()

In [3]:
def count_to_step(min, max, count):
    """Return the distance between two consecutive points of an axis.

    The axis is assumed to hold `count` evenly spaced points from `min` to
    `max`, both ends included, hence `count - 1` intervals between them.

    Examples:
    >>> count_to_step(0, 80, 5)
    20.0

    Raises:
        ValueError: if `count` is lower than 2, since a step is then undefined
            (one point has no interval; zero or negative counts are meaningless).
    """
    if count < 2:
        raise ValueError(f"count must be at least 2 to define a step, got {count!r}")
    return float(max - min) / (count - 1)

def value_to_index(min, step, value):
    """Return the index of the axis point closest to `value`.

    Axis points are `min + i * step`, so the index is `(value - min) / step`.
    The quotient is rounded to the nearest integer rather than truncated:
    slider values are multiples of a non-integer step, and floating-point noise
    (e.g. `0.3 / 0.1 == 2.9999999999999996`) would otherwise yield index - 1.

    Examples:
    >>> value_to_index(0, 10, 0)
    0
    >>> value_to_index(0, 10, 40)
    4
    >>> value_to_index(3, 1, 6)
    3
    >>> value_to_index(10, 5, 20)
    2
    """
    # int() guarantees a plain Python int even if `value` is a numpy scalar.
    return int(round((value - min) / step))

In [4]:
# Salary axis: COUNT points from MIN to MAX; STEP is the gap between two points.
COUNT, MIN, MAX = 20, 0, 500_000
STEP = count_to_step(MIN, MAX, COUNT)
INITIAL_VALUE = 0  # salary selected when the slider is first shown
PERIOD = "2019"  # period used for the axis and for every calculation

In [5]:
# One person (Michel) alone in each group entity. The "axes" entry asks OpenFisca
# to vary salaire_de_base between MIN and MAX in COUNT points over PERIOD;
# count_to_step / value_to_index assume those points are evenly spaced.
test_case = {
    "individus": {
        "Michel": {"date_naissance": {"ETERNITY": "1980-01-01"}},
    },
    "familles": {
        "famille_1": {"parents": ["Michel"]},
    },
    "foyers_fiscaux": {
        "foyer_fiscal_1": {"declarants": ["Michel"]},
    },
    "menages": {
        "menage_1": {"personne_de_reference": ["Michel"]},
    },
    "axes": [
        [
            {
                "name": "salaire_de_base",
                "count": COUNT,
                "min": MIN,
                "max": MAX,
                "period": PERIOD,
            },
        ],
    ],
}


In [6]:
# Build the simulation (including the salary axis) from the test case.
simulation_builder = SimulationBuilder()
simulation = simulation_builder.build_from_entities(
    tax_benefit_system,
    test_case,
)

In [7]:
# Salary along the axis, added up over PERIOD via calculate_add.
salaire_de_base = simulation.calculate_add("salaire_de_base", period=PERIOD)

In [8]:
# Disposable income for every simulated case of the axis, over PERIOD.
revenu_disponible = simulation.calculate("revenu_disponible", period=PERIOD)

In [10]:
# Load the decomposition tree: nested nodes with a "code" and optional "name" and
# "children" keys. The safe loader only builds plain YAML types (dicts, lists, scalars).
yaml = YAML(typ="safe")
decomposition = yaml.load(Path("./decomposition.yaml"))

In [11]:
class BarType(Enum):
    """Role of a bar in the waterfall chart."""

    VALUE = 1  # leaf of the decomposition; drawn as a "relative" step
    SUB_TOTAL = 2  # intermediate node that has children; drawn as a "total"
    TOTAL = 3  # root node of the decomposition; drawn as a "total"

@dataclass
class Bar:
    """One bar of the waterfall chart, built from a decomposition node."""

    code: str  # variable name passed to simulation.calculate_add; key into results
    type: BarType
    name: Optional[str] = None  # display label; the code is shown when it is None
    value: Optional[float] = None  # set per axis point via dataclasses.replace

In [12]:
def get_bars(decomposition: dict) -> List[Bar]:
    """Flatten a decomposition tree into the ordered list of waterfall bars.

    Nodes are visited depth-first in post-order, so a node's children always
    come before the node itself:
    - nodes without children become VALUE bars;
    - nodes with children become SUB_TOTAL bars, except the root, which
      becomes the TOTAL bar.

    Each node is a dict with a "code" key and optional "name" and "children"
    (a list of nodes) keys.
    """
    def visit(node, is_root):
        children = node.get("children")
        if not children:
            yield Bar(code=node["code"], name=node.get("name"), type=BarType.VALUE)
            return
        for child in children:
            yield from visit(child, is_root=False)
        # The root is identified by its position in the tree, not by its code, so
        # a nested node that reuses the root's code is still a subtotal.
        type_ = BarType.TOTAL if is_root else BarType.SUB_TOTAL
        yield Bar(code=node["code"], name=node.get("name"), type=type_)

    return list(visit(decomposition, is_root=True))

In [13]:
# Flatten the decomposition tree into waterfall bars (children before their parent).
bars = get_bars(decomposition=decomposition)

In [14]:
def calculate_bars(simulation, period, bars: List[Bar]) -> dict:
    """Compute every bar's variable with `simulation.calculate_add` over `period`.

    Returns a dict mapping each bar code to the value returned by
    `calculate_add` (presumably one entry per simulated axis case — TODO confirm).
    If several bars share a code, the last computed value wins.
    """
    values_by_code = {}
    for bar in bars:
        values_by_code[bar.code] = simulation.calculate_add(bar.code, period)
    return values_by_code

In [15]:
# One array of values per bar code, computed over PERIOD.
results = calculate_bars(simulation, period=PERIOD, bars=bars)

In [16]:
def iter_bars_with_value_at_index(bars, results, index, include_zero=False):
    """Yield a copy of each bar carrying its value at position `index`.

    Args:
        bars: iterable of dataclass bars with `code` and `value` fields; the
            originals are left untouched (copies are made with dataclasses.replace).
        results: mapping from bar code to an indexable sequence of values.
        index: position to read in each sequence.
        include_zero: when false (the default), bars whose value is 0 are skipped.
    """
    for bar in bars:
        bar_value = results[bar.code][index]
        if not include_zero and bar_value == 0:
            continue
        yield dataclasses.replace(bar, value=bar_value)

In [17]:
def iter_displayed_bars(bars: List[Bar], results: dict, index: int, include_subtotals: bool = True):
    """Yield the bars to draw at position `index`, with their values filled in.

    SUB_TOTAL bars are skipped unless `include_subtotals` is true. Bars whose
    value at `index` is zero are dropped (iter_bars_with_value_at_index default).
    """
    candidates = (
        bar
        for bar in bars
        if include_subtotals or bar.type != BarType.SUB_TOTAL
    )
    yield from iter_bars_with_value_at_index(candidates, results, index)

In [18]:
def create_waterfall(bars, results, title="Waterfall"):
    """Build an interactive waterfall chart of `bars` along the salary axis.

    The returned VBox holds a slider over the salary axis (MIN..MAX by STEP),
    an "Include subtotals" checkbox and a plotly FigureWidget. Moving the slider
    or toggling the checkbox redraws the chart for the matching axis index:
    VALUE bars are drawn as "relative" steps, SUB_TOTAL/TOTAL bars as "total".

    Args:
        bars: list of Bar, as returned by get_bars.
        results: dict mapping bar code to per-axis-point values, as returned
            by calculate_bars.
        title: chart title.

    Returns:
        ipywidgets.VBox with the controls and the figure.
    """
    slider = widgets.FloatSlider(min=MIN, max=MAX, step=STEP, value=INITIAL_VALUE)
    include_subtotals = widgets.Checkbox(description="Include subtotals")
    fig = go.FigureWidget(
        data=[go.Waterfall(textposition="inside")],
        layout=go.Layout(title=title),
    )

    def redraw(*_change):
        """Refresh the waterfall trace from the current widget values."""
        index = value_to_index(MIN, STEP, slider.value)
        shown = list(iter_displayed_bars(bars, results, index, include_subtotals.value))
        labels, heights, measures, texts = [], [], [], []
        for bar in shown:
            is_step = bar.type == BarType.VALUE
            labels.append(bar.name or bar.code)
            # "total" bars are summed by plotly from the preceding steps; y is 0 for them.
            heights.append(bar.value if is_step else 0)
            measures.append("relative" if is_step else "total")
            texts.append(("{:+.0f}" if is_step else "{:.0f}").format(bar.value))
        with fig.batch_update():
            trace = fig.data[0]
            trace.x = labels
            trace.y = heights
            trace.measure = measures
            trace.text = texts

    include_subtotals.observe(redraw, "value")
    slider.observe(redraw, "value")
    redraw()

    return widgets.VBox([
        widgets.HBox([widgets.Label("Salaire de base"), slider]),
        include_subtotals,
        fig,
    ])

In [19]:
# Render the interactive chart: salary slider, subtotal checkbox and waterfall figure.
create_waterfall(bars=bars, results=results, title="Waterfall")

VBox(children=(HBox(children=(Label(value='Salaire de base'), FloatSlider(value=0.0, max=500000.0, step=26315.…