# Chapter 2 Exercise solutions



## Q1)

For this question we will revisit the health
expenditure data, which can be loaded by running:


In [None]:
import seaborn as sns

data = sns.load_dataset("healthexp")
data.head()

Again, the code below will generate a subset of the data for Great Britain:


In [None]:
gb = data[data["Country"] == "Great Britain"]
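The line above uses boolean indexing. The comparison produces a True/False Series with one entry per row, and indexing the DataFrame with it keeps only the rows marked True. A minimal sketch on a small hypothetical DataFrame (the values are made up and only the column names match `healthexp`) shows that `DataFrame.query` selects the same subset:

```python
import pandas as pd

# Hypothetical toy data with the same column names as healthexp
toy = pd.DataFrame(
    {
        "Year": [1970, 1970, 1971, 1971],
        "Country": ["Great Britain", "Japan", "Great Britain", "Japan"],
        "Spending_USD": [100.0, 150.0, 110.0, 160.0],
        "Life_Expectancy": [72.0, 72.5, 72.2, 73.0],
    }
)

# Boolean mask: True where the row belongs to Great Britain
mask = toy["Country"] == "Great Britain"
gb_mask = toy[mask]

# Equivalent selection written as a query string
gb_query = toy.query('Country == "Great Britain"')

print(mask.tolist())
print(gb_mask)
```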

### Q1a)

Run the code below to initialise a `plt.subplots()` figure with two panels.


In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(
    nrows=2,
    ncols=1,
    figsize=(6, 5),
)

The top panel is accessed using `ax[0]` and the bottom panel is accessed with `ax[1]`.
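With a single column of panels, `plt.subplots` returns `ax` as a 1-D NumPy array, which is why one index is enough. Passing `squeeze=False` keeps the array 2-D, so panels are always addressed as `[row, col]`. A short sketch comparing the two:

```python
import matplotlib.pyplot as plt

# Two rows, one column: ax is a 1-D array of Axes, indexed ax[0], ax[1]
fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(6, 5))
print(ax.shape)

# squeeze=False always returns a 2-D array, indexed ax2[row, col]
fig2, ax2 = plt.subplots(nrows=2, ncols=1, squeeze=False)
print(ax2.shape)
top, bottom = ax2[0, 0], ax2[1, 0]
```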

### Q1b)

The code below adds a regression plot to the upper panel, showing the line of best fit
of life expectancy against health expenditure for the Great Britain data.

Add a residual plot to the lower panel.

*Hint*: use the argument `ax` to specify which axes to use.


In [None]:
fig, ax = plt.subplots(
    nrows=2,
    ncols=1,
    figsize=(6, 5),
)
sns.regplot(
    data=gb,
    x="Spending_USD",
    y="Life_Expectancy",
    ax=ax[0],
)
# Your Q1b) code here
sns.residplot(
    data=gb,
    x="Spending_USD",
    y="Life_Expectancy",
    ax=ax[1],
)

_NB. Because our figure was created with matplotlib, it can be customised
using matplotlib syntax. For example, `ax[1].set_ylabel("Residuals")` changes the
y-axis label of the lower panel to "Residuals"._
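With its default settings (`order=1`, no `lowess`), `sns.residplot` fits a straight line of y on x and plots, for each point, the observed y minus the fitted value. A minimal sketch of that calculation with `np.polyfit`, using hypothetical numbers rather than the real Great Britain data:

```python
import numpy as np

# Hypothetical spending (x) and life expectancy (y) values, for illustration only
x = np.array([100.0, 200.0, 300.0, 400.0, 500.0])
y = np.array([72.0, 73.1, 73.9, 75.2, 75.8])

# Least-squares straight line: y is approximated by slope * x + intercept
slope, intercept = np.polyfit(x, y, deg=1)
fitted = slope * x + intercept

# Residuals: the values a residual plot shows on its y-axis
residuals = y - fitted
print(np.round(residuals, 3))
```

For a least-squares fit that includes an intercept, the residuals sum to zero, so a well-behaved residual plot scatters evenly about the horizontal zero line.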

## Q2)

In this question we will plot the distribution of pulse rates for different
types of exercise and diet. First, load the dataset with:


In [None]:
import seaborn as sns

exercise = sns.load_dataset("exercise")
exercise.head()

### Q2a)

Plot the distribution of pulse rates (`pulse`) using a histogram.


In [None]:
# Your Q2a) code here
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8, 6))
sns.histplot(data=exercise, x="pulse", ax=ax)
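`sns.histplot` counts how many observations fall into each bin; by default it chooses the bin edges with NumPy's `"auto"` rule. The sketch below reproduces those counts with NumPy directly, on hypothetical pulse values rather than the `exercise` data:

```python
import numpy as np

# Hypothetical pulse readings in bpm (not the real exercise data)
pulse = np.array([85, 90, 97, 80, 91, 83, 87, 92, 97, 100, 140, 150])

# Same binning rule that histplot uses by default (bins="auto")
edges = np.histogram_bin_edges(pulse, bins="auto")
counts, _ = np.histogram(pulse, bins=edges)

# One line per bin: its range and how many readings fall inside it
for left, right, n in zip(edges[:-1], edges[1:], counts):
    print(f"{left:6.1f}-{right:6.1f} bpm: {n}")
```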

### Q2b)

Plot the distribution of pulse rates, split by exercise type (listed in the `"kind"` column).
Use stepped histograms.

*Hint*: the `hue` argument can be used to split a histogram by a variable.

In [None]:
# Your Q2b) code here
fig, ax = plt.subplots(figsize=(8, 6))
sns.histplot(
    data=exercise,
    x="pulse",
    element="step",
    hue="kind",
    ax=ax,
)
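Passing `hue="kind"` draws one histogram per exercise type. A rough matplotlib-only equivalent, sketched on a small hypothetical DataFrame that only shares the column names of `exercise`, loops over the groups and draws each one with `histtype="step"` on shared bin edges:

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Hypothetical data with the same column names as the exercise dataset
df = pd.DataFrame(
    {
        "pulse": [85, 90, 88, 96, 103, 98, 110, 125, 140],
        "kind": ["rest"] * 3 + ["walking"] * 3 + ["running"] * 3,
    }
)

# Shared bin edges so that the groups are directly comparable
edges = np.histogram_bin_edges(df["pulse"], bins="auto")

fig, ax = plt.subplots(figsize=(8, 6))
group_counts = {}
for kind, group in df.groupby("kind"):
    # One stepped (unfilled) histogram per exercise type
    counts, _, _ = ax.hist(group["pulse"], bins=edges, histtype="step", label=kind)
    group_counts[kind] = int(counts.sum())
ax.set_xlabel("pulse")
ax.legend(title="kind")
```

Seaborn handles the grouping, shared binning, colours and legend in a single call, which is the main reason to prefer `hue` here.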

### Q2c)

Recreate the plot from Q2b), but visualising the distributions with a KDE,
rather than a histogram.


In [None]:
# Your Q2c) code here
fig, ax = plt.subplots(figsize=(8, 6))
sns.kdeplot(data=exercise, x="pulse", hue="kind", ax=ax)
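A KDE places a smooth kernel (here a Gaussian) on every observation and adds them up. The hand-written sketch below assumes Scott's rule for the bandwidth, which is the default `bw_method` of both `sns.kdeplot` and scipy's `gaussian_kde`; seaborn's own implementation may differ in detail, and the data are hypothetical. Note also that with `hue`, `kdeplot` defaults to `common_norm=True`, scaling each curve by its group's share of the data so that the areas under all the curves together sum to one.

```python
import numpy as np


def gaussian_kde_1d(data, grid):
    """Evaluate a 1-D Gaussian KDE of `data` at the points in `grid`."""
    data = np.asarray(data, dtype=float)
    n = data.size
    # Scott's rule: bandwidth = sample standard deviation * n**(-1/5)
    h = data.std(ddof=1) * n ** (-1 / 5)
    # One Gaussian bump per observation, evaluated on the whole grid
    z = (grid[:, None] - data[None, :]) / h
    kernels = np.exp(-0.5 * z**2) / np.sqrt(2 * np.pi)
    return kernels.sum(axis=1) / (n * h)


# Hypothetical pulse readings in bpm
pulse = np.array([85, 90, 97, 80, 91, 83, 87, 92, 97, 100])
grid = np.linspace(40, 140, 2001)
density = gaussian_kde_1d(pulse, grid)

# Riemann-sum check: the estimated density should integrate to about 1
area = density.sum() * (grid[1] - grid[0])
print(round(area, 3))
```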

## Q3)

### Q3a)

Using a facet grid of histograms, display the distributions of pulse rates (`pulse`) for
every combination of `diet` and exercise type (`kind`).

In [None]:
# Run this cell to hide the pandas `iteritems` deprecation warning that some seaborn versions trigger
import warnings

warnings.filterwarnings("ignore", message="iteritems is deprecated")
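`warnings.filterwarnings` changes the filters for the rest of the session, and the `message` argument is a regular expression matched against the start of the warning text. To silence the message only around a particular block of code, `warnings.catch_warnings()` restores the previous filters on exit. A minimal sketch, in which `noisy_call` is a hypothetical stand-in for the seaborn call:

```python
import warnings


def noisy_call():
    """Hypothetical stand-in for a library call that emits a FutureWarning."""
    warnings.warn("iteritems is deprecated and will be removed", FutureWarning)
    return "done"


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record every warning in this block

    # The "ignore" filter only applies inside this inner block
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="iteritems is deprecated")
        result = noisy_call()

    # Unrelated warnings are still reported after the inner block
    warnings.warn("some other warning", UserWarning)

print(len(caught), caught[0].message)
```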

In [None]:
# Your Q3a) code here
g = (
    sns.FacetGrid(data=exercise, row="diet", col="kind")
    .map(sns.histplot, "pulse")
)

### Q3b)

The x-axis labels have no units! Redo the plot with the unit "bpm" (beats per
minute) displayed on the x-axes.

*Hint: do this for the lower panels only*


In [None]:
# Your Q3b) code here
g = (
    sns.FacetGrid(data=exercise, row="diet", col="kind")
    .map(sns.histplot, "pulse")
)
# Only the bottom row shows x-axis labels, so relabel that row
for axis in g.axes[-1, :]:
    axis.set_xlabel("Pulse / bpm")

## Extra exercises!

## Q4)

For this question we will construct a pair grid plot using the planets dataset,
which can be loaded by running:


In [None]:
planets = sns.load_dataset("planets")
planets.head()
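Unless `vars` is given, `PairGrid` uses every numeric column, so the text column `method` is left out of the grid. `select_dtypes` shows which columns remain; a sketch on a small hypothetical frame that only shares the planets column names:

```python
import pandas as pd

# Hypothetical rows with the same columns as the planets dataset
df = pd.DataFrame(
    {
        "method": ["Radial Velocity", "Transit"],
        "number": [1, 2],
        "orbital_period": [100.0, 3.5],
        "mass": [2.0, None],
        "distance": [50.0, 150.0],
        "year": [2006, 2010],
    }
)

# Columns a PairGrid would use by default
numeric_cols = df.select_dtypes(include="number").columns.tolist()
print(numeric_cols)
```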

### Q4a)

The code below will generate an empty pair grid. Update this so that the
diagonal panels display histograms, the upper panels display scatter plots, and
the lower panels display KDE plots.


In [None]:
# Your Q4a) code here
g = (
    sns.PairGrid(data=planets, diag_sharey=False)
    .map_diag(sns.histplot)
    .map_lower(sns.kdeplot)
    .map_upper(sns.scatterplot)
)

### Q4b)

A linear scale is unsuitable for the orbital period and the distance, whose values
span several orders of magnitude. Running the code below will create a new dataset
`planets_log`, with the base-10 logarithm applied to the orbital period and distance columns:


In [None]:
import numpy as np

planets_log = planets.copy()
planets_log["orbital_period"] = np.log10(planets["orbital_period"])
planets_log["distance"] = np.log10(planets["distance"])
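`np.log10` works element-wise, leaves missing values as `NaN`, and turns each factor of ten into a step of one on the new scale. Zero would give `-inf` and negative values `NaN`, so the transform suits strictly positive quantities such as periods and distances. A quick check on hypothetical values:

```python
import numpy as np
import pandas as pd

# Hypothetical orbital periods, including a missing value
period = pd.Series([1.0, 10.0, 1000.0, np.nan])
log_period = np.log10(period)
print(log_period.tolist())
```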

Rerun your code from part a) using the new dataset. The KDEs should now span the
full axis range.


In [None]:
# Your Q4b) code here
g = (
    sns.PairGrid(data=planets_log, diag_sharey=False)
    .map_diag(sns.histplot)
    .map_lower(sns.kdeplot)
    .map_upper(sns.scatterplot)
)

### Q4c)

Change the axis labels to `"Number"`, `"Orbital period / days"`, `"Mass / $M_J$"`,
`"Distance / pc"`, `"Year"`.

*Hint: use a for loop to update the labels.*


In [None]:
# Your Q4c) code here
g = (
    sns.PairGrid(data=planets_log, diag_sharey=False)
    .map_diag(sns.histplot)
    .map_lower(sns.kdeplot)
    .map_upper(sns.scatterplot)
)

fig, ax = g.figure, g.axes
labels = [
    "Number",
    "Orbital period / days",
    "Mass / $M_J$",
    "Distance / pc",
    "Year",
]
# Label the y-axes of the first column and the x-axes of the bottom row
for i, lab in enumerate(labels):
    ax[i, 0].set_ylabel(lab)
    ax[-1, i].set_xlabel(lab)
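Keying the labels by column name, through the grid's `x_vars` and `y_vars` attributes, avoids depending on the column order. A sketch on hypothetical random data; the `pretty` dictionary is an illustrative name, not part of seaborn:

```python
import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng(1)
# Hypothetical numeric data with three of the planets column names
df = pd.DataFrame(rng.normal(size=(50, 3)), columns=["number", "mass", "year"])

# Hypothetical mapping from column name to display label
pretty = {"number": "Number", "mass": "Mass / $M_J$", "year": "Year"}

g = sns.PairGrid(data=df).map_offdiag(sns.scatterplot)
for i, var in enumerate(g.y_vars):
    g.axes[i, 0].set_ylabel(pretty[var])   # first column: y labels
for j, var in enumerate(g.x_vars):
    g.axes[-1, j].set_xlabel(pretty[var])  # bottom row: x labels
```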

### Q4d)

There is a mismatch between the axis labels and tick values for the orbital
period and the distance, where base-10 logarithms are now used. Correct this by
replacing the tick labels using scientific notation. For example, 1 should be
replaced with $10^1$ (written as `$10^1$` in math mode), etc.

Finally, save your figure as a PDF.

*Hint: because of the shared axes, you only need to amend the x ticks for one
row and the y ticks for one column.*


In [None]:
# Your Q4d) code here
g = (
    sns.PairGrid(data=planets_log, diag_sharey=False)
    .map_diag(sns.histplot)
    .map_lower(sns.kdeplot)
    .map_upper(sns.scatterplot)
)

fig, ax = g.figure, g.axes
for i, lab in enumerate(labels):
    ax[i, 0].set_ylabel(lab)
    ax[4, i].set_xlabel(lab)

# Update orbital period tick labels
ax[4, 1].set_xticks([0, 3, 6])
ax[4, 1].set_xticklabels(["$10^0$", "$10^3$", "$10^6$"])
ax[1, 0].set_yticks([0, 3, 6])
ax[1, 0].set_yticklabels(["$10^0$", "$10^3$", "$10^6$"])

# Update distance tick labels
ax[4, 3].set_xticks([0, 2, 4])
ax[4, 3].set_xticklabels(["$10^0$", "$10^2$", "$10^4$"])
ax[3, 0].set_yticks([0, 2, 4])
ax[3, 0].set_yticklabels(["$10^0$", "$10^2$", "$10^4$"])

fig.savefig("planets_grid.pdf")
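Instead of hard-coding tick positions and labels, a `FuncFormatter` can build each label from the tick value, so the labels stay correct if the axis limits change. A sketch assuming the axis holds log10 values (the data are hypothetical), which also saves the result as a PDF in a temporary directory:

```python
import os
import tempfile

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter, MultipleLocator

# Hypothetical log10-transformed values
log_values = np.array([-1.0, 0.5, 2.0, 3.5, 5.8])

fig, ax = plt.subplots()
ax.plot(log_values, np.zeros_like(log_values), "o")

# Put a tick every 2 decades and show each tick value v as 10^v
ax.xaxis.set_major_locator(MultipleLocator(2))
fmt = FuncFormatter(lambda v, pos: f"$10^{{{v:g}}}$")
ax.xaxis.set_major_formatter(fmt)

# Save the figure as a PDF in a temporary directory
path = os.path.join(tempfile.mkdtemp(), "log_ticks.pdf")
fig.savefig(path)
```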