<a href="https://colab.research.google.com/github/manolan1/ml-lab/blob/master/05-04-Seaborn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


In [None]:
%matplotlib notebook
import pandas as pd
import matplotlib.pyplot as plt

# text properties can be passed individually as keyword arguments, or collected in a dictionary and passed as fontdict
fontdict = {
    'fontsize': 12,
    'weight': 'bold',
    'horizontalalignment': 'center'
}

# Comparing Pyplot and Seaborn

Our dataset comes from Capital Bikeshare, the bike-sharing system in Washington, DC.

In [None]:
day = pd.read_csv('Bike-Sharing-Dataset/day.csv')

## Pyplot

We will start by trying a simple scatter plot.

In [None]:
# color each point by its season code (1 to 4), using the default colormap
plt.scatter('temp', 'cnt', data = day, c = 'season')
plt.xlabel('Normalized Temperature', fontsize = 'large')
plt.ylabel('Count of Total Bike Rentals', fontsize = 'large')

That's interesting: there appears to be some correlation between temperature and the number of bike rentals. The colors encode the season, but without a legend we can't tell which color is which, so we would like to show the role of the seasons more clearly.

In [None]:
fig = plt.figure(figsize = (9.5, 6))

spring = plt.scatter('temp', 'cnt', data = day[day['season'] == 1], marker='o', color = 'green')
summer = plt.scatter('temp', 'cnt', data = day[day['season'] == 2], marker='o', color = 'orange')
autumn = plt.scatter('temp', 'cnt', data = day[day['season'] == 3], marker='o', color = 'brown')
winter = plt.scatter('temp', 'cnt', data = day[day['season'] == 4], marker='o', color = 'blue')
plt.legend(
    handles = (spring, summer, autumn, winter),
    labels = ('Spring', 'Summer', 'Autumn', 'Winter'),
    title = 'Season',
    title_fontsize = 12,
    scatterpoints = 1,
    loc = 'upper left', 
    ncol = 1,
    fontsize = 12
)
plt.title('Bike Rentals at Different Temperatures\nBy Season', fontdict = fontdict, color = 'black')
plt.xlabel('Normalized temperature', fontdict = fontdict)
plt.ylabel('Count of Total Rental Bikes', fontdict = fontdict)

That's better: now we can clearly see the seasons. However, trends would be easier to spot if we could separate them.

We will add subplots to separate each season.

When adding a subplot, we can pass either three separate integers or a single three-digit integer. In either case, the first number is the number of rows, the second is the number of columns, and the third is the index of the subplot we are addressing, counting from 1 across each row and then down. This code creates a 2x2 grid of subplots.
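As a quick check of that equivalence, here is a minimal sketch (not part of the original notebook) that builds the same subplot both ways on two throwaway figures. It switches to matplotlib's non-interactive ``Agg`` backend only so that it runs outside a notebook.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend, so the sketch runs without a display
import matplotlib.pyplot as plt

# Two separate figures, so that each add_subplot call creates a fresh Axes
fig_ints = plt.figure()
fig_digits = plt.figure()

ax_ints = fig_ints.add_subplot(2, 2, 3)  # 2 rows, 2 columns, subplot 3
ax_digits = fig_digits.add_subplot(223)  # the same grid and subplot, as one integer

# Subplots are numbered from 1, left to right and then top to bottom,
# so subplot 3 is the bottom-left cell (0-based row 1, column 0)
spec = ax_digits.get_subplotspec()
print(spec.rowspan, spec.colspan)  # range(1, 2) range(0, 1)
print(ax_ints.get_position().bounds == ax_digits.get_position().bounds)  # True
```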

Confusingly, matplotlib calls each subplot an "Axes", which is easily mistaken for the plural of "axis".

In [None]:
fig = plt.figure(figsize = (9.5, 6))

plt.subplots_adjust(wspace = 0.2, hspace = 0.3) # specified as fraction of subplot sizes

fig.suptitle('Bike Rentals at Different Temperatures\nBy Season', 
             fontsize = 12, 
             fontweight = 'bold', 
             color = 'black', 
             position = (0.5, 0.99))

ax1 = fig.add_subplot(221)
ax1.scatter('temp', 'cnt', data = day[day['season'] == 1], c = 'g')
ax1.set_title('Spring', fontdict = fontdict, color = 'g')
# nudge the label down (in axes coordinates) so that it sits between the two left-hand subplots
ax1.set_ylabel('Count of Total Rental Bikes', fontdict = fontdict, position = (0, -0.1))

ax2 = fig.add_subplot(222)
ax2.scatter('temp', 'cnt', data = day[day['season'] == 2], c= 'orange')
ax2.set_title('Summer', fontdict = fontdict, color = 'orange')

ax3 = fig.add_subplot(223)
ax3.scatter('temp', 'cnt', data=day[day['season'] == 3], c = 'brown')
ax3.set_title('Autumn', fontdict = fontdict, color = 'brown')

ax4 = fig.add_subplot(224)
ax4.scatter('temp', 'cnt', data = day[day['season'] == 4], c = 'b')
ax4.set_title('Winter', fontdict = fontdict, color = 'b')
# nudge the label left so that it sits between the two bottom subplots
ax4.set_xlabel('Normalized temperature', fontdict = fontdict, position = (-0.1, 0));

There's a problem here: the axis scales, the temperature scale in particular, are not the same in every subplot, and this could lead us to erroneous conclusions. The problem becomes clearer if we stack the subplots vertically, so that their temperature axes line up.

In [None]:
fig = plt.figure(figsize = (9.5, 6))

plt.subplots_adjust(hspace = 0.6)

fig.suptitle('Bike Rentals at Different Temperatures\nBy Season', 
             fontsize = 12, 
             fontweight = 'bold', 
             color = 'black', 
             position = (0.5, 0.99))

ax1 = fig.add_subplot(411)
ax1.scatter('temp', 'cnt', data = day[day['season'] == 1], c = 'g')
ax1.set_title('Spring', fontdict = fontdict, color = 'g')

ax2 = fig.add_subplot(412)
ax2.scatter('temp', 'cnt', data = day[day['season'] == 2], c= 'orange')
ax2.set_title('Summer', fontdict = fontdict, color = 'orange')
ax2.set_ylabel('Count of Total Rental Bikes', fontdict = fontdict, position = (-0.5, -0.1))

ax3 = fig.add_subplot(413)
ax3.scatter('temp', 'cnt', data=day[day['season'] == 3], c = 'brown')
ax3.set_title('Autumn', fontdict = fontdict, color = 'brown')

ax4 = fig.add_subplot(414)
ax4.scatter('temp', 'cnt', data = day[day['season'] == 4], c = 'b')
ax4.set_title('Winter', fontdict = fontdict, color = 'b')
ax4.set_xlabel('Normalized temperature', fontdict = fontdict);

Now we can clearly see that the scales are different.

We can fix this by setting the limits to be the same.

In [None]:
ax1.set_xlim(0, 1.0)
ax2.set_xlim(0, 1.0)
ax3.set_xlim(0, 1.0)
ax4.set_xlim(0, 1.0)
y_min = day.cnt.min()
y_max = day.cnt.max()
ax1.set_ylim(y_min, y_max)
ax2.set_ylim(y_min, y_max)
ax3.set_ylim(y_min, y_max)
ax4.set_ylim(y_min, y_max)

That's much clearer now.

However, it wasn't ideal that we had to set every limit individually. We were also helped by this dataset: the temperature values are normalized, so it was easy to see that 0 to 1 covers the whole temperature scale.
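One way to cut the repetition is a loop over the subplots that takes the limits from the whole dataset rather than from any single season. This is a sketch on hypothetical random data with the same column names as ``day`` (``temp``, ``cnt`` and ``season``), so that it runs without the CSV file; it also assumes the ``Agg`` backend.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend, so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-in for the 'day' table: random values, not the real data
rng = np.random.default_rng(0)
demo = pd.DataFrame({
    'temp': rng.uniform(0, 1, 200),
    'cnt': rng.integers(0, 9000, 200),
    'season': rng.integers(1, 5, 200),  # codes 1 to 4
})

fig = plt.figure(figsize = (9.5, 6))
axes = [fig.add_subplot(4, 1, i) for i in range(1, 5)]
for season, ax in zip(range(1, 5), axes):
    ax.scatter('temp', 'cnt', data = demo[demo['season'] == season])

# Limits come from the whole dataset, so no season's points are cut off
x_lim = (demo['temp'].min(), demo['temp'].max())
y_lim = (demo['cnt'].min(), demo['cnt'].max())
for ax in axes:
    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)
```

The same loop works for any number of subplots, and nothing in it depends on knowing the data in advance.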

As an alternative, we can share the axes between subplots.

In [None]:
fig = plt.figure(figsize = (9.5, 6))

plt.subplots_adjust(hspace = 0.6)

fig.suptitle('Bike Rentals at Different Temperatures\nBy Season', 
             fontsize = 12, 
             fontweight = 'bold', 
             color = 'black', 
             position = (0.5, 0.99))

# ax2 (Summer) is created first and the other subplots share its axes
ax2 = fig.add_subplot(412)
ax2.scatter('temp', 'cnt', data = day[day['season'] == 2], c = 'orange')
ax2.set_title('Summer', fontdict = fontdict, color = 'orange')
ax2.set_ylabel('Count of Total Rental Bikes', fontdict = fontdict, position = (-0.5, -0.1))

ax1 = fig.add_subplot(411, sharex = ax2, sharey = ax2)
ax1.scatter('temp', 'cnt', data = day[day['season'] == 1], c = 'g')
ax1.set_title('Spring', fontdict = fontdict, color = 'g')

ax3 = fig.add_subplot(413, sharex = ax2, sharey = ax2)
ax3.scatter('temp', 'cnt', data=day[day['season'] == 3], c = 'brown')
ax3.set_title('Autumn', fontdict = fontdict, color = 'brown')

ax4 = fig.add_subplot(414, sharex = ax2, sharey = ax2)
ax4.scatter('temp', 'cnt', data = day[day['season'] == 4], c = 'b')
ax4.set_title('Winter', fontdict = fontdict, color = 'b')
ax4.set_xlabel('Normalized temperature', fontdict = fontdict);

That's useful, and we don't actually need to know which season has the widest data range: when axes are shared, matplotlib autoscales the whole group so that it covers the data of every subplot in it, whichever one is created first. Setting the limits explicitly, as shown earlier, is still the better choice when we want a fixed range, such as 0 to 1 for the normalized temperature.
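A more compact way to build a stack of shared subplots is ``plt.subplots`` with ``sharex = True`` and ``sharey = True``. This sketch uses hypothetical random data and deliberately gives the last subplot the widest temperature range, to check that the shared limits still cover it. It assumes the ``Agg`` backend so that it runs outside a notebook.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend, so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical temperatures: each series spans a wider range than the one before
series = [rng.uniform(0.3 - 0.1 * i, 0.5 + 0.1 * i, 50) for i in range(4)]

fig, axes = plt.subplots(nrows = 4, ncols = 1, sharex = True, sharey = True, figsize = (9.5, 6))
for ax, temps in zip(axes, series):
    ax.scatter(temps, rng.integers(0, 9000, temps.size))

fig.canvas.draw()  # make sure autoscaling has been applied before reading the limits
widest = series[-1]
left, right = axes[0].get_xlim()
# The first subplot's limits already cover the data plotted last
print(left <= widest.min() and right >= widest.max())
```

When the x-axis is shared along a column, ``plt.subplots`` also shows x tick labels only on the bottom subplot, which suits a vertical stack.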

## Seaborn

Now let's see what impact seaborn has on this problem.

We start by importing it and applying seaborn's default theme.

In [None]:
import seaborn as sns
sns.set()

If you go back and re-run any of the previous plots, you will see that they immediately pick up seaborn's default style.

We will also make a small change to the dataset. Seaborn uses column names and values in legends and labels automatically, so it is more convenient to capitalize the column name and map the season codes to their names up front than to filter the data for each series.

In [None]:
day.rename(columns = {'season': 'Season'}, inplace = True)
day['Season'] = day.Season.map({1: 'Spring', 2: 'Summer', 3: 'Autumn', 4: 'Winter'})
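``Series.map`` replaces each value using the dictionary and returns ``NaN`` for any value it cannot find. That makes the cell above unsafe to run twice, as this small sketch on a hypothetical sample of season codes shows.

```python
import pandas as pd

# A few season codes standing in for the 'season' column (hypothetical sample)
sample = pd.DataFrame({'season': [1, 2, 3, 4, 1]})
season_names = {1: 'Spring', 2: 'Summer', 3: 'Autumn', 4: 'Winter'}

sample = sample.rename(columns = {'season': 'Season'})
sample['Season'] = sample['Season'].map(season_names)
print(sample['Season'].tolist())
# ['Spring', 'Summer', 'Autumn', 'Winter', 'Spring']

# Running the same two steps again is not harmless: rename silently ignores the
# missing 'season' column, and map finds no string keys, so every value becomes NaN
rerun = sample.rename(columns = {'season': 'Season'})
rerun['Season'] = rerun['Season'].map(season_names)
print(rerun['Season'].isna().all())
# True
```

If the cell has been run twice by accident, reload the CSV before continuing.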

Now we can create a seaborn plot to match our first seasonal scatter plot.

Most of the code here is aimed at the appearance of the plot. But compare the 4 calls to ``plt.scatter`` that we had to make before with the single call to ``sns.scatterplot`` this time.

With ``style = 'Season'``, seaborn also gives each series its own marker. We could have done the same with ``plt.scatter`` by passing a different ``marker`` in each call, for example 'x', 's' (for square) or '+'.
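For comparison, here is a sketch of the pyplot version with a different color and marker per series, written as a loop over a dictionary of styles rather than four separate calls. It uses hypothetical random data with the same column names as ``day`` after the renaming above, and assumes the ``Agg`` backend.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend, so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-in for the renamed 'day' table: random values, not the real data
rng = np.random.default_rng(2)
demo = pd.DataFrame({
    'temp': rng.uniform(0, 1, 100),
    'cnt': rng.integers(0, 9000, 100),
    'Season': rng.choice(['Spring', 'Summer', 'Autumn', 'Winter'], 100),
})

# One (color, marker) pair per season: 'o' circle, 'x' cross, 's' square, '+' plus
styles = {
    'Spring': ('green', 'o'),
    'Summer': ('orange', 'x'),
    'Autumn': ('brown', 's'),
    'Winter': ('blue', '+'),
}

fig, ax = plt.subplots(figsize = (7, 6))
for season, (color, marker) in styles.items():
    subset = demo[demo['Season'] == season]
    ax.scatter('temp', 'cnt', data = subset, color = color, marker = marker, label = season)
ax.legend(title = 'Season', loc = 'upper left')
```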

In [None]:
plt.figure(figsize = (7, 6))

sns.set_context('talk', font_scale = 0.9)
sns.set_style('ticks')

sns.scatterplot(
    x = 'temp',
    y = 'cnt', 
    hue = 'Season',   # change color for each value
    data = day, 
    style = 'Season', # change the marker for each value
    palette = ['green', 'orange', 'brown', 'blue'],
    legend = 'full'
)

plt.legend(
    scatterpoints = 1,
    loc = 'upper left',
    ncol = 1,
    fontsize = 12
)
plt.xlabel('Normalized Temperature', fontdict = fontdict)
plt.ylabel('Count of Total Bike Rentals', fontdict = fontdict)
plt.title('Bike Rentals at Different Temperatures\nBy Season', fontdict = fontdict, position = (0.5, 1));

Now let's create the stacked scatter that we had before.

Again, notice that we can achieve this in a single plotting call.

In [None]:
sns.set()  # reset to the default theme, undoing the 'talk' context and 'ticks' style set above

g = sns.relplot(
    x = 'temp', 
    y = 'cnt', 
    hue = 'Season', 
    data = day,
    palette = ['green', 'orange', 'brown', 'blue'],
    row = 'Season', 
    legend = False,
    height = 2, 
    aspect = 4.8, 
    style = 'Season'
)

g.fig.subplots_adjust(top = 0.9)
g.fig.suptitle('Bike Rentals at Different Temperatures\nBy Season', position=(0.5, 0.99), fontweight = 'bold', size = 12)
# this is a bit of a cheat: suppress the default labels and attach the axis labels to the last subplot
# we could also use g.set_xlabels, or work with g.axes directly
g.set_axis_labels('', '')
plt.ylabel('Count of Total Bike Rentals', fontdict = fontdict, position = (0, 2.0))
plt.xlabel('Normalized Temperatures', fontdict = fontdict)
g.set_titles(template = '{row_name}')

# Colors

In [None]:
sns.palplot(sns.color_palette())

In [None]:
sns.palplot(sns.color_palette('muted'))

In [None]:
sns.palplot(sns.color_palette('bright'))

In [None]:
sns.palplot(sns.color_palette('pastel'))

In [None]:
sns.palplot(sns.color_palette('rainbow'))

In [None]:
sns.palplot(sns.color_palette('Blues'))

In [None]:
sns.palplot(sns.crayon_palette(['Maroon', 'Bittersweet', 'Burnt Orange', 'Canary', 'Fern', "Robin's Egg Blue", 'Royal Purple']))

# Other Plots

In [None]:
tips = sns.load_dataset('tips')
iris = sns.load_dataset('iris')
mpg = sns.load_dataset('mpg')

## Using Size in a Scatter Plot

In [None]:
sns.set(style="ticks")

g = sns.relplot(x = 'total_bill', y = 'tip', hue = 'time', size = 'size',
                palette = ['b', 'r'], sizes = (10, 100),
                col = 'time', data = tips)

## Various Categorical Plots

A strip plot is a scatter plot in which one axis is categorical. By default, the points are jittered (shifted sideways by a small random amount), but they can still overlap. A swarm plot arranges the points so that they do not overlap.

In [None]:
sns.set(style = 'whitegrid') 

ax = sns.catplot(x = 'species', y = 'sepal_length', data = iris, kind = 'strip'); 
plt.title('Sepal Length by Species') 
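To make the idea concrete, this is roughly what a strip plot does, sketched in plain matplotlib with hypothetical data: each category gets an integer position on the x-axis, and each point is nudged sideways by a small random jitter. Seaborn's own jitter width and defaults may differ; the ``Agg`` backend is an assumption so that it runs outside a notebook.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend, so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(3)
# Hypothetical measurements for three categories
groups = {
    'a': rng.normal(5.0, 0.3, 30),
    'b': rng.normal(6.0, 0.5, 30),
    'c': rng.normal(6.5, 0.6, 30),
}

fig, ax = plt.subplots()
for position, (name, values) in enumerate(groups.items()):
    # Random sideways offsets spread the points out, but nothing stops them overlapping
    jitter = rng.uniform(-0.2, 0.2, values.size)
    ax.scatter(position + jitter, values, s = 15)
ax.set_xticks(range(len(groups)))
ax.set_xticklabels(list(groups))
```

A swarm plot replaces the random offsets with positions chosen so that no two points overlap.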

In [None]:
ax = sns.catplot(x = 'species', y = 'sepal_length', data = iris, kind = 'swarm'); 
plt.title('Sepal Length by Species')

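In a categorical bar chart, seaborn plots the mean of each category as the height of the bar. By default, the error bar shows a 95% confidence interval for that mean, estimated by bootstrapping, rather than the range of the values.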
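The sketch below shows the idea behind that default error bar, a percentile bootstrap: resample the observations with replacement many times, take the mean of each resample, and use the 2.5th and 97.5th percentiles of those means as a 95% interval. It uses numpy on a handful of illustrative values; seaborn's own implementation differs in its details, and the function name here is hypothetical.

```python
import numpy as np

def bootstrap_mean_ci(values, n_boot = 1000, level = 95, seed = 0):
    """Percentile bootstrap confidence interval for the mean of `values`."""
    rng = np.random.default_rng(seed)
    values = np.asarray(values, dtype = float)
    # Each row is one resample of the data, drawn with replacement
    resamples = rng.choice(values, size = (n_boot, values.size), replace = True)
    boot_means = resamples.mean(axis = 1)
    tail = (100 - level) / 2
    low, high = np.percentile(boot_means, [tail, 100 - tail])
    return values.mean(), low, high

# Illustrative sepal lengths for one species
sample = [5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9]
mean, low, high = bootstrap_mean_ci(sample)
print(f'mean = {mean:.2f}, 95% CI = ({low:.2f}, {high:.2f})')
```

The interval is much narrower than the spread of the values themselves, which is why the error bars on a bar chart should not be read as the range of the data.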

In [None]:
ax = sns.catplot(x = 'species', y = 'sepal_length', data = iris, kind = 'bar'); 
plt.title('Sepal Length by Species')

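A categorical point chart also plots the mean value of each category, shown as a point with an error bar. The means in each series are joined by lines, which makes this a good choice for showing trends. Setting ``dodge = True`` shifts the series slightly apart along the x-axis so that their points and error bars do not sit on top of one another.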

In [None]:
ax = sns.catplot(x = 'model_year', y = 'mpg', data = mpg, kind = 'point', hue = 'cylinders', dodge = True); 
plt.title('Miles per gallon by year and engine cylinders')
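The points in this chart are group means, so we can check the values behind them with pandas. This is a sketch on a few hypothetical rows with the same column names as ``mpg``: each series (cylinder count) becomes a column and each model year a row.

```python
import pandas as pd

# Hypothetical rows shaped like the mpg dataset's columns (not real measurements)
cars = pd.DataFrame({
    'model_year': [70, 70, 70, 71, 71, 71],
    'cylinders': [4, 4, 8, 4, 8, 8],
    'mpg': [26.0, 24.0, 15.0, 28.0, 14.0, 16.0],
})

# Mean mpg for each (year, cylinders) pair: one column per series, one row per x position
means = cars.groupby(['model_year', 'cylinders'])['mpg'].mean().unstack('cylinders')
print(means)
```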

## Joint Plot

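The joint plot adds a marginal plot along the x- and y-axis of the main plot. By default, it combines a scatter plot with a histogram of each variable; here the scatter looks like a strip chart because ``model_year`` takes only a few discrete values.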
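As a rough picture of what ``jointplot`` assembles, the sketch below builds a similar layout by hand in matplotlib with hypothetical random data: a scatter plot in the main panel and a histogram of each variable along the top and right edges, all positioned with a ``GridSpec``. It is not how seaborn implements it internally, and it assumes the ``Agg`` backend.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend, so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(4)
# Hypothetical paired data with a linear relationship plus noise
x = rng.normal(0, 1, 300)
y = 2 * x + rng.normal(0, 1, 300)

fig = plt.figure(figsize = (6, 6))
grid = fig.add_gridspec(2, 2, width_ratios = (4, 1), height_ratios = (1, 4),
                        wspace = 0.05, hspace = 0.05)
ax_joint = fig.add_subplot(grid[1, 0])
ax_top = fig.add_subplot(grid[0, 0], sharex = ax_joint)    # marginal for x
ax_right = fig.add_subplot(grid[1, 1], sharey = ax_joint)  # marginal for y

ax_joint.scatter(x, y, s = 10)
ax_top.hist(x, bins = 20)
ax_right.hist(y, bins = 20, orientation = 'horizontal')
# Hide tick labels that would duplicate the main plot's
ax_top.tick_params(axis = 'x', labelbottom = False)
ax_right.tick_params(axis = 'y', labelleft = False)
```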

In [None]:
ax = sns.jointplot(x = 'model_year', y = 'mpg', data = mpg); 

## Heatmap

A heatmap uses color, or the intensity of a color, to indicate the data value in each cell of a matrix. Here we first pivot the flights data so that each row is a month and each column a year.
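``heatmap`` expects a rectangular, wide-form table, which is why the flights data is pivoted first. A minimal sketch of that reshaping on a made-up four-row table, passing ``index``, ``columns`` and ``values`` by keyword, which recent pandas versions require:

```python
import pandas as pd

# Made-up long-form data: one row per (month, year) pair
long_form = pd.DataFrame({
    'year': [1949, 1949, 1950, 1950],
    'month': ['Jan', 'Feb', 'Jan', 'Feb'],
    'passengers': [1, 2, 3, 4],
})

# Months become the row index, years the columns, and passengers the cell values
wide = long_form.pivot(index = 'month', columns = 'year', values = 'passengers')
print(wide)
```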

In [None]:
flights = sns.load_dataset('flights')
flights = flights.pivot(index = 'month', columns = 'year', values = 'passengers')
ax = sns.heatmap(flights)