In [1]:
import dataclasses
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import List, Optional

from ruamel.yaml import YAML
from IPython.display import display, JSON
import ipywidgets as widgets
import plotly.graph_objects as go

from openfisca_core.simulation_builder import SimulationBuilder
from openfisca_france import FranceTaxBenefitSystem

In [2]:
# Instantiate the French tax-and-benefit system once at notebook level;
# the button callback passes this same object to every simulation it builds.
tax_benefit_system = FranceTaxBenefitSystem()

In [3]:
# Parse the decomposition tree with ruamel's "safe" loader.
# The path is relative, so it is resolved against the kernel's working directory.
yaml = YAML(typ='safe')
decomposition = yaml.load(Path("./decomposition.yaml"))

In [4]:
class BarType(Enum):
    """Role of a bar in the waterfall chart."""
    # Node without children: plotted as a "relative" step
    VALUE = 1
    # Node with children that is not the root of the decomposition
    SUB_TOTAL = 2
    # Root of the decomposition, when it has children
    TOTAL = 3

In [5]:
@dataclass
class Bar:
    """One bar of the waterfall chart, derived from a node of the decomposition tree.

    ``short_name`` and ``value`` default to ``None`` and are therefore annotated
    as ``Optional``.
    """
    # Code of the decomposition node; also the key used to compute and look up results
    code: str
    # Role of the bar in the chart (plain value, sub-total or total)
    type: BarType
    # Display label; the chart falls back to ``code`` when it is missing
    short_name: Optional[str] = None
    # Amount to plot; None until a copy of the bar is made for a selected salary
    value: Optional[float] = None

In [6]:
def get_bars(decomposition: dict) -> List[Bar]:
    """Flatten a decomposition tree into the ordered list of waterfall bars.

    The tree is walked depth-first: the bars of a node's children are emitted
    before the bar of the node itself, so every (sub-)total follows the bars
    it aggregates.

    Args:
        decomposition: Root node. Each node is a mapping with a ``"code"`` key
            and optional ``"short_name"`` and ``"children"`` keys.

    Returns:
        One ``Bar`` per node. Nodes without children are ``BarType.VALUE``;
        the root, when it has children, is ``BarType.TOTAL``; every other node
        with children is ``BarType.SUB_TOTAL``.

    Raises:
        KeyError: If a node has no ``"code"`` key.
    """
    def visit(node, is_root=False):
        children = node.get("children")
        if not children:
            # Missing or empty "children": the node is a plain value
            yield Bar(code=node["code"], short_name=node.get("short_name"), type=BarType.VALUE)
            return
        for child in children:
            yield from visit(child)
        # The root is tracked explicitly rather than recognised by its code, so a
        # nested node that happens to reuse the root's code stays a sub-total.
        type_ = BarType.TOTAL if is_root else BarType.SUB_TOTAL
        yield Bar(code=node["code"], short_name=node.get("short_name"), type=type_)

    return list(visit(decomposition, is_root=True))

In [7]:
# Flatten the decomposition tree once; this list is reused by every computation below.
bars = get_bars(decomposition)

In [8]:
def iter_displayed_bars(bars: List[Bar], results: dict, salaire_de_base_arr, salaire_de_base: float, include_subtotals: bool, include_zero: bool):
    """Yield the bars to draw for one selected salary, each carrying its value.

    Args:
        bars: Bars in display order; they are not modified, copies are yielded.
        results: Maps a bar code to a sequence of values aligned, position by
            position, with ``salaire_de_base_arr``.
        salaire_de_base_arr: Sampled salaries.
        salaire_de_base: Selected salary; it must be one of the samples.
        include_subtotals: When false, ``BarType.SUB_TOTAL`` bars are skipped.
        include_zero: When false, bars whose value equals 0 are skipped.

    Yields:
        A copy of each retained bar with ``value`` set to its result for the
        selected salary.

    Raises:
        KeyError: When a bar has to be valued and either its code is missing
            from ``results`` or ``salaire_de_base`` is not a sampled salary.
    """
    # The position of each salary is the same for every bar, so the lookup table
    # is built once instead of once per bar. A salary sampled several times
    # resolves to its last position.
    index_by_salaire = {
        salaire: index
        for index, salaire in enumerate(salaire_de_base_arr)
    }
    for bar in bars:
        if bar.type == BarType.SUB_TOTAL and not include_subtotals:
            continue
        value = results[bar.code][index_by_salaire[salaire_de_base]]
        if include_zero or value != 0:
            yield dataclasses.replace(bar, value=value)

In [9]:
def calculate_bars(bars, simulation, period):
    """Compute the result of every bar for the given period.

    Args:
        bars: Iterable of objects exposing a ``code`` attribute.
        simulation: Object exposing ``calculate_add(code, period)``.
        period: Period forwarded unchanged to ``calculate_add``.

    Returns:
        A dict mapping each bar code to whatever ``calculate_add`` returned for it.
    """
    results_by_code = {}
    for bar in bars:
        code = bar.code
        results_by_code[code] = simulation.calculate_add(code, period)
    return results_by_code

In [18]:
def create_waterfall(bars, results, salaire_de_base_arr, period):
    """Build the interactive waterfall widget for one set of computed results.

    The returned container stacks a salary slider, two checkboxes (show
    sub-totals, show zero values) and a plotly figure. The figure is redrawn
    whenever one of the three controls changes value.

    Args:
        bars: Bars describing the decomposition, in display order.
        results: Computed values, keyed by bar code.
        salaire_de_base_arr: Sampled salaries, used as the slider options.
        period: Simulation period, shown in the chart title.

    Returns:
        A ``widgets.VBox`` holding the controls and the figure.
    """
    def refresh(*_):
        # Re-select the bars from the current state of the three controls
        shown = list(iter_displayed_bars(
            bars,
            results,
            salaire_de_base_arr=salaire_de_base_arr,
            salaire_de_base=slider.value,
            include_subtotals=include_subtotals.value,
            include_zero=include_zero.value,
        ))
        with fig.batch_update():
            trace = fig.data[0]
            value_flags = [bar.type == BarType.VALUE for bar in shown]
            trace.x = [bar.short_name or bar.code for bar in shown]
            # Only plain values contribute a height; (sub-)totals are given 0
            # and the "total" measure.
            trace.y = [
                bar.value if is_value else 0
                for bar, is_value in zip(shown, value_flags)
            ]
            trace.measure = [
                "relative" if is_value else "total"
                for is_value in value_flags
            ]
            # Signed label for plain values, unsigned label for (sub-)totals
            trace.text = [
                ("{:+.0f}" if is_value else "{:.0f}").format(bar.value)
                for bar, is_value in zip(shown, value_flags)
            ]

    include_subtotals = widgets.Checkbox(description="Afficher les sous-totaux")
    include_zero = widgets.Checkbox(description="Afficher les valeurs 0")
    slider = widgets.SelectionSlider(options=salaire_de_base_arr)
    for control in (include_subtotals, include_zero, slider):
        control.observe(refresh, names='value')

    fig = go.FigureWidget(
        data=[go.Waterfall(textposition="inside")],
        layout=go.Layout(title=f"Décomposition du revenu disponible ({period})"),
    )
    # Initial draw, before any control has been touched
    refresh()

    salary_row = widgets.HBox([widgets.Label('Salaire de base'), slider])
    options_row = widgets.HBox([include_subtotals, include_zero])
    return widgets.VBox([salary_row, options_row, fig])

In [24]:
def build_test_case(salaire_min, salaire_max, salaire_count, period):
    """Build the situation dict: one person with a varying ``salaire_de_base``.

    The person is the only member of one family, one tax household and one
    household. A single axis varies ``salaire_de_base`` for the given period.

    Args:
        salaire_min: Value stored as the axis ``"min"``.
        salaire_max: Value stored as the axis ``"max"``.
        salaire_count: Value stored as the axis ``"count"``.
        period: Value stored as the axis ``"period"``.

    Returns:
        A new dict with the keys ``"individus"``, ``"familles"``,
        ``"foyers_fiscaux"``, ``"menages"`` and ``"axes"``.
    """
    person = "Michel"
    salaire_axis = {
        "name": 'salaire_de_base',
        "count": salaire_count,
        "min": salaire_min,
        "max": salaire_max,
        "period": period,
    }
    individus = {person: {'date_naissance': {'ETERNITY': '1980-01-01'}}}
    familles = {"famille_1": {"parents": [person]}}
    foyers_fiscaux = {"foyer_fiscal_1": {"declarants": [person]}}
    menages = {"menage_1": {"personne_de_reference": [person]}}
    return {
        "individus": individus,
        "familles": familles,
        "foyers_fiscaux": foyers_fiscaux,
        "menages": menages,
        # A single group containing a single axis
        "axes": [[salaire_axis]],
    }

Situation : un célibataire né le 1er janvier 1980 dont le salaire de base varie entre deux bornes.

- choisir la période de simulation et les bornes du salaire de base
- cliquer sur "Calculer" pour générer un [diagramme en cascade](https://fr.wikipedia.org/wiki/Diagramme_en_cascade) de la décomposition du revenu disponible
- une fois le diagramme affiché, faire varier le salaire de base à l'aide du curseur (« slider ») situé au-dessus du diagramme

Il est possible de générer plusieurs diagrammes à partir de paramètres différents : chaque clic sur « Calculer » ajoute un nouveau diagramme.

In [27]:
# Output area that hosts both the parameter form and every generated chart.
output = widgets.Output()

# Simulation parameters, read by the button callback at click time.
period = widgets.IntText(description="Période", value=2019)
salaire_min = widgets.IntText(description="Minimum", value=0)
salaire_max = widgets.IntText(description="Maximum", value=100_000)
# Becomes the "count" of the salaire_de_base axis in the test case
salaire_count = widgets.IntText(description="Nombre d'échantillons", value=101)

button = widgets.Button(description="Calculer")

@output.capture()
def on_button_clicked(b):
    """Run a simulation from the form values and display its waterfall chart.

    The widget values are read when the button is clicked. The
    ``@output.capture()`` decorator sends whatever this handler displays to
    the ``output`` widget.

    Args:
        b: The clicked button, as passed by ``on_click``; unused.
    """
    year = period.value
    test_case = build_test_case(
        salaire_min=salaire_min.value,
        salaire_max=salaire_max.value,
        salaire_count=salaire_count.value,
        period=year,
    )
    simulation = SimulationBuilder().build_from_entities(tax_benefit_system, test_case)
    # The sampled salaries are read back from the simulation and become the slider options
    salaires = simulation.calculate_add("salaire_de_base", period=year)
    results = calculate_bars(bars, simulation, period=year)
    chart = create_waterfall(bars, results, salaires, period=year)
    display(chart)

# Run the computation each time the "Calculer" button is clicked.
button.on_click(on_button_clicked)

# Render the parameter form inside `output`, the same area the callback writes its charts to.
with output:
    display(widgets.VBox([
        period,
        widgets.HBox([widgets.Label("Salaire"), salaire_min, salaire_max, salaire_count]),
        button,
    ]))

# Last expression of the cell: shows the Output widget itself.
output

Output()