# MPLLayout Demo

`mpllayout` models axes and other elements in figures as geometric primitives which can be constrained relative to each other. 
This gives a flexible way to precisely position figure elements.

The following demo produces Figure 1 from the paper "Ten Simple Rules for Better Figures" (Rougier NP, Droettboom M, Bourne PE (2014) Ten Simple Rules for Better Figures. PLOS Computational Biology 10(9): e1003833. https://doi.org/10.1371/journal.pcbi.1003833).
This figure is itself a remake of one originally published in the [New York Times](https://archive.nytimes.com/www.nytimes.com/imagepages/2007/07/29/health/29cancer.graph.web.html?action=click&module=RelatedCoverage&pgtype=Article&region=Footer).

The sections below show how to build the layout for this figure using a `Grid` type constraint together with more basic constraints, and then how to create the figure from that layout.

In [None]:
# Import the relevant packages

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from mpllayout.solver import solve
from mpllayout.layout import Layout, update_layout_constraints
from mpllayout import geometry as geo
from mpllayout.matplotlibutils import subplots, update_subplots, find_axis_position
from mpllayout import ui

In [None]:
def figure_layout(layout, *, max_iter=50):
    """
    Solve a layout's constraints and plot the resulting primitives

    Parameters
    ----------
    layout : Layout
        Layout whose root primitive and flattened constraints are solved
    max_iter : int, optional
        Maximum number of iterations passed to `solve` (default 50, the value
        that used to be hard-coded here)

    Returns
    -------
    tuple
        ``(ui.figure_prims(prim_tree_n), solve_info)``, where ``prim_tree_n`` is
        the solved primitive tree; callers in this notebook unpack the first
        item as ``(fig, ax)``
    """
    prim_tree_n, solve_info = solve(
        layout.root_prim, *layout.flat_constraints(), max_iter=max_iter
    )
    return ui.figure_prims(prim_tree_n), solve_info

### Create the layout with `Grid` constraints


First, create the `Layout` object used to track geometric primitives and constraints.

In [None]:
# The `Layout` tracks the geometric primitives and the constraints between them
layout = Layout()

#### Add all the `Figure` and `Axes` elements

Represent the `mpl.Figure` by a `Quadrilateral` and each `mpl.Axes` by an `Axes` or `AxesXY` primitive.
The `AxesXY` primitive contains three `Quadrilateral`s representing the drawing frame and the x and y axis bounding boxes.
The `Axes` primitive contains a `Quadrilateral` for the drawing frame only (no x or y axis bounding boxes).

In [None]:
# Add a quadrilateral called 'Figure' representing the matplotlib figure.
# The `geo.Box` constraint keeps the quadrilateral's sides vertical and its
# top/bottom horizontal.
layout.add_prim(geo.Quadrilateral(), "Figure")
layout.add_constraint(geo.Box(), ("Figure",), ())

# Add an `AxesXY` called 'AxesLeft' (drawing frame plus x/y axis boxes).
# Its x-axis is placed on the top side and its y-axis on the right side.
layout.add_prim(geo.AxesXY(), "AxesLeft")
for key in ["AxesLeft/Frame", "AxesLeft/XAxis", "AxesLeft/YAxis"]:
    layout.add_constraint(geo.Box(), (key,), ())

layout.add_constraint(geo.PositionXAxis(top=True, bottom=False), ("AxesLeft", ), ())
layout.add_constraint(geo.PositionYAxis(left=False, right=True), ("AxesLeft", ), ())

# Add an `Axes` (frame only, no x and y axis bounding boxes) called 'AxesMid'.
# We'll use this axes to put the central labels in.
layout.add_prim(geo.Axes(), "AxesMid")
layout.add_constraint(geo.Box(), ("AxesMid/Frame",), ())

# Add an `AxesXY` called 'AxesRight'.
# Its x-axis is placed on the top side and its y-axis on the left side.
# The position constraints are added only after the primitive itself, so they
# never refer to a primitive the layout does not contain yet.
layout.add_prim(geo.AxesXY(), "AxesRight")
for key in ["AxesRight/Frame", "AxesRight/XAxis", "AxesRight/YAxis"]:
    layout.add_constraint(geo.Box(), (key,), ())

layout.add_constraint(geo.PositionXAxis(top=True, bottom=False), ("AxesRight", ), ())
layout.add_constraint(geo.PositionYAxis(left=True, right=False), ("AxesRight", ), ())

In [None]:
# Size the x/y axis bounding boxes of both `AxesXY` primitives.
# The parameter is left as `None` here; presumably it is filled in from the
# drawn matplotlib axes by `update_layout_constraints` (TODO confirm).
# Each constraint gets an explicit key such as "AxesLeft.XAxisHeight".
for axes_key in ("AxesLeft", "AxesRight"):
    layout.add_constraint(
        geo.XAxisHeight(), (f"{axes_key}/XAxis",), (None,), f"{axes_key}.XAxisHeight"
    )
    layout.add_constraint(
        geo.YAxisWidth(), (f"{axes_key}/YAxis",), (None,), f"{axes_key}.YAxisWidth"
    )

#### Fix the Figure dimensions and position

In [None]:
## Set the figure dimensions
fig_width, fig_height = (12, 7)

# Pin the bottom-left corner of 'Figure' (first point of its first line)
# to the origin
figure_origin = np.array([0, 0])
layout.add_constraint(geo.Fix(), ("Figure/Line0/Point0",), {"location": figure_origin})

# 'Figure/Line1' carries the height and 'Figure/Line0' the width
layout.add_constraint(geo.Length(), ("Figure/Line1",), (fig_height,))
layout.add_constraint(geo.Length(), ("Figure/Line0",), (fig_width,))

In [None]:
# Solve the constraints added so far and plot the primitives to inspect them
layout_plot, solve_info = figure_layout(layout)
fig, ax = layout_plot
print(solve_info)

#### Constrain the `Axes` to a 1 by 3 rectilinear grid



In [None]:
# Place the three axes frames on a rectilinear grid of one row and three columns
shape = (1, 3)
grid_frames = ("AxesLeft/Frame", "AxesMid/Frame", "AxesRight/Frame")
layout.add_constraint(geo.RectilinearGrid(shape), grid_frames, ())

In [None]:
# Set zero margins between the left/right axes and the middle axes
for side, neighbour_frame in (("left", "AxesLeft/Frame"), ("right", "AxesRight/Frame")):
    layout.add_constraint(
        geo.OuterMargin(side=side), ("AxesMid/Frame", neighbour_frame), (0,)
    )

# Frame widths (Line0) relative to 'AxesLeft': 'AxesRight' is equal, 'AxesMid' half
for frame_line, ratio in (("AxesRight/Frame/Line0", 1.0), ("AxesMid/Frame/Line0", 0.5)):
    layout.add_constraint(
        geo.RelativeLength(), (frame_line, "AxesLeft/Frame/Line0"), (ratio,)
    )

In [None]:
## Set top/bottom and left/right margins for the axes
margin_top, margin_bottom = (0.5, 0.5)
margin_left, margin_right = (0.5, 0.5)

# Each entry is (figure side, primitive, margin to that side of 'Figure').
# The top margin is measured from the x-axis bounding box rather than the
# frame, so the x-axis text is not cut off at the top of the figure.
inner_margins = [
    ("top", "AxesRight/XAxis", margin_top),
    ("bottom", "AxesRight/Frame", margin_bottom),
    ("left", "AxesLeft/Frame", margin_left),
    ("right", "AxesRight/Frame", margin_right),
]
for side, prim_key, margin in inner_margins:
    layout.add_constraint(
        geo.InnerMargin(side=side), (prim_key, "Figure"), (margin,)
    )

In [None]:
# Solve again with the grid and margin constraints and plot the result
layout_plot, solve_info = figure_layout(layout)
fig, ax = layout_plot
print(solve_info)

## Create the Figure from the layout

In [None]:
# Solve the complete layout; `prim_tree_n` holds the solved primitives that the
# matplotlib figure and axes are created from.
# NOTE(review): unlike `figure_layout`, no `max_iter` is passed here, so the
# solver's default iteration limit applies.
prim_tree_n, solve_info = solve(layout.root_prim, *layout.flat_constraints())

In [None]:
# The data below is approximated from a New York Times graphic
# (https://archive.nytimes.com/www.nytimes.com/imagepages/2007/07/29/health/29cancer.graph.web.html)
# and is adapted from the figure-1.py file available at https://github.com/rougier/ten-rules
#
# Every numeric list has one entry per disease, in the same order as `diseases`.
diseases = [
    "Kidney Cancer", "Bladder Cancer", "Esophageal Cancer", "Ovarian Cancer",
    "Liver Cancer", "Non-Hodgkin's\nlymphoma", "Leukemia", "Prostate Cancer",
    "Pancreatic Cancer", "Breast Cancer", "Colorectal Cancer", "Lung Cancer",
]
men_deaths = [
    10_000, 12_000, 13_000, 0, 14_000, 12_000,
    16_000, 25_000, 20_000, 500, 25_000, 80_000,
]
men_cases = [
    30_000, 50_000, 13_000, 0, 16_000, 30_000,
    25_000, 220_000, 22_000, 600, 55_000, 115_000,
]
women_deaths = [
    6_000, 5_500, 5_000, 20_000, 9_000, 12_000,
    13_000, 0, 19_000, 40_000, 30_000, 70_000,
]
women_cases = [
    20_000, 18_000, 5_000, 25_000, 9_000, 29_000,
    24_000, 0, 21_000, 160_000, 55_000, 97_000,
]

# One bar row (y position) per disease
y_diseases = np.arange(len(diseases))

In [None]:
def format_axes(ax):
    """
    Apply the Axes formatting used in "Ten Simple Rules"

    Hides the spine on the far side of the bars and the bottom spine, raises
    the origin-side spine's zorder, places unlabelled y ticks at every disease
    row, and moves the top spine to ``y = len(diseases) + 0.25``, coloured white.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        One half of the horizontal bar chart. Whether its x-axis is inverted
        decides which side counts as the origin side, so any inversion must be
        applied before calling this function.

    Notes
    -----
    Reads the notebook-level ``diseases`` and ``y_diseases``.
    """
    # With an inverted x-axis, x = 0 is on the right and bars grow leftwards
    if ax.xaxis.get_inverted():
        origin_side, far_side = "right", "left"
    else:
        origin_side, far_side = "left", "right"

    ax.spines[far_side].set_color("none")
    # Draw the origin-side spine above the bars
    ax.spines[origin_side].set_zorder(10)
    ax.spines["bottom"].set_color("none")

    # Tick *positions* (x-axis on top, y-axis on the origin side) are not set
    # here; presumably they follow from the layout's PositionXAxis /
    # PositionYAxis constraints -- TODO confirm
    ax.yaxis.set_ticks(y_diseases, labels=[""] * len(y_diseases))

    ax.spines["top"].set_position(("data", len(diseases) + 0.25))
    ax.spines["top"].set_color("w")

In [None]:
fig, axs = subplots(prim_tree_n)

# Start every axes with the same case-count x-range
for panel in axs.values():
    panel.set_xlim(0, 200000)

ax_women, ax_men = axs["AxesLeft"], axs["AxesRight"]

# Wide, faint bars show new cases; narrower, stronger bars drawn over them show
# deaths. The women's panel has its x-axis inverted so its bars grow leftwards.
ax_women.barh(y_diseases, women_cases, height=0.8, fc="red", alpha=0.1)
ax_women.barh(y_diseases, women_deaths, height=0.55, fc="red", alpha=0.5)
ax_women.xaxis.set_inverted(True)

ax_men.barh(y_diseases, men_cases, height=0.8, fc="blue", alpha=0.1)
ax_men.barh(y_diseases, men_deaths, height=0.55, fc="blue", alpha=0.5)

# Shared styling; the tick at x = 0 is labelled with the bold category name.
# `format_axes` must run after the inversion above.
for ax, category in ((ax_women, "women"), (ax_men, "men")):
    format_axes(ax)
    ax.set_xticks(
        [0, 50000, 100000, 150000, 200000],
        [category.upper(), "50,000", "100,000", "150,000", "200,000"],
    )
    ax.grid(which="major", axis="x", color="white")
    ax.get_xticklabels()[0].set_weight("bold")

# 'AxesMid' has its axis turned off and only carries the disease names; it
# shares the y-range of 'AxesLeft' so each name lines up with its bar row
mid_ax = axs["AxesMid"]
mid_ax.set_axis_off()
mid_ax.set_ylim(axs["AxesLeft"].get_ylim())
mid_ax.set_xlim(-1, 1)

for label, y in zip(diseases, y_diseases):
    mid_ax.text(0, y, label, ha="center", va="center")

# Add the "NEW CASES" and "DEATHS" annotations on the top bar row.
# Devil hides in the details: anchor fractions and text offsets (in points)
# are hand-tuned per annotation.
arrowprops = {"arrowstyle": "-", "connectionstyle": "angle,angleA=0,angleB=90,rad=0"}

y = y_diseases[-1]
annotation_specs = [
    # (axes key, text, bar value, anchor fraction, ha, text offset)
    ("AxesLeft", "NEW CASES", women_cases[-1], 0.9, "right", (-40, -3)),
    ("AxesLeft", "DEATHS", women_deaths[-1], 0.85, "right", (-50, -25)),
    ("AxesRight", "NEW CASES", men_cases[-1], 0.9, "left", (+40, -3)),
    ("AxesRight", "DEATHS", men_deaths[-1], 0.9, "left", (+50, -25)),
]
for axes_key, text, x, fraction, halign, offset in annotation_specs:
    axs[axes_key].annotate(
        text,
        xy=(fraction * x, y),
        xycoords="data",
        ha=halign,
        fontsize=10,
        xytext=offset,
        textcoords="offset points",
        arrowprops=arrowprops,
    )

# Add the caption: a large title with a smaller paragraph below it, both
# positioned in data coordinates of 'AxesLeft'
caption_title = "Leading Causes\nOf Cancer Deaths"
caption_body = (
    "In 2007, there were more\n"
    "than 1.4 million new cases\n"
    "of cancer in the United States."
)
axs["AxesLeft"].text(165000, 8.2, caption_title, fontsize=18, va="top")
axs["AxesLeft"].text(165000, 7, caption_body, va="top", fontsize=10)

# Update the bounding boxes of the x/y axes now that text has been inserted,
# then re-solve the layout and update the existing figure/axes from it
update_layout_constraints(layout.root_constraint, layout.root_constraint_param, axs)
prim_tree_n, solve_info = solve(layout.root_prim, *layout.flat_constraints())
update_subplots(prim_tree_n, "Figure", fig, axs)

# `savefig` does not create missing directories, so make sure "out/" exists
out_dir = Path("out")
out_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(out_dir / "ten_simple_rules_demo.svg")

In [None]:
# Plot the layout once more, after the axis bounding boxes were updated
layout_plot, solve_info = figure_layout(layout)
fig, ax = layout_plot
print(solve_info)