# Plotting with Matplotlib

Now that you have learned how to handle numerical data and perform computations on it with the NumPy library, it's time to learn how to visualize your data and computation results. By far the most popular plotting package for Python is **Matplotlib**.

Just like NumPy, you start using Matplotlib by importing the package. However, this time we are going to import a particular submodule, `pyplot`, that houses all of the common plotting functions.

In [None]:
import matplotlib.pyplot as plt

Furthermore, we are going to tell the Jupyter notebook that we want all plots to show up inside the notebook:

In [None]:
%matplotlib inline

As we are going to be working with numerical data during the plotting, we are also going to import `numpy`:

In [None]:
import numpy as np

This pattern of importing multiple packages at the top of your notebook/script is a very common practice when working in Python!

## Your first plot

Let's dive right into plotting something! Let's start simple with a line: $y=x$

In [None]:
x = np.linspace(0, 5, 3)
y = x
plt.plot(x, y)

That was quite simple, wasn't it? Let's now take some time to understand what just happened.
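
In a notebook with `%matplotlib inline`, the figure appears as soon as the cell runs. In a plain Python script, nothing is typically drawn on screen until you call `plt.show()`, and you can also write the figure to a file with `plt.savefig`. A minimal sketch (the file name `first_plot.png` and the temporary directory are arbitrary choices for illustration):

```python
import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 5, 3)
y = x
plt.plot(x, y)

# In a script run from the terminal you would usually call plt.show() here
# to open a window; saving to a file works the same way everywhere.
out_path = os.path.join(tempfile.mkdtemp(), "first_plot.png")
plt.savefig(out_path)
print(out_path)
```

The file format is inferred from the extension, so `.pdf` or `.svg` work the same way.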

## Plotting in Matplotlib

Let's take a closer look at the `plt.plot` function we just used:

In [None]:
plt.plot?

There is a lot of information in that docstring, but the essential information is summarized by the first sentence:
> Plot y versus x as lines and/or markers.

When you plot with `plt.plot`, you are plotting a **sequence of (x, y) points** connected by straight lines. To see this more clearly, let's try plotting a slightly more complex function: $y = x^2$

In [None]:
x = np.linspace(0, 5, 3)
y = x**2
plt.plot(x, y)

Notice that our plot doesn't look very smooth. This is because you used only 3 pairs of points in your plot, as you can see by inspecting `x` and `y`:

In [None]:
x

In [None]:
y
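
You can also confirm this from the plot itself: `plt.plot` returns a list of `Line2D` objects, and each line keeps exactly the points it connects. A short sketch using `Line2D.get_xydata`:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 5, 3)
y = x**2

# plt.plot returns a list of Line2D objects, one per plotted line
lines = plt.plot(x, y)
line = lines[0]

# The line stores exactly the (x, y) pairs you passed in, one row per point;
# Matplotlib just draws straight segments between consecutive rows.
points = line.get_xydata()
print(points)
```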

To get a smoother curve, you have to plot many more points in between:

In [None]:
x = np.linspace(0, 5, 100)  # using 100 points instead
y = x**2
plt.plot(x, y)

Now your curve looks much smoother!
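
Why does adding points help? Between samples, `plt.plot` draws straight segments, so the drawn curve can deviate from the true function. This small sketch uses `np.interp` to evaluate those straight segments and measures the largest gap for 3 versus 100 points (the helper name `max_gap` is just for illustration):

```python
import numpy as np

def max_gap(n_points, n_fine=10001):
    """Largest vertical gap between y = x**2 and its straight-segment
    approximation through n_points evenly spaced samples on [0, 5]."""
    x = np.linspace(0, 5, n_points)
    fine = np.linspace(0, 5, n_fine)
    # np.interp evaluates the piecewise-linear curve, i.e. what plt.plot draws
    segments = np.interp(fine, x, x**2)
    return np.max(np.abs(segments - fine**2))

print(max_gap(3))    # about 1.56: only 2 long segments
print(max_gap(100))  # roughly 0.0006: 99 short segments
```

For this parabola, the worst gap on a segment of width `h` is `h**2 / 4`, so shrinking the spacing by a factor of about 50 shrinks the gap by a factor of about 2500.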

## Customizing the plotting of data points

By default, `plot` simply draws a straight line segment between each pair of consecutive points. You can customize the plot styling by passing in a format string that specifies the **line and marker style**.

In [None]:
x = np.linspace(0, 5, 10)
y = x**2
plt.plot(x, y, 'o') # plot a circle at each data point, with no connecting lines

In [None]:
x = np.linspace(0, 5, 10)
y = x**2
plt.plot(x, y, 'o-') # plot circles connected by lines

In [None]:
x = np.linspace(0, 5, 10)
y = x**2
plt.plot(x, y, '--') # plot dashed line

In [None]:
x = np.linspace(0, 5, 10)
y = x**2
plt.plot(x, y, 'D--') # plot dashed line connecting diamonds

You can check out Matplotlib's reference pages for more information about the available [marker](https://matplotlib.org/api/markers_api.html) and [line](https://matplotlib.org/gallery/lines_bars_and_markers/line_styles_reference.html) styles.
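
The format string is shorthand for the `marker` and `linestyle` keyword arguments (a color letter such as `'r'` can also be included, e.g. `'ro--'`). A small sketch showing that both spellings produce the same styling; the `+ 5` offset only keeps the two lines from overlapping:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 5, 10)
y = x**2

plt.figure()
# Format-string shorthand: diamond markers joined by a dashed line
(short,) = plt.plot(x, y, 'D--')

# The same styling spelled out with keyword arguments
(explicit,) = plt.plot(x, y + 5, marker='D', linestyle='--')

print(short.get_marker(), short.get_linestyle())
print(explicit.get_marker(), explicit.get_linestyle())
```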

## Adding more information to your plot

When you plot data, it is always a good idea to **label the axes** and give your plot a title, so that it becomes easier for others to understand what you plotted. You can use `plt.xlabel`, `plt.ylabel`, and `plt.title` to add this information:

In [None]:
angle = np.linspace(-2*np.pi, 2*np.pi)
value = np.cos(angle)

plt.plot(angle, value)

# add x and y labels
plt.xlabel('Angle (radians)')
plt.ylabel('Value')

# add a title
plt.title('Trigonometric function')
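
`plt.xlabel`, `plt.ylabel`, and `plt.title` all act on the *current axes* of the current figure. You can grab that axes with `plt.gca()` ("get current axes") and read the text back, which is a handy way to check what a plot contains:

```python
import matplotlib.pyplot as plt
import numpy as np

plt.figure()  # start a fresh figure for this example
angle = np.linspace(-2*np.pi, 2*np.pi)
plt.plot(angle, np.cos(angle))
plt.xlabel('Angle (radians)')
plt.ylabel('Value')
plt.title('Trigonometric function')

# The plt.* helpers act on the "current" axes, which plt.gca() returns
ax = plt.gca()
print(ax.get_xlabel(), '|', ax.get_ylabel(), '|', ax.get_title())
```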

## Plotting multiple lines

You can plot multiple curves together in the same plot:

In [None]:
angle = np.linspace(-2*np.pi, 2*np.pi)
value1 = np.cos(angle)
value2 = np.sin(angle)

plt.plot(angle, value1)
plt.plot(angle, value2)


plt.xlabel('Angle (radians)')
plt.ylabel('Value')

plt.title('Trigonometric function')

While the lines are color coded, it is not easy to tell which curve is which. You want a **legend** for the lines. Thankfully, there is a function for that called `plt.legend`!

In [None]:
angle = np.linspace(-2*np.pi, 2*np.pi)
value1 = np.cos(angle)
value2 = np.sin(angle)

plt.plot(angle, value1)
plt.plot(angle, value2)


plt.xlabel('Angle (radians)')
plt.ylabel('Value')

plt.title('Trigonometric function')
plt.legend()

Notice that calling `plt.legend` didn't give you a nice legend, but rather a warning: `No handles with labels found to put in legend` (newer Matplotlib versions word it as `No artists with labels found to put in legend`). This is Matplotlib's way of saying that you haven't labeled any lines to generate a legend for!

When you plot, you can specify the keyword argument `label` in `plot`:

In [None]:
angle = np.linspace(-2*np.pi, 2*np.pi)
value1 = np.cos(angle)
value2 = np.sin(angle)

plt.plot(angle, value1, label='cos(x)')
plt.plot(angle, value2, label='sin(x)')


plt.xlabel('Angle (radians)')
plt.ylabel('Value')

plt.title('Trigonometric function')
plt.legend()
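
`plt.legend()` collects every line that has a label; lines without one (internally they get a label starting with an underscore) are skipped. You can see exactly what it will collect with `Axes.get_legend_handles_labels`. In this sketch, the gray zero line is an extra, deliberately unlabeled line added for illustration:

```python
import matplotlib.pyplot as plt
import numpy as np

plt.figure()
angle = np.linspace(-2*np.pi, 2*np.pi)
plt.plot(angle, np.cos(angle), label='cos(x)')
plt.plot(angle, np.sin(angle), label='sin(x)')
plt.plot(angle, np.zeros_like(angle), color='gray')  # no label: left out of the legend

# These are the (line, label) pairs plt.legend() picks up automatically
handles, labels = plt.gca().get_legend_handles_labels()
print(labels)  # ['cos(x)', 'sin(x)']

legend = plt.legend()
```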

## Customizing plots with keyword arguments

It turns out that you can customize the behavior of the plotting function by using additional keyword arguments. Let's take a look at a few:

### Changing color of the line

In [None]:
x = np.linspace(-10, 10, 50)
y = x**3 - x**2 - 50 * x - 10

plt.plot(x, y, color='red')

### Changing the thickness of the line

In [None]:
x = np.linspace(-10, 10, 50)
y = x**3 - x**2 - 50 * x - 10

plt.plot(x, y, linewidth=5)

### Changing marker and marker size

In [None]:
x = np.linspace(-10, 10, 20)
y = x**3 - x**2 - 50 * x - 10

plt.plot(x, y, marker='s')

In [None]:
x = np.linspace(-10, 10, 20)
y = x**3 - x**2 - 50 * x - 10

plt.plot(x, y, marker='s', markersize=10)
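
These keyword arguments can be freely combined in a single call. `plt.plot` returns the created `Line2D` objects, so you can read each setting back (the particular values here are arbitrary):

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-10, 10, 20)
y = x**3 - x**2 - 50 * x - 10

plt.figure()
# Several styling keyword arguments combined in one call
(line,) = plt.plot(x, y, color='red', linewidth=3, marker='s', markersize=10)

# Each setting can be read back from the returned Line2D
print(line.get_color(), line.get_linewidth(), line.get_marker(), line.get_markersize())
```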

## Changing the properties of the axes

You can also heavily customize the properties of the plotting area, such as axis limits and tick placements.

Consider the following spiral:

In [None]:
n = 100
r = np.linspace(0, 1, n)
theta = np.linspace(0, 8*np.pi, n)
x = r * np.cos(theta)
y = r * np.sin(theta)

plt.plot(x, y)

Let's set the limits of both the x and y axes to be exactly -1 to 1 using `plt.xlim` and `plt.ylim`:

In [None]:
n = 100
r = np.linspace(0, 1, n)
theta = np.linspace(0, 8*np.pi, n)
x = r * np.cos(theta)
y = r * np.sin(theta)

plt.plot(x, y)

# set x and y limits
plt.xlim([-1, 1])
plt.ylim([-1, 1])
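
Called without arguments, `plt.xlim` and `plt.ylim` return the current limits instead of changing them. This makes it easy to see that, by default, Matplotlib pads the data range with a small margin, which is why the spiral did not touch the edges of the first plot:

```python
import matplotlib.pyplot as plt
import numpy as np

n = 100
r = np.linspace(0, 1, n)
theta = np.linspace(0, 8*np.pi, n)
x = r * np.cos(theta)
y = r * np.sin(theta)

plt.figure()
plt.plot(x, y)

# With no arguments, plt.xlim() returns the current limits instead of setting them.
# Autoscaling pads the data range by a small margin, so the limits start out
# slightly wider than the data itself.
auto_left, auto_right = plt.xlim()
print('automatic:', auto_left, auto_right)

plt.xlim([-1, 1])
plt.ylim([-1, 1])
print('after setting:', plt.xlim(), plt.ylim())
```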

We also want ticks only at -1, 0, and 1. Use `plt.xticks` and `plt.yticks`!

In [None]:
n = 100
r = np.linspace(0, 1, n)
theta = np.linspace(0, 8*np.pi, n)
x = r * np.cos(theta)
y = r * np.sin(theta)

plt.plot(x, y)

# set x and y limits
plt.xlim([-1, 1])
plt.ylim([-1, 1])

# set ticks
plt.xticks([-1, 0, 1])
plt.yticks([-1, 0, 1])
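
`plt.xticks` and `plt.yticks` also take an optional second argument with the text to show at each tick. This sketch revisits the cosine plot from earlier and labels the x ticks in multiples of $\pi$ (the `$...$` strings are rendered by Matplotlib's built-in math text):

```python
import matplotlib.pyplot as plt
import numpy as np

angle = np.linspace(-2*np.pi, 2*np.pi)
plt.figure()
plt.plot(angle, np.cos(angle))

tick_locations = [-2*np.pi, -np.pi, 0, np.pi, 2*np.pi]
tick_labels = [r'$-2\pi$', r'$-\pi$', '0', r'$\pi$', r'$2\pi$']
# The optional second argument sets the text displayed at each tick location
plt.xticks(tick_locations, tick_labels)
plt.yticks([-1, 0, 1])

# With no arguments, plt.xticks() returns the current locations and labels
locs, labels = plt.xticks()
print(locs)
```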

Finally, it would be nice if the plot were actually square. Use `plt.axis` to control the axis **aspect ratio**!

In [None]:
n = 100
r = np.linspace(0, 1, n)
theta = np.linspace(0, 8*np.pi, n)
x = r * np.cos(theta)
y = r * np.sin(theta)

plt.plot(x, y)

# set x and y limits
plt.xlim([-1, 1])
plt.ylim([-1, 1])

# set ticks
plt.xticks([-1, 0, 1])
plt.yticks([-1, 0, 1])

# set aspect ratio to "square"
plt.axis('square')
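
`plt.axis('square')` combines two things: equal scaling (one data unit has the same length on both axes) and equal x and y spans, which together make the plotting area square. If you only want equal scaling, `plt.axis('equal')` (or `ax.axis('equal')`) does just that. A sketch comparing the two on an ellipse, using the object-based `plt.subplots` interface introduced later in this notebook:

```python
import matplotlib.pyplot as plt
import numpy as np

theta = np.linspace(0, 2*np.pi, 100)
x = 2 * np.cos(theta)   # an ellipse twice as wide as it is tall
y = np.sin(theta)

fig, (ax_equal, ax_square) = plt.subplots(1, 2)

ax_equal.plot(x, y)
ax_equal.axis('equal')    # same scale on both axes: the ellipse keeps its true shape

ax_square.plot(x, y)
ax_square.axis('square')  # same scale AND same x and y span: a square plotting area

x_span = np.diff(ax_square.get_xlim())[0]
y_span = np.diff(ax_square.get_ylim())[0]
print(x_span, y_span)
```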

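## A figure with multiple subplots

`plt.subplot(nrows, ncols, index)` splits the figure into a grid of `nrows` rows and `ncols` columns and makes the `index`-th cell the current axes, so the following `plt` calls draw into it. The index starts at 1 and runs left to right, then top to bottom.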

In [None]:
x = np.linspace(-1, 1, 100)

plt.subplot(2, 3, 1)
plt.plot(x, x**0)

plt.subplot(2, 3, 2)
plt.plot(x, x**1)

plt.subplot(2, 3, 3)
plt.plot(x, x**2)

plt.subplot(2, 3, 4)
plt.plot(x, x**3)

plt.subplot(2, 3, 5)
plt.plot(x, x**4)

plt.subplot(2, 3, 6)
plt.plot(x, x**5)

Notice that labels are overlapping. You can call `plt.tight_layout` to let Matplotlib automatically adjust spacing between subplots to avoid overlap.

In [None]:
x = np.linspace(-1, 1, 100)

plt.subplot(2, 3, 1)
plt.plot(x, x**0)

plt.subplot(2, 3, 2)
plt.plot(x, x**1)

plt.subplot(2, 3, 3)
plt.plot(x, x**2)

plt.subplot(2, 3, 4)
plt.plot(x, x**3)

plt.subplot(2, 3, 5)
plt.plot(x, x**4)

plt.subplot(2, 3, 6)
plt.plot(x, x**5)

plt.tight_layout()
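
Since the six calls above differ only in the subplot index, you can also write them as a loop. This sketch does that and then checks which grid cell each index landed in via `get_subplotspec` (it assumes a reasonably recent Matplotlib, 3.2 or newer, for the `rowspan`/`colspan` attributes):

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-1, 1, 100)

plt.figure()
axes = []
for k in range(1, 7):            # subplot indices start at 1, not 0
    ax = plt.subplot(2, 3, k)    # 2 rows, 3 columns, k-th cell
    ax.plot(x, x**(k - 1))
    axes.append(ax)
plt.tight_layout()

# Index k fills the grid row by row: 1-3 on the top row, 4-6 on the bottom row
for k, ax in enumerate(axes, start=1):
    spec = ax.get_subplotspec()
    print(k, '-> row', spec.rowspan.start, 'col', spec.colspan.start)
```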

# A closer look at Matplotlib: Figures, Subplots, Axes, and Ticks

So far, we have been using Matplotlib's convenience interface for creating new figures and plotting into them. While this works very well, learning a bit about how figures are structured in Matplotlib can turn you from a Matplotlib novice into a Matplotlib wizard!

## Figure object

So far, a Matplotlib **figure** object has been created for you behind the scenes whenever you called functions like `plt.plot`. But you can also explicitly create a new **figure object**, which serves as the container for everything you plot.

In [None]:
fig = plt.figure()  # create a new empty figure
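
`plt.figure` accepts a `figsize` argument (width, height in inches). A figure keeps its axes in the `fig.axes` list, which lets you confirm that a fresh figure really is empty. This sketch uses `fig.add_subplot()` as one way of adding an axes (it assumes Matplotlib 3.1 or newer, where the arguments are optional):

```python
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(4, 3))   # width, height in inches
print(fig.get_size_inches())       # the requested size
print(len(fig.axes))               # no axes yet: the figure is empty

ax = fig.add_subplot()             # add a single axes filling the figure
print(len(fig.axes))
```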

Since the figure is empty, there is nothing to see yet. Let's go ahead and add an **axes** object!

## Axes object

While a figure represents the whole image, including the margins, an **axes** object represents the region of the figure where you actually plot things.

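You can create an axes yourself by adding one to the figure at an explicit position, given as `[left, bottom, width, height]` in fractions of the figure size:

In [None]:
fig = plt.figure()  # create a new empty figure
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])  # add an axes at [left, bottom, width, height]

However, this is rather inconvenient. Instead, you can use the `plt.subplots` function to create a new figure with an axes already placed inside.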

In [None]:
fig, ax = plt.subplots()

When you create a figure through this **object-based** approach, you use the `ax` object to plot:

In [None]:
fig, ax = plt.subplots()

x = np.linspace(0, 1, 30)
ax.plot(x, x**2, 'o', color='orange')
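
With the object-based approach, most of the `plt.*` functions from earlier have an `Axes` method counterpart; the setters typically gain a `set_` prefix. A side-by-side sketch (the labels and title are arbitrary):

```python
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
x = np.linspace(0, 1, 30)
ax.plot(x, x**2, 'o', color='orange', label='$x^2$')

# Object-based counterparts of the plt.* functions used earlier
ax.set_xlabel('x')              # plt.xlabel
ax.set_ylabel('y')              # plt.ylabel
ax.set_title('A parabola')      # plt.title
ax.set_xlim(0, 1)               # plt.xlim
ax.set_xticks([0, 0.5, 1])      # plt.xticks
ax.legend()                     # plt.legend
```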

At this point, you may be wondering what the point of all this is. It may look like this approach makes your life unnecessarily complicated. Just bear with me a bit longer, and you'll soon see its power!

With this method, you can create two separate figures and keep plotting into either one, in any order:

In [None]:
# create two figures, with axes in each
fig1, ax1 = plt.subplots()

fig2, ax2 = plt.subplots()

x = np.linspace(0, 1, 30)

ax1.plot(x, x**2, 'o', color='orange')
ax2.plot(x, x**3, 'o', color='blue')

# plot again in ax1
ax1.plot(x, x, 'x', color='green')

The true strength of this approach shows when you work with *subplots* **programmatically**.

In [None]:
fig, axs = plt.subplots(3, 3)  # create 3 x 3 subplots - 9 axes placed in a NumPy array

In [None]:
axs

In [None]:
axs.shape
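
The shape of what `plt.subplots` returns depends on the grid you ask for: a single `Axes` for a 1 x 1 grid, a 1-D array when there is only one row or column, and a 2-D array otherwise. Passing `squeeze=False` always gives a 2-D array, which can make loops more predictable. A sketch:

```python
import matplotlib.pyplot as plt
import numpy as np

fig1, ax_single = plt.subplots()        # 1 x 1: a single Axes, not an array
fig2, ax_row = plt.subplots(1, 3)       # 1 x 3: a 1-D array of 3 Axes
fig3, ax_grid = plt.subplots(3, 3)      # 3 x 3: a 2-D array, indexed [row, col]
fig4, ax_always2d = plt.subplots(1, 3, squeeze=False)  # always 2-D, here shape (1, 3)

print(isinstance(ax_single, np.ndarray))
print(ax_row.shape, ax_grid.shape, ax_always2d.shape)

top_right = ax_grid[0, 2]               # row 0, column 2
```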

In [None]:
fig, axs = plt.subplots(3, 3)  # create 3 x 3 subplots - 9 axes placed in a NumPy array

x = np.linspace(-1, 1, 100)
for ax in axs.ravel():   # flatten and visit one ax at a time
    ax.plot(x, x)

Even more interesting: remember `enumerate`?

In [None]:
fig, axs = plt.subplots(3, 3)  # create 3 x 3 subplots - 9 axes placed in a NumPy array

x = np.linspace(-1, 1, 100)
for i, ax in enumerate(axs.ravel()):   # flatten and visit one ax at a time
    ax.plot(x, x**i)
    ax.set_title('y=x^{}'.format(i))

In [None]:
fig, axs = plt.subplots(3, 3)  # create 3 x 3 subplots - 9 axes placed in a NumPy array

x = np.linspace(-1, 1, 100)
for i, ax in enumerate(axs.ravel()):   # flatten and visit one ax at a time
    ax.plot(x, x**i)
    ax.set_title('y=x^{}'.format(i))
    
fig.tight_layout()

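Let's increase the size of the figure by setting `figsize`, its (width, height) in inches. This time the titles are also wrapped in `$...$`, which Matplotlib renders as math text: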

In [None]:
fig, axs = plt.subplots(3, 3, figsize=(12, 12))  # create 3 x 3 subplots - 9 axes placed in a NumPy array

x = np.linspace(-1, 1, 100)
for i, ax in enumerate(axs.ravel()):   # flatten and visit one ax at a time
    ax.plot(x, x**i)
    ax.set_title('$y=x^{}$'.format(i))
    
fig.tight_layout()
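
The `$...$` titles use Matplotlib's math text renderer. As in TeX, `^` raises only the next character, so `'$y=x^{}$'.format(10)` would leave the 0 on the baseline. A sketch with a larger 4 x 4 grid, using doubled braces in an f-string to group multi-digit exponents:

```python
import matplotlib.pyplot as plt
import numpy as np

fig, axs = plt.subplots(4, 4, figsize=(12, 12))  # 16 axes, so exponents reach 15

x = np.linspace(-1, 1, 100)
for i, ax in enumerate(axs.ravel()):
    ax.plot(x, x**i)
    # ^ raises only the next character, so multi-digit exponents need braces.
    # In an f-string, {{ and }} produce literal braces around the value of i.
    ax.set_title(f'$y=x^{{{i}}}$')

fig.tight_layout()
```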

Alternatively (or in addition), you can control the **resolution** of the figure by setting **dpi** (dots per inch):

In [None]:
fig, axs = plt.subplots(3, 3, figsize=(12, 12), dpi=200)  # create 3 x 3 subplots - 9 axes placed in a NumPy array

x = np.linspace(-1, 1, 100)
for i, ax in enumerate(axs.ravel()):   # flatten and visit one ax at a time
    ax.plot(x, x**i)
    ax.set_title('$y=x^{}$'.format(i))
    
fig.tight_layout()
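
The size in pixels of a saved figure is its size in inches multiplied by its dpi, so `figsize=(12, 12)` with `dpi=200` gives a 2400 x 2400 pixel image. This small sketch checks that on a 4 x 3 inch figure at 100 dpi by reading the width and height from the saved PNG's header. (The inline notebook display may crop whitespace around the figure, so what you see there can be a bit smaller.)

```python
import io
import struct

import matplotlib.pyplot as plt

fig = plt.figure(figsize=(4, 3), dpi=100)
fig.add_subplot().plot([0, 1], [0, 1])

# Save as PNG into memory; by default savefig uses the figure's own dpi
buffer = io.BytesIO()
fig.savefig(buffer, format='png')
data = buffer.getvalue()

# A PNG file stores its width and height as two big-endian 32-bit integers
# right after the 8-byte signature and the 8-byte IHDR chunk header
width, height = struct.unpack('>II', data[16:24])
print(width, height)  # 400 300
```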

# Other kinds of plots

While `plt.plot` gives you **line plots**, which are very useful for sequential data, you will often want other kinds of plots as well. Don't worry, Matplotlib has you covered!

## Bar plots

You can make **bar plots** using `plt.bar`:

In [None]:
n = 12
X = np.arange(n)
Y = np.random.rand(n)
plt.bar(X, Y)
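
`plt.bar` draws one bar of height `Y[i]` at each position `X[i]` and returns the bars, so you can inspect or restyle them. A sketch with two common options, `width` (default 0.8) and `color`; the seeded generator (seed 0, an arbitrary choice) just makes the bars reproducible. `plt.barh` draws horizontal bars in a similar way.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)   # seeded generator so the bars are reproducible
n = 12
X = np.arange(n)
Y = rng.random(n)

plt.figure()
bars = plt.bar(X, Y, width=0.5, color='teal')  # narrower bars than the default 0.8

# plt.bar returns a container with one Rectangle per bar
heights = [bar.get_height() for bar in bars]
print(heights)
```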

## Scatter plots

You can make **scatter plots** using `plt.scatter`

In [None]:
x = np.random.randn(100)
y = np.random.randn(100)

plt.scatter(x, y)

There is an optional argument `c` that can be used to control the **color** of the points based on some value.

In [None]:
x = np.random.randn(100)
y = np.random.randn(100)

delta = np.abs(x - y)

plt.scatter(x, y, c=delta)
plt.colorbar()
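
Besides `c`, `plt.scatter` accepts `s` to set each marker's **size** (area in points squared), `cmap` to choose the colormap, and `alpha` for transparency. A sketch that encodes `delta` in both color and size; the scaling `20 + 40 * delta` and the seed are arbitrary choices:

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(1)   # seeded for reproducibility
x = rng.standard_normal(100)
y = rng.standard_normal(100)
delta = np.abs(x - y)

plt.figure()
# c maps each value to a color through the colormap; s sets marker area (points^2)
points = plt.scatter(x, y, c=delta, s=20 + 40 * delta, cmap='viridis', alpha=0.7)
plt.colorbar(label='|x - y|')

# The returned collection remembers the values used for coloring and sizing
color_values = points.get_array()
sizes = points.get_sizes()
```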

You can also combine different kinds of plots in one figure. Here we show a ground-truth linear model as a line, together with some noisy samples drawn from it as a scatter plot:

In [None]:
x = np.linspace(0, 1, 100)
a = 5
b = 10

# linear model
y = a * x + b 

plt.plot(x, y, label='ground truth')

n = 30
sample_x = np.linspace(0, 1, n)
noise = np.random.randn(n) * 0.5
sample_y = a * sample_x + b + noise

plt.scatter(sample_x, sample_y, label='samples')

plt.legend()
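
A natural next step is to estimate the model from the noisy samples and plot the estimate alongside the truth. One way, shown here as a sketch, is an ordinary least-squares line fit with `np.polyfit` (the seed is arbitrary and only makes the noise reproducible):

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(42)   # seeded so the noise is reproducible
a, b = 5, 10
n = 30
sample_x = np.linspace(0, 1, n)
sample_y = a * sample_x + b + rng.standard_normal(n) * 0.5

# Least-squares fit of a degree-1 polynomial: returns [slope, intercept]
slope, intercept = np.polyfit(sample_x, sample_y, 1)

x = np.linspace(0, 1, 100)
plt.figure()
plt.plot(x, a * x + b, label='ground truth')
plt.plot(x, slope * x + intercept, '--', label='fitted line')
plt.scatter(sample_x, sample_y, label='samples')
plt.legend()
print(slope, intercept)
```

With only 30 noisy samples, the fitted slope and intercept land close to, but not exactly on, `a` and `b`.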

## Images

Plotting images is also very common. It turns out that you can simply treat an appropriately shaped NumPy array as an image! Here is a 128 x 128 array of random values:

In [None]:
img = np.random.randn(128, 128)

...and display it using `plt.imshow`

In [None]:
plt.imshow(img)

For single-channel (2-D) images like this one, which carry no color information of their own, `imshow` maps each value to a color using a **colormap**. You can pick one with the `cmap` argument:

In [None]:
plt.imshow(img, cmap='gray')

In [None]:
plt.imshow(img, cmap='jet')
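
By default, `imshow` stretches the colormap between the smallest and largest values in the array. You can fix that range with `vmin` and `vmax`, and add a colorbar to show the mapping from values to colors (the range -2 to 2 and the seed are arbitrary choices):

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
img = rng.standard_normal((128, 128))

plt.figure()
# vmin/vmax fix which values map to the two ends of the colormap;
# values outside the range are drawn with the end colors
im = plt.imshow(img, cmap='gray', vmin=-2, vmax=2)
plt.colorbar(im)   # show which value each shade of gray stands for
```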

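You can also load a sample image that ships with SciPy. In SciPy 1.10 and later, sample images live in the `scipy.datasets` module, which downloads them on first use and requires the `pooch` package; older code often uses `scipy.misc.face`, which was deprecated in SciPy 1.10 in favor of `scipy.datasets.face`.

In [None]:
from scipy.datasets import face  # SciPy >= 1.10; on older SciPy versions: from scipy.misc import face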

In [None]:
img = face(gray=True) # get gray scale image of a "face"

In [None]:
plt.imshow(img)

In [None]:
plt.imshow(img, cmap='gray')

In [None]:
plt.imshow(img, cmap='cool')
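
Colormaps only apply to 2-D arrays. If the array has shape (M, N, 3) or (M, N, 4), `imshow` treats the last axis as RGB or RGBA color channels and draws the colors directly (this is what you get from `face()` without `gray=True`). A tiny hand-made example:

```python
import matplotlib.pyplot as plt
import numpy as np

# A 2 x 2 RGB image: the last axis holds the red, green, and blue channels.
# Float values are expected in [0, 1] (integer images use 0-255 as uint8).
rgb = np.array([
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],   # red, green
    [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0]],   # blue, white
])

plt.figure()
im = plt.imshow(rgb)   # no cmap needed: the colors come from the array itself
```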