# MPLLayout Demo

`mpllayout` models axes and other elements in figures as geometric primitives which can be constrained relative to each other. 
This gives a flexible way to precisely position figure elements.

The following demo produces Figure 1  from the paper "Ten Simple Rules for Better Figures" (Rougier NP, Droettboom M, Bourne PE (2014) Ten Simple Rules for Better Figures. PLOS Computational Biology 10(9): e1003833. https://doi.org/10.1371/journal.pcbi.1003833).
This figure is itself a remake of one originally published in the [New York Times](https://archive.nytimes.com/www.nytimes.com/imagepages/2007/07/29/health/29cancer.graph.web.html?action=click&module=RelatedCoverage&pgtype=Article&region=Footer).

The two sections below illustrate how to create this figure using a grid-type constraint (`RectilinearGrid`) together with more basic constraints.

In [None]:
# Import the relevant packages

import matplotlib.pyplot as plt
import numpy as np

from mpllayout.solver import solve
from mpllayout.layout import Layout, update_layout_constraints
from mpllayout import primitives as pr
from mpllayout import constraints as co
from mpllayout.matplotlibutils import subplots, update_subplots, find_axis_position
from mpllayout import ui

In [None]:
def figure_layout(layout):
    """
    Solve `layout` and return a figure of the solved primitives

    Parameters
    ----------
    layout:
        The layout (primitives and constraints) to solve and plot

    Returns
    -------
    The return value of `ui.figure_prims` for the solved primitives
    (cells in this notebook unpack it as `fig, ax`)
    """
    # Only the solved primitives are needed here; the solver info is discarded
    prims_n, _ = solve(layout)
    return ui.figure_prims(prims_n)

## Create the `layout`

The `layout` is a collection of:
- geometric primitives to represent figure elements
- and constraints which position the figure elements


In [None]:
# Create the layout object that the following cells add primitives and constraints to
layout = Layout()

### Add geometric primitives

Represent the `mpl.Figure` by a `Quadrilateral` and `mpl.Axes` by an `Axes` primitive.
The `Axes` primitive contains a `Quadrilateral` to represent the frame, and optionally, a `Quadrilateral` and `Point` to represent the x axis and y axis.

In [None]:
# First add a box called "Figure" to the layout
layout.add_prim(pr.Quadrilateral(), "Figure")

# Then add "axes" primitives to represent the left, middle, and right axes
# Axes can contain `Quadrilateral` and `Point` primitives to represent the
# axes frame, x/y axis, and axis labels
# "AxesMid" is created without the optional x/y axis primitives
layout.add_prim(pr.Axes(xaxis=True, yaxis=True), "AxesLeft")
layout.add_prim(pr.Axes(), "AxesMid")
layout.add_prim(pr.Axes(xaxis=True, yaxis=True), "AxesRight")

### Add geometric constraints

#### Make all `Quadrilateral`s rectangular

`mpllayout` doesn't force quadrilaterals to be rectangular, unlike the figure and axes frames in matplotlib, so this must be constrained explicitly.

In [None]:
# `co.Box` forces quadrilateral sides to be vertical and tops/bottoms to be horizontal
# It has no parameters so the last argument is an empty tuple
layout.add_constraint(co.Box(), ("Figure",), ())

# "AxesMid" only has a frame (no x/y axis)
layout.add_constraint(co.Box(), ("AxesMid/Frame",), ())

# Here we constrain all child quads of the left and right axes to be boxes
# Child primitives are addressed by "/"-separated keys such as "AxesLeft/Frame"
for axes_key in ["AxesLeft", "AxesRight"]:
    for quad_key in ["Frame", "XAxis", "YAxis"]:
        layout.add_constraint(co.Box(), (f"{axes_key}/{quad_key}",), ())

In [None]:
# This plots the geometry created so far
# Note that by default all the quads are unit squares
figure_layout(layout)

#### Fix the Figure dimensions and position

Set the figure width/height and fix the bottom left point to the origin

In [None]:
## Set the figure dimensions and position

# Fix the bottom left point of 'Figure' to the origin
layout.add_constraint(co.Fix(), ("Figure/Line0/Point0",), {"location": np.array([0, 0])})

# Set the 'Figure' width and height (in inches)
# 'Figure/Line0' is given the width and 'Figure/Line1' the height
fig_width, fig_height = (12, 7)
layout.add_constraint(co.Length(), ("Figure/Line1",), (fig_height,))
layout.add_constraint(co.Length(), ("Figure/Line0",), (fig_width,))

In [None]:
# Note the figure quadrilateral is now 12" by 7"
# The axes are still unit squares since they haven't been constrained yet
fig, ax = figure_layout(layout)

#### Constrain the `Axes` to a 1 by 3 rectilinear grid

We can force the left, middle, and right axes to align on a 1 by 3 grid and set their relative widths.



In [None]:
# Align the axes frames on a 1x3 rectilinear grid
# The frames are passed in left, middle, right order
shape = (1, 3)
layout.add_constraint(
    co.RectilinearGrid(shape),
    ("AxesLeft/Frame", "AxesMid/Frame", "AxesRight/Frame"),
    ()
)

In [None]:
# Set zero margins between the left/right axes and the middle axes
layout.add_constraint(co.OuterMargin(side='left'), ("AxesMid/Frame", "AxesLeft/Frame"), (0,))
layout.add_constraint(co.OuterMargin(side='right'), ("AxesMid/Frame", "AxesRight/Frame"), (0,))

# Make the left/right axes the same width and the central axes 0.5 times that width
# (widths are compared through each frame's 'Line0')
layout.add_constraint(
    co.RelativeLength(), ("AxesRight/Frame/Line0", "AxesLeft/Frame/Line0"), (1.0,)
)
layout.add_constraint(
    co.RelativeLength(), ("AxesMid/Frame/Line0", "AxesLeft/Frame/Line0"), (0.5,)
)

In [None]:
# Note the 3 axes are now aligned
# It's difficult to see because the x/y axis quads of the left and right axes are also shown
fig, ax = figure_layout(layout)

#### Position the x-axis and y-axis for left and right axes

In [None]:
# These constraints fix the x/y axis to one side of the axes
# When creating the figure from a layout, these axis positions will be inherited
# The left axes has its x axis on top and its y axis on the right
layout.add_constraint(co.PositionXAxis(top=True, bottom=False), ("AxesLeft", ), ())
layout.add_constraint(co.PositionYAxis(left=False, right=True), ("AxesLeft", ), ())

# The right axes has its x axis on top and its y axis on the left
layout.add_constraint(co.PositionXAxis(top=True, bottom=False), ("AxesRight", ), ())
layout.add_constraint(co.PositionYAxis(left=True, right=False), ("AxesRight", ), ())


In [None]:
# These constraints set the variable width of the y-axis and variable height of the x-axis
# The axis dimensions are variable since they depend on the size of any tick labels
# Axis dimensions can be updated using `update_layout_constraints` after
# axis text has been generated, so `None` is passed as the size for now
# Note that axis labels aren't included in the size of the axis!
layout.add_constraint(
    co.XAxisHeight(), ("AxesLeft/XAxis",), (None,)
)
layout.add_constraint(
    co.YAxisWidth(), ("AxesLeft/YAxis",), (None,)
)
layout.add_constraint(
    co.XAxisHeight(), ("AxesRight/XAxis",), (None,)
)
layout.add_constraint(
    co.YAxisWidth(), ("AxesRight/YAxis",), (None,)
)

In [None]:
# Note the x axis is now stuck to the top of both the left and right axes
figure_layout(layout)

#### Set Margins

Note that earlier we never set the absolute width of the axes; to ensure nice whitespace we can specify margins to indirectly set the axes dimensions.

In [None]:
# Set the top/bottom margins
# The top margin is set above the x-axis bounding box, which ensures the tick text won't be cut off by the figure edge
margin_top, margin_bottom = (0.5, 0.5)

# The `InnerMargin` constraint sets the gap between an
# inner quad ("AxesRight/XAxis") and an outer quad ("Figure")
layout.add_constraint(
    co.InnerMargin(side="top"), ("AxesRight/XAxis", "Figure"), (margin_top,)
)

# The bottom margin is measured from the axes frame
layout.add_constraint(
    co.InnerMargin(side="bottom"), ("AxesRight/Frame", "Figure"), (margin_bottom,)
)

# Set the left/right margins
margin_left, margin_right = (0.5, 0.5)

layout.add_constraint(
    co.InnerMargin(side='left'), ("AxesLeft/Frame", "Figure"), (margin_left,)
)

layout.add_constraint(
    co.InnerMargin(side='right'), ("AxesRight/Frame", "Figure"), (margin_right,)
)

In [None]:
# Now the margins are all constrained!
# The grid arrangement of the left/middle/right axes is clearer now that the axes have been moved apart
fig, ax = figure_layout(layout)

## Use the `layout` to plot the figure

We can solve the `layout` to determine a set of primitives that satisfy the constraints.
The solved primitives are then used to generate matplotlib figure and axes objects that reflect the layout.

This is nice because the figure design and arrangement are separated from the plotting of data.

In [None]:
# Solve the layout for primitives that satisfy the constraints
prim_tree_n, solve_info = solve(layout)

# The `subplots` function uses the solved primitives to create figure and axes objects with the determined sizes
# `axs` is a dictionary with keys matching the axes names ("AxesLeft", "AxesMid", "AxesRight")
fig, axs = subplots(prim_tree_n)

## Plot the "Ten Simple Rules for Better Figures" dataset

We can use the generated figure and axes to plot data now.

In [None]:
# The data below is approximated from the New York Times graphic that the
# "Ten Simple Rules" figure remakes, and is adapted from the figure-1.py
# file available at (https://github.com/rougier/ten-rules)

# Disease names; each list below has one entry per disease, in the same order
diseases = [
    "Kidney Cancer", "Bladder Cancer", "Esophageal Cancer", "Ovarian Cancer",
    "Liver Cancer", "Non-Hodgkin's\nlymphoma", "Leukemia", "Prostate Cancer",
    "Pancreatic Cancer", "Breast Cancer", "Colorectal Cancer", "Lung Cancer",
]

# Men: deaths and new cases per disease
men_deaths = [
    10_000, 12_000, 13_000, 0, 14_000, 12_000,
    16_000, 25_000, 20_000, 500, 25_000, 80_000,
]
men_cases = [
    30_000, 50_000, 13_000, 0, 16_000, 30_000,
    25_000, 220_000, 22_000, 600, 55_000, 115_000,
]

# Women: deaths and new cases per disease
women_deaths = [
    6_000, 5_500, 5_000, 20_000, 9_000, 12_000,
    13_000, 0, 19_000, 40_000, 30_000, 70_000,
]
women_cases = [
    20_000, 18_000, 5_000, 25_000, 9_000, 29_000,
    24_000, 0, 21_000, 160_000, 55_000, 97_000,
]

# One y position (0, 1, 2, ...) per disease for the horizontal bars
y_diseases = np.arange(len(diseases))

In [None]:
def format_axes(ax):
    """
    Style `ax` like the bar chart axes in "Ten Simple Rules"

    The spine on the origin side of the x axis is kept and drawn on top,
    the opposite and bottom spines are hidden, the y ticks get blank labels
    and the top spine is moved and coloured white.
    """
    # The origin is on the right when the x axis is inverted, otherwise on the left
    if ax.xaxis.get_inverted():
        origin_side, far_side = "right", "left"
    else:
        origin_side, far_side = "left", "right"

    # Hide the far-side and bottom spines; raise the origin spine's z-order
    for hidden_side in (far_side, "bottom"):
        ax.spines[hidden_side].set_color("none")
    ax.spines[origin_side].set_zorder(10)

    # Keep one y tick per disease but leave the tick labels blank
    blank_labels = [""] * len(y_diseases)
    ax.yaxis.set_ticks(y_diseases, labels=blank_labels)

    # Place the top spine at y = len(diseases) + 0.25 in data coordinates
    # and colour it white
    top_spine = ax.spines["top"]
    top_spine.set_position(("data", len(diseases) + 0.25))
    top_spine.set_color("w")

In [None]:
## Here we plot the remake of the NYT figure from the article

# Start every axes with the same x limits
for ax in axs.values():
    ax.set_xlim(0, 200000)

# Plot the women's (left) and men's (right) data
# Cases use wide, faint bars; deaths use narrower, darker bars drawn over them
axs["AxesLeft"].barh(y_diseases, women_cases, height=0.8, fc="red", alpha=0.1)
axs["AxesLeft"].barh(y_diseases, women_deaths, height=0.55, fc="red", alpha=0.5)
# Invert the left x axis so its values increase towards the left
axs["AxesLeft"].xaxis.set_inverted(True)

axs["AxesRight"].barh(y_diseases, men_cases, height=0.8, fc="blue", alpha=0.1)
axs["AxesRight"].barh(y_diseases, men_deaths, height=0.55, fc="blue", alpha=0.5)

# Format both data axes; the tick at 0 is labelled with the category name in bold
axs_labels = ["AxesLeft", "AxesRight"]
axs_categories = ["women", "men"]
category_to_ax = {
    category: axs[key]
    for category, key in zip(axs_categories, axs_labels)
}
for category, ax in category_to_ax.items():
    format_axes(ax)
    ax.set_xticks(
        [0, 50000, 100000, 150000, 200000],
        [category.upper(), "50,000", "100,000", "150,000", "200,000"],
    )
    ax.grid(which="major", axis="x", color="white")
    ax.get_xticklabels()[0].set_weight("bold")

# Write the disease names in 'AxesMid'
# It shares y limits with 'AxesLeft' so each name lines up with its bars
axs["AxesMid"].set_axis_off()
axs["AxesMid"].set_ylim(axs["AxesLeft"].get_ylim())
axs["AxesMid"].set_xlim(-1, 1)

for y, disease_name in zip(y_diseases, diseases):
    axs["AxesMid"].text(0, y, disease_name, ha="center", va="center")

# Add the "NEW CASES" and "DEATHS" annotations
# Devil hides in the details...
arrowprops = {"arrowstyle": "-", "connectionstyle": "angle,angleA=0,angleB=90,rad=0"}

# All four annotations point at the bars of the last entry in the data ("Lung Cancer")
x = women_cases[-1]
y = y_diseases[-1]
axs["AxesLeft"].annotate(
    "NEW CASES",
    xy=(0.9 * x, y),
    xycoords="data",
    ha="right",
    fontsize=10,
    xytext=(-40, -3),
    textcoords="offset points",
    arrowprops=arrowprops,
)

x = women_deaths[-1]
axs["AxesLeft"].annotate(
    "DEATHS",
    xy=(0.85 * x, y),
    xycoords="data",
    ha="right",
    fontsize=10,
    xytext=(-50, -25),
    textcoords="offset points",
    arrowprops=arrowprops,
)

x = men_cases[-1]
axs["AxesRight"].annotate(
    "NEW CASES",
    xy=(0.9 * x, y),
    xycoords="data",
    ha="left",
    fontsize=10,
    xytext=(+40, -3),
    textcoords="offset points",
    arrowprops=arrowprops,
)

x = men_deaths[-1]
axs["AxesRight"].annotate(
    "DEATHS",
    xy=(0.9 * x, y),
    xycoords="data",
    ha="left",
    fontsize=10,
    xytext=(+50, -25),
    textcoords="offset points",
    arrowprops=arrowprops,
)

# Add the caption text (positions are in the data coordinates of 'AxesLeft')
axs["AxesLeft"].text(
    165000, 8.2, "Leading Causes\nOf Cancer Deaths", fontsize=18, va="top"
)
axs["AxesLeft"].text(
    165000,
    7,
    "In 2007, there were more\n"
    "than 1.4 million new cases\n"
    "of cancer in the United States.",
    va="top",
    fontsize=10,
)

# Display the figure
fig

#### Update x and y axis sizes

Now that the data is plotted, the x and y axes of the left/right axes contain tick labels, which changes the amount of whitespace they need.
You can use `update_layout_constraints` and the `XAxisHeight` and `YAxisWidth` constraints to update the axis sizes and adjust the figure layout.

In [None]:
# Update bounding boxes for the x/y axis now that text has been inserted
# This will update the layout of the axes
update_layout_constraints(layout, axs)

# Re-solve the layout and apply the new positions to the existing figure/axes
prim_tree_n, info = solve(layout)
update_subplots(prim_tree_n, "Figure", fig, axs)

fig.savefig("ten_simple_rules_demo.svg")

# Keep `fig` as the last expression so the notebook displays the updated figure
fig

In [None]:
# If you plot the layout after axis sizes are updated, you can see the altered dimensions!
# Note this rebinds `fig` to the layout plot rather than the data figure
fig, ax = figure_layout(layout)