In [None]:
# plotting imports
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# other imports
import numpy as np
import pandas as pd
from scipy import stats

Hello world
---

Using the `pyplot` notation, very similar to how MATLAB works

In [None]:
# pyplot "state machine" style: every call acts on the current figure/axes
plt.plot([0, 1, 2, 3, 4],
         [0, 1, 2, 5, 10], 'bo-')
plt.text(1.5, 5, 'Hello world', size=14)
# "\\mu" keeps a literal backslash for mathtext; a bare "\m" is an invalid
# Python string escape (SyntaxWarning on Python >= 3.12)
plt.xlabel('X axis\n($\\mu g/mL$)')
plt.ylabel('y axis\n($X^2$)');

Hello world, reprise
---

Using the recommended "object-oriented" (OO) style

In [None]:
# object-oriented style: create the Figure and Axes explicitly,
# then call methods on the Axes object
fig, ax = plt.subplots()
ax.plot([0, 1, 2, 3, 4],
        [0, 1, 2, 5, 10], 'bo-')
ax.text(1.5, 5, 'Hello world', size=14)
# "\\mu" keeps a literal backslash for mathtext without an invalid-escape warning
ax.set_xlabel('X axis\n($\\mu g/mL$)')
ax.set_ylabel('y axis\n($X^2$)');

In [None]:
# 100 evenly spaced samples on [0, 2], both endpoints included
x = np.linspace(start=0, stop=2, num=100)

In [None]:
fig, ax = plt.subplots()

# three curves on one Axes; each `label` becomes a legend entry
ax.plot(x, x, label='linear')
ax.plot(x, x**2, label='quadratic')
ax.plot(x, x**3, label='cubic')

ax.set_xlabel('x label')
ax.set_ylabel('y label')
ax.set_title('Simple Plot')
# trailing ';' hides the Legend object repr, like the other plotting cells
ax.legend();

Controlling a figure aspect
---

In [None]:
# figsize is (width, height)
fig, ax = plt.subplots(figsize=(9, 4))

for power, curve_name in [(1, 'linear'), (2, 'quadratic'), (3, 'cubic')]:
    ax.plot(x, x ** power, label=curve_name)

ax.set(xlabel='x label', ylabel='y label', title='Simple Plot')
ax.legend();

In [None]:
fig, ax = plt.subplots(figsize=(9, 4))

# (y values, line/marker format string, color, legend label)
curve_specs = [
    (x, '--', 'grey', 'linear'),
    (x**2, '.-', 'red', 'quadratic'),
    (x**3, '*', '#3bb44a', 'cubic'),
]
for y_vals, fmt, line_color, curve_name in curve_specs:
    ax.plot(x, y_vals, fmt, color=line_color, label=curve_name)

ax.set(xlabel='x label', ylabel='y label', title='Simple Plot')

# pin the legend to a named corner of the axes
ax.legend(loc='upper right');
# to place it outside the axes on the right instead:
# ax.legend(loc='center left',
#           bbox_to_anchor=(1, 0.5),
#           ncol=3);

Multiple panels
---

In [None]:
x1 = np.linspace(0.0, 5.0)
x2 = np.linspace(0.0, 2.0)

# damped (top panel) vs. undamped (bottom panel) cosine
y1 = np.cos(2 * np.pi * x1) * np.exp(-x1)
y2 = np.cos(2 * np.pi * x2)

# 2 rows, 1 column -> `axes` holds the two panels, top to bottom
fig, axes = plt.subplots(2, 1, figsize=(6, 4))
print(axes)

ax_top, ax_bottom = axes

ax_top.plot(x1, y1, 'o-')
ax_top.set_title('A tale of 2 subplots')
ax_top.set_ylabel('Damped oscillation')

ax_bottom.plot(x2, y2, '.-')
ax_bottom.set_xlabel('time (s)')
ax_bottom.set_ylabel('Undamped');

Automagically adjust panels so that they fit in the figure
---

In [None]:
def example_plot(ax, fontsize=12):
    """Draw a short line on `ax` and label its x axis, y axis and title.

    Parameters
    ----------
    ax : matplotlib Axes
        Panel to draw on (anything exposing plot/set_xlabel/set_ylabel/set_title).
    fontsize : int, optional
        Font size used for both axis labels and the title (default 12).
    """
    ax.plot([1, 2])
    labellers = (
        (ax.set_xlabel, 'x-label'),
        (ax.set_ylabel, 'y-label'),
        (ax.set_title, 'Title'),
    )
    for apply_label, text in labellers:
        apply_label(text, fontsize=fontsize)

In [None]:
# baseline: no automatic layout management
fig, axs = plt.subplots(2, 2, figsize=(4, 4), constrained_layout=False)
print(axs)  # 2x2 array of Axes
for panel in axs.ravel():
    example_plot(panel)

In [None]:
# constrained_layout adjusts the panels so they fit in the figure
# (it was flagged as experimental in older matplotlib releases)
fig, axs = plt.subplots(2, 2, figsize=(4, 4), constrained_layout=True)

for panel in axs.ravel():
    example_plot(panel)

In [None]:
# alternative way: build the grid without constrained_layout and let
# tight_layout adjust subplot spacing afterwards
fig, axs = plt.subplots(2, 2, figsize=(4, 4), constrained_layout=False)

for ax in axs.flat:
    example_plot(ax)

# call it on the Figure object (OO style) instead of pyplot's "current figure"
fig.tight_layout();

Example of manipulating axes limits
---

Extra: a look at ways to choose colors
and to manipulate transparency

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(9, 4))

# same plot for both panels; only the last panel's axis limits change
for ax in axes:
    # more color choices
    # (full list: https://matplotlib.org/tutorials/colors/colors.html)

    # xkcd rgb color survey: https://xkcd.com/color/rgb/
    ax.plot(x, x, '--', color='xkcd:olive green', label='linear')
    # RGBA tuple (red, green, blue, alpha)
    ax.plot(x, x**2, '.-', color=(0.1, 0.2, 0.5, 0.3), label='quadratic')
    # one of {'b', 'g', 'r', 'c', 'm', 'y', 'k', 'w'}: single-character
    # shorthands for blue, green, red, cyan, magenta, yellow, black, white
    ax.plot(x, x**3, '*', color='m', label='cubic')
    # transparency is controlled with the "alpha" keyword argument
    ax.plot(x, x**4, '-', color='b', linewidth=4, alpha=0.3, label='quartic')

    ax.set_xlabel('x label')
    ax.set_ylabel('y label')
    ax.set_title('Simple Plot')

# `ax` still refers to the last panel after the loop: zoom into it only
ax.set_ylim(1, 16.4)
ax.set_xlim(1.65, 2.03)

# legend outside the axes, to the right of the last panel
ax.legend(loc='center left',
          bbox_to_anchor=(1, 0.5),
          title='Fit');

Other sample plots using "vanilla" matplotlib
---

In [None]:
# scatter plot where marker size and color both vary per point
fig, ax = plt.subplots(figsize=(6, 4))

N = 10
x = np.linspace(0, 1, N)
y = x ** 2
# one number per point; scatter maps these through the colormap
colors = np.linspace(0, 1, N)
# explicit per-point colors also work:
# colors = ['b', 'b', 'b',
#           'k', 'k', 'k',
#           'r', 'r', 'r',
#           'xkcd:jade']
area = 5 + (20 * x) ** 3

for array_name, values in [('x', x), ('y', y), ('colors', colors), ('area', area)]:
    print(f'{array_name}: {values}')

ax.scatter(x, y, s=area, c=colors, alpha=0.9,
           edgecolors='w', linewidths=3, label='Data')
ax.legend(loc='upper left');

In [None]:
# generate 2-D random data: 2 rows x 100 standard-normal samples
# seeded generator so Restart & Run All reproduces the same figures
rng = np.random.default_rng(42)
data = rng.standard_normal((2, 100))
# peek at the first few columns instead of dumping all 200 values
data[:, :5]

In [None]:
# 1-D histogram (left) and 2-D histogram (right) of the same sample
n_bins = 25
fig, (ax_hist, ax_hist2d) = plt.subplots(1, 2, figsize=(6, 3))

ax_hist.hist(data[0], bins=n_bins)
ax_hist2d.hist2d(data[0], data[1], bins=n_bins);

Other useful tips
---

In [None]:
# scatter plot with a log-scaled y axis (y = 2**x grows exponentially)
fig, ax = plt.subplots(figsize=(6, 4))

N = 10
x = np.linspace(0, 10, N)
y = 2 ** x
colors = np.linspace(0, 1, N)
area = 500

ax.scatter(x, y, s=area, c=colors,
           alpha=0.9,
           edgecolors='w', linewidths=3,
           label='Data')
# `base=` replaces the per-axis `basey=` keyword
# (deprecated in matplotlib 3.3, removed in 3.5)
ax.set_yscale('log', base=10);

In [None]:
# scatter plot with both axes log-scaled, using different bases
fig, ax = plt.subplots(figsize=(6, 4))

N = 10
x = 10 ** np.linspace(1, 4, N)
y = x ** 2
colors = np.linspace(0, 1, N)
area = 500

ax.scatter(x, y, s=area, c=colors,
           alpha=0.9,
           edgecolors='w', linewidths=3,
           label='Data')
# `base=` replaces `basex=`/`basey=` (deprecated in matplotlib 3.3, removed in 3.5)
ax.set_yscale('log', base=2)
ax.set_xscale('log', base=10);

In [None]:
# changing colormap
# find an exhaustive list here:
# https://matplotlib.org/3.1.0/tutorials/colors/colormaps.html
fig, ax = plt.subplots(figsize=(6, 4))

N = 10
x = 10 ** np.linspace(1, 4, N)
y = x ** 2
colors = np.linspace(0, 1, N)
area = 500

# `cmap` turns the numeric `colors` values into actual colors; other options
# to try: 'plasma', 'jet', 'Blues', 'Blues_r' ('_r' reverses a colormap)
ax.scatter(x, y, s=area, c=colors,
           alpha=0.9,
           edgecolors='w', linewidths=3,
           label='Data',
           cmap='tab20')
# `base=` replaces `basex=`/`basey=` (deprecated in matplotlib 3.3, removed in 3.5)
ax.set_yscale('log', base=2)
ax.set_xscale('log', base=10);

Saving your plot
---

In [None]:
fig, ax = plt.subplots(figsize=(3, 2))

N = 10
x = 10 ** np.linspace(1, 4, N)
y = x ** 2
colors = np.linspace(0, 1, N)
area = 500

ax.scatter(x, y, s=area, c=colors,
           alpha=0.9,
           edgecolors='w', linewidths=3,
           cmap='tab20',
           label='My awesome data is the best thing ever',
           # rasterized=True,
           )

# legend placed outside the axes, to the right
ax.legend(bbox_to_anchor=(1, 0.5),
          loc='center left')

# `base=` replaces `basex=`/`basey=` (deprecated in matplotlib 3.3, removed in 3.5)
ax.set_yscale('log', base=2)
ax.set_xscale('log', base=10)

# save through the Figure object (OO style); bbox_inches='tight' computes a
# tight bounding box so the outside legend is included in the saved file
fig.savefig('the_awesomest_plot_ever.png',
            dpi=300,
            bbox_inches='tight',
            transparent=True)
fig.savefig('the_awesomest_plot_ever.svg',
            dpi=300, bbox_inches='tight',
            transparent=True);

What about other plot types?
---

Such as boxplots, heatmaps, and complex plots *à la* ggplot.
Find a gallery with code here: https://matplotlib.org/gallery/index.html

**However:** seaborn covers them much better, leveraging the power of pandas `DataFrame`s

In [None]:
# tidy example dataset shipped via seaborn (may require network access)
tips = sns.load_dataset(name='tips')

In [None]:
tips.head(n=5)

In [None]:
# column names, dtypes and non-null counts
tips.info()

In [None]:
sns.relplot(data=tips,
            x="total_bill", y="tip",
            # one panel per value of "time"
            col="time",
            # smoker status drives both color and marker style
            hue="smoker", hue_order=['No', 'Yes'],
            palette=['xkcd:indigo', 'xkcd:grass green'],
            style="smoker",
            # marker size follows party size
            size="size",
            # panel height and width/height ratio
            height=4, aspect=1);

Statistical relationships: `relplot`
---

In [None]:
# map "size" to marker size and widen the range of sizes used
sns.relplot(data=tips, x="total_bill", y="tip",
            size="size", sizes=(1, 100));

In [None]:
# a second example dataset
fmri = sns.load_dataset(name="fmri")

In [None]:
fmri.head(n=5)

In [None]:
# default relplot kind: scatter of every observation
sns.relplot(data=fmri, x="timepoint", y="signal");

In [None]:
# kind="line" draws a lineplot instead of a scatterplot; repeated values
# per timepoint are aggregated (by default: mean and 95% confidence interval)
sns.relplot(data=fmri, x="timepoint", y="signal", kind="line");

In [None]:
# aggregate with the median and shade a standard-deviation band
# NOTE(review): seaborn >= 0.12 spells ci='sd' as errorbar='sd'
sns.relplot(data=fmri, x="timepoint", y="signal", kind="line",
            estimator="median", ci='sd');

In [None]:
# one panel per region, one color per event
sns.relplot(data=fmri, x="timepoint", y="signal",
            kind="line", hue="event", col="region");

In [None]:
# estimator=None disables aggregation: the raw points are connected directly
sns.relplot(data=fmri, x="timepoint", y="signal",
            kind="line", hue="event", col="region",
            estimator=None);

In [None]:
# no aggregation, but units="subject" draws a separate line per subject
sns.relplot(data=fmri, x="timepoint", y="signal",
            kind="line", hue="event", col="region",
            units="subject", estimator=None);

**How do I access individual axes?**

In [None]:
rp = sns.relplot(x="timepoint", y="signal",
                 col="region",
                 hue="event",
                 kind="line",
                 data=fmri)
# rp.axes holds one Axes per facet; flatten it to loop over them
for ax in rp.axes.flatten():
    ax.set_xlabel('Changing stuff')
    ax.set_ylabel('Also here')
    # crazy thing: base-2 log y axis (`base=` replaces the removed `basey=`)
    ax.set_yscale('log', base=2)

In [None]:
# the Axes objects of the facet grid above
rp.axes

In [None]:
# (rows, columns) of the facet grid
np.shape(rp.axes)

In [None]:
# facet by both rows (event) and columns (region); one line per subject
sns.relplot(data=fmri,
            x="timepoint", y="signal", hue="subject",
            row="event", col="region",
            kind="line", estimator=None, palette='hsv',
            height=3);

Categorical data: `catplot`
---

Categorical scatterplots:

    stripplot() (with kind="strip"; the default)

    swarmplot() (with kind="swarm")

Categorical distribution plots:

    boxplot() (with kind="box")

    violinplot() (with kind="violin")

    boxenplot() (with kind="boxen")

Categorical estimate plots:

    pointplot() (with kind="point")

    barplot() (with kind="bar")

    countplot() (with kind="count")


In [None]:
# default catplot kind is "strip"; try jitter=False or jitter=0.4
# to change how much points are spread horizontally
sns.catplot(data=tips, x="day", y="total_bill");

In [None]:
# swarm: points are arranged so that they do not overlap
sns.catplot(data=tips, x="day", y="total_bill", kind='swarm');

In [None]:
# color by sex, with an explicit order for both the days and the hue levels
sns.catplot(data=tips, x="day", y="total_bill",
            order=['Sun', 'Sat', 'Fri', 'Thur'],
            hue='sex', hue_order=['Female', 'Male'],
            kind='swarm');

In [None]:
# a numeric column ("size") used as the categorical axis
# (alternative: orient='h')
sns.catplot(data=tips, x="size", y="total_bill", kind="swarm");

In [None]:
# dodge=True draws each sex as its own swarm side by side within a day,
# instead of mixing both hue levels in a single swarm
sns.catplot(x="day", y="total_bill",
            hue='sex',
            hue_order=['Female', 'Male'],
            order=['Sun', 'Sat', 'Fri', 'Thur'],
            kind='swarm',
            dodge=True,
            data=tips);

In [None]:
# notched boxplots per day and sex
sns.catplot(data=tips, x="day", y="total_bill",
            order=['Sun', 'Sat', 'Fri', 'Thur'],
            hue='sex', hue_order=['Female', 'Male'],
            kind='box', notch=True);

In [None]:
# "boxen" (letter-value) plots per day and sex
sns.catplot(data=tips, x="day", y="total_bill",
            order=['Sun', 'Sat', 'Fri', 'Thur'],
            hue='sex', hue_order=['Female', 'Male'],
            kind='boxen');

In [None]:
# violin plots with a sequential palette; other options to try:
#   split=True      - draw the two hue levels as halves of one violin
#   bw=0.1          - bandwidth argument for the underlying KDE
#   palette='jet'
sns.catplot(data=tips, x="day", y="total_bill",
            order=['Sun', 'Sat', 'Fri', 'Thur'],
            hue='sex', hue_order=['Female', 'Male'],
            kind='violin', palette='Blues');

In [None]:
# empty violins (inner=None) with every observation swarmed on top
violin_grid = sns.catplot(data=tips, x="day", y="total_bill",
                          kind="violin", inner=None)
sns.swarmplot(data=tips, x="day", y="total_bill",
              color="k", size=3,
              ax=violin_grid.ax);

In [None]:
# horizontal boxplot with the raw points overlaid on the same Axes
box_grid = sns.catplot(data=tips, y="day", x="total_bill",
                       kind="box", orient='h', color='xkcd:pale grey',
                       height=3, aspect=2)
sns.swarmplot(data=tips, y="day", x="total_bill",
              orient='h', color="k", alpha=0.7, size=6,
              ax=box_grid.ax);

In [None]:
# bar chart of total_bill per day, split by sex
# (alternative summary: ci='sd', estimator=np.median)
sns.catplot(data=tips, x="day", y="total_bill", hue='sex',
            kind='bar', height=5, aspect=0.7);

In [None]:
# same data as the bar chart, drawn as point estimates
sns.catplot(data=tips, x="day", y="total_bill", hue='sex',
            kind='point', height=5, aspect=0.7);

Linear regressions: `lmplot`
---

In [None]:
sns.lmplot(data=tips, x="total_bill", y="tip");

In [None]:
# Anscombe's quartet; the "dataset" column identifies each group
anscombe = sns.load_dataset(name="anscombe")
anscombe.head(n=5)

In [None]:
# one regression panel per dataset, wrapped into two columns
# fit variants to try: order=2, robust=True, lowess=True
sns.lmplot(data=anscombe, x="x", y="y",
           col='dataset', col_wrap=2,
           ci=None, height=4, aspect=1);

In [None]:
# a separate fit per smoker group (add ci=None to drop the bands)
sns.lmplot(data=tips, x="total_bill", y="tip", hue="smoker");

Univariate and bivariate distributions: `distplot`, `jointplot`, `pairplot`
---

In [None]:
# random sample from a standard normal distribution
# (seeded so the figures below are reproducible on Restart & Run All)
rng = np.random.default_rng(42)
x = rng.normal(size=100)

In [None]:
# histogram of x (toggles to try: kde=False, rug=True)
# NOTE(review): distplot is deprecated since seaborn 0.11 in favour of
# histplot/displot; kept as-is for older seaborn versions
sns.distplot(x, bins=20)
sns.despine();

In [None]:
# random sample from a gamma distribution with shape parameter 6
# (seeded so the fit plot below is reproducible)
rng = np.random.default_rng(42)
x = rng.gamma(6, size=200)

In [None]:
# histogram plus a fitted scipy gamma distribution
# (try fit=stats.norm to compare against a normal fit)
sns.distplot(x, fit=stats.gamma, kde=False);

In [None]:
# generate a correlated bivariate normal sample (200 rows, columns x and y)
mean, cov = [0, 1], [(1, .5), (.5, 1)]
# seeded so the joint plot below is reproducible
rng = np.random.default_rng(42)
data = rng.multivariate_normal(mean, cov, 200)
df = pd.DataFrame(data, columns=["x", "y"])

In [None]:
# joint distribution of x and y (alternatives: kind='hex', kind='kde')
sns.jointplot(data=df, x="x", y="y");

In [None]:
# another dataset
iris = sns.load_dataset("iris")
# random_state makes the displayed sample identical across re-runs
iris.sample(5, random_state=0)

In [None]:
# pairwise relationships between all numeric columns
# (pass hue='species' to color points by species)
sns.pairplot(iris);

Heatmaps, clustermaps
---

- Heatmap `sns.heatmap` example: http://seaborn.pydata.org/examples/many_pairwise_correlations.html
- Clustermap `sns.clustermap` example: http://seaborn.pydata.org/examples/structured_heatmap.html

Last: changing styles
---

In [None]:
# names accepted by plt.style.use / plt.style.context
plt.style.available

In [None]:
# preview every available style sheet on the same regression plot
for style in plt.style.available:
    with plt.style.context(style):
        lp = sns.lmplot(x="total_bill", y="tip", hue="smoker",
                        ci=None,
                        data=tips)
        lp.axes.flatten()[0].set_title(style)
        # render this figure now, while its style is still active; this also
        # releases it, instead of keeping one open figure per style until the
        # cell ends (matplotlib warns once more than 20 figures are open)
        plt.show()