In [None]:
from src.projects.fagradalsfjall.common.paths import get_blog_post_subfolder
from src.tools.matplotlib import plot_style_matplotlib_default

import matplotlib.pyplot as plt

import numpy as np

In [None]:
# -------------------------------------------------------------------------
#  Output path settings
# -------------------------------------------------------------------------
# Folder that this notebook's figures are written to. The returned object is
# joined with "/" when the figure is saved, so it is presumably a
# pathlib.Path — TODO confirm against get_blog_post_subfolder.
blog_post_index, subfolder_name = 3, "figures"
path_figures = get_blog_post_subfolder(blog_post_index, subfolder_name)

In [None]:
# -------------------------------------------------------------------------
#  Figure & axes
# -------------------------------------------------------------------------
# Apply the project's matplotlib style first so that the figure created
# next picks it up.
plot_style_matplotlib_default()

fig, ax = plt.subplots()  # type: plt.Figure, plt.Axes

# -------------------------------------------------------------------------
#  Colors (RGB tuples; each curve shares its color with its text label)
# -------------------------------------------------------------------------
blue, green, red = (0.3, 0.4, 0.8), (0.3, 0.8, 0.4), (0.8, 0.3, 0.4)

# -------------------------------------------------------------------------
#  Arrow-tipped axes
# -------------------------------------------------------------------------
# Only the bottom and left spines are kept, and a black triangle marker is
# drawn at the far end of each. Both markers sit exactly on the boundary of
# the axis limits set further down, so clipping is disabled to keep them whole.
x_max = y_max = 1.1

for hidden_side in ("top", "right"):
    ax.spines[hidden_side].set_visible(False)

for arrow_x, arrow_y, arrow_fmt in ((x_max, 0, ">k"), (0, y_max, "^k")):
    ax.plot(arrow_x, arrow_y, arrow_fmt, clip_on=False)

# -------------------------------------------------------------------------
#  Error curves (hand-picked illustrative shapes, not fitted to data)
# -------------------------------------------------------------------------
# Sampling starts at 0.05 rather than 0 because the bias term divides by x.
x = np.linspace(0.05, 1.05, 100)

y_irreducible = np.full_like(x, 0.05)  # constant floor
y_bias = 0.1 * (x**-1)                  # decreases with complexity
y_variance = 0.1 + (x**4)               # increases with complexity
y_total = y_irreducible + y_bias + y_variance

lw = 2
for y_curve, curve_color in (
    (y_irreducible, red),
    (y_bias, green),
    (y_variance, blue),
    (y_total, "k"),
):
    ax.plot(x, y_curve, c=curve_color, lw=lw)

# --- in-plot labels, one per curve, in the curve's color ----------------
# "Total error" sits just above the minimum of y_total (near x ~ 0.48).
text_kwargs = dict(ha="center", va="center", fontsize=10, fontweight=600)
curve_labels = (
    (0.48, 0.45, "Total error", "k"),
    (0.5, 0.08, "Irreducible error", red),
    (0.15, 0.5, "Bias", green),
    (0.87, 0.5, "Variance", blue),
)
for label_x, label_y, label_text, label_color in curve_labels:
    ax.text(label_x, label_y, label_text, c=label_color, **text_kwargs)

# -------------------------------------------------------------------------
#  Axis limits, labels & grid
# -------------------------------------------------------------------------
ax.set_xlim(0, x_max)
ax.set_ylim(0, y_max)

# Tick positions are kept (the grid is drawn at them), but their numeric
# labels are blanked because the curves are purely qualitative.
ax.set_xticklabels([])
ax.set_yticklabels([])

ax.set_xlabel("Model complexity")
ax.set_ylabel("Error")
ax.grid(True)

# -------------------------------------------------------------------------
#  Figure size & layout
# -------------------------------------------------------------------------
fig.set_size_inches(w=10, h=5)
fig.tight_layout()

In [None]:
# Export the bias/variance figure built above as a 300-dpi PNG into the
# blog post's figures folder.
output_file = path_figures / "bias_variance.png"
fig.savefig(output_file, dpi=300)