# Styling Charts in Seaborn

Seaborn aims to replicate much of what ggplot does on top of a matplotlib core, providing a nice grammar-of-graphics-style API for Python users. Building on the monolithic matplotlib library gives it almost limitless flexibility, although how to do a particular thing is sometimes buried in Stack Overflow replies. The library also integrates well with pandas, so the data wrangling-to-viz workflow is seamless.

In this notebook, I want to give a brief overview of how to get some charts off the ground and looking nice using seaborn.

## Contents
1. Bar chart with data labels
2. Line chart with timeseries axis
3. Small multiples/facet grid
4. Saving figures

## Data
I will be using two primary data tables relating to Covid-19 case numbers: the number of new cases reported per country (source: https://covid19.who.int/), and global population data (source: https://data.worldbank.org/). Both were retrieved at the start of June 2020.



## General style elements
Setting the general style for seaborn charts is best done right at the start using the `sns.set()` function. There are a few different “style” presets we can choose from, including “white”, “whitegrid”, “dark”, “darkgrid”, and “ticks”. I normally opt for “whitegrid”. We can also apply a scaling factor to font sizes if desired and, perhaps most importantly, set a custom color palette (or choose from one of the many available seaborn palettes).

At the top of my notebook I have the following block:

In [None]:
# imports used throughout the notebook
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

# color palette can be passed as a list of hex codes
custom_colors = ["#9b59b6", "#3498db", "#95a5a6", "#e74c3c", "#34495e", "#2ecc71"]

# set overall plot style, font size scaling factor, and color palette
sns.set(style="whitegrid", font_scale=1.2, palette=custom_colors)

Note that we can (and will) tweak some settings later, but these top level style settings allow us to maintain a consistent chart aesthetic throughout our work.
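To confirm that the settings took effect, a minimal sketch like the following can be run straight after the block above. It assumes the same `custom_colors` list; the names `active_hex` and `grid_on` are illustrative only, and the `Agg` backend line is there just so the check runs outside a notebook.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; not needed inside a notebook
import seaborn as sns

custom_colors = ["#9b59b6", "#3498db", "#95a5a6", "#e74c3c", "#34495e", "#2ecc71"]
sns.set(style="whitegrid", font_scale=1.2, palette=custom_colors)

# sns.color_palette() with no arguments returns the currently active palette
active_hex = list(sns.color_palette().as_hex())

# sns.axes_style() with no arguments returns the currently active style parameters
grid_on = sns.axes_style()["axes.grid"]

print(active_hex)
print(grid_on)
```

If the printed hex codes match `custom_colors`, every subsequent chart will cycle through that palette by default.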

### Helper functions
There are a couple of helper functions that are very useful for quickly adding or styling additional elements to our seaborn charts. These typically glue a lot of matplotlib code together, and so can be a bit verbose to write out each time we need it, hence wrapping them up as helpers.

For example, to add data labels to bar charts something like the following should do:

In [1]:
def datalabel_bar(ax, fontsize=12):
  # collect the bar rectangles, skipping the axes background patch
  rects = [rect for rect in ax.get_children() 
    if isinstance(rect, mpl.patches.Rectangle) and rect is not ax.patch]
  
  for rect in rects:
    height = rect.get_height()
    if height > 0:
      ax.annotate(f"{height:.0f}", xy=(rect.get_x() + rect.get_width() / 2, height), 
        xytext=(0, 3), # 3 points offset
        textcoords="offset points",
        ha="center", va="bottom", fontsize=fontsize)

We could tweak this to handle both horizontal and vertical bar charts, but I find it simpler just to have a separate function, which we’ll call `datalabel_hbar()` for horizontal bars.

In [1]:
def datalabel_hbar(ax, fontsize=12):
  rects = [rect for rect in ax.get_children() if isinstance(rect, mpl.patches.Rectangle)]
  
  for rect in rects:
    width = rect.get_width()
    if width > 1: # skips bars of width <= 1 and the axes background patch (width exactly 1)
      ax.annotate(f"{width:.0f}", xy=(width, rect.get_y() + rect.get_height() / 2),
        xytext=(5, -1), # 5 points offset
        textcoords="offset points", ha="left", va="center", fontsize=fontsize)

The only required argument for these functions is a matplotlib axes object, which is what seaborn's axes-level plotting functions (such as `sns.barplot()`) return.
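As a quick illustration of passing an axes object to such a helper, here is a minimal, self-contained sketch. `label_bars` is a hypothetical, simplified stand-in for `datalabel_bar()` that iterates over `ax.patches` instead, and the bars are drawn with plain matplotlib so that the example does not depend on any dataset.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; not needed inside a notebook
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

def label_bars(ax, fontsize=12):
    # ax.patches holds the artists added by ax.bar(), one Rectangle per bar
    for rect in ax.patches:
        if isinstance(rect, Rectangle) and rect.get_height() > 0:
            ax.annotate(
                f"{rect.get_height():.0f}",
                xy=(rect.get_x() + rect.get_width() / 2, rect.get_height()),
                xytext=(0, 3),              # 3 points above the bar
                textcoords="offset points",
                ha="center", va="bottom", fontsize=fontsize,
            )

fig, ax = plt.subplots()
ax.bar(["a", "b", "c"], [3, 7, 5])
label_bars(ax)

# annotations are stored on the axes, so we can inspect what was drawn
label_texts = [t.get_text() for t in ax.texts]
print(label_texts)
```

The same call works on the axes returned by `sns.barplot()`, since that is also a matplotlib axes object.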

Note that you can change the string formatting of the labels as desired; for example, if you were plotting percentages and wanted them displayed to 2 d.p. you would change the first argument of the call to `.annotate()` above to `f"{width:.2%}"`.
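For reference, here are a few format specifications that could be swapped into the f-string. This is plain Python string formatting, and the sample values are made up for illustration. Note that the `%` format multiplies by 100, so it expects fractions (0.1234) rather than percentages (12.34).

```python
count = 1234.567
share = 0.1234

labels = {
    "integer": f"{count:.0f}",      # no decimal places
    "thousands": f"{count:,.0f}",   # thousands separator
    "percent": f"{share:.2%}",      # fraction shown as a percentage to 2 d.p.
}

for name, text in labels.items():
    print(name, text)
```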

I also really like the `nice_dates()` function [written by Zi Chong Kao](https://kaomorphism.com/2017/09/10/Sane-Date-Axes.html) for creating hierarchical timeseries axis labels in a smart way. I’ve tweaked it slightly to operate on matplotlib axes objects rather than figure objects as I find it handles sizing a bit better.

In [1]:
# Adapted from https://kaomorphism.com/2017/09/10/Sane-Date-Axes.html
import pandas as pd
import matplotlib.dates as mdates
from datetime import timedelta as tdelta
from matplotlib.ticker import FuncFormatter

INTERVALS = {
  'YEARLY'  : [1, 2, 4, 5, 10],
  'MONTHLY' : [1, 2, 3, 4, 6],
  'DAILY'   : [1, 2],
  'WEEKLY'  : [1, 2],
  'HOURLY'  : [1, 2, 3, 4, 6, 12],
  'MINUTELY': [1, 5, 10, 15, 30],
  'SECONDLY': [1, 5, 10, 15, 30],
}

TICKS_PER_INCH = 1.5

def _next_largest(value, options):
    for i in options:
        if i >= value:
            return i
    return i

def _get_dynamic_formatter(timedelta, *fmts):
    def dynamic_formatter(x, pos):
        dx = mdates.num2date(x)
        strparts = [dx.strftime(fmt) for fmt in fmts]
        if pos > 0:
            # renders previous tick and removes common parts
            prior_dx = dx - timedelta
            prior_strparts = [prior_dx.strftime(fmt) for fmt in fmts]
            strparts = [new if new != prior else '' for new, prior in zip(strparts, prior_strparts)]
        return '\n'.join(strparts).strip()
    return dynamic_formatter


def _deduce_locators_formatters(max_ticks, data):
    data_interval_seconds = (data.max() - data.min()) / tdelta(seconds=1)
    interval_seconds = data_interval_seconds / max_ticks
    
    if interval_seconds < tdelta(minutes=0.5).total_seconds():
        # print("xticks: seconds")
        unit_multiple = _next_largest(interval_seconds, INTERVALS['SECONDLY'])
        timedelta = tdelta(seconds=unit_multiple)
        return (mdates.SecondLocator(bysecond=range(0, 60, unit_multiple)),
                FuncFormatter(_get_dynamic_formatter(timedelta, '%M:%S', '%-Hh', '%-d %b')))
    elif interval_seconds < tdelta(hours=0.5).total_seconds():
        # print("xticks: minutes")
        unit_multiple = _next_largest(interval_seconds / tdelta(minutes=1).total_seconds(), INTERVALS['MINUTELY'])
        timedelta = tdelta(minutes=unit_multiple)
        return (mdates.MinuteLocator(byminute=range(0, 60, unit_multiple)),
                FuncFormatter(_get_dynamic_formatter(timedelta, '%H:%M', '%-d %b', '%Y')))
    elif interval_seconds < tdelta(days=0.5).total_seconds():
        # print("xticks: hours")
        unit_multiple = _next_largest(interval_seconds / tdelta(hours=1).total_seconds(), INTERVALS['HOURLY'])
        timedelta = tdelta(hours=unit_multiple)
        return (mdates.HourLocator(byhour=range(0, 24, unit_multiple)),
                FuncFormatter(_get_dynamic_formatter(timedelta, '%-Hh', '%-d %b', '%Y')))
    elif interval_seconds < tdelta(days=3).total_seconds():
        # print("xticks: days")
        unit_multiple = _next_largest(interval_seconds / tdelta(days=1).total_seconds(), INTERVALS['DAILY'])
        timedelta = tdelta(days=unit_multiple)
        return (mdates.WeekdayLocator(byweekday=range(0, 7, unit_multiple)),
                FuncFormatter(_get_dynamic_formatter(timedelta, '%-d', '%b', '%Y')))
    elif interval_seconds < tdelta(days=14).total_seconds():
        # print("xticks: weeks")
        unit_multiple = _next_largest(interval_seconds / tdelta(weeks=1).total_seconds(), INTERVALS['WEEKLY'])
        timedelta = tdelta(days=unit_multiple * 7)
        return (mdates.WeekdayLocator(byweekday=0, interval=unit_multiple),
                FuncFormatter(_get_dynamic_formatter(timedelta, '%-d', '%b', '%Y')))
    elif interval_seconds < tdelta(weeks=26).total_seconds():
        # print("xticks: months")
        unit_multiple = _next_largest(interval_seconds / tdelta(weeks=4).total_seconds(), INTERVALS['MONTHLY'])
        timedelta = tdelta(weeks=unit_multiple * 4)
        return (mdates.MonthLocator(bymonth=range(1, 13, unit_multiple)),
                FuncFormatter(_get_dynamic_formatter(timedelta, '%b', '%Y')))
    else:
        # print("xticks: years")
        unit_multiple = _next_largest(interval_seconds / tdelta(weeks=52).total_seconds(), INTERVALS['YEARLY'])
        return (mdates.YearLocator(base=unit_multiple),
                mdates.DateFormatter('%Y'))
    
def nice_dates(ax):
    fig = ax.get_figure()
    
    # information for deciding tick locations
    bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    xaxis_length_inch, yaxis_length_inch = bbox.width, bbox.height
    max_ticks = xaxis_length_inch * TICKS_PER_INCH
    data = pd.to_datetime(ax.lines[0].get_xdata())
    
    maj_locator, maj_formatter = _deduce_locators_formatters(max_ticks, data)

    ax.xaxis.set_major_locator(maj_locator)
    ax.xaxis.set_major_formatter(maj_formatter)
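The core idea behind `_get_dynamic_formatter()` is to render each tick in several formats and blank out the parts that repeat from the previous tick. A simplified, standalone sketch of that idea follows. `hierarchical_labels` is a hypothetical helper, not part of the code above, and it works on a plain list of datetimes rather than on matplotlib tick positions.

```python
from datetime import datetime, timedelta

def hierarchical_labels(ticks, *fmts):
    """Render each tick with every format, hiding parts unchanged since the previous tick."""
    labels = []
    prior_parts = None
    for tick in ticks:
        parts = [tick.strftime(fmt) for fmt in fmts]
        if prior_parts is None:
            shown = parts  # the first tick always shows every level
        else:
            shown = [new if new != old else "" for new, old in zip(parts, prior_parts)]
        labels.append("\n".join(shown).strip())
        prior_parts = parts
    return labels

# four daily ticks that cross a month boundary
ticks = [datetime(2020, 1, 30) + timedelta(days=i) for i in range(4)]
labels = hierarchical_labels(ticks, "%d", "%b", "%Y")

for label in labels:
    print(repr(label))
```

The first tick carries the day, month, and year; later ticks show the month again only when it changes. Month names from `%b` follow the active locale.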

## Let’s start plotting

### 1. Bar chart with data labels
For our bar chart we want a horizontal bar chart with countries on the y-axis, and the count of cases per 100,000 people on the x-axis. We’ll filter to just the top 25 countries to make the chart more manageable.

Our dataset currently looks like this:

![data_1](data.png)

We’ll use pandas to sort and filter this data as we pass it into seaborn.

In [3]:
fig,ax = plt.subplots(figsize=(10,8))
g = sns.barplot(
  x="cases_per100k", 
  y="Country", 
  data=df.sort_values(by="cases_per100k", ascending=False).iloc[:25],
  color=sns.color_palette()[0],
  saturation=1.,
)
# add data labels
datalabel_hbar(g)
# remove x-axis tick labels
g.xaxis.set_ticklabels([])
# rename axes labels
g.set(xlabel='Cases per 100k people', ylabel='Country')
# remove border
sns.despine(bottom=True);

We do a few things here. First, we create top-level matplotlib figure and axes objects with `plt.subplots()`. There are ways of setting the figure size in seaborn, but I find this to be the most commonly used method. Then we call `sns.barplot()` and pass in our variables and our dataframe, as well as a single color (the list of colors is stored in `sns.color_palette()`), and an override of the saturation (which seaborn defaults to 0.75).

Importantly, the return value is saved as a variable `g` (which is a matplotlib axes object), making it easy to apply additional tweaks to various aspects of the chart. For example, we can pass this `g` axes to our `datalabel_hbar()` helper function to add data labels. We then grab the x-axis and set the tick labels to an empty list (effectively removing them), rename the x and y axis labels with `g.set()`, and finally use `sns.despine()` to remove the top, right, and bottom axes “spines”.

The above results in the following output:

![](fig_1.png)

Looking pretty good! If you think the x-axis tick labels should be in there, simply delete the line that we added to remove them.

### 2. Line chart with timeseries axis
Looking at our dataset we see that the countries are grouped into “WHO_regions”. Let’s calculate a 7-day rolling average of new cases reported per region.

First we aggregate the countries into regions, taking the sum of the daily new case numbers:

In [4]:
dff = df.groupby(['Date_reported', 'WHO_region']).agg(
  {'New_cases': 'sum'}
).reset_index()

Then we’ll apply pandas `.rolling()` method per region with a 7 day window, taking the mean of new cases reported:

In [5]:
p = dff.groupby('WHO_region').apply(
  # average only the numeric column; the time-based window needs a datetime index
  lambda x: x.set_index("Date_reported")[['New_cases']].rolling(window='7D').mean()
).reset_index().rename(columns={
  'New_cases': 'New_cases_rolling',
})

![](data_2.png)
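To see what the time-based rolling window does, here is a small sketch on made-up data. The column names mirror the ones used above, but the regions `A` and `B` and their values are invented. It also uses pandas' grouped `.rolling()` directly rather than `.apply()`, which is an alternative formulation and not the code from this notebook.

```python
import pandas as pd

dates = pd.date_range("2020-05-01", periods=10, freq="D")
toy = pd.DataFrame({
    "Date_reported": list(dates) * 2,
    "WHO_region": ["A"] * 10 + ["B"] * 10,
    "New_cases": list(range(1, 11)) + [10] * 10,
})

rolled = (
    toy.sort_values("Date_reported")      # time-based windows need sorted dates
    .set_index("Date_reported")           # ...and a datetime index
    .groupby("WHO_region")["New_cases"]
    .rolling("7D")                        # each window covers the trailing 7 days
    .mean()
    .reset_index()
    .rename(columns={"New_cases": "New_cases_rolling"})
)

region_a = rolled.loc[rolled["WHO_region"] == "A", "New_cases_rolling"]
first_a, last_a = region_a.iloc[0], region_a.iloc[-1]
print(rolled.tail())
```

Early rows are averaged over however many days are available, so the first value for each region is just that day's count.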

We’ll now use seaborn’s `lineplot()` function to create a line chart of the daily new cases by region.

In [6]:
# set figure aesthetics
sns.set_context(rc={'lines.linewidth': 2.5})
g = sns.lineplot(
  data=p, 
  x="Date_reported", 
  y="New_cases_rolling", 
  hue="WHO_region",
)
# format dates on x-axis
nice_dates(g)
# move legend outside plot
g.legend(loc=2, bbox_to_anchor=(1.05, 1))
# rename axes labels
g.set(xlabel="", ylabel="Daily new cases (7d rolling avg)")
# remove x gridlines
g.xaxis.grid(False)
# despine all sides (top and right are removed by default)
sns.despine(left=True, bottom=True);

To make this chart a bit clearer we increased the line width by using `sns.set_context()` and passing an `rc` param. To see what other params can be tweaked in this way, run `sns.plotting_context()` (with no params passed) and it will show the current values. Other params are stored in `sns.axes_style()`, but those are set in a slightly different way.
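A minimal sketch of the two mechanisms side by side: context params go through `sns.set_context()`, while style params go through `sns.set_style()`. The dashed grid line is only an illustrative choice and is not used elsewhere in this notebook.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; not needed inside a notebook
import seaborn as sns

sns.set(style="whitegrid")

# context params (sizes, line widths) are overridden via set_context
sns.set_context(rc={"lines.linewidth": 2.5})
linewidth = sns.plotting_context()["lines.linewidth"]

# style params (grid, spines, colors) are overridden via set_style
sns.set_style("whitegrid", rc={"grid.linestyle": "--"})
grid_linestyle = sns.axes_style()["grid.linestyle"]

print(sorted(sns.plotting_context().keys()))  # params that set_context accepts
print(linewidth, grid_linestyle)
```

Keys that do not belong to the respective group are ignored by each function, which is why the two sets of params have to be set separately.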

Other styling we applied to the above chart was the use of the helper function `nice_dates()` to format the timeseries axis in a hierarchical manner, moving the legend to outside of the plot area, renaming the axes, removing vertical gridlines, and despining all sides of the plot (note: `top` and `right` are `True` by default, whereas `left` and `bottom` are `False` by default).

This results in the following chart output.

![](fig_2.png)

### 3. Small multiples/facet grid
For the FacetGrid example we’ll filter our dataset to include only the top 8 countries by total number of cases and will overlay the daily new cases reported with the 7-day rolling average. Make sure that the data is in a “tidy” format before plotting.
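The facet code relies on a per-country `New_cases_rolling` column and on a fixed country ordering. One way to prepare both is sketched below on made-up data. The countries `X`, `Y`, `Z` and their numbers are invented, and the row-based 7-row window assumes exactly one row per country per day.

```python
import pandas as pd

dates = pd.date_range("2020-05-01", periods=3, freq="D")
toy = pd.DataFrame({
    "Date_reported": list(dates) * 3,
    "Country": ["X"] * 3 + ["Y"] * 3 + ["Z"] * 3,
    "New_cases": [1, 2, 3, 10, 20, 30, 5, 5, 5],
})
toy = toy.sort_values(["Country", "Date_reported"])
toy["Cumulative_cases"] = toy.groupby("Country")["New_cases"].cumsum()

# rolling mean over the trailing 7 rows within each country
toy["New_cases_rolling"] = toy.groupby("Country")["New_cases"].transform(
    lambda s: s.rolling(7, min_periods=1).mean()
)

# keep the top 2 countries by total cases, ordered from most to fewest cases
top = toy.groupby("Country")["Cumulative_cases"].max().sort_values(ascending=False).index[:2]
subset = toy.loc[toy["Country"].isin(top)].copy()
subset["Country_cat"] = pd.Categorical(subset["Country"], categories=top, ordered=True)
subset = subset.sort_values(["Country_cat", "Date_reported"])

country_order = subset["Country"].unique().tolist()
rolling_y = subset.loc[subset["Country"] == "Y", "New_cases_rolling"].tolist()
print(country_order)
print(rolling_y)
```

Each row holds one observation (one country on one date) with the raw and smoothed counts as separate columns, which is the tidy layout the plotting functions expect.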

To generate our layered chart we’ll use seaborn’s FacetGrid object and overlay a seaborn lineplot and a matplotlib barplot. I find it’s easiest to use pandas for sorting/ordering the data before passing to the plotting functions, which we do here by making use of the pandas categorical data type so that we can sort the countries by passing a list of values. The full code looks like this:

In [2]:
# restrict to just the top 8 countries by total cases
countries = df.groupby('Country')['Cumulative_cases'].max().sort_values(ascending=False).index[:8]
p = df.loc[df['Country'].isin(countries)].copy()  # copy so the new column is not set on a view of df
p['Country_cat'] = pd.Categorical(
  p['Country'], 
  categories=countries, 
  ordered=True
)
data_sorted = p.sort_values(by='Country_cat')
# create facet grid
g = sns.FacetGrid(
  data=data_sorted, 
  col="Country", 
  col_wrap=4
)
# map the bar plot using a matplotlib plotting function
g.map(plt.bar, "Date_reported", "New_cases", color="b", alpha=0.4)
# map the lineplot using seaborn
kws_line = dict(
  x="Date_reported",
  y="New_cases_rolling",
)
g.map_dataframe(sns.lineplot, **kws_line)
# set subtitles and axis titles
g.set_titles("{col_name}")
# apply nicely formatted timeseries labels to every subplot
for ax in g.axes:
  nice_dates(ax)


A few things to note here. I like to use the `.map_dataframe()` method rather than `.map()` when using a seaborn plotting function. I find it a bit simpler to list out the params as keyword arguments in a separate `kws = dict()` variable and then unpack them, rather than relying on the positional arguments that must be passed if using `.map()`. However, do note that this relies on the data being in a tidy format already, and so if using a matplotlib plotting function (e.g. `plt.bar`) where the data passed is not in the tidy format that seaborn prefers, then using `.map()` with the positional arguments would be better (or necessary).

The `g` returned here is no longer a matplotlib axes object but rather a seaborn FacetGrid object, and so setting additional style parameters is a bit different. We use the nifty `.set_titles("{col_name}")` method to override the default seaborn style of having the title of each subplot as “var_name = col_name”. We also looped through the matplotlib axes (fetched from `g.axes`) to apply the `nice_dates()` helper function to nicely format the datetime axis. FacetGrid despines the subplots by default, so we don’t need to call that here.

The final chart output looks like this.

![](fig_3.png)

You can map multiple charts onto the same FacetGrid object, but just make sure the data is being passed correctly (i.e. using .map() or .map_dataframe()). This creates really useful and data rich visualisations.

### 4. Saving figures
Saving charts in seaborn is very straightforward and leverages matplotlib’s savefig() functionality. However, there is a slight difference between the single charts shown above and the facet chart.

The single charts return a matplotlib axes object, which we store in the variable g. To call .savefig() we actually first need to get the figure object. For example:

In [3]:
fig = g.get_figure()
fig.savefig('example_chart.svg', bbox_inches="tight")

The facet chart doesn’t return a matplotlib axes object, but rather a seaborn FacetGrid object, which we store in the variable g. Seaborn has built the .savefig() method directly into this object, and so saving a facet chart is as easy as:

In [4]:
g.savefig('example_facet_chart.svg', bbox_inches="tight")

Alternatively, if you’d like to keep it consistent you can first grab the matplotlib figure object with .fig, then save in the usual way:

In [5]:
fig = g.fig # note: fig is an attribute not a method
fig.savefig('example_facet_chart.svg', bbox_inches="tight")
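As a self-contained illustration of the axes-to-figure step, the sketch below saves a throwaway matplotlib chart to a temporary directory in two formats. The file names and the `dpi=300` setting are illustrative assumptions. matplotlib infers the output format from the file extension, and `dpi` only matters for raster formats such as PNG.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend; not needed inside a notebook
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot([0, 1, 2], [1, 3, 2])

out_dir = tempfile.mkdtemp()
svg_path = os.path.join(out_dir, "example_chart.svg")
png_path = os.path.join(out_dir, "example_chart.png")

# an axes object cannot save itself, so fetch its parent figure first
figure = ax.get_figure()
figure.savefig(svg_path, bbox_inches="tight")           # vector output
figure.savefig(png_path, bbox_inches="tight", dpi=300)  # raster output at 300 dpi
plt.close(figure)

saved_sizes = {path: os.path.getsize(path) for path in (svg_path, png_path)}
print(saved_sizes)
```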

## Summary
There is obviously a bunch more you can do by digging through all the possible matplotlib objects and functions, but it does get pretty overwhelming pretty quickly.

The seriousness of the data underlying these examples is nothing to take lightly. However, it is important that we can interpret and communicate issues in our society accurately and fairly, and so giving ourselves the skills to do so can help to make the available information more democratic and robust.

For some useful resources on communicating data honestly and effectively, I can highly recommend the following:

- [Storytelling with Data](http://www.storytellingwithdata.com/) — a classic data viz blog and book
- [The WSJ Guide to Information Graphics](https://www.goodreads.com/book/show/6542897-the-wall-street-journal-guide-to-information-graphics) by Dona M. Wong

<div class="admonition note alert alert-info">
<p class="first admonition-title" style="font-weight: bold;">Note</p>
<p>The content in this notebook was copied from https://medium.com/@tttgm/styling-charts-in-seaborn-92136331a541</p>
</div>