# Advanced plots with seaborn and pandas

In this activity, we will learn how to create more advanced plots using matplotlib, seaborn and pandas.

Contents of the activity:
* How do seaborn functions work?
* Advanced plotting:
  * Visualize groups of data
  * Customize plots and styling
  * Use subplots for chart comparison
* Use seaborn and pandas to visualize statistics

![](https://seaborn.pydata.org/_static/logo-wide-lightbg.svg)


In [None]:
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
sns.set_theme()

## How do seaborn functions work?

As a high-level interface to the matplotlib library, seaborn provides a set of functions with similar signatures for different tasks.
These functions are organized into three modules:
* Relational
* Distributional
* Categorical

In addition, seaborn functions are classified across these modules as either "axes-level" or "figure-level". Axes-level functions plot data onto a single `matplotlib.pyplot.Axes` object, which is the return value of the function. In contrast, figure-level functions interface with matplotlib through a seaborn object (usually a `FacetGrid`) that manages the whole figure. Each module has a single figure-level function, which offers a unified interface to its various axes-level functions.

For example, the figure-level function `displot` is an interface to the axes-level functions `histplot`, `kdeplot`, etc.; see the figure below.

![](https://seaborn.pydata.org/_images/function_overview_8_0.png)

**Let's have a look at the arguments of the seaborn functions!**

The functions provided by seaborn broadly share the same signature, meaning that their arguments are (almost) the same.
Seaborn also works closely with pandas DataFrames, which makes it easier to use than plain matplotlib.

The main arguments of the axes-level functions are:
* data: usually a pandas DataFrame, in either long-form or wide-form.
* x, y: column names in the DataFrame
* hue: name of the column used to determine the color of plot elements
* ax: matplotlib `Axes` object on which to draw the plot; if omitted, the currently-active `Axes` is used

The figure-level functions have broadly the same main arguments. However, since they create a matplotlib figure themselves, they do not accept an `ax` argument.

In [None]:
sns.histplot?

### Examples of distribution plots

In [None]:
# Load the weather dataset
weather = pd.read_csv("https://gitlabsu.sorbonne-universite.fr/scai/data-visualization/-/raw/main/data/weather.csv", sep=";")

# Convert the date column to datetime objects
weather["date"] = pd.to_datetime(weather["date"])

# Extract the year and the month name (used later as "month") from the dates
weather["year"] = weather["date"].dt.year
weather["month"] = weather["date"].dt.month_name()

# Extract the seasons: month % 12 // 3 + 1 gives 1 for Dec-Feb, 2 for Mar-May, 3 for Jun-Aug, 4 for Sep-Nov
season_names = {1: "winter", 2: "spring", 3: "summer", 4: "autumn"}
weather["season"] = (weather["date"].dt.month % 12 // 3 + 1).map(season_names)  # add a new column in the dataframe

# Set the date as new index
# We keep the date in a separate column (drop=False) because we will need it later
weather.set_index("date", inplace=True, drop=False)

weather.head()

In [None]:
# Load the Employees dataset
employees = pd.read_csv("https://gitlabsu.sorbonne-universite.fr/scai/data-visualization/-/raw/main/data/employees.csv")
employees.head()

In [None]:
# Create a histogram of the maximum temperature for each year
# The argument multiple can take the values "stack", "fill", "dodge", "layer"
# Change the palette to get different colors
sns.histplot(data=weather, x="temp_max", hue="year", multiple="stack", palette="viridis")

In [None]:
# Kernel density estimate, kdeplot
sns.kdeplot(data=weather, x="temp_max", hue="year", multiple="stack", palette="viridis")

### Axes-level vs Figure-level plots

In [None]:
# Example of an axes-level distribution plot, same as the previous one
sns.histplot(data=weather, x="temp_max", hue="year", multiple="stack", palette="viridis")

In [None]:
# Example of a figure-level distribution plot, same as the previous one
sns.displot(data=weather, x="temp_max", hue="year", multiple="stack", palette="viridis")

## Visualization x Statistics

Sometimes, data visualization requires a process of aggregation or estimation in which numerous data points are condensed into a summary statistic, like the mean or median. When displaying a summary statistic, it is generally advisable to include error bars, offering a visual indication of the accuracy with which the summary reflects the original data points.

Seaborn integrates error estimation directly into its plotting functions. For a barplot, the bar height is the summary statistic (the mean by default) and the error bar shows the uncertainty of that estimate, by default a 95% confidence interval computed by bootstrapping.

### Barplot with error bar

In [None]:
# Data
df = pd.DataFrame({
    "categories": ["A", "B", "C", "A", "B", "C"],
    "column_1": [15, 25, 35, 45, 30, 20]
})

# Barplot: each bar is the mean of its category; the dark line on top is the error bar (95% bootstrap confidence interval by default)
sns.barplot(data=df, x="categories", y="column_1")

### Regression plot

A regression plot draws a scatter plot of two variables and fits a linear regression line to them, with a shaded 95% confidence band.

In [None]:
# Scatter plot and linear regression fit, function sns.regplot

# extract a sample for the example
sample = weather.sample(100)

sns.regplot(data=sample, x="temp_max", y="wind")
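To see what `regplot` fits, this sketch compares the slope of its line with an ordinary least-squares fit from `np.polyfit`. The data are hypothetical and synthetic, with a known trend y = 2x + 1 plus noise; `ci=None` only turns off the confidence band.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical synthetic data with a known linear trend: y = 2x + 1 + noise
rng = np.random.default_rng(42)
x = rng.uniform(0, 10, size=100)
toy = pd.DataFrame({"x": x, "y": 2 * x + 1 + rng.normal(scale=1.0, size=100)})

fig, ax = plt.subplots()
sns.regplot(data=toy, x="x", y="y", ci=None, ax=ax)

# The scatter points are a collection; the fitted line is the only Line2D on the axes
line_x, line_y = ax.lines[0].get_data()
seaborn_slope = (line_y[-1] - line_y[0]) / (line_x[-1] - line_x[0])

# Ordinary least-squares fit with numpy for comparison
numpy_slope, numpy_intercept = np.polyfit(toy["x"], toy["y"], deg=1)
print(seaborn_slope, numpy_slope)
```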

## Visualize groups of data

Most seaborn functions have a `hue` parameter that adds information about a categorical variable to the chart by coloring the plot elements. This is the case for barplot, boxplot, histplot, kdeplot, lineplot and scatterplot, among others.

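The data usually needs to be in long format: one row per observation and one column per variable, so that the grouping variable is a column that can be passed to `hue`.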
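For comparison, here is a minimal pandas sketch showing the same values in long and in wide format. `DataFrame.pivot` turns each class into its own column, which is the wide layout that `hue` cannot use directly.

```python
import pandas as pd

# Long format: one row per (category, class) observation
long_df = pd.DataFrame({
    "categories": ["A", "B", "C", "A", "B", "C"],
    "classes": ["class 1", "class 1", "class 1", "class 2", "class 2", "class 2"],
    "values": [10, 20, 30, 45, 25, 15],
})

# Wide format: one row per category, one column per class
wide_df = long_df.pivot(index="categories", columns="classes", values="values")
print(wide_df)
```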

In [None]:
# Data with long format
df = pd.DataFrame({
    "categories": ["A", "B", "C", "A", "B", "C"],
    "classes": ["class 1", "class 1", "class 1", "class 2", "class 2", "class 2"],
    "values": [10, 20, 30, 45, 25, 15]
})
df

### Grouped barplot

To create grouped bar plots, you can use `sns.barplot()` as below. It produces an axes-level plot using the arguments `x`, `y` and `hue`.

In [None]:
sns.barplot(data=df, x="categories", y="values", hue="classes")

In [None]:
# Another palette
sns.barplot(data=df, x="categories", y="values", hue="classes", palette="hot")

### Grouped barplot for columns comparison

To compare the values of two (comparable) columns, you first need to reshape the data into long format with the `pd.melt` function.

In [None]:
# Data
df = pd.DataFrame({
    "categories": ["A", "B", "C", "A", "B", "C"],
    "column_1": [15, 25, 35, 45, 30, 20],
    "column_2": [10, 20, 30, 45, 25, 15]
})

# transform into a long format where the previous column names will be considered as categories
melted = pd.melt(df, id_vars="categories", var_name="column", value_name="values")
melted
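`pd.melt` also accepts `value_vars` to choose which columns are stacked; columns listed in neither `id_vars` nor `value_vars` are dropped. In this short sketch, `column_3` is a hypothetical extra column added only to show the selection.

```python
import pandas as pd

df = pd.DataFrame({
    "categories": ["A", "B", "C", "A", "B", "C"],
    "column_1": [15, 25, 35, 45, 30, 20],
    "column_2": [10, 20, 30, 45, 25, 15],
    "column_3": [1, 2, 3, 4, 5, 6],  # hypothetical column we do not want to compare
})

# Stack only column_1 and column_2; column_3 is left out of the long table
melted = pd.melt(
    df,
    id_vars="categories",
    value_vars=["column_1", "column_2"],
    var_name="column",
    value_name="values",
)
print(melted.shape)  # (12, 3): 6 rows per melted column, 3 columns
```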

In [None]:
# Axes-level plot
sns.barplot(data=melted, x="categories", y="values", hue="column")

In [None]:
# Figure-level plot
sns.catplot(data=melted, x="categories", y="values", hue="column", kind="bar")

### Grouped boxplot

In [None]:
# Axes-level plot
sns.set_theme(style="whitegrid")
sns.boxplot(data=melted, x="categories", y="values", hue="column")

In [None]:
# Figure-level plot
sns.catplot(data=melted, x="categories", y="values", hue="column", kind="box")

## High-dimensional scatter plots

High-dimensional plots represent more than two variables in a single chart.

Seaborn does not provide a 3D scatter plot that represents three variables x, y and z on three axes. Instead, we use the function `sns.scatterplot` and modify the appearance of the data points to represent the other variables. In particular, categorical variables can be shown by changing the style of the data points (colors, sizes, shapes).

Note that `Matplotlib` has a 3D toolkit for this purpose, see [here](https://matplotlib.org/stable/gallery/mplot3d/index.html).
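As an aside, a true 3D scatter plot with matplotlib's toolkit could look like the sketch below, on hypothetical random data. The `projection="3d"` argument comes from `mpl_toolkits.mplot3d`, which ships with matplotlib; the explicit import is only needed on older matplotlib versions.

```python
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers "3d" on old matplotlib)

# Hypothetical random data for three numeric variables
rng = np.random.default_rng(0)
x, y, z = rng.normal(size=(3, 50))

fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(projection="3d")

# Position encodes x, y and z; color repeats z to help read depth
ax.scatter(x, y, z, c=z, cmap="viridis")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
```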

The function `sns.scatterplot` has the following arguments for this purpose:
* `hue`: grouping variable that will produce points with different colors. Can be either categorical or numeric.
* `size`: grouping variable that will produce points with different sizes. Can be either categorical or numeric.
* `style`: grouping variable that will produce points with different markers. Can have a numeric dtype but will always be treated as categorical.

By setting these arguments, you can produce a scatterplot with up to five dimensions! Be careful: adding dimensions does not necessarily make the chart easier to interpret.

The figure-level function `sns.relplot` also accepts the arguments `col` and `row`, which create a faceted figure with multiple subplots arranged across the columns and rows of a grid.
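A minimal faceting sketch with `sns.relplot`, on hypothetical toy data: `col` splits the plots by `period` and `row` by `day`, so the grid has one subplot per combination.

```python
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical toy data: two numeric variables and two categorical ones
rng = np.random.default_rng(1)
toy = pd.DataFrame({
    "x": rng.normal(size=120),
    "y": rng.normal(size=120),
    "period": np.tile(["lunch", "dinner"], 60),
    "day": np.repeat(["Sat", "Sun", "Mon"], 40),
})

# One subplot per value of "day" (rows) and of "period" (columns)
g = sns.relplot(data=toy, x="x", y="y", col="period", row="day", height=3)
print(g.axes.shape)  # (3, 2): 3 rows, 2 columns
```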

### Toy data examples

In [None]:
# Load a small data set from seaborn
tips = sns.load_dataset("tips")
tips.head()

In [None]:
# 3D plot
sns.scatterplot(data=tips, x="total_bill", y="tip", hue="time")

In [None]:
# 4D plot
sns.scatterplot(data=tips, x="total_bill", y="tip", hue="time", size="size", sizes=(10, 300))

In [None]:
# 5D plot: use relplot to place the legend outside the axes, and increase the height
sns.relplot(data=tips, x="total_bill", y="tip", hue="time", size="size", sizes=(10, 300), style="sex", height=10)

### Weather data examples

In [None]:
# 3D plot
months = ["January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December"]
sns.scatterplot(data=weather, x="temp_max", y="temp_min", 
                hue="month", hue_order=months,
                palette="viridis")

In [None]:
# 3D plot: Just extract a sample to better visualize
sample = weather.sample(100)

months = ["January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December"]
sns.scatterplot(data=sample, x="temp_max", y="temp_min", 
                hue="month", hue_order=months,
                palette="viridis")

In [None]:
# 4D plot: add different style for the points
sns.scatterplot(data=weather, x="temp_min", y="temp_max", hue="month", palette="viridis", style="year")

In [None]:
# 5D plot: Add different sizes
sns.scatterplot(data=weather, x="temp_min", y="temp_max", hue="month", palette="viridis", style="year", size="weather")

In [None]:
# Figure-level plot
months = ["January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December"]
sns.relplot(data=weather, x="temp_max", y="temp_min", 
            hue="month", hue_order=months,
            palette="viridis", kind="scatter",
            style="year", size="weather")

## Faceting plots

The most useful feature offered by the figure-level functions is that they can easily create figures with multiple subplots. For example, instead of stacking the distributions of all years in the same axes, we can “facet” them by plotting each distribution in its own column of the figure:

```python
sns.displot(data=weather, x="temp_max", col="year")
```

In [None]:
# Four plots = four categories
sns.displot(data=weather, x="temp_max", col="year")

In [None]:
# Four plots = four categories with colors and a legend
sns.displot(data=weather, x="temp_max", hue="year", multiple="stack", palette="viridis", col="year")

In [None]:
# Add the weather categories with hue parameter
sns.displot(data=weather, x="temp_max", hue="weather", multiple="stack", palette="viridis", col="year")

In [None]:
# Overlay the histograms instead of stacking them with multiple="layer"
sns.displot(data=weather, x="temp_max", hue="weather", multiple="layer", palette="viridis", col="year")

In [None]:
# Layered (overlapping) density estimates, filled
sns.displot(data=weather, x="temp_max", hue="weather", multiple="layer", col="year", kind="kde", fill=True)

## Seaborn x Matplotlib: How to use the `ax` argument in axes-level plots?

Axes-level plots accept an optional argument `ax` that specifies the matplotlib `Axes` object on which to draw the plot. By default, the currently-active `Axes` is used.
This makes it possible to use matplotlib tools (figure size, subplots, titles, labels, saving) with seaborn plots.

In [None]:
# Example 1

import matplotlib.pyplot as plt

# Create the figure and the axes, with a specific figsize
fig, ax = plt.subplots(figsize=(8, 6))

# Plot the data in the axes
sns.histplot(data=weather, x="temp_max", hue="year", multiple="stack", palette="viridis", ax=ax)

# Set titles and labels
ax.set_title("Seaborn x Matplotlib", fontsize=18)
ax.set_xlabel("X-axis", fontsize=18)
ax.set_ylabel("Y-axis", fontsize=18)

# Save the figure
fig.savefig("fig.png")

In [None]:
# Example 2

# Create the figure and the axes, with a specific figsize
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True, sharex=True)

# Plot an histogram in the first ax
sns.histplot(data=weather, x="temp_max", hue="year", multiple="stack", palette="viridis", ax=axes[0])

# Plot another histogram in the second ax
sns.histplot(data=weather, x="temp_min", hue="year", multiple="stack", palette="viridis", ax=axes[1], legend=False) # disable the legend because it is the same as the other chart

# Set titles and labels in each subplot
axes[0].set_xlabel("Maximum temperature", fontsize=18)
axes[0].set_ylabel("Count", fontsize=18)

axes[1].set_title("Seaborn x Matplotlib", fontsize=18)
axes[1].set_xlabel("Minimum temperature", fontsize=18)
axes[1].set_ylabel("Count", fontsize=18)



## Time series

### Long-form data

Seaborn accepts both data forms when dealing with time series: long and wide.

We illustrate both with the weather dataset. Let's say that we need to plot the mean temperature versus time.

In [None]:
# Compute the mean temperature
weather["temp_mean"] = (weather["temp_max"] + weather["temp_min"]) / 2

In [None]:
# Use the function relplot with kind="line"
sns.relplot(data=weather["temp_mean"], kind="line", height=10)

In [None]:
# The function lineplot may be more convenient for resizing the figure
fig, ax = plt.subplots(1, 1, figsize=(18, 6))

# Just pass the desired column, and make sure that the index contains the dates as datetime objects
sns.lineplot(data=weather["temp_mean"], ax=ax)

In [None]:
# Here we specify which column is X and which one is Y
# Make sure that the dates are in a separate column

# Add colors with year column
fig, ax = plt.subplots(1, 1, figsize=(18, 6))
sns.lineplot(data=weather, x="date", y="temp_mean", hue="year", ax=ax, palette="tab10")
ax.set_ylabel("Mean temperature")

In [None]:
# Let's say that we want to visualize the temperature versus the month
# We need to group the data by month (numeric_only=True skips text and date columns, required in pandas >= 2.0)
mean_by_month = weather.groupby("month").mean(numeric_only=True)

# Re-order the months
months = ["January", "February", "March",
          "April", "May", "June",
          "July", "August", "September",
          "October", "November", "December"]
mean_by_month = mean_by_month.loc[months]

fig, ax = plt.subplots(1, 1, figsize=(18, 6))
sns.lineplot(data=mean_by_month, x="month", y="temp_mean", ax=ax)
ax.set_ylabel("Mean temperature")
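Instead of re-ordering the rows with `.loc[months]`, the month column can be turned into an ordered pandas `Categorical`, so that sorting and `groupby` follow calendar order. A sketch on a hypothetical 12-row table whose temperatures are made up:

```python
import pandas as pd

months = ["January", "February", "March", "April", "May", "June",
          "July", "August", "September", "October", "November", "December"]

# Hypothetical small table with months in arbitrary (alphabetical) order
toy = pd.DataFrame({
    "month": ["April", "August", "December", "February", "January", "July",
              "June", "March", "May", "November", "October", "September"],
    "temp_mean": [11, 19, 5, 6, 5, 19, 16, 8, 14, 7, 11, 16],
})

# An ordered categorical makes groupby (and sorting) follow calendar order
toy["month"] = pd.Categorical(toy["month"], categories=months, ordered=True)
by_month = toy.groupby("month", observed=True)["temp_mean"].mean()
print(by_month.index.tolist())
```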

In [None]:
# In fact, lineplot can do this in one line!
# It aggregates the data points by month, plots the mean curve and a 95% confidence band

fig, ax = plt.subplots(1, 1, figsize=(18, 6))
sns.lineplot(data=weather, x="month", y="temp_mean", ax=ax)
ax.set_ylabel("Mean temperature")

In [None]:
# Same with separated years
fig, ax = plt.subplots(1, 1, figsize=(18, 6))
sns.lineplot(data=weather, x="month", y="temp_mean", hue="year", ax=ax, palette="tab10")
ax.set_ylabel("Mean temperature")

### Wide-form data

It is also possible to plot the time series when the data are in a wide form.

Applied to the previous example, the wide form holds the same three variables (mean temperature, month and year), but organizes them differently: the months are the index, the years are the columns, and the temperature values are stored in the cells.

In [None]:
wide = pd.pivot_table(data=weather, index="month", columns="year", values="temp_mean")

# Re-order the months
months = ["January", "February", "March",
          "April", "May", "June",
          "July", "August", "September",
          "October", "November", "December"]
wide = wide.loc[months]
wide

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(18, 6))
sns.lineplot(data=wide, ax=ax, palette="tab10")
ax.set_ylabel("Mean temperature")

## Combining multiple views on the data

The functions `pairplot` and `jointplot` combine several charts in order to analyze the data from different points of view.

A joint plot is a 2D chart that displays a bivariate plot (a scatter plot by default) together with the marginal distribution of each variable along the axes.
A pair plot is a matrix of scatter plots: each pair of columns is plotted in its own scatter plot, and the diagonal shows the distribution of each column.

### Joint plot

In [None]:
# Joint plot / scatter

sample = weather.sample(100)

# No hue here, so no palette is needed
sns.jointplot(data=sample, x="temp_max", y="temp_min", kind="scatter")

In [None]:
# Joint plot / Hex

sample = weather.sample(100)

sns.jointplot(data=sample, x="temp_max", y="temp_min", kind="hex")

In [None]:
# Joint plot / Density

sample = weather.sample(100)

sns.jointplot(data=sample, x="temp_max", y="temp_min", kind="kde", fill=True)

In [None]:
# Joint plot / Density (contour lines only, no fill)

sample = weather.sample(100)

sns.jointplot(data=sample, x="temp_max", y="temp_min", kind="kde")

In [None]:
# Joint plot / with regression plot

sample = weather.sample(100)

sns.jointplot(data=sample, x="temp_max", y="temp_min", kind="reg")

In [None]:
# Joint plot with hue parameter, it works with kind="scatter", "kde" or "hist"

sample = weather.sample(100)

sns.jointplot(data=sample, x="temp_max", y="temp_min", 
              hue="season",
              palette="tab10",
              kind="scatter")
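For finer control, `jointplot` is built on `sns.JointGrid`, which lets you choose the central plot and the marginal plots separately. A sketch on hypothetical data with a temp_max/temp_min relationship; pairing a scatter plot with marginal boxplots is just one possible choice.

```python
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical correlated temperatures
rng = np.random.default_rng(5)
toy = pd.DataFrame({"temp_max": rng.normal(16, 7, size=100)})
toy["temp_min"] = toy["temp_max"] - 8 + rng.normal(0, 2, size=100)

# JointGrid separates the central (joint) plot from the two marginal plots
g = sns.JointGrid(data=toy, x="temp_max", y="temp_min")
g.plot(sns.scatterplot, sns.boxplot)  # joint function, then marginal function
```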

### Pair plot

In [None]:
# Pair plot / IT MAY TAKE A WHILE TO COMPUTE
# Select only the numerical columns with the vars argument
# The kind argument changes the kind of the plot (scatter, kde, hist or reg)

sample = weather.sample(100)

sns.pairplot(data=sample, vars=["precipitation", "temp_max", "temp_min", "wind"], kind="scatter")

In [None]:
# Pair plot / IT MAY TAKE A WHILE TO COMPUTE
# Select only the numerical columns with the vars argument

sample = weather.sample(100)

sns.pairplot(data=sample, vars=["precipitation", "temp_max", "temp_min", "wind"], hue="season", palette="tab10", kind="scatter")