In [4]:
%matplotlib inline

import time

import numpy as np
import pandas as pd

# Make inline plots vector graphics instead of raster graphics
from IPython.display import set_matplotlib_formats

set_matplotlib_formats('pdf', 'svg')

import matplotlib
import matplotlib.pyplot as plt

import seaborn as sns
sns.set_style("whitegrid")
sns.set_context("paper")

import mpld3

class d3:
    """Context manager that turns on mpld3 rendering for a single plot.

    mpld3's notebook display is enabled when the ``with`` block is entered
    and disabled again when it is left (``__exit__`` also runs when the
    block raises; the exception is not suppressed).
    """

    def __enter__(self):
        # switch figure display over to mpld3
        mpld3.enable_notebook()

    def __exit__(self, type, value, traceback):
        # switch figure display back off mpld3
        mpld3.disable_notebook()

        
class Swap:
    """Base class for context managers that temporarily change a setting.

    The constructor only records its positional and keyword arguments in
    ``self.args`` and ``self.kwargs``. Subclasses apply them in
    ``__enter__`` and restore the previous setting in ``__exit__``. On its
    own this class is a no-op context manager.
    """

    def __init__(self, *args, **kwargs):
        self.args, self.kwargs = args, kwargs

    def __enter__(self):
        return None

    def __exit__(self, type, value, traceback):
        return None
        
        
class SwapPalette(Swap):
    """Temporarily switch the seaborn colour palette.

    Constructor arguments are forwarded to ``sns.set_palette`` on entry.
    The palette returned by ``sns.color_palette()`` just before that is
    re-applied on exit.
    """

    def __enter__(self):
        previous_palette = sns.color_palette()
        self.orig = previous_palette
        sns.set_palette(*self.args, **self.kwargs)

    def __exit__(self, type, value, traceback):
        sns.set_palette(self.orig)

        
class SwapStyle(Swap):
    """Temporarily switch the seaborn axes style.

    Constructor arguments are handed to ``sns.set_style`` when the ``with``
    block is entered. The style dictionary captured beforehand with
    ``sns.axes_style()`` is re-applied on exit.
    """

    def __enter__(self):
        saved_style = sns.axes_style()
        self.orig = saved_style
        sns.set_style(*self.args, **self.kwargs)

    def __exit__(self, type, value, traceback):
        sns.set_style(self.orig)

        
class SwapContext(Swap):
    """Temporarily switch the seaborn plotting context.

    Constructor arguments are forwarded to ``sns.set_context`` on entry,
    e.g. ``SwapContext("notebook", font_scale=0.5)``. The context that was
    active beforehand is restored on exit.
    """

    def __enter__(self):
        # remember the current context so it can be put back afterwards
        self.orig = sns.plotting_context()
        sns.set_context(*self.args, **self.kwargs)

    def __exit__(self, type, value, traceback):
        # The saved value comes from plotting_context(), so restore it with
        # the matching setter (previously set_style was called by mistake).
        sns.set_context(self.orig)

        
class Timer:
    """Context manager measuring the wall-clock time spent in its body.

    ``with Timer() as t:`` binds the timer itself. After the block ends,
    ``t.start`` and ``t.end`` hold ``time.time()`` stamps, and ``t.secs`` /
    ``t.msecs`` hold the elapsed time in seconds / milliseconds. With
    ``verbose=True`` the elapsed milliseconds are also printed.
    Exceptions raised in the block are not suppressed.
    """

    def __init__(self, verbose=False):
        self.verbose = verbose

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *args):
        self.end = time.time()
        elapsed = self.end - self.start
        self.secs = elapsed
        self.msecs = elapsed * 1000  # milliseconds
        if not self.verbose:
            return
        print('elapsed time: {} ms'.format(self.msecs))


In [None]:

def horizon_plot(df, key, width, cut=None, start='start', chrom='chrom', pop='pop'):
    """
    Horizon bar plot made allowing multiple chromosomes and multiple samples.

    Draws one facet per (population, chromosome) pair. Populations are rows
    and chromosomes are columns, ordered naturally (chr2 before chr11, numeric
    names before non-numeric ones). Each column's width is proportional to
    the largest ``start`` value on that chromosome.

    The values in column ``key`` are cut into horizon bands of height ``cut``.
    Positive values fill three blue bands and negative values (by magnitude)
    three red bands. Anything left over after three bands is flagged by a
    full-height fourth band: black for positive values, grey for negative.

    Parameters
    ----------
    df : pandas.DataFrame
        Data with columns ``key``, ``start``, ``chrom`` and ``pop``.
    key : str
        Name of the column holding the values to plot.
    width : float
        Width of each bar (typically end - start of each window).
    cut : float, optional
        Height of one horizon band. Defaults to a third of the largest
        absolute value in ``df[key]``.
    start, chrom, pop : str
        Names of the position, chromosome and population columns.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the facet grid.

    Raises
    ------
    ValueError
        If ``df`` has no rows.
    """

    from math import isclose, floor, log10

    # band columns added to the plotted frame: positive bands yp1-yp4
    # followed by negative bands yn1-yn4
    col_names = ('yp1', 'yp2', 'yp3', 'yp4',
                 'yn1', 'yn2', 'yn3', 'yn4')

    def horizon(val, cut):
        """
        Split ``val`` into the eight band heights, in ``col_names`` order.

        Three bands of at most ``cut`` each, plus a fourth band that is a
        full ``cut`` high if anything remains. They go on the positive side
        (first four entries) or the negative side (last four entries)
        depending on the sign of ``val``; the other side is all zeros.
        """
        remaining = abs(val)
        bands = []
        for _ in range(3):
            bands.append(min(cut, remaining))
            remaining = max(0, remaining - cut)
        # overflow marker: full band height if the value exceeds three bands
        bands.append(int(not isclose(remaining, 0, abs_tol=1e-8)) * cut)
        zeros = [0] * 4
        return zeros + bands if val < 0 else bands + zeros

    def chrom_sort(item):
        """
        Sort key that orders chromosome names naturally.

        An optional 'chr' prefix is ignored. Numeric names sort numerically
        and before non-numeric ones, which sort alphabetically
        (chr1, chr2, ..., chr11, chrX, chrY). Returning tuples keeps ints
        and strings from ever being compared with each other.
        """
        name = str(item)
        if name.startswith('chr'):
            name = name[3:]
        if name.isdigit():
            return (0, int(name), '')
        return (1, 0, name)

    def round_to_1_signif(x):
        """
        Rounds to first significant digit (0 is returned unchanged).
        """
        if x == 0:
            return 0
        return round(x, -int(floor(log10(abs(x)))))

    if df.empty:
        raise ValueError('horizon_plot: df has no rows to plot')

    # set cut if not set: three bands then span the largest absolute value
    if cut is None:
        cut = df[key].abs().max() / 3

    # make the data frame to plot: one extra column per horizon band.
    # Values are read directly from the column (not via itertuples) so that
    # column names that are not valid Python identifiers also work.
    bands = [horizon(val, cut) for val in df[key]]
    band_columns = {name: list(column)
                    for name, column in zip(col_names, zip(*bands))}
    df2 = df.assign(**band_columns)

    # chromosome names in natural order; they define the facet columns
    chrom_order = sorted(df[chrom].dropna().unique(), key=chrom_sort)

    # size of each chromosome (largest start position), in the same order
    # as the facet columns; used as relative facet width and as x extent
    chrom_sizes = df.groupby(chrom)[start].max().reindex(chrom_order).tolist()

    # make the plot (module-level SwapStyle restores the previous style)
    with SwapStyle('ticks'):

        # make the facet grid: one row per population, one column per
        # chromosome. width_ratios needs exactly one entry per column.
        # NOTE(review): newer seaborn renamed FacetGrid's `size` argument to
        # `height`; confirm against the installed seaborn version.
        g = sns.FacetGrid(df2,
                          col=chrom,
                          row=pop,
                          sharex=False,
                          margin_titles=True,
                          size=1,
                          aspect=10,
                          col_order=chrom_order,
                          row_order=None,
                          gridspec_kws={'hspace': 0.0,
                                        'width_ratios': chrom_sizes})

        # plot colors: three shades plus an overflow colour for each sign
        colours = (sns.color_palette("Blues", 3) + ['black'] +
                   sns.color_palette("Reds", 3) + ['grey'])

        # first y tick
        ytic1 = round_to_1_signif(cut / 3)

        for col_name, colour in zip(col_names, colours):
            plt.setp(g.fig.texts, text="")  # hack to make y facet labels align...
            # map barplots to each facet
            g.map(plt.bar,
                  start,
                  col_name,
                  edgecolor="none",
                  width=width,
                  color=colour)

        # no tick labels on x
        g.set(xticklabels=[])

        # g.axes is (populations x chromosomes): pair every row of axes with
        # the per-chromosome sizes
        for row_axes in g.axes:
            for ax, max_val in zip(row_axes, chrom_sizes):
                ax.set_xlim(0, max_val + 1)
                ax.set_ylim(0, cut)
                ax.set(xlabel='', ylabel='')
                tick_step = round_to_1_signif(max_val) / 10
                # a zero step (chromosome with only start == 0) would make
                # np.arange fail, so fall back to a single tick
                xticks = np.arange(0, max_val, tick_step) if tick_step > 0 else [0]
                ax.set(xticks=xticks)
                ax.set(yticks=[ytic1, ytic1 * 2, ytic1 * 3])

        # remove top and right frame
        sns.despine()

    return g.fig


# Example usage (uncomment to run):
#
# n = 100
# df = pd.DataFrame({'chrom': ['chr11']*2*n + ['chr2']*2*n,
#                    'pop': ['EUR']*1*n + ['AFR']*1*n + ['EUR']*1*n + ['AFR']*1*n,
#                    'start': list(range(1*n)) * 4,
#                    'pi': list(np.sin(np.linspace(-np.pi, np.pi, 1*n))) * 4})
#
# with Timer() as t:
#     with SwapContext("notebook", font_scale=0.5):
#         fig = horizon_plot(df, 'pi', width=1)  # width should be end - start
#         # save to file (relative path instead of a hardcoded home directory)
#         plt.savefig('horizon_plot.pdf')
#         # suppress inline plot
#         plt.close(fig)
# print(t.secs)