# Notebook 6: Matplotlib

Here we gain more experience with the Matplotlib plotting library. We will do so by analyzing data from the famous Iris dataset (collected by Edgar Anderson) described in R. A. Fisher's 1936 paper:

`R. A. Fisher (1936). "The use of multiple measurements in taxonomic problems". Annals of Eugenics. 7 (2): 179–188. `

In [None]:
# Load standard libraries
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
%matplotlib inline

# Set seaborn plotting style
sns.set_style('ticks')

In [None]:
# Load Iris data and preview it
df = sns.load_dataset('iris')
df.head()

In [None]:
# Display the number of entries for each species
df['species'].value_counts()

In [None]:
# Get rows for species setosa
rows = (df['species']=='setosa')

In [None]:
# Extract sepal lengths and widths for setosa
lengths = df.loc[rows, 'sepal_length'].values
widths = df.loc[rows, 'sepal_width'].values

In [None]:
# To create a simple scatter plot, use plt.scatter
plt.scatter(x=widths, y=lengths)

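There is a visualization problem here: some setosa flowers have exactly the same sepal width and length, so their points are drawn on top of each other and look like a single point. One way to deal with overlapping points is to decrease the opacity, so that spots where several points coincide appear darker. We can set the opacity to 50% by passing `alpha=.5` as a keyword argument to `plt.scatter()`.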
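Why does lowering the opacity help? With standard "over" alpha compositing, each additional marker drawn on the same spot covers a fraction `alpha` of whatever is still visible beneath it, so `k` coincident markers reach a combined opacity of `1 - (1 - alpha)**k`. The sketch below simply evaluates that formula; the helper name `stacked_opacity` is ours, not Matplotlib's, and it assumes every marker uses the same alpha and ignores antialiasing at marker edges.

```python
import numpy as np

def stacked_opacity(alpha, k):
    """Combined opacity of k identical markers drawn on top of each other,
    assuming standard "over" compositing with the same alpha for each."""
    return 1 - (1 - alpha) ** np.asarray(k)

counts = np.array([1, 2, 3, 4])
print(stacked_opacity(0.5, counts))  # opacities 0.5, 0.75, 0.875, 0.9375
```

So at `alpha=.5`, a spot where two flowers coincide is noticeably darker than a single point, which is what makes the overlaps visible.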

In [None]:
# Reduce opacity
plt.scatter(x=widths, y=lengths, alpha=.5)

Another common strategy is to add a little bit of random "jitter" to each data point, i.e. a small amount of noise that nudges coincident points apart.
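Jitter works best when the noise is small compared with the spacing of the data. The Iris measurements are recorded to one decimal place (0.1 cm), so a standard deviation of a few hundredths of a centimeter separates coincident points without moving them far from their true values. The sketch below is a variant of the notebook's `jitter` helper; the name `jitter_seeded` and its `seed` argument are ours, added so that the jittered plot is reproducible from run to run, and the example values are made up rather than taken from the dataset.

```python
import numpy as np

def jitter_seeded(vec, sigma, seed=0):
    """Return a copy of vec with N(0, sigma**2) noise added to each entry.
    A fixed seed makes the jitter (and hence the plot) reproducible."""
    rng = np.random.default_rng(seed)
    vec = np.asarray(vec, dtype=float)
    return vec + rng.normal(loc=0.0, scale=sigma, size=vec.shape)

# Made-up measurements on a 0.1 cm grid, with several exact duplicates
x = np.array([3.4, 3.4, 3.4, 3.5, 3.5, 3.0])
x_j = jitter_seeded(x, sigma=0.05)

# 3 distinct values before jittering, 6 after
print(len(np.unique(x)), len(np.unique(x_j)))
```

Calling `jitter_seeded` again with the same seed returns exactly the same jittered values, whereas the notebook's `jitter` draws fresh noise every time the cell runs.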

In [None]:
# Define a function to add some jitter to each data point
def jitter(vec, sigma):
    "Adds a normal random number with std sigma to each entry of the np.array vec"
    return vec + sigma*np.random.randn(len(vec))

# Add a little bit of jitter to each data point
sigma = .05
widths_j = jitter(widths, sigma)
lengths_j = jitter(lengths, sigma)

# Plot the jittered data
plt.scatter(x=widths_j, y=lengths_j, alpha=.5)

Now that we know what to plot, we can begin to style the figure more to our liking.
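The styling cell below passes `fontsize=10` to every label and title call. If you want the same sizes throughout a figure, one alternative (a sketch, not what this notebook does) is to set them once through Matplotlib's rcParams, temporarily, with `mpl.rc_context`. The three points plotted here are made up, and the `mpl.use("Agg")` line only lets the snippet run without a display; leave it out inside the notebook.

```python
import matplotlib as mpl
mpl.use("Agg")  # headless backend so the sketch runs outside a notebook
import matplotlib.pyplot as plt

# Font sizes applied to Axes labels and titles created inside the block
style = {"axes.labelsize": 10, "axes.titlesize": 10}

with mpl.rc_context(style):
    fig, ax = plt.subplots(figsize=[3, 3])
    ax.scatter(x=[3.0, 3.5, 4.0], y=[4.8, 5.1, 5.5], alpha=.5, linewidths=0)
    ax.set_xlabel('sepal width (cm)')   # no fontsize= needed
    ax.set_ylabel('sepal length (cm)')
    ax.set_title('species: setosa')

print(ax.xaxis.label.get_fontsize(), ax.title.get_fontsize())
```

Outside the `with` block the previous rcParams are restored, so other figures in the session are unaffected.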

In [None]:
# Create a figure on which to draw 
fig, ax = plt.subplots(figsize=[3,3])

# Draw points
ax.scatter(x=widths_j,              
           y=lengths_j, 
           alpha=.5,             # Opacity   
           marker='o',           # Marker shape
           s=30,                 # Marker size
           color='C0',           # Marker color
           linewidths=0)         # Removes boundary from marker

# Set the x and y labels
ax.set_xlabel('sepal width (cm)', fontsize=10)
ax.set_ylabel('sepal length (cm)', fontsize=10)

# Set the xlims and ylims
ax.set_xlim([2,5])
ax.set_ylim([3.5,6.5])

# Set xticks and yticks
ax.set_xticks([2,3,4,5])
ax.set_yticks([4,5,6])

# Make a title
ax.set_title('species: setosa', fontsize=10)

# Make sure labels don't get pushed off plot
plt.tight_layout()

# Save figure
file_name = '6_matplotlib_1.pdf'
fig.savefig(file_name)

# Open the saved file to check it (open is a macOS command; on Linux use xdg-open)
!open $file_name
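`!open` asks macOS to open the file, so that check only works on a Mac. A platform-independent sanity check is to confirm from Python that the file exists and starts with the PDF signature `%PDF`. The sketch below does this for a throwaway figure written to a temporary directory; the helper name `looks_like_pdf` is ours.

```python
import os
import tempfile

import matplotlib as mpl
mpl.use("Agg")  # headless backend; not needed inside the notebook
import matplotlib.pyplot as plt

def looks_like_pdf(path):
    """True if path exists and begins with the PDF magic bytes b'%PDF'."""
    if not os.path.exists(path):
        return False
    with open(path, "rb") as f:
        return f.read(4) == b"%PDF"

fig, ax = plt.subplots(figsize=[3, 3])
ax.plot([1, 2, 3], [1, 4, 9])

with tempfile.TemporaryDirectory() as tmp:
    file_name = os.path.join(tmp, "check.pdf")
    fig.savefig(file_name)   # output format inferred from the .pdf extension
    ok = looks_like_pdf(file_name)
    size = os.path.getsize(file_name)

print(ok, size > 0)
```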

We can plot data from other species in different colors by adding a `for` loop.
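The loop below never passes a `color`: each call to `ax.scatter` without one takes the next color from Matplotlib's property cycle, which is why the three species come out as `'C0'`, `'C1'` and `'C2'`. A small check of that behavior, on made-up points and assuming the default color cycle:

```python
import numpy as np
import matplotlib as mpl
mpl.use("Agg")  # headless backend; not needed inside the notebook
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgb

fig, ax = plt.subplots()
for i in range(3):
    # No color argument: Matplotlib picks the next color in the cycle
    ax.scatter(x=[i, i + 1], y=[0, 1], alpha=.5, linewidths=0, label=f'group {i}')

# Each scatter call adds one PathCollection to ax.collections
face_rgb = [tuple(c.get_facecolor()[0][:3]) for c in ax.collections]
expected = [to_rgb(f'C{i}') for i in range(3)]
print(np.allclose(face_rgb, expected))
```

Passing `color=` explicitly, as the three-panel figure later does, overrides the cycle.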

In [None]:
# Create a figure on which to draw
fig, ax = plt.subplots(figsize=[4,4])

# Get a list of the three species
species = df['species'].unique()
print("Iterating over the species:", species)

# For each species...
for i, s in enumerate(species):

    # Get rows of dataframe specific to that species
    rows = (df['species']==s)

    # Extract sepal lengths and widths
    lengths = df.loc[rows, 'sepal_length'].values
    widths = df.loc[rows, 'sepal_width'].values
    
    # Add jitter
    lengths_j = jitter(lengths, sigma)
    widths_j = jitter(widths, sigma)

    # Draw points (no color is given, so Matplotlib takes the next color in its color cycle on each iteration)
    ax.scatter(x=widths_j,              
               y=lengths_j,
               label=s,              # For legend
               alpha=.5,             # Opacity   
               marker='o',           # Marker shape
               s=50,                 # Marker size
               linewidths=0)         # Removes boundary from marker

# Set the x and y labels (only needed once, so done after the loop)
ax.set_xlabel('sepal width (cm)', fontsize=10)
ax.set_ylabel('sepal length (cm)', fontsize=10)

# Adjust lims
ax.set_xlim([1.75, 5.75])
ax.set_ylim([4, 8])

# Set xticks and yticks
ax.set_xticks([2,3,4,5])
ax.set_yticks([4,5,6,7,8])
    
# Create legend
ax.legend(loc='lower right')
    
# Make sure labels don't get pushed off plot
plt.tight_layout()

# Save figure
file_name = '6_matplotlib_2.pdf'
fig.savefig(file_name)

# Open the saved file to check it (open is a macOS command; on Linux use xdg-open)
!open $file_name

We can instead create a figure containing three panels, one for each species.
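`plt.subplots(nrows=1, ncols=3)` returns the Figure together with a NumPy array of Axes objects, which is why the loop below can pick a panel with `axs[i]`. With a single row (or column) the array is one-dimensional; with several rows it is two-dimensional, and `axs.flat` is a simple way to visit every panel. A quick check:

```python
import numpy as np
import matplotlib as mpl
mpl.use("Agg")  # headless backend; not needed inside the notebook
import matplotlib.pyplot as plt

fig1, axs1 = plt.subplots(nrows=1, ncols=3, figsize=[8, 3])
print(type(axs1).__name__, axs1.shape)   # ndarray (3,)

fig2, axs2 = plt.subplots(nrows=2, ncols=3)
print(axs2.shape)                        # (2, 3)

# axs.flat visits all panels in row-major order, whatever the grid shape
for k, ax in enumerate(axs2.flat):
    ax.set_title(f'panel {k}')

print([ax.get_title() for ax in axs2[1]])  # ['panel 3', 'panel 4', 'panel 5']
```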

In [None]:
# Create a 3-panel figure
fig, axs = plt.subplots(nrows=1,ncols=3,figsize=[8,3])

# Define dot colors to use
colors = ['C0','C1','C2']

# Iterate over species
species = df['species'].unique()
for i, s in enumerate(species):
    
    # Get rows specific to species
    rows = (df['species']==s)

    # Extract sepal lengths and widths
    lengths = df.loc[rows, 'sepal_length'].values
    widths = df.loc[rows, 'sepal_width'].values
    
    # Add jitter
    lengths_j = jitter(lengths, 0.05)
    widths_j = jitter(widths, 0.05)
    
    # Choose Axes object from axs array
    ax = axs[i]
    
    # Draw points (manually setting color)
    ax.scatter(x=widths_j,              
               y=lengths_j, 
               alpha=.5,             # Opacity   
               marker='o',           # Marker shape
               s=50,                 # Marker size
               color=colors[i],      # Marker color
               linewidths=0)         # Removes boundary from marker

    # Set the x and y labels
    ax.set_xlabel('sepal width (cm)', fontsize=10)
    ax.set_ylabel('sepal length (cm)', fontsize=10)
    
    # Make a title
    ax.set_title('species: %s'%s, fontsize=10)
    
# Make sure labels don't get pushed off plot
plt.tight_layout()

# Label panels
fig.text(x=.01, y=.99, s='(A)', horizontalalignment='left', verticalalignment='top', fontsize=12)
fig.text(x=.35, y=.99, s='(B)', horizontalalignment='left', verticalalignment='top', fontsize=12)
fig.text(x=.67, y=.99, s='(C)', horizontalalignment='left', verticalalignment='top', fontsize=12)

# Save figure
file_name = '6_matplotlib_3.pdf'
fig.savefig(file_name)

# Open the saved file to check it (open is a macOS command; on Linux use xdg-open)
! open $file_name
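The `fig.text` calls above place the panel labels at hand-tuned figure coordinates (`x=.01`, `.35`, `.67`), which need re-tuning whenever the figure size or panel spacing changes. One alternative (a sketch, not the notebook's method) is to position each label relative to its own panel with `transform=ax.transAxes`, in which (0, 0) is the lower-left corner of the Axes and (1, 1) the upper-right. The offsets `-0.25` and `1.15` are arbitrary choices that put each label above and to the left of its panel.

```python
import string

import matplotlib as mpl
mpl.use("Agg")  # headless backend; not needed inside the notebook
import matplotlib.pyplot as plt

fig, axs = plt.subplots(nrows=1, ncols=3, figsize=[8, 3])

for i, ax in enumerate(axs):
    label = '(%s)' % string.ascii_uppercase[i]   # '(A)', '(B)', '(C)'
    # Coordinates are fractions of this Axes, so the label moves with its panel
    ax.text(x=-0.25, y=1.15, s=label, transform=ax.transAxes,
            horizontalalignment='left', verticalalignment='top', fontsize=12)

fig.tight_layout()
print([ax.texts[0].get_text() for ax in axs])  # ['(A)', '(B)', '(C)']
```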

## Exercise

**E6.1** By specifying `sharex=True` and `sharey=True` in `plt.subplots()`, and by calling `ax.set_ylabel()` only when `i==0` (the leftmost panel), make all three panels share the same x- and y-axes. You might also want to increase the spacing between panels by passing `w_pad=3` to `fig.tight_layout()`.

In [None]:
# Answer here