# Basic plots with Matplotlib and Seaborn

![](https://matplotlib.org/_static/logo_dark.svg)

Matplotlib is a comprehensive library for creating static, animated, and interactive visualizations in Python. It is widely used for generating high-quality plots, charts, histograms, and other graphical representations of data.

Matplotlib provides a flexible interface for constructing a wide range of visualizations, making it suitable for both simple exploratory data analysis and complex data visualization tasks.

## Online Documentation

The [matplotlib.org](https://matplotlib.org/) project website is the primary online resource for the library's documentation. It contains the example galleries, FAQs, API documentation, and tutorials.
You can take a look at the [Gallery](https://matplotlib.org/stable/gallery/index.html) and the [tutorial](https://github.com/matplotlib/AnatomyOfMatplotlib) to see the variety of ways one can make figures.

## Basics of Matplotlib

Let's start by taking a closer look at the anatomy of a Matplotlib plot.

First of all, Matplotlib is based on two main concepts: 
* **Figure:** the top-level container that holds all elements of the plot, such as titles, legends or color bars. It can contain one or more Axes.
* **Axes:** the area where data is plotted. A figure can contain many Axes (for example, arranged in a grid), but a given Axes object belongs to only one Figure. Each Axes contains two Axis objects (x and y), or three for 3D visualizations (x, y and z).
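The links between these objects can be inspected directly. The following is a minimal sketch (the variable names `ax_left` and `ax_right` are ours) that creates a figure with two Axes and walks the hierarchy from Figure to Axes to Axis:

```python
import matplotlib.pyplot as plt

# A figure holding a 1x2 grid of Axes
fig, (ax_left, ax_right) = plt.subplots(1, 2)

# The Figure keeps a list of all its Axes...
print(len(fig.axes))            # 2
# ...and each Axes knows the single Figure it belongs to
print(ax_left.figure is fig)    # True

# A 2D Axes owns two Axis objects, one per direction
print(type(ax_left.xaxis).__name__, type(ax_left.yaxis).__name__)  # XAxis YAxis
```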

![](https://gitlabsu.sorbonne-universite.fr/scai/data-visualization/-/raw/main/assets/Figure_Axes.png)



## Two approaches for creating a plot

There are two ways to plot data using matplotlib.

**Simple approach**

In the simple approach, you call functions of the `pyplot` module (such as `plt.plot`, `plt.title` or `plt.xlabel`), which act on the *current* figure and axes. The figure and axes are created implicitly and automatically when they are needed to produce the desired plot.

For example, calling `plt.plot` will automatically create the necessary figure and axes to achieve the desired plot, as follows:

```python
import matplotlib.pyplot as plt # usual import of matplotlib
import numpy as np # import numpy for data sampling

# Sample data
x = np.linspace(0, 10, 100)
y = np.sin(x)

plt.plot(x, y)
plt.title("Simple Approach", fontsize=18)
plt.xlabel("X-axis", fontsize=18)
plt.ylabel("Y-axis", fontsize=18)
plt.show() # not necessary in Jupyter Notebook and Google Colab
```
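To see that the simple approach really works on an implicit "current" figure and Axes, you can ask pyplot for them with `plt.gcf` ("get current figure") and `plt.gca` ("get current Axes"). A small sketch; the initial `plt.close("all")` is only there so the example starts from a clean state:

```python
import matplotlib.pyplot as plt
import numpy as np

plt.close("all")              # start without any open figure

x = np.linspace(0, 10, 100)
plt.plot(x, np.sin(x))        # implicitly creates a figure and an Axes

fig = plt.gcf()               # the figure created by plt.plot
ax = plt.gca()                # the Axes created by plt.plot

# The line drawn by plt.plot lives in the current Axes of the current figure
print(len(ax.lines))          # 1
print(ax in fig.axes)         # True
```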

Calling `plt.figure` before plotting is useful if you want to set the dimensions of the figure:

```python
plt.figure(figsize=(8, 6))  # (width, height) in inches
```


**Advanced approach**

In comparison, the advanced (object-oriented) approach creates the figure and the axes explicitly, then calls methods of the axes object to draw the chart:

```python
# Create the figure
fig = plt.figure()

# Create axes from the figure; the rectangle [left, bottom, width, height] is given
# in fractions of the figure size, so [0, 0, 1, 1] fills the whole figure
ax = fig.add_axes([0, 0, 1, 1])

# Plot the data in the axes
ax.plot(x, y)

# Set the title and the axis labels
ax.set_title("Advanced Approach", fontsize=18)
ax.set_xlabel("X-axis", fontsize=18)
ax.set_ylabel("Y-axis", fontsize=18)
```

Another way is to use the function `plt.subplots`, which creates the figure and the axes in one line:

```python
# Create the figure and the axes, with a specific figsize
fig, ax = plt.subplots(figsize=(8, 6))

# Plot the data in the axes
ax.plot(x, y)

# Set the title and the axis labels
ax.set_title("Advanced Approach", fontsize=18)
ax.set_xlabel("X-axis", fontsize=18)
ax.set_ylabel("Y-axis", fontsize=18)
```

In [None]:
# Usual import
import matplotlib.pyplot as plt

# Needed in the following
import pandas as pd
import numpy as np

## Data sampling

In [None]:
# generate data
x = np.random.normal(0, 1, 1000)
y = np.random.normal(0, 1, 1000)

## Save a figure

In [None]:
# The figure can be saved using the savefig method

# New names, so that x and y from the data sampling cell are not overwritten
x_line = np.linspace(0, 10, 100)
y_line = np.sin(x_line)

# First, get a Figure object
fig = plt.figure()

# Plot the data
plt.plot(x_line, y_line)

# Save the figure
fig.savefig("some_figure.png")
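`savefig` infers the file format from the extension of the file name (`.png`, `.pdf`, `.svg`, ...). Two options are often useful: `dpi` sets the resolution of raster formats, and `bbox_inches="tight"` trims the surrounding whitespace, which also avoids clipped axis labels (as can happen with `fig.add_axes([0, 0, 1, 1])`). The sketch below saves into an in-memory buffer instead of a file, only so that the result can be checked; the figure content is arbitrary.

```python
import io

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(x, np.sin(x))
ax.set_xlabel("X-axis")

# Write into a memory buffer; with a buffer there is no extension,
# so the format is given explicitly
buffer = io.BytesIO()
fig.savefig(buffer, format="png", dpi=150, bbox_inches="tight")

png_bytes = buffer.getvalue()
print(png_bytes[:8])  # b'\x89PNG\r\n\x1a\n', the PNG file signature
```

With a real file, `fig.savefig("some_figure.pdf")` would produce a vector PDF instead.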

## One-dimensional charts

One-dimensional charts plot the values of a single variable:
* histogram and density: distribution of numerical values
* barplot and piechart: distribution of categorical values
* lineplot: lines/timeseries plot

### Histogram

In [None]:
# Histogram, numerical 1D data

# Create the figure with a specific size
plt.figure(figsize=(8, 6))

# Plot
plt.hist(x)

# Set title and labels
plt.title("Histogram")
plt.xlabel("Values")
plt.ylabel("Frequency")

In [None]:
# Histogram, numerical 1D data with a different style

# Create the figure with a specific size
plt.figure(figsize=(8, 6))

# Plot
plt.hist(x, bins=30, color="skyblue", edgecolor="black")

# Set title and labels
plt.title("Histogram")
plt.xlabel("Values")
plt.ylabel("Frequency")
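`plt.hist` does not only draw: it returns the count in each bin, the bin edges and the drawn bars. For a single dataset, the binning is the same as NumPy's `np.histogram` with the same `bins`, which is handy to get the numbers behind the chart. A sketch on randomly generated data (the seed and the names are ours):

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(42)
values = rng.normal(0, 1, 1000)

# Counts per bin, bin edges, and the container of drawn bars
counts, edges, bars = plt.hist(values, bins=30, color="skyblue", edgecolor="black")

# The same binning computed by NumPy, without drawing anything
np_counts, np_edges = np.histogram(values, bins=30)

print(len(edges))     # 31: one more edge than the number of bins
print(counts.sum())   # 1000.0: every sample falls into a bin
print(np.array_equal(counts, np_counts), np.allclose(edges, np_edges))  # True True
```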

In [None]:
# With a title and axis labels

plt.hist(x)

plt.title("Example of a histogram")
plt.xlabel("Bins")
plt.ylabel("Frequencies")

In [None]:
# Alternatively, with the object-oriented approach

# First create the figure object
fig = plt.figure()

# Create the axes object
ax = fig.add_axes([0, 0, 1, 1])

# Build the chart within the axes
ax.hist(x, bins=30, color="skyblue", edgecolor="black")

# Set title and labels
ax.set_title("Histogram")
ax.set_xlabel("Values")
ax.set_ylabel("Frequency")

### Boxplot: visualize statistics of numerical values

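A boxplot summarizes the distribution of a numerical variable: the box spans the first and third quartiles with a line at the median, the whiskers extend to the most extreme values within 1.5 times the interquartile range (Matplotlib's default), and the points beyond the whiskers are drawn individually as outliers.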
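The statistics behind the box can be computed by hand with NumPy. This is a sketch assuming Matplotlib's default whisker setting `whis=1.5`; the data and variable names are made up for the example:

```python
import numpy as np

rng = np.random.default_rng(0)
values = rng.normal(0, 1, 1000)

# The box: first quartile, median, third quartile
q1, median, q3 = np.percentile(values, [25, 50, 75])
iqr = q3 - q1  # interquartile range, i.e. the height of the box

# Whiskers (whis=1.5): most extreme points still within 1.5 * IQR of the box
lower_fence = q1 - 1.5 * iqr
upper_fence = q3 + 1.5 * iqr
whisker_low = values[values >= lower_fence].min()
whisker_high = values[values <= upper_fence].max()

# Points beyond the fences are drawn individually as outliers
outliers = values[(values < lower_fence) | (values > upper_fence)]

print(f"Q1={q1:.2f}, median={median:.2f}, Q3={q3:.2f}")
print(f"whiskers: [{whisker_low:.2f}, {whisker_high:.2f}], {len(outliers)} outliers")
```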

In [None]:
# Boxplot

plt.boxplot(x, labels=["Some variable"])
plt.title("Example of a boxplot")
plt.ylabel("Values")

In [None]:
# Many boxplots / variables comparison

# The data are stored in a dictionary for convenience
data = {
    "variable 1": np.random.normal(0, 1, 1000),
    "variable 2": np.random.normal(1, 1, 1000),
    "variable 3": np.random.normal(3, 1, 1000)
}

plt.boxplot(data.values(), labels=data.keys())
plt.title("Many boxplots")
plt.ylabel("Values")

### Bar plots and Pie chart

Bar plots and pie charts visualize how many samples fall into each category of a categorical variable.

In [None]:
# Generate example data
categories = ["A", "B", "C", "D"]
counts = [20, 35, 30, 25]

plt.bar(categories, counts, color="orange")
plt.title("Bar Plot")
plt.xlabel("Categories")
plt.ylabel("Values")

In [None]:
# Horizontal bar plot

# Generate example data
categories = ["A", "B", "C", "D"]
counts = [20, 35, 30, 25]

plt.barh(categories, counts, color="green")
plt.title("Horizontal Bar Plot")
plt.xlabel("Values")
plt.ylabel("Categories")

In [None]:
# Exercise: load the weather dataset and create a barplot with the variable "weather"
# Tip: you first need to count the number of samples in each category

weather = pd.read_csv("https://gitlabsu.sorbonne-universite.fr/scai/data-visualization/-/raw/main/data/weather.csv", sep=";")

# TO DO
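The counting step mentioned in the tip can be done with the pandas method `value_counts`. Here is a sketch on a small made-up column (not the weather dataset, so the exercise is left to you):

```python
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical categorical column
observations = pd.Series(["sun", "rain", "sun", "fog", "sun", "rain"])

# Number of samples in each category, most frequent first
counts = observations.value_counts()
print(counts.to_dict())  # {'sun': 3, 'rain': 2, 'fog': 1}

# The index holds the categories, the values hold the counts
plt.bar(counts.index, counts.values, color="orange")
plt.xlabel("Category")
plt.ylabel("Number of samples")
```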

In [None]:
# Piechart
# plt.pie takes the number of samples of each category (and optionally their labels).
# It then computes the proportion of each category: together, the slices make up the whole pie, i.e. the proportions sum to 1

# Reference: https://chartio.com/learn/charts/pie-chart-complete-guide/

# Generate example data
categories = ["A", "B", "C", "D"]
counts = [20, 35, 30, 25]

plt.pie(counts, labels=categories, autopct='%1.1f%%')
plt.title("Pie Chart")
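The percentages written on the slices can be reproduced by hand: in recent Matplotlib versions, `plt.pie` normalizes the counts by their total (default `normalize=True`), and `autopct='%1.1f%%'` formats each proportion with one decimal. A small check with the same example counts:

```python
import numpy as np

categories = ["A", "B", "C", "D"]
counts = np.array([20, 35, 30, 25])

# Each slice covers count / total of the pie
proportions = counts / counts.sum()

for category, p in zip(categories, proportions):
    # Same formatting as autopct='%1.1f%%'
    print(f"{category}: {p * 100:.1f}%")
# A: 18.2%
# B: 31.8%
# C: 27.3%
# D: 22.7%
```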

### Line plot / time series

In [None]:
# Line plot using plt.plot

# Samples
x = np.linspace(0, 10, 100)
y = np.sin(x)

plt.plot(x, y, color="green")
plt.title("Line Plot")
plt.xlabel("X-axis")
plt.ylabel("Y-axis")

In [None]:
# Many lines: just stack the plt.plot instructions
# We also add a legend

# Samples
x = np.linspace(0, 10, 100)
y = np.sin(x)

# Plots
plt.plot(x, y, label="Line 1")
plt.plot(x+1, y, label="Line 2")
plt.plot(x+2, y, label="Line 3")

# Title and labels
plt.title("Line Plot")
plt.xlabel("X-axis")
plt.ylabel("Y-axis")

# Legend
plt.legend()

In [None]:
# What about time series?

# Sample time series data
data_1 = {
    "Date": pd.date_range(start="2022-01-01", periods=1000),
    "Value": np.random.randn(1000).cumsum() # cumulative sum of random normal steps (a random walk)
}

data_2 = {
    "Date": pd.date_range(start="2022-01-01", periods=1000),
    "Value": np.random.randn(1000).cumsum() # cumulative sum of random normal steps (a random walk)
}

# Set the figure size
plt.figure(figsize=(10, 6))

# Plotting the time series
plt.plot(data_1["Date"], data_1["Value"], linestyle="-", label="Time series 1")
plt.plot(data_2["Date"], data_2["Value"], linestyle="-", label="Time series 2")

# Title and labels
plt.title("Time Series Plot")
plt.xlabel("Date")
plt.ylabel("Value")

# legend
plt.legend()

# Add a grid in the background
plt.grid(True)

## Two-dimensional charts

Two-dimensional charts visualize the values of two variables together. The goal is to compare the two variables and analyze how they are related, for example whether they are correlated:
* scatter plot
* heatmap
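A scatter plot shows the relationship visually; the Pearson correlation coefficient summarizes its linear part in one number between -1 and 1. A sketch with made-up data, one pair of related variables and one pair of independent ones, using `np.corrcoef`:

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(0, 1, 1000)
y_related = 2 * x + rng.normal(0, 0.5, 1000)   # depends linearly on x, plus noise
y_unrelated = rng.normal(0, 1, 1000)           # independent of x

# np.corrcoef returns the 2x2 correlation matrix; entry [0, 1] is Pearson's r
r_related = np.corrcoef(x, y_related)[0, 1]
r_unrelated = np.corrcoef(x, y_unrelated)[0, 1]

print(f"r(x, y_related)   = {r_related:.2f}")    # close to 1
print(f"r(x, y_unrelated) = {r_unrelated:.2f}")  # close to 0
```

A scatter plot of `x` against `y_related` would show points along a line, while `x` against `y_unrelated` would show a round cloud.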

### Scatter plots

In [None]:
# Case 1: one group of data points

# Generate data
x = np.random.normal(0, 1, 1000)
y = np.random.normal(0, 1, 1000)

# scatter
plt.scatter(x, y, color="red")

plt.title("Some title")
plt.xlabel("Label for X")
plt.ylabel("Label for Y")

# Change the limits of the axis
plt.xlim(-3, 3)
plt.ylim(-3, 3)

In [None]:
# Case 2: two groups of data points

# Generate data
x1 = np.random.normal(0, 1, 100)
y1 = np.random.normal(0, 1, 100)

x2 = np.random.normal(5, 1, 100)
y2 = np.random.normal(5, 1, 100)


# Two scatters for the two groups
# Here we add a 50% transparency to the data points using alpha=.5

plt.scatter(x1, y1, label="Group 1", alpha=.5)
plt.scatter(x2, y2, label="Group 2", alpha=.5)

plt.title("Scatter plot with two groups")
plt.xlabel("Label for X")
plt.ylabel("Label for Y")
plt.legend()

# Add a grid on the Y-axis
plt.grid(True, axis="y")

### Bubble plot

A bubble plot is a scatter plot in which the size of each data point encodes a third variable. It allows you to display three variables in a two-dimensional chart.

In [None]:
x3 = np.random.uniform(0, 1, 100)
plt.scatter(x1, y1, s=x3*1000, label="Group 1", alpha=.5)
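Note that `s` is the marker area in points squared, so the third variable usually needs rescaling to get visible bubbles. A legend for the sizes can be built with the `legend_elements` method of the object returned by `scatter`. A self-contained sketch with made-up data, where the scaling factor 1000 is an arbitrary choice:

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(0, 1, 100)
y = rng.normal(0, 1, 100)
z = rng.uniform(0, 1, 100)  # third variable, encoded as marker size

fig, ax = plt.subplots(figsize=(8, 6))
# s is an area in points^2: scale z to a visible range
sc = ax.scatter(x, y, s=z * 1000, alpha=.5)

# Legend entries for a few representative sizes; func maps sizes back to z values
handles, labels = sc.legend_elements(prop="sizes", num=4, func=lambda s: s / 1000, alpha=.5)
ax.legend(handles, labels, title="z")
```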

### Heatmap

Heatmaps are useful for visualizing the distribution or relationships between two-dimensional data, such as correlation matrices, spatial data, or density plots.

The matplotlib function for heatmaps is `imshow`. The colormap (`cmap`) sets how cell values are mapped to colors. Setting `interpolation="nearest"` draws each cell as a uniform block, without smoothing between neighboring cells.

Colormaps documentation: https://matplotlib.org/stable/users/explain/colors/colormaps.html

In [None]:
# Generate sample data for the heatmap, a matrix of size 10x10
data = np.random.rand(10, 10)

In [None]:
# Create the heatmap

# Create the figure
plt.figure(figsize=(8, 6))

# Create heatmap with a colorbar indicating the values of the cells
plt.imshow(data, cmap="viridis", interpolation="nearest")
plt.colorbar(label="Value")
plt.title("Heatmap Example")

In general, your data is not in matrix format, so you need to reshape it before creating the heatmap.

For example, suppose the data contains two categorical variables `categorical_1` and `categorical_2` and one numerical variable `values`, and that you want to visualize the (mean) value of `values` for each combination of the two categories.

```text
categorical_1  categorical_2  values
A              AA             50
B              BB             45
C              CC             60
...            ...            ...
```

In [None]:
# Generate a toy example
cat_1 = np.random.choice(["A", "B", "C", "D", "E", "F"], 1000)
cat_2 = np.random.choice(["AA", "BB", "CC", "DD", "EE", "FF"], 1000)
values = np.random.uniform(0, 100, 1000)
data = pd.DataFrame({"categorical_1": cat_1, "categorical_2": cat_2, "values": values})
data

In [None]:
# Build a pivot table from the dataset
df = data.pivot_table(index="categorical_1", columns="categorical_2", values="values", aggfunc="mean")
df
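`pivot_table` can be understood as a group-by followed by a reshape: group the rows by both categories, aggregate `values` with the mean, then move the second category from the row index to the columns with `unstack`. A sketch on a smaller made-up dataset that checks the two give the same table:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
data = pd.DataFrame({
    "categorical_1": rng.choice(["A", "B", "C"], 300),
    "categorical_2": rng.choice(["AA", "BB"], 300),
    "values": rng.uniform(0, 100, 300),
})

# Rows from categorical_1, columns from categorical_2, cells = mean of values
pivot = data.pivot_table(index="categorical_1", columns="categorical_2",
                         values="values", aggfunc="mean")

# Same table step by step: group by both categories, average,
# then move categorical_2 from the index to the columns
by_hand = data.groupby(["categorical_1", "categorical_2"])["values"].mean().unstack()

print(pivot.shape)  # (3, 2)
```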

In [None]:
# Create the figure
fig, ax = plt.subplots(figsize=(12, 8))

# Create the heatmap in the axes
im = ax.imshow(df)

# Change the tick labels to the categories
ax.set_xticks(np.arange(df.shape[1]), labels=list(df.columns))
ax.set_yticks(np.arange(df.shape[0]), labels=list(df.index))

# Axis labels and colorbar
ax.set_xlabel("categorical_2")
ax.set_ylabel("categorical_1")
fig.colorbar(im, ax=ax, label="Mean value")

In [None]:
# Exercise: apply the previous cells to the weather dataset, e.g. visualize the mean temperature for each month and each year
# Tip: you first need to extract the year from the dates

# TO DO

## Creating multiple subplots

In matplotlib, a subplot refers to a grid-based layout of multiple plots within a single figure. Subplots allow you to visualize and compare different aspects of your data side by side, making it easier to explore relationships and patterns. 

The `plt.subplot()` function is used to create subplots, and it takes three arguments: the number of rows (`nrows`), the number of columns (`ncols`), and the index of the subplot you want to create. The index starts from 1 and increases row-wise. For example, `plt.subplot(1, 2, 1)` creates a grid with 1 row and 2 columns and selects the first subplot in the grid. Similarly, `plt.subplot(1, 2, 2)` selects the second subplot, and so on.

Example:

```python
# Create a grid of two plots side-by-side
# plot a histogram in the first plot and a boxplot in the second one

plt.subplot(1, 2, 1)
plt.hist(x)

plt.subplot(1, 2, 2)
plt.boxplot(x)
```
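The row-wise numbering of the index can be checked by asking each Axes for its position in the grid with `get_subplotspec`. A sketch on a 2x3 grid (the `positions` dictionary is just for illustration):

```python
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(9, 6))
positions = {}

# Indices 1..6 are assigned row by row: left to right, then top to bottom
for index in range(1, 7):
    ax = plt.subplot(2, 3, index)
    spec = ax.get_subplotspec()
    row, col = spec.rowspan.start, spec.colspan.start  # counted from 0
    positions[index] = (row, col)
    ax.set_title(f"index {index}: row {row}, col {col}")

print(positions[4])  # (1, 0): index 4 is the first plot of the second row
```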

In [None]:
# Exercise: use the function plt.subplot with one of the datasets.

# TO DO


Another convenient way to create subplots is the function `plt.subplots`. It creates the figure and the axes at the same time, and passes additional keyword arguments (such as `figsize`) to the internal `plt.figure` call.
Its first two arguments are the number of rows (`nrows`) and the number of columns (`ncols`). It returns a figure object and the axes: a single Axes object when there is only one subplot, otherwise a NumPy array of Axes objects.

Example:

```python
# Create a grid of two plots side-by-side, set the figure size.
# plot a histogram in the first plot and a boxplot in the second one

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))
axes[0].hist(x)
axes[1].boxplot(x)
```
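What `plt.subplots` returns as `axes` depends on the grid shape, because of its default `squeeze=True`. A quick sketch printing the shapes; with `squeeze=False` the result is always a 2D array, which can simplify generic code:

```python
import matplotlib.pyplot as plt

fig1, ax_single = plt.subplots()                    # a single Axes object, not an array
fig2, axes_row = plt.subplots(1, 3)                 # 1D array of 3 Axes
fig3, axes_grid = plt.subplots(2, 3)                # 2D array of Axes
fig4, axes_2d = plt.subplots(1, 3, squeeze=False)   # always 2D

print(isinstance(ax_single, plt.Axes))                 # True
print(axes_row.shape, axes_grid.shape, axes_2d.shape)  # (3,) (2, 3) (1, 3)
```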

In [None]:
# Exercise: use the function plt.subplots with one of the datasets.

# TO DO

### Stacking subplots

Both functions can stack many subplots by setting `nrows` and/or `ncols` larger than 1.
With `plt.subplot`, you just need to increase the plot index, as shown in the example below:

```python
# Create a grid of four plots, 2 rows and 2 cols

plt.subplot(2, 2, 1)
plt.hist(x)

plt.subplot(2, 2, 2)
plt.boxplot(x)

plt.subplot(2, 2, 3)
plt.scatter(x, x)

plt.subplot(2, 2, 4)
plt.hist(y)
```

There is a difference when using the function `plt.subplots`: when both `nrows` and `ncols` are larger than 1, the returned `axes` is a 2D NumPy array, so each Axes object is accessed with a row index and a column index. This makes customizing the plots a bit more verbose.

```python
fig, axes = plt.subplots(2, 2)

axes[0, 0].plot(x, y)
axes[0, 0].set_title('Axis [0, 0]')

axes[0, 1].plot(x, y)
axes[0, 1].set_title('Axis [0, 1]')

axes[1, 0].plot(x, -y)
axes[1, 0].set_title('Axis [1, 0]')

axes[1, 1].plot(x, -y)
axes[1, 1].set_title('Axis [1, 1]')
```

If you need to set the same parameters on every subplot, it is handy to iterate over `axes.flat`, which goes through the 2D array as a flat sequence. Example:

```python
for ax in axes.flat:
    ax.set_xlabel("X Label")
    ax.set_ylabel("Y Label")
```

In [None]:
# Exercise: use the function plt.subplots with one of the datasets.

# TO DO

### Sharing axes

By default, each Axes is scaled independently. Thus, if the data ranges differ, the tick values of the subplots do not align.

In [None]:
# Data for the example
x = np.linspace(0, 2*np.pi, 400)
y = np.sin(x**2)

In [None]:
# Create the subplot and store the axes objects in ax1 and ax2
fig, (ax1, ax2) = plt.subplots(2)
ax1.plot(x, y)
ax2.plot(x + 1, -y)

You can use `sharex` or `sharey` to align the horizontal or vertical axis.

In [None]:
# sharex=True
fig, (ax1, ax2) = plt.subplots(2, sharex=True) # or sharex="all"
ax1.plot(x, y)
ax2.plot(x + 1, -y)
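`sharex` and `sharey` also accept `"col"` and `"row"`, which is useful for grids: subplots in the same column share their x-axis, subplots in the same row share their y-axis. A sketch on a 2x2 grid with made-up curves; the last lines check which pairs of Axes are actually linked:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 400)

# Same column -> shared x, same row -> shared y
fig, axes = plt.subplots(2, 2, sharex="col", sharey="row")
axes[0, 0].plot(x, np.sin(x))
axes[0, 1].plot(x + 3, np.sin(x))
axes[1, 0].plot(x, 3 * np.cos(x))
axes[1, 1].plot(x + 3, 3 * np.cos(x))

shared_x = axes[0, 0].get_shared_x_axes()
shared_y = axes[0, 0].get_shared_y_axes()
print(shared_x.joined(axes[0, 0], axes[1, 0]))  # True: same column
print(shared_x.joined(axes[0, 0], axes[0, 1]))  # False: different columns
print(shared_y.joined(axes[0, 0], axes[0, 1]))  # True: same row
```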

# Seaborn library

Seaborn is a Python visualization library based on matplotlib that provides a high-level interface for creating attractive and informative statistical graphics.

**Statistical data visualization**

Seaborn simplifies the creation of complex statistical plots by providing functions for common visualization tasks such as plotting univariate and bivariate distributions, visualizing linear relationships, and exploring pairwise relationships in datasets.

**Integration with pandas**

Seaborn seamlessly integrates with pandas data frames, allowing users to directly pass DataFrame objects to plotting functions. This integration makes it easy to work with tabular data and visualize patterns and relationships within datasets.

**Default Aesthetics**

Seaborn enhances the visual appeal of plots by providing attractive default aesthetics, including color palettes, themes, and styles. These defaults can be easily customized to match specific preferences or the requirements of a particular analysis.

**Advanced Features**

Seaborn includes advanced features for visualizing complex relationships in data, such as conditional plots, faceting, and multivariate analysis. These features enable users to explore and visualize high-dimensional datasets effectively.

*Remark: As seaborn is based on matplotlib, many matplotlib functions can (or should) be used with seaborn, in particular for titles, axis labels, legends, etc.*



![](https://seaborn.pydata.org/_static/logo-wide-lightbg.svg)

Official documentation: https://seaborn.pydata.org


## Usual import

In [None]:
import seaborn as sns

## Change the visual theme

Set aspects of the visual theme for all matplotlib and seaborn plots.

Be aware that calling the function `set_theme` changes the style of every matplotlib figure created afterwards, not only seaborn plots.

See https://seaborn.pydata.org/generated/seaborn.set_theme.html

In [None]:
sns.set_theme()
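If you only want a style for some figures, seaborn's `axes_style` can also be used as a context manager: the style applies inside the `with` block and the previous settings are restored afterwards. The figure must be created inside the block. A sketch that checks the `axes.grid` setting, which the `"whitegrid"` style turns on:

```python
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

grid_before = mpl.rcParams["axes.grid"]

# The style is only active inside the `with` block
with sns.axes_style("whitegrid"):
    grid_inside = mpl.rcParams["axes.grid"]
    fig, ax = plt.subplots()
    ax.plot([0, 1, 2], [0, 1, 4])

grid_after = mpl.rcParams["axes.grid"]
print(grid_before, grid_inside, grid_after)
```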

## Histogram and density

In [None]:
data = np.random.normal(size=10000)

In [None]:
# Simple histogram
sns.histplot(data)

In [None]:
# Histogram and density estimator in the same figure
sns.histplot(data, kde=True)

# Some matplotlib functions
plt.title("A seaborn histogram")
plt.xlabel("Values")
plt.ylabel("Frequency")

In [None]:
# Change the background style
sns.set_theme(style="whitegrid")

# Density estimator
sns.kdeplot(data)

# Some matplotlib functions
plt.title("Density estimator")
plt.xlabel("Values")
plt.ylabel("Density")

In [None]:
# Change the background style
sns.set_theme(style="whitegrid")

# Density estimator
sns.kdeplot(data, fill=True, color="orange")

# Some matplotlib functions
plt.title("Density estimator")
plt.xlabel("Values")
plt.ylabel("Density")

In [None]:
# Two density plots in the same chart

sns.kdeplot(data, fill=True, label="Data 1")

# Generate another dataset and plot it
data = np.random.normal(4, 1, size=10000)
sns.kdeplot(data, fill=True, color="orange", label="Data 2")

plt.legend()


In [None]:
# Another way using a data frame

data_1 = np.random.normal(0, 1, size=1000)
data_2 = np.random.normal(4, 1, size=1000)
df = pd.DataFrame({"data_1": data_1, "data_2": data_2})

# A single kdeplot call draws one density per column
sns.kdeplot(data=df, fill=True)

## Boxplot

In [None]:
# Default theme
sns.set_theme()

In [None]:
# Simple boxplot, horizontally
sns.boxplot(x=data)

In [None]:
# Simple boxplot, vertically
sns.boxplot(y=data)

In [None]:
# What if we try to plot two boxplots in the same chart?
# Is it what you want?

data = np.random.normal(0, 1, size=1000)
sns.boxplot(y=data)

# Generate another dataset and plot it
data = np.random.normal(4, 1, size=1000)
sns.boxplot(y=data)

In [None]:
# SOLUTION: Pass a data frame instead of the two series of data

# Create the data frame (or import it)
data_1 = np.random.normal(0, 1, size=1000)
data_2 = np.random.normal(4, 1, size=1000)
df = pd.DataFrame({"data_1": data_1, "data_2": data_2})

# A single boxplot call draws one box per column
sns.boxplot(data=df)
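The data frame above is in "wide" format (one column per group). Seaborn also accepts "long" format, with one row per observation and a column naming the group, which is the layout you often get from a CSV file. A sketch converting the wide frame with `melt` and drawing the equivalent boxplot; the column names `variable` and `value` are our choice:

```python
import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng(0)
df = pd.DataFrame({
    "data_1": rng.normal(0, 1, size=1000),
    "data_2": rng.normal(4, 1, size=1000),
})

# Long format: one row per observation, plus a column telling its group
long_df = df.melt(var_name="variable", value_name="value")
print(long_df.shape)  # (2000, 2)

# Equivalent boxplot, with the grouping column mapped explicitly to x
ax = sns.boxplot(data=long_df, x="variable", y="value")
```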

In [None]:
# Exercise: use one of the CSV datasets, visualize and compare distributions with histplot, kdeplot or boxplot

# TO DO

## Barplot

In [None]:
# Create a data frame with categories and values
df = pd.DataFrame({
    "categories": ["A", "B", "C"],
    "values": [10, 20, 30]
})

In [None]:
# Barplot, vertically
# Use the x and y arguments with the names of the columns
sns.barplot(data=df, x="categories", y="values")

In [None]:
# Barplot, horizontally
# Use the x and y arguments with the names of the columns
sns.barplot(data=df, y="categories", x="values")

## Line plot and time series

In [None]:
# Create a data frame with two time series
data = {
    "date": pd.date_range(start="2022-01-01", periods=1000),
    "series_1": np.random.randn(1000).cumsum(), # cumulative sum of random normal steps (a random walk)
    "series_2": np.random.randn(1000).cumsum() # cumulative sum of random normal steps (a random walk)
}

df = pd.DataFrame(data)
df

In [None]:
sns.set_theme(style="whitegrid")
plt.figure(figsize=(10, 6))
sns.lineplot(data=df.set_index("date"))  # dates as the index, so they are used for the X-axis

## Scatter plot

In [None]:
# data
data_1 = np.random.normal(0, 1, size=1000)
data_2 = np.random.normal(4, 1, size=1000)
df = pd.DataFrame({"data_1": data_1, "data_2": data_2})

In [None]:
# Create a scatter plot from two columns in a data frame

# If you just pass the data frame, the X-axis is the row index of each data point, and each column is displayed as a separate group
sns.scatterplot(data=df)

In [None]:
# Create a scatter plot from two columns in a data frame

# Pass x and y arguments with the names of the columns
sns.scatterplot(data=df, x="data_1", y="data_2", alpha=.5)

## Bubble plot

In [None]:
df["data_3"] = np.random.exponential(1, size=1000)
sns.scatterplot(data=df, x="data_1", y="data_2", size="data_3", alpha=.5, legend=False, sizes=(20, 200))

## Heatmap

In [None]:
# As before, we need to transform our data into a pivot table
# Say we want to visualize the numerical values of one column for each combination of the categories in two other columns

# Generate a toy example
cat_1 = np.random.choice(["A", "B", "C", "D", "E", "F"], 1000)
cat_2 = np.random.choice(["AA", "BB", "CC", "DD", "EE", "FF"], 1000)
values = np.random.uniform(0, 100, 1000)
data = pd.DataFrame({"categorical_1": cat_1, "categorical_2": cat_2, "values": values})

# Build a pivot table from the dataset
df = data.pivot_table(index="categorical_1", columns="categorical_2", values="values", aggfunc="mean")
df

In [None]:
sns.heatmap(data=df, cbar=True, annot=False, cmap="viridis")

In [None]:
sns.heatmap(data=df, cbar=False, annot=True, cmap="seismic")

## Exercise: Anscombe quartet

Reproduce the Anscombe quartet phenomenon: compute basic statistics and visualize the data points for each dataset.

In [None]:
anscombe_data = sns.load_dataset("anscombe")
anscombe_data

In [None]:
# Compute basic statistics using pandas: means, variances, correlations, etc.

# TO DO

In [None]:
# Visualize the data points

# TO DO