# Unit 7: Advanced Data Visualization
-------------------------------------

**After completing this unit, you should be able to:**

- Create faceted plots in `matplotlib`
- Use the `seaborn` package for more advanced plots

## 7.1. Creating faceted plots

In Unit 4, we created an example where we plotted a time-series data set and its derivative. Let's dissect this code to understand how the subplots were created. The key difference, compared to what we have done previously, is that we have introduced new parameters for the [`plt.subplots()`](https://matplotlib.org/stable/api/figure_api.html#matplotlib.figure.Figure.subplots) function.

| Parameter | Default Value | Possible Values | Description |
|-----------|---------------|-----------------|-------------|
| `nrows` | 1 | integer | number of rows of plots to create |
| `ncols` | 1 | integer | number of columns of plots to create |
| `sharex` | `False` | `True`, `False`, `'row'`, `'col'` | should each axis share the same x-ticks? |
| `sharey` | `False` | `True`, `False`, `'row'`, `'col'` | should each axis share the same y-ticks? |

In this example, we have created 2 *rows* of plots, with the default value of 1 *column*. Because both plots share the same time scale, it is appropriate for them to share the same x-tick positions and labels. However, the signal and its derivative are not in the same units, so their y-scales are not shared (`sharey` is left at its default of `False`).

When `nrows` or `ncols` is greater than 1, the `ax` object becomes an array of axes instead of a single axis, so array indexing is required to select a specific axis. By default, when `nrows==ncols==1`, the `plt.subplots()` function collapses the array to a scalar value so that indexing is not required. This is why we have not needed array indexing previously. In contrast, if both `nrows` *and* `ncols` are greater than 1, the returned `ax` array is 2-dimensional, and an axis is selected in $row \times column$ form, just like an element of any other 2-dimensional matrix. For example, the axis in the first row and first column is selected with `ax[0, 0]`. Try changing the FID plot below to display the two signals side-by-side, instead of top and bottom.
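
The shapes described above can be checked directly. The following is a minimal sketch that creates only empty figures; the variable names are illustrative, and the `squeeze` parameter (not listed in the table above) is the `plt.subplots()` option that controls whether the array is collapsed.

```python
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

# default: nrows == ncols == 1, so a single Axes object is returned
fig1, ax_single = plt.subplots()

# only nrows is greater than 1: a 1-dimensional array of Axes
fig2, ax_1d = plt.subplots(nrows=2, sharex=True)

# both nrows and ncols are greater than 1: a 2-dimensional array,
# indexed as ax_2d[row, column]
fig3, ax_2d = plt.subplots(nrows=2, ncols=3)

# squeeze=False turns off the collapsing, so a 2-dimensional array
# is returned even for a single plot
fig4, ax_unsqueezed = plt.subplots(squeeze=False)

print(isinstance(ax_single, Axes))
print(ax_1d.shape)
print(ax_2d.shape)
print(ax_unsqueezed.shape)
```

Passing `squeeze=False` can be convenient when the number of rows or columns is computed at run time, because the same 2-dimensional indexing then works in every case.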

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

plt.style.use('ggplot')

# read a .csv file into a DataFrame variable named fid_df
fid_df = pd.read_csv('../../data/pcr-polyethylene_gc-fid.csv')
fid_df['seconds'] = fid_df['minutes'] * 60

# calculate the first derivative (change in FID signal) / (change in time)
fid_df['1st_derivative'] = fid_df['fid'].diff() / fid_df['seconds'].diff()

# plot the FID signal and its 1st derivative
# the \n in the label is the new-line character to split the label into 2 lines
fig, ax = plt.subplots(nrows=2, sharex=True)

ax[0].plot(fid_df['seconds'], fid_df['fid'])
ax[0].set_ylabel('Original\nFID Signal')

ax[1].plot(fid_df['seconds'], fid_df['1st_derivative'])
ax[1].set_xlabel('Seconds')
ax[1].set_ylabel('1st Derivative\nof FID Signal')

Faceted plots are especially useful when the data span multiple scales or need to be separated into groups. Separate plots are often easier for a reader to interpret than a single plot with multiple y-scales.
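
The `'row'` and `'col'` options for `sharex` and `sharey` listed in the table in this section were not used in the example above. As a rough sketch, assuming synthetic data (the arrays and labels below are made up for illustration and are not part of the course data), a 2 x 2 grid can share x-ticks within each column and y-ticks within each row:

```python
import numpy as np
import matplotlib.pyplot as plt

# synthetic data (illustrative only): two time ranges and two signal scales
t_short = np.linspace(0, 1, 100)
t_long = np.linspace(0, 10, 100)

# plots in the same column share an x-axis;
# plots in the same row share a y-axis
fig, ax = plt.subplots(nrows=2, ncols=2, sharex='col', sharey='row')

# top row: small-amplitude signals
ax[0, 0].plot(t_short, np.sin(2 * np.pi * t_short))
ax[0, 1].plot(t_long, np.sin(2 * np.pi * t_long))

# bottom row: large-amplitude signals
ax[1, 0].plot(t_short, 100 * t_short)
ax[1, 1].plot(t_long, 100 * t_long)

ax[0, 0].set_ylabel('Small scale')
ax[1, 0].set_ylabel('Large scale')
ax[1, 0].set_xlabel('Short time range')
ax[1, 1].set_xlabel('Long time range')
```

With these settings, each column keeps its own x-limits and each row keeps its own y-limits, so neither time range nor signal scale is squashed by the other.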

## 7.2. Advanced plotting using the `seaborn` package

The [`seaborn`](https://seaborn.pydata.org/) package was developed on top of `matplotlib` to provide easy-to-use functions to produce more complicated plots, especially when there are multiple categories to display on the same plot.

### 7.2.1. Faceted histogram

In Unit 5, we explored the `axis.hist()` function for the creation of histograms. In that unit, we plotted histograms on a single axis. When there are many categories to be explored, a single axis can get crowded. Using what we learned in section 7.1, we could create faceted plots from scratch. The `seaborn.displot()` function will automatically create the facets on our behalf, speeding up the creation of new plots.

In [None]:
import seaborn as sns

df = sns.load_dataset("penguins")

sns.displot(
    df, x="flipper_length_mm", col="species", row="sex",
    binwidth=3, height=3, facet_kws=dict(margin_titles=True)
)

df.head()

### 7.2.2. Grouped bar plot



In [None]:
penguins = sns.load_dataset("penguins")

# Draw a nested barplot by species and sex
g = sns.catplot(
    data=penguins, kind="bar",
    x="species", y="body_mass_g", hue="sex",
    # errorbar="sd" replaces the deprecated ci="sd" (assumes seaborn >= 0.12)
    errorbar="sd", palette="dark", alpha=.6, height=6
)
g.despine(left=True)
g.set_axis_labels("", "Body mass (g)")
g.legend.set_title("")

### 7.2.3. Scatter plot with categorical value



In [None]:
# Load the example diamonds dataset
diamonds = sns.load_dataset("diamonds")

# Draw a scatter plot while assigning point colors and sizes to different
# variables in the dataset
f, ax = plt.subplots(figsize=(6.5, 6.5))
sns.despine(f, left=True, bottom=True)
clarity_ranking = ["I1", "SI2", "SI1", "VS2", "VS1", "VVS2", "VVS1", "IF"]
sns.scatterplot(x="carat", y="price",
                hue="clarity", size="depth",
                palette="ch:r=-.2,d=.3_r",
                hue_order=clarity_ranking,
                sizes=(1, 8), linewidth=0,
                data=diamonds, ax=ax)

### 7.2.4. Time series with a confidence interval

In [None]:
import seaborn as sns
sns.set_theme(style="darkgrid")

# Load an example dataset with long-form data
fmri = sns.load_dataset("fmri")

# Plot the responses for different events and regions
sns.lineplot(x="timepoint", y="signal",
             hue="region", style="event",
             data=fmri)

### 7.2.5. Linear regression plot



In [None]:
# Load the penguins dataset
penguins = sns.load_dataset("penguins")

# Plot bill depth as a function of bill length, with a separate fit for each species
g = sns.lmplot(
    data=penguins,
    x="bill_length_mm", y="bill_depth_mm", hue="species",
    height=5
)

# Use more informative axis labels than are provided by default
g.set_axis_labels("Snoot length (mm)", "Snoot depth (mm)")

### 7.2.6. Box plot with observations



In [None]:
# Initialize the figure with a logarithmic x axis
f, ax = plt.subplots(figsize=(7, 6))
ax.set_xscale("log")

# Load the example planets dataset
planets = sns.load_dataset("planets")

# Plot the distance for each detection method with horizontal boxes
sns.boxplot(x="distance", y="method", data=planets,
            whis=[0, 100], width=.6, palette="vlag")

# Add in points to show each observation
sns.stripplot(x="distance", y="method", data=planets,
              size=4, color=".3", linewidth=0)

# Tweak the visual presentation
ax.xaxis.grid(True)
ax.set(ylabel="")
sns.despine(trim=True, left=True)

### 7.2.7. Joint plot



In [None]:
import seaborn as sns

tips = sns.load_dataset("tips")
g = sns.jointplot(x="total_bill", y="tip", data=tips,
                  kind="reg", truncate=False,
                  xlim=(0, 60), ylim=(0, 12),
                  color="m", height=7)

--------------
## Next Steps:

1. Complete the [Unit 7 Problems](./unit07-solutions.ipynb) to test your understanding
2. Advance to [Unit 8](../08-image-analysis/unit08-lesson.ipynb) when you're ready for the next step