# STEP-BY-STEP: FIGURES

In this notebook we will be running through the creation of a simple multi-panel figure using matplotlib. 

I will highlight complex changes to the code using "# <======= #" followed by a short description.

Click through each cell below by pressing [ SHIFT + ENTER ] and trying out the EXPLORE activities when they come up!


In [None]:
# Let's start by creating a list of the even integers from 0 to 20, written out by hand (11 values)

list_A = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20]

In [None]:
# Let's print the items in the list using a for loop (one value per line)...
for x in list_A:
    print(x)

In [None]:
# Instead of explicitly setting these values, we can instead use the linspace function from numpy:
# NOTE: linspace includes both end points, so to go from 0 to 20 in jumps of 2 we need 11 points (10 gaps of 2)

import numpy as np               # <======= # We only need to import once per notebook

list_A = np.linspace(0,20,11)    # <======= # np.linspace(START,STOP,NUM) - NUM is the number of points, not the step size

In [None]:
# EXPLORE : If you press [ SHIFT + TAB ] after clicking inside a function, the documentation describing that function will pop up!
# EXPLORE : Try it above by clicking inside of "linspace" and pressing [ SHIFT + TAB ] (Note, this will only work after importing a package)

In [None]:
# Let's check that list_A is still made up of the same values using a for loop (look closely at how they are printed!)

for x in list_A:
    print(x)

In [None]:
# Uh oh! Linspace returns floating point numbers by default (0.0, 2.0, ...); if we want integers we can set "dtype=int" within the linspace function

list_A = np.linspace(0,20,11, dtype=int)

for n in list_A:
    print(n)

In [None]:
# Much better! We can now make our first graph using the most basic plotting function for matplotlib: "plt.plot()"
# NOTE: matplotlib can handle floats and ints, I'm just showing you how to specify this if you do need it!

import matplotlib.pyplot as plt

plt.plot(list_A)   # only one sequence is given, so it is treated as the y-values

plt.show()

In [None]:
# Beautiful! Here, we can see that the x-axis values of the line have been taken to be the position (index) in the list, and the list_A values are plotted on the y-axis
# This is easier to see if we add the individual datapoints to the graph by adding marker="o"

plt.plot(list_A, marker="o") # <======= # MODIFIED

plt.show()

In [None]:
# EXPLORE : Above, try setting marker to "x", "^", or "8"

In [None]:
# By setting the x- and y-axis limits to span the same range (0-20), we can start to see the true slope of the line we've drawn
# (list_A rises by 2 for every step along the x-axis, so the slope is 2)

plt.plot(list_A, marker="o")

plt.xlim(0,20) # <======= # NEW
plt.ylim(0,20) # <======= # NEW

plt.show()

In [None]:
# The default figure size in matplotlib is 6.4x4.8 inches (see plt.rcParams["figure.figsize"]),
# we can further reveal the true slope of the line by first creating a figure object called "fig" and setting "figsize" to 4x4 inches,
# so that equal data ranges on the x- and y-axis take up (roughly) equal lengths on screen.

fig = plt.figure(figsize=(4, 4)) # <======= # NEW

plt.plot(list_A, marker="o")

plt.xlim(0,20)
plt.ylim(0,20)

plt.show()

In [None]:
# To create multiple figure panels, we can add individual axes to the figure
# plt.subplots() returns two things at once: the figure ("fig") and its axes ("ax"); with ncols=2, ax holds two panels, ax[0] and ax[1]

# NOTE: We use "plt.subplots" now instead of "plt.figure" and can set the number of columns and rows using "ncols=" and "nrows="
# NOTE: Now we have to specify which axis we want to plot to ( e.g. "ax[0].plot()" ) instead of using "plt.plot()"
# NOTE: To set the x- or y-limits on an individual ax, you need to use "set_xlim" instead of "xlim", a similar case is true for other properties 

fig, ax = plt.subplots(ncols=2,figsize=(4, 4)) # <======= # MODIFIED

ax[0].plot(list_A, marker="o")                 # <======= # MODIFIED

ax[0].set_xlim(0,20)                           # <======= # MODIFIED
ax[0].set_ylim(0,20)                           # <======= # MODIFIED

plt.show()

In [None]:
# We can switch the panel this graph is plotted on by changing to "ax[1].plot()"

fig, ax = plt.subplots(ncols=2,figsize=(4, 4)) 

ax[1].plot(list_A, marker="o")                 # <======= # MODIFIED

ax[1].set_xlim(0,20)                           # <======= # MODIFIED
ax[1].set_ylim(0,20)                           # <======= # MODIFIED

plt.show()

In [None]:
# EXPLORE : What happens if you still use "plt.plot()" instead of "ax[0].plot()"?

In [None]:
# Uh oh... this figure is a bit squished!
# Let's increase its size by doubling the width value (1st value of the tuple) of "figsize=" 

fig, ax = plt.subplots(ncols=2,figsize=(8, 4)) # <======= # MODIFIED

ax[1].plot(list_A, marker="o")

ax[1].set_xlim(0,20)
ax[1].set_ylim(0,20)

plt.show()

In [None]:
# But how do you know which ax is which? 
# To find out, we can add an annotation naming each one...
# NOTE: xycoords='axes fraction' measures xy from the bottom-left (0,0) to the top-right (1,1) of each ax, whatever its data limits

fig, ax = plt.subplots(ncols=2,figsize=(8, 4))

ax[1].plot(list_A, marker="o")

ax[0].annotate("ax[0]",xy=(0.275,0.5),xycoords='axes fraction', fontsize=30,color="red") # <======= # NEW
ax[1].annotate("ax[1]",xy=(0.275,0.5),xycoords='axes fraction', fontsize=30,color="red") # <======= # NEW

ax[1].set_xlim(0,20)
ax[1].set_ylim(0,20)

plt.show()

In [None]:
# These plotting functions become a lot more useful when you link them to other functions! 
# For example, instead of referring to list_A directly, we can instead call the linspace function from within the ax[1].plot() function

fig, ax = plt.subplots(ncols=2,figsize=(8, 4))

ax[1].plot(np.linspace(0,20,11), marker="o") # <======= # MODIFIED

ax[1].set_xlim(0,20)
ax[1].set_ylim(0,20)
plt.show()

In [None]:
# It's the same as before... so now we have a way of rapidly adding more data to our graph!
# Below, let's add different linspace functions to our graph...

# NOTE: the color of the different datasets is specified automatically

fig, ax = plt.subplots(ncols=2,figsize=(8, 4))

ax[1].plot(np.linspace(0,5,11), marker="o")
ax[1].plot(np.linspace(0,10,11), marker="o") # <======= # NEW
ax[1].plot(np.linspace(0,15,11), marker="o") # <======= # NEW
ax[1].plot(np.linspace(0,20,11), marker="o") # <======= # NEW

ax[1].set_xlim(0,20)
ax[1].set_ylim(0,20)

plt.show()

In [None]:
# As we are simply changing the second argument in linspace (the STOP value) to draw each of these lines, 
# we can instead use a for loop to draw them!

fig, ax = plt.subplots(ncols=2,figsize=(8, 4))

for x in [5,10,15,20]:                            # <======= # NEW
    ax[1].plot(np.linspace(0,x,11), marker="o")   # <======= # MODIFIED 

ax[1].set_xlim(0,20)
ax[1].set_ylim(0,20)

plt.show()

In [None]:
# Of course, this for loop could itself use a linspace function instead of an explicit list of values "[5,10,15,20]"
# (np.linspace(5,20,4,dtype=int) gives exactly 5, 10, 15, 20)

fig, ax = plt.subplots(ncols=2,figsize=(8, 4))

for x in np.linspace(5,20,4,dtype=int):          # <======= # MODIFIED 
    ax[1].plot(np.linspace(0,x,11), marker="o")

ax[1].set_xlim(0,20)
ax[1].set_ylim(0,20)

plt.show()

In [None]:
# Now we're getting complicated! 
# We can add more lines within that range by asking linspace for 16 values instead of 4 (5, 6, 7, ..., 20)

fig, ax = plt.subplots(ncols=2,figsize=(8, 4))

for x in np.linspace(5,20,16,dtype=int):         # <======= # MODIFIED number of values from 4 to 16
    ax[1].plot(np.linspace(0,x,11), marker="o")

ax[1].set_xlim(0,20)
ax[1].set_ylim(0,20)
plt.show()

In [None]:
# We can also use this for loop to set data on ax[0]
# (on ax[0] each line runs the other way: from x down to 0)

# NOTE: For simplicity, let's forget about setting x- and y-axis limits...

fig, ax = plt.subplots(ncols=2,figsize=(8, 4))

for x in np.linspace(5,20,16,dtype=int):
    ax[0].plot(np.linspace(x,0,11), marker="o")    # <======= # NEW 
    ax[1].plot(np.linspace(0,x,11), marker="o")
                                                   # <======= # REMOVED set_xlim() and set_ylim() 
plt.show()

In [None]:
# We can add another row to this figure by adding "nrows=" to plt.subplots() during the creation of the fig object 

# NOTE: We must explicitly refer to both the row and column number for each ax object, i.e. ax[row,col], so ax[0,0] is the top LEFT ax
# NOTE: I've changed the for loop to start at 0 (21 values: 0, 1, ..., 20) and removed the line markers

fig, ax = plt.subplots(ncols=2,nrows=2,figsize=(8, 8))     # <======= # MODIFIED (add nrows=2)

for x in np.linspace(0,20,21,dtype=int):                   # <======= # MODIFIED
    ax[0,0].plot(np.linspace(x,0,11))                      # <======= # MODIFIED (set as ax[0,0] instead of ax[0])
    ax[0,1].plot(np.linspace(0,x,11))                      # <======= # MODIFIED (set as ax[0,1] instead of ax[1])
    ax[1,0].plot(np.linspace(x,20,11))                     # <======= # NEW (set to second row)
    ax[1,1].plot(np.linspace(20,x,11))                     # <======= # NEW (set to second row)

plt.show()

In [None]:
# How can we tell which ax is which? This could get confusing in large figures! Let's add some annotations! 

fig, ax = plt.subplots(ncols=2,nrows=2,figsize=(8, 8))

for x in np.linspace(0,20,21,dtype=int):
    ax[0,0].plot(np.linspace(x,0,11))
    ax[0,1].plot(np.linspace(0,x,11))
    ax[1,0].plot(np.linspace(x,20,11))
    ax[1,1].plot(np.linspace(20,x,11))

ax[0,0].annotate("ax[0,0]",xy=(0.275,0.5),xycoords='axes fraction', fontsize=30,color="red")             # <======= # NEW
ax[0,1].annotate("ax[0,1]",xy=(0.275,0.5),xycoords='axes fraction', fontsize=30,color="red")             # <======= # NEW
ax[1,0].annotate("ax[1,0]",xy=(0.275,0.5),xycoords='axes fraction', fontsize=30,color="red")             # <======= # NEW
ax[1,1].annotate("ax[1,1]",xy=(0.275,0.5),xycoords='axes fraction', fontsize=30,color="red")             # <======= # NEW

plt.show()


In [None]:
# These annotations could also be set using nested range() loops and f-strings inside the annotation text
# (range(2) gives 0 and 1, so i walks the rows and j walks the columns)

fig, ax = plt.subplots(ncols=2,nrows=2,figsize=(8, 8))

for x in np.linspace(0,20,21,dtype=int):
    ax[0,0].plot(np.linspace(x,0,11))
    ax[0,1].plot(np.linspace(0,x,11))
    ax[1,0].plot(np.linspace(x,20,11))
    ax[1,1].plot(np.linspace(20,x,11))

for i in range(2):                                                                                         # <======= # NEW
    for j in range(2):                                                                                     # <======= # NEW
        ax[i,j].annotate(f"ax[{i},{j}]\n row {i}\n  col {j}",xy=(0.275,0.25),xycoords='axes fraction', fontsize=30,color="red")  # <======= # MODIFIED (i and j are used to refer to rows and column number)

plt.show()


In [None]:
# Because these datasets span the exact same range of values, we can force our data to "share" the x- and y-axis limits, by adding these "keyword arguments" to plt.subplots()
# When you do this, the "tick labels" are only displayed for the lowest x-axis or most left y-axis by default

fig, ax = plt.subplots(ncols=2,nrows=2,figsize=(8, 8),sharex=True,sharey=True) # <======= # MODIFIED

for x in np.linspace(0,20,21,dtype=int):
    ax[0,0].plot(np.linspace(x,0,11))
    ax[0,1].plot(np.linspace(0,x,11))
    ax[1,0].plot(np.linspace(x,20,11))
    ax[1,1].plot(np.linspace(20,x,11))

plt.show()

In [None]:
# We can reduce the spacing between these plots using "plt.subplots_adjust()"
# (hspace / wspace are the vertical / horizontal gaps between panels, as a fraction of the average panel height / width)

fig, ax = plt.subplots(ncols=2,nrows=2,figsize=(8, 8),sharex=True,sharey=True)

for x in np.linspace(0,20,21,dtype=int):
    ax[0,0].plot(np.linspace(x,0,11))
    ax[0,1].plot(np.linspace(0,x,11))
    ax[1,0].plot(np.linspace(x,20,11))
    ax[1,1].plot(np.linspace(20,x,11))

plt.subplots_adjust(hspace = 0.05,wspace = 0.05)        # <======= # NEW

plt.show()

In [None]:
# EXPLORE : What happens when you set hspace and wspace in subplots_adjust to zero?

In [None]:
# We can change how the axis ticks are displayed on each ax using "tick_params"
# Here we keep tick marks only on the outside edges of the grid: left ticks on the left column, bottom ticks on the bottom row
# NOTE: left=/bottom=/... switch the tick MARKS on or off; the tick LABELS have their own options (e.g. labelleft=, labelbottom=)

fig, ax = plt.subplots(ncols=2,nrows=2,figsize=(8, 8),sharex=True,sharey=True)

for x in np.linspace(0,20,21,dtype=int):
    ax[0,0].plot(np.linspace(x,0,11))
    ax[0,1].plot(np.linspace(0,x,11))
    ax[1,0].plot(np.linspace(x,20,11))
    ax[1,1].plot(np.linspace(20,x,11))

plt.subplots_adjust(hspace = 0.03,wspace = 0.03)

ax[0,0].tick_params(which='both', # Options for both major and minor ticks   # <======= # NEW
                top=False, # turn off top ticks                              # <======= # NEW
                left=True, # turn on left ticks                              # <======= # NEW
                right=False,  # turn off right ticks                         # <======= # NEW
                bottom=False) # turn off bottom ticks                        # <======= # NEW

ax[0,1].tick_params(which='both', # Options for both major and minor ticks   # <======= # NEW
                top=False, # turn off top ticks                              # <======= # NEW
                left=False, # turn off left ticks                            # <======= # NEW
                right=False,  # turn off right ticks                         # <======= # NEW
                bottom=False) # turn off bottom ticks                        # <======= # NEW

ax[1,0].tick_params(which='both', # Options for both major and minor ticks   # <======= # NEW
                top=False, # turn off top ticks                              # <======= # NEW
                left=True, # turn on left ticks                              # <======= # NEW
                right=False,  # turn off right ticks                         # <======= # NEW
                bottom=True) # turn on bottom ticks                          # <======= # NEW

ax[1,1].tick_params(which='both', # Options for both major and minor ticks   # <======= # NEW
                top=False, # turn off top ticks                              # <======= # NEW
                left=False, # turn off left ticks                            # <======= # NEW
                right=False,  # turn off right ticks                         # <======= # NEW
                bottom=True) # turn on bottom ticks                          # <======= # NEW

plt.show()

In [None]:
# This can be simplified by using a for loop that iterates through a dictionary of axis settings
# Each value is a tuple, so item[0] is the row i, item[1] the column j, item[2] the "top" setting, and so on

fig, ax = plt.subplots(ncols=2,nrows=2,figsize=(8, 8),sharex=True,sharey=True)

for x in np.linspace(0,20,21,dtype=int):
    ax[0,0].plot(np.linspace(x,0,11))
    ax[0,1].plot(np.linspace(0,x,11))
    ax[1,0].plot(np.linspace(x,20,11))
    ax[1,1].plot(np.linspace(20,x,11))

plt.subplots_adjust(hspace = 0.03,wspace = 0.03)

#         {"ax[i,j]":(i,j,top,left,right,bottom)      # <======= # NEW (i and j are used to refer to rows and column number)
ax_dict = {"ax[0,0]":(0,0,False,True,False,False),    # <======= # NEW
           "ax[0,1]":(0,1,False,False,False,False),   # <======= # NEW
           "ax[1,0]":(1,0,False,True,False,True),     # <======= # NEW
           "ax[1,1]":(1,1,False,False,False,True)}    # <======= # NEW

for key, item in ax_dict.items():                                                                               # <======= # NEW
    ax[item[0],item[1]].tick_params(which='both', top=item[2], left=item[3], right=item[4], bottom=item[5])     # <======= # MODIFIED & CONDENSED 

plt.show()

In [None]:
# Let's add another graph to this figure that is generated by a function! I've chosen to use the normal distribution, generated by the scipy package

import scipy.stats as stats          # <======= # NEW (only need to import once per notebook)

h = np.linspace(0,10,100)            # <======= # NEW (create the x-axis data points for the normal distribution)
fit = stats.norm.pdf(h,5,1)*10       # <======= # NEW (y-axis data points: mean 5, std 1, scaled by 10 so fit/4 peaks at ~1)

fig, ax = plt.subplots(ncols=3,nrows=2,figsize=(12, 8),sharex=True,sharey=True)

for x in np.linspace(0,20,21,dtype=int):
    ax[0,0].plot(np.linspace(x,0,11))
    ax[0,1].plot(h,fit/4*x)                        # <======= # NEW (normal distribution, peak height ~x)
    ax[0,2].plot(np.linspace(0,x,11))              # <======= # MODIFIED (increased column number)
    
    ax[1,0].plot(np.linspace(x,20,11))
    ax[1,1].plot(h,20-fit/4*x)                     # <======= # NEW (normal distribution, flipped and hanging down from 20)
    ax[1,2].plot(np.linspace(20,x,11))             # <======= # MODIFIED (increased column number)

plt.subplots_adjust(hspace = 0.03,wspace = 0.03)

#         {"ax[i,j]":(i,j,top,left,right,bottom)
ax_dict = {"ax[0,0]":(0,0,False,True,False,False),
           "ax[0,1]":(0,1,False,False,False,False), # <======= # NEW       
           "ax[0,2]":(0,2,False,False,False,False), # <======= # MODIFIED (increased column number)
           "ax[1,0]":(1,0,False,True,False,True),
           "ax[1,1]":(1,1,False,False,False,True),  # <======= # NEW         
           "ax[1,2]":(1,2,False,False,False,True)}  # <======= # MODIFIED (increased column number)

for key, item in ax_dict.items():
    ax[item[0],item[1]].tick_params(which='both', top=item[2], left=item[3], right=item[4], bottom=item[5])

plt.show()

In [None]:
# EXPLORE: how would you adjust the x and y axis limits in the above example?

In [None]:
# Next, we can add axis labels using "ax[i,j].set_ylabel("")" and "ax[i,j].set_xlabel("")"
# Only the outer panels get labels: y-labels on the left column (j = 0) and x-labels on the bottom row (i = 1)

h = np.linspace(0,10,100)
fit = stats.norm.pdf(h,5,1)*10  

fig, ax = plt.subplots(ncols=3,nrows=2,figsize=(12, 8),sharex=True,sharey=True)

for x in np.linspace(0,20,21,dtype=int):
    ax[0,0].plot(np.linspace(x,0,11))
    ax[0,1].plot(h,fit/4*x)
    ax[0,2].plot(np.linspace(0,x,11))
    
    ax[1,0].plot(np.linspace(x,20,11))
    ax[1,1].plot(h,20-fit/4*x)
    ax[1,2].plot(np.linspace(20,x,11))

plt.subplots_adjust(hspace = 0.03,wspace = 0.03)

#         {"ax[i,j]":(i,j,top,left,right,bottom)
ax_dict = {"ax[0,0]":(0,0,False,True,False,False),
           "ax[0,1]":(0,1,False,False,False,False),         
           "ax[0,2]":(0,2,False,False,False,False), 
           "ax[1,0]":(1,0,False,True,False,True),
           "ax[1,1]":(1,1,False,False,False,True),           
           "ax[1,2]":(1,2,False,False,False,True)}

for key, item in ax_dict.items():
    ax[item[0],item[1]].tick_params(which='both', top=item[2], left=item[3], right=item[4], bottom=item[5])

for i in [0,1]:
    ax[i,0].set_ylabel("y-label text (units)")   # <======= # NEW    

for j in [0,1,2]:
    ax[1,j].set_xlabel("x-label text (units)")   # <======= # NEW    

plt.show()

In [None]:
# Of course we could just expand the ax_dict to include these values, which would be more useful when we aren't sharing the x- and y-axis...
# A label of None leaves that ax without a label

h = np.linspace(0,10,100)
fit = stats.norm.pdf(h,5,1)*10  

fig, ax = plt.subplots(ncols=3,nrows=2,figsize=(12, 8),sharex=True,sharey=True)

for x in np.linspace(0,20,21,dtype=int):
    ax[0,0].plot(np.linspace(x,0,11))
    ax[0,1].plot(h,fit/4*x)
    ax[0,2].plot(np.linspace(0,x,11))
    
    ax[1,0].plot(np.linspace(x,20,11))
    ax[1,1].plot(h,20-fit/4*x)
    ax[1,2].plot(np.linspace(20,x,11))

plt.subplots_adjust(hspace = 0.03,wspace = 0.03)

#         {"ax[i,j]":(i,j,top,left,right,bottom,xlabel,ylabel)
ax_dict = {"ax[0,0]":(0,0,False,True,False,False,None,"y-label text (units)"),                   # <======= # MODIFIED 
           "ax[0,1]":(0,1,False,False,False,False,None,None),         
           "ax[0,2]":(0,2,False,False,False,False,None,None), 
           "ax[1,0]":(1,0,False,True,False,True,"x-label text (units)","y-label text (units)"),  # <======= # MODIFIED 
           "ax[1,1]":(1,1,False,False,False,True,"x-label text (units)",None),                   # <======= # MODIFIED 
           "ax[1,2]":(1,2,False,False,False,True,"x-label text (units)",None)}                   # <======= # MODIFIED 

for key, item in ax_dict.items():
    ax[item[0],item[1]].tick_params(which='both', top=item[2], left=item[3], right=item[4], bottom=item[5])
    ax[item[0],item[1]].set_xlabel(item[6])
    ax[item[0],item[1]].set_ylabel(item[7])

plt.show()

In [None]:
# We can also add Roman-numeral panel labels to the figure using annotate
# The last tuple entry is the label position in 'axes fraction' coordinates, chosen per panel so the label avoids the data

h = np.linspace(0,10,100)
fit = stats.norm.pdf(h,5,1)*10 

fig, ax = plt.subplots(ncols=3,nrows=2,figsize=(12, 8),sharex=True,sharey=True)

for x in np.linspace(0,20,21,dtype=int):
    ax[0,0].plot(np.linspace(x,0,11))
    ax[0,1].plot(h,fit/4*x)
    ax[0,2].plot(np.linspace(0,x,11))
    
    ax[1,0].plot(np.linspace(x,20,11))
    ax[1,1].plot(h,20-fit/4*x)
    ax[1,2].plot(np.linspace(20,x,11))

plt.subplots_adjust(hspace = 0.03,wspace = 0.03)

#         {"ax[i,j]":(i,j,top,left,right,bottom,xlabel,ylabel,panel label,panel loc)
ax_dict = {"ax[0,0]":(0,0,False,True,False,False,None,"y-label text (units)","(i)",(0.9,0.925)),                         # <======= # MODIFIED
           "ax[0,1]":(0,1,False,False,False,False,None,None,"(ii)",(0.9,0.925)),                                         # <======= # MODIFIED   
           "ax[0,2]":(0,2,False,False,False,False,None,None,"(iii)",(0.025,0.925)),                                      # <======= # MODIFIED
           "ax[1,0]":(1,0,False,True,False,True,"x-label text (units)","y-label text (units)","(iv)",(0.875,0.05)),      # <======= # MODIFIED
           "ax[1,1]":(1,1,False,False,False,True,"x-label text (units)",None,"(v)",(0.9,0.05)),                          # <======= # MODIFIED
           "ax[1,2]":(1,2,False,False,False,True,"x-label text (units)",None,"(vi)",(0.025,0.05))}                       # <======= # MODIFIED

for key, item in ax_dict.items():
    ax[item[0],item[1]].tick_params(which='both', top=item[2], left=item[3], right=item[4], bottom=item[5])
    ax[item[0],item[1]].set_xlabel(item[6])
    ax[item[0],item[1]].set_ylabel(item[7])
    ax[item[0],item[1]].annotate(item[8], xy=item[9],xycoords='axes fraction', fontsize=12)                              # <======= # MODIFIED

plt.show()

In [None]:
# We can add a legend using "fig.legend()"
# NOTE: only the ax[0,0] lines get a label, so the legend lists each x value exactly once
# NOTE: with loc='center right' the legend sits on top of the right-hand panels; we will move it outside the figure later

h = np.linspace(0,10,100)
fit = stats.norm.pdf(h,5,1)*10

fig, ax = plt.subplots(ncols=3,nrows=2,figsize=(12, 8),sharex=True,sharey=True)

for x in np.linspace(0,20,21,dtype=int):
    ax[0,0].plot(np.linspace(x,0,11),label=x)      # <======= # MODIFIED (label each line with its x value)
    ax[0,1].plot(h,fit/4*x)
    ax[0,2].plot(np.linspace(0,x,11))
    
    ax[1,0].plot(np.linspace(x,20,11))
    ax[1,1].plot(h,20-fit/4*x)
    ax[1,2].plot(np.linspace(20,x,11))

plt.subplots_adjust(hspace = 0.03,wspace = 0.03)

#         {"ax[i,j]":(i,j,top,left,right,bottom,xlabel,ylabel,panel label,panel loc)
ax_dict = {"ax[0,0]":(0,0,False,True,False,False,None,"y-label text (units)","(i)",(0.9,0.925)),
           "ax[0,1]":(0,1,False,False,False,False,None,None,"(ii)",(0.9,0.925)),         
           "ax[0,2]":(0,2,False,False,False,False,None,None,"(iii)",(0.025,0.925)), 
           "ax[1,0]":(1,0,False,True,False,True,"x-label text (units)","y-label text (units)","(iv)",(0.875,0.05)),
           "ax[1,1]":(1,1,False,False,False,True,"x-label text (units)",None,"(v)",(0.9,0.05)),      
           "ax[1,2]":(1,2,False,False,False,True,"x-label text (units)",None,"(vi)",(0.025,0.05))} 

for key, item in ax_dict.items():
    ax[item[0],item[1]].tick_params(which='both', top=item[2], left=item[3], right=item[4], bottom=item[5])
    ax[item[0],item[1]].set_xlabel(item[6])
    ax[item[0],item[1]].set_ylabel(item[7])
    ax[item[0],item[1]].annotate(item[8], xy=item[9],xycoords='axes fraction', fontsize=12)

fig.legend(loc='center right', frameon=True,title="x value:")     # <======= # NEW
plt.show()

In [None]:
# We can also turn this whole figure into a function by indenting everything and using "def" to define a function

# NOTE: this is particularly useful for reproducibility as you can just add new data and always process and display it the exact same way!
# NOTE: when you run this cell there will be no output, you are simply creating the function to be called elsewhere
# NOTE: 20 was a common value used throughout this function, it has been replaced with the variable "a" so we can modify how many lines are drawn
# NOTE: fig.tight_layout() makes sure everything fits nicely, but it also recalculates the gaps between the panels,
#       so we call it FIRST and then re-apply our narrow hspace/wspace with fig.subplots_adjust()

def myFigure(a):                                      # <======= # NEW
    """Draw and show the 2x3 demo figure with one line per x value in 0..a.

    Parameters
    ----------
    a : int
        Largest x value. Lines are drawn for x = 0, 1, ..., a (a + 1 lines per
        panel), the bottom row is offset by ``a``, and the normal curve is
        sampled at ``a * 5`` points between 0 and 10. Because ``a`` is used as
        a number of points, it must be a non-negative whole number.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    h = np.linspace(0,10,a*5)                         # <======= # MODIFIED
    fit = stats.norm.pdf(h,5,1)*10                    # normal curve (mean 5, std 1); fit/4 peaks at ~1, so fit/4*x peaks at ~x

    fig, ax = plt.subplots(ncols=3,nrows=2,figsize=(12, 8),sharex=True,sharey=True)

    for x in np.linspace(0,a,a+1,dtype=int):          # <======= # MODIFIED
        ax[0,0].plot(np.linspace(x,0,11),label=x)     # only this panel is labelled, so the legend lists each x once
        ax[0,1].plot(h,fit/4*x)
        ax[0,2].plot(np.linspace(0,x,11))

        ax[1,0].plot(np.linspace(x,a,11))             # <======= # MODIFIED
        ax[1,1].plot(h,a-fit/4*x)                     # <======= # MODIFIED
        ax[1,2].plot(np.linspace(a,x,11))             # <======= # MODIFIED

    #         {"ax[i,j]":(i,j,top,left,right,bottom,xlabel,ylabel,panel label,panel loc)
    ax_dict = {"ax[0,0]":(0,0,False,True,False,False,None,"y-label text (units)","(i)",(0.9,0.925)),
               "ax[0,1]":(0,1,False,False,False,False,None,None,"(ii)",(0.9,0.925)),
               "ax[0,2]":(0,2,False,False,False,False,None,None,"(iii)",(0.025,0.925)),
               "ax[1,0]":(1,0,False,True,False,True,"x-label text (units)","y-label text (units)","(iv)",(0.875,0.05)),
               "ax[1,1]":(1,1,False,False,False,True,"x-label text (units)",None,"(v)",(0.9,0.05)),
               "ax[1,2]":(1,2,False,False,False,True,"x-label text (units)",None,"(vi)",(0.025,0.05))}

    # Unpack each settings tuple into named variables instead of indexing item[0] ... item[9]
    for i, j, top, left, right, bottom, xlabel, ylabel, panel_label, panel_loc in ax_dict.values():
        ax[i,j].tick_params(which='both', top=top, left=left, right=right, bottom=bottom)
        ax[i,j].set_xlabel(xlabel)
        ax[i,j].set_ylabel(ylabel)
        ax[i,j].annotate(panel_label, xy=panel_loc, xycoords='axes fraction', fontsize=12)

    # The legend is anchored just outside the right edge of the figure; ncol grows with a so long legends wrap into columns
    fig.legend(loc="center left",bbox_to_anchor=(1.00,0.5), frameon=True, title="x value:", ncol=int(a/20)+1) # <======= # MODIFIED
    fig.tight_layout()                                # fits the outer margins (and, as a side effect, resets the panel gaps)...
    fig.subplots_adjust(hspace = 0.03,wspace = 0.03)  # ...so the narrow panel gaps are applied AFTER tight_layout
    plt.show()

In [None]:
# RUN THIS FUNCTION with a set to 20 (21 lines per panel, x = 0, 1, ..., 20)!
myFigure(20)

In [None]:
# Change a to any value you want (it might take forever to render above 200, since every panel draws a + 1 lines...)
myFigure(200)

In [None]:
# EXPLORE : how low can you go?
# (a is used to set how many points np.linspace generates, so it has to be a non-negative whole number)
myFigure(3)

In [None]:
# We can also put our functions into a separate file that we can import as a module (a python "package" made of a single .py file).
# Above, press FILE>NEW>TEXT FILE, paste the entire myFigure function into that new text file, then save it as "myPackage.py"
# Go through this notebook and find the import lines for every package myFigure uses (numpy, scipy.stats and matplotlib.pyplot), and paste them at the top of myPackage.py
# Now we can import this file as a python package!
# NOTE: if you forgot to import a package and got an error here, add the missing import to myPackage.py, save the file, and re-run this cell;
#       the "%load_ext autoreload" and "%autoreload 2" lines below tell IPython to reload myPackage when the saved file changes

%load_ext autoreload
%autoreload 2

import myPackage as mypac

mypac.myFigure(20)

In [None]:
# You can also explicitly import only the myFigure() function and give it whatever name you want using "as"

from myPackage import myFigure as myfig

myfig(20)

In [None]:
# See! You can call it whatever you want! But you should try to name things usefully... 

from myPackage import myFigure as i_am_batman

i_am_batman(20)

In [None]:
# Finally, let's save this figure by creating a new function called "mySavedFigure" where we also add a "savename" argument
# NOTE: the legend sits OUTSIDE the figure area, so we save with bbox_inches="tight" - otherwise it would be cropped off the saved file

def mySavedFigure(a,savename): # <======= # MODIFIED 
    """Draw the 2x3 demo figure for x values 0..a, save it to ``savename``, then show it.

    Parameters
    ----------
    a : int
        Largest x value. Lines are drawn for x = 0, 1, ..., a (a + 1 lines per
        panel), the bottom row is offset by ``a``, and the normal curve is
        sampled at ``a * 5`` points between 0 and 10. Must be a non-negative
        whole number.
    savename : str or path-like
        Output file name; matplotlib picks the file format from its
        extension (e.g. "savedFig.jpg" or "savedFig.pdf").

    Returns
    -------
    None
        The figure is written to disk and then displayed with ``plt.show()``.
    """
    h = np.linspace(0,10,a*5)
    fit = stats.norm.pdf(h,5,1)*10                    # normal curve (mean 5, std 1); fit/4 peaks at ~1, so fit/4*x peaks at ~x

    fig, ax = plt.subplots(ncols=3,nrows=2,figsize=(12, 8),sharex=True,sharey=True)

    for x in np.linspace(0,a,a+1,dtype=int):
        ax[0,0].plot(np.linspace(x,0,11),label=x)     # only this panel is labelled, so the legend lists each x once
        ax[0,1].plot(h,fit/4*x)
        ax[0,2].plot(np.linspace(0,x,11))

        ax[1,0].plot(np.linspace(x,a,11))
        ax[1,1].plot(h,a-fit/4*x)
        ax[1,2].plot(np.linspace(a,x,11))

    #         {"ax[i,j]":(i,j,top,left,right,bottom,xlabel,ylabel,panel label,panel loc)
    ax_dict = {"ax[0,0]":(0,0,False,True,False,False,None,"y-label text (units)","(i)",(0.9,0.925)),
               "ax[0,1]":(0,1,False,False,False,False,None,None,"(ii)",(0.9,0.925)),
               "ax[0,2]":(0,2,False,False,False,False,None,None,"(iii)",(0.025,0.925)),
               "ax[1,0]":(1,0,False,True,False,True,"x-label text (units)","y-label text (units)","(iv)",(0.875,0.05)),
               "ax[1,1]":(1,1,False,False,False,True,"x-label text (units)",None,"(v)",(0.9,0.05)),
               "ax[1,2]":(1,2,False,False,False,True,"x-label text (units)",None,"(vi)",(0.025,0.05))}

    # Unpack each settings tuple into named variables instead of indexing item[0] ... item[9]
    for i, j, top, left, right, bottom, xlabel, ylabel, panel_label, panel_loc in ax_dict.values():
        ax[i,j].tick_params(which='both', top=top, left=left, right=right, bottom=bottom)
        ax[i,j].set_xlabel(xlabel)
        ax[i,j].set_ylabel(ylabel)
        ax[i,j].annotate(panel_label, xy=panel_loc, xycoords='axes fraction', fontsize=12)

    # The legend is anchored just outside the right edge of the figure; ncol grows with a so long legends wrap into columns
    fig.legend(loc="center left",bbox_to_anchor=(1.00,0.5), frameon=True, title="x value:", ncol=int(a/20)+1)
    fig.tight_layout()                                # fits the outer margins (and, as a side effect, resets the panel gaps)...
    fig.subplots_adjust(hspace = 0.03,wspace = 0.03)  # ...so the narrow panel gaps are applied AFTER tight_layout
    fig.savefig(savename, bbox_inches="tight")        # <======= # NEW ("tight" grows the saved area to include the outside legend)
    plt.show()

In [None]:
# Draw the figure with a = 20 and save it as "savedFig.jpg" (a relative name, so it lands in the notebook's working directory)
mySavedFigure(20,"savedFig.jpg")

In [None]:
# EXPLORE: what happens if you change .jpg to .pdf?

In [None]:
# That's it! You should now be able to make simple figures using matplotlib!
# Just as importantly, you've been working with object-oriented programming (the fig and ax objects and their methods)!

In [None]:
# NOTE: You can make much more complicated graph layouts using a variety of built-in tools
# E.g. gridspec!  https://matplotlib.org/3.2.1/tutorials/intermediate/gridspec.html
# gs[row, col] uses normal python slicing: ":" spans every row/column, "-1" is the last one, ":-1" is all but the last

# Example 1: one 3x3 grid where panels span different numbers of cells
fig3 = plt.figure(constrained_layout=True)
gs = fig3.add_gridspec(3, 3)
f3_ax1 = fig3.add_subplot(gs[0, :])
f3_ax1.set_title('gs[0, :]')
f3_ax2 = fig3.add_subplot(gs[1, :-1])
f3_ax2.set_title('gs[1, :-1]')
f3_ax3 = fig3.add_subplot(gs[1:, -1])
f3_ax3.set_title('gs[1:, -1]')
f3_ax4 = fig3.add_subplot(gs[-1, 0])
f3_ax4.set_title('gs[-1, 0]')
f3_ax5 = fig3.add_subplot(gs[-1, -2])
f3_ax5.set_title('gs[-1, -2]')
plt.show()

# Example 2: two independent grids side by side; left/right are fractions of the figure width
fig9 = plt.figure(constrained_layout=False)
gs1 = fig9.add_gridspec(nrows=3, ncols=3, left=0.05, right=0.48,
                        wspace=0.05)
f9_ax1 = fig9.add_subplot(gs1[:-1, :])
f9_ax2 = fig9.add_subplot(gs1[-1, :-1])
f9_ax3 = fig9.add_subplot(gs1[-1, -1])

gs2 = fig9.add_gridspec(nrows=3, ncols=3, left=0.55, right=0.98,
                        hspace=0.05)
f9_ax4 = fig9.add_subplot(gs2[:, :-1])
f9_ax5 = fig9.add_subplot(gs2[:-1, -1])
f9_ax6 = fig9.add_subplot(gs2[-1, -1])
plt.show()
