# Plotting in Python

An introduction to the stars of the show: `matplotlib`, `pandas`, and `seaborn`

## Intro to pyplot

`matplotlib.pyplot` is a collection of functions that make matplotlib work like MATLAB, if you're wondering about etymology. For our purposes, it's Python's best basic plotting interface. Each `pyplot` function makes some change to a figure: e.g., creates a figure, creates a plotting area in a figure, plots some lines in a plotting area, decorates the plot with labels, etc.

In `matplotlib.pyplot`, various states are preserved across function calls: it keeps track of things like the current figure and plotting area, and plotting functions are directed to the current axes ("axes" is matplotlib-speak for a single plotting area, i.e. one subplot with its own x and y axes).
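To see that state tracking in action, here is a minimal sketch using `plt.gcf()` (get current figure) and `plt.gca()` (get current axes). The figure and axes variable names are just for illustration.

```python
import matplotlib.pyplot as plt

# Creating a figure makes it the "current" figure
fig_a = plt.figure()
ax_a = plt.gca()                 # no axes exist yet, so this creates one on fig_a
plt.plot([1, 2, 3], [1, 4, 9])   # pyplot draws on the current axes, i.e. ax_a

# A second figure with two subplots now becomes the current figure
fig_b, (ax_left, ax_right) = plt.subplots(1, 2)
plt.sca(ax_left)                 # explicitly make the left subplot current
plt.plot([1, 2, 3], [3, 2, 1])   # lands on ax_left, not on fig_a

print(plt.gcf() is fig_b, plt.gca() is ax_left)

# Switch back to the first figure by its number; its current axes is remembered
plt.figure(fig_a.number)
print(plt.gca() is ax_a, len(ax_a.lines))
```

This implicit "current" state is what lets a one-line `plt.plot` work, but it is also why longer scripts often switch to calling methods on the `ax` object directly.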

Generating visualizations with pyplot can be as quick as a single line with `plt.plot`:

In [None]:
import matplotlib.pyplot as plt
x = [1, 2, 3, 4, 5]
y = [2, 4, 6, 8, 10]
plt.plot(x, y)

But you can customize things quite a bit. Feel free to change any of the parameters here and re-run the cell to see what happens.

In [None]:
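import matplotlib.pyplot as plt
import numpy as np
x = np.array([1, 2, 3, 4, 5])
y = x**2
# Create a figure with one set of axes; 'constrained' layout keeps labels from overlapping
fig, ax = plt.subplots(layout='constrained')
plt.sca(ax)  # make ax the current axes (plt.subplots already does this; shown for clarity)
# Line style, marker style, and transparency all go into the same call
plt.plot(x, y,
         color='red', linestyle=':', linewidth=2,
         marker='o', markersize=10, markerfacecolor='black',
         alpha=0.9)
# Axis labels and title
plt.ylabel('Dependent Variable')
plt.xlabel('Independent Variable', fontsize=14)
plt.title('Simple Line Plot')
# Put a tick at each x value, label it with that value, and style the labels
xticks = x
xticklabels = xticks
plt.xticks(xticks, xticklabels,
           fontsize=12, fontweight='bold', rotation=90)
# Free text placed in data coordinates (x=1, y=25), anchored at its top-left corner
ax.text(1, 25, 'Text Annotation', fontsize=12,
        color='blue', ha='left', va='top')
plt.show()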
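The cell above mixes pyplot calls (`plt.xlabel`, `plt.xticks`, ...) with one `Axes` method (`ax.text`). As a sketch, here is the same figure written entirely against the `ax` object returned by `plt.subplots`, which avoids depending on which axes happens to be current. The styling values are copied from the cell above; only the method names change.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.array([1, 2, 3, 4, 5])
y = x**2

fig, ax = plt.subplots(layout='constrained')
ax.plot(x, y,
        color='red', linestyle=':', linewidth=2,
        marker='o', markersize=10, markerfacecolor='black',
        alpha=0.9)
ax.set_ylabel('Dependent Variable')                 # plt.ylabel equivalent
ax.set_xlabel('Independent Variable', fontsize=14)  # plt.xlabel equivalent
ax.set_title('Simple Line Plot')                    # plt.title equivalent
ax.set_xticks(x)                                    # tick positions
ax.set_xticklabels([str(v) for v in x],             # tick labels and their styling
                   fontsize=12, fontweight='bold', rotation=90)
ax.text(1, 25, 'Text Annotation', fontsize=12,
        color='blue', ha='left', va='top')
plt.show()
```

The rule of thumb: `plt.<thing>(...)` usually becomes `ax.set_<thing>(...)` in the object-oriented style.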

However, you may have heard me talk about `seaborn` before. Seaborn is built on top of `matplotlib` but adds some functionality and has much prettier defaults. You will generally find yourself dropping down into matplotlib directly to change things like figure size and axis labels, but for the base plotting you'll substitute `sns.lineplot` for `plt.plot` (and `sns.scatterplot` may become your closest friend and companion).

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

x = np.array([1, 2, 3, 4, 5])
y = x**2
fig, ax = plt.subplots(layout='constrained')
plt.sca(ax)
sns.lineplot(x=x, y=y,
             color='red', linestyle=':', linewidth=2,
             marker='o', markersize=10, markerfacecolor='black',
             alpha=0.9)
plt.ylabel('Dependent Variable')
plt.xlabel('Independent Variable', fontsize=14)
plt.title('Simple Line Plot')
xticks = x
xticklabels = xticks
plt.xticks(xticks, xticklabels,
           fontsize=12, fontweight='bold', rotation=90)
ax.text(1, 25, 'Text Annotation', fontsize=12,
        color='blue', ha='left', va='top')
plt.show()

One nice feature of `seaborn` is easy access to a number of color palettes. You may or may not recognize the PARIS cinematic universe, but at least one of these palettes is definitely a slight improvement on the `matplotlib` default. I generally prefer grabbing individual colors from these palettes over named colors like `'red'` or `'blue'`, although both will work.

In [None]:
display(sns.color_palette())
display(sns.color_palette('dark'))
display(sns.color_palette('muted'))
display(sns.color_palette('mako'))
display(sns.color_palette('rocket'))
display(sns.color_palette('husl', n_colors=8))
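Grabbing an individual color is just indexing: a seaborn palette behaves like a list of `(r, g, b)` tuples, and any entry can go wherever matplotlib or seaborn accepts a `color=`. A small sketch (which palette and which indices to use are arbitrary here):

```python
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

palette = sns.color_palette('muted')   # list-like of (r, g, b) tuples
print(len(palette))                    # number of colors in the palette
print(palette.as_hex())                # the same colors as hex strings

x = np.arange(10)
fig, ax = plt.subplots()
# Index into the palette instead of using a named color like 'red'
ax.plot(x, x, color=palette[0], label='first muted color')
ax.plot(x, 2 * x, color=palette[3], label='fourth muted color')
ax.legend()
plt.show()
```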

## Pandas: A brief aside

You can spend quite a lot of time just plotting, but you'll likely spend more time organizing and transforming your data. To keep things tight, you'll want to keep everything in a `pandas.DataFrame`. The basic concept is very familiar: a dataframe is just a fancy word for a table:


In [None]:
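import numpy as np
import pandas as pd
# Build a table from a dict: keys become column names, lists become columns
df = pd.DataFrame({'Participant ID': [f'03374-{str(i).zfill(3)}' for i in range(1, 6)],})
# Add columns of made-up data: random 5-digit sample IDs and random antibody values
df['Sample ID'] = [''.join(np.random.choice(np.arange(1,10).astype(str), 5)) for _ in range(5)]
df['Spike-Binding Antibodies'] = np.random.randint(5, 40, 5)
display(df)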
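A few basic operations make the table analogy concrete: every column has a name, selecting a column by that name gives a `pandas.Series`, and new columns can be computed from existing ones. The participant IDs and antibody values below are invented for illustration, and the cutoff of 10 is arbitrary.

```python
import pandas as pd

df = pd.DataFrame({
    'Participant ID': ['03374-001', '03374-002', '03374-003'],
    'Spike-Binding Antibodies': [12, 30, 7],   # invented values
})

print(df.shape)            # (rows, columns)
print(list(df.columns))    # column names

# Select one column by name, which gives a Series
antibodies = df['Spike-Binding Antibodies']

# Compute a new True/False column from an existing one
df['Above 10'] = antibodies > 10

# Keep only the rows where that column is True
high = df[df['Above 10']]
print(high)
```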

It can also read data directly from Excel files and CSVs:

In [None]:
pd.read_csv('../data/atlantis.csv').head()
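The `atlantis.csv` file isn't reproduced here, so as a self-contained sketch: `pd.read_csv` accepts any file-like object, including an in-memory `io.StringIO` buffer. The column names and values below are invented. Excel files go through `pd.read_excel` in much the same way, though reading `.xlsx` files requires an engine such as `openpyxl` to be installed.

```python
import io
import pandas as pd

# A tiny invented CSV standing in for a file on disk
csv_text = """sample,visit,titer
A,1,10
A,2,40
B,1,20
"""

df = pd.read_csv(io.StringIO(csv_text))
print(df.head())    # .head() shows the first 5 rows
print(df.dtypes)    # pandas infers a type for each column
```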

However, the way you interact with the tables is really a whole new language. The long and short of it is that Excel has a very convenient interface for entering new data (see: our entire Sharepoint), while `pandas` almost exclusively reads and manipulates existing datasets.

What we want from `pandas` today, though, is really just a way to refer to our data by column names. See the `sns.scatterplot` syntax below:

In [None]:
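# Random example data: b is a noisy linear function of a
data = {'a': np.arange(50),
        'color': np.random.randint(0, 50, 50),
        'diameter': np.random.randn(50)}
data['b'] = data['a'] + 10 * np.random.randn(50)
data['diameter'] = np.abs(data['diameter']) * 100
df = pd.DataFrame(data)

# Every aesthetic (x, y, hue, size) is given as a column name from df
sns.scatterplot(data=df, x='a', y='b', hue='color', size='diameter')
plt.xlabel('Entry a')
plt.ylabel('Entry b')
plt.legend(ncols=2)  # two-column legend (the keyword is `ncol` in matplotlib < 3.6)
plt.show()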
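Passing `data=df` plus column names is shorthand for passing the columns themselves. As a sketch with a small invented frame, both calls below draw the same points; when handed a named `Series`, seaborn also uses the Series name as the axis label.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

df = pd.DataFrame({'a': np.arange(10), 'b': np.arange(10) ** 2})

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3))

# Refer to columns by name
sns.scatterplot(data=df, x='a', y='b', ax=ax1)

# Pass the columns (pandas Series) directly
sns.scatterplot(x=df['a'], y=df['b'], ax=ax2)

plt.show()
```

The column-name form scales better: once `hue`, `size`, and `style` enter the picture, each is just another string.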

There are a lot more plots to be made. Check out a surfeit of examples here:
https://seaborn.pydata.org/examples/index.html

I lifted a couple of fun ones from that gallery that you can run below:

In [None]:
import seaborn as sns

with sns.axes_style("dark"):
    flights = sns.load_dataset("flights")

    # Plot each year's time series in its own facet
    g = sns.relplot(
        data=flights,
        x="month", y="passengers", col="year", hue="year",
        kind="line", palette="crest", linewidth=4, zorder=5,
        col_wrap=3, height=2, aspect=1.5, legend=False,
    )

    # Iterate over each subplot to customize further
    for year, ax in g.axes_dict.items():

        # Add the title as an annotation within the plot
        ax.text(.8, .85, year, transform=ax.transAxes, fontweight="bold")

        # Plot every year's time series in the background
        sns.lineplot(
            data=flights, x="month", y="passengers", units="year",
            estimator=None, color=".7", linewidth=1, ax=ax,
        )

    # Reduce the frequency of the x axis ticks
    ax.set_xticks(ax.get_xticks()[::2])

    # Tweak the supporting aspects of the plot
    g.set_titles("")
    g.set_axis_labels("", "Passengers")
    g.tight_layout()

In [None]:

with sns.axes_style("ticks"):
    rs = np.random.RandomState(11)
    x = rs.gamma(2, size=1000)
    y = -.5 * x + rs.normal(size=1000)

    sns.jointplot(x=x, y=y, kind="hex", color="#4CB391")