In [14]:
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn import preprocessing
import numpy as np
# Apply seaborn's default plot theme globally to all subsequent matplotlib figures.
sns.set()

def stacked_histogram(df, param, yscale = 'linear'):
    """Draw a stacked, density-normalised histogram of ``param`` split by ``df.mom``.

    One stacked layer is drawn per distinct value of the ``mom`` column, except
    the first distinct value (in order of appearance), which is skipped.

    @params:
        df     = DataFrame with a ``mom`` column and a ``param`` column
        param  = name of the column whose values are histogrammed
        yscale = matplotlib y-axis scale, e.g. 'linear' or 'log'
    @returns:
        the matplotlib Axes the histogram was drawn on
    """
    fig, ax = plt.subplots(1, 1, figsize=(8, 5))
    # NOTE(review): unique() keeps order of first appearance, so [1:] drops whichever
    # mom value happens to occur first in df -- confirm this is the intended exclusion.
    moms = df.mom.unique()[1:]
    vals = [df.loc[df.mom == mom, param] for mom in moms]
    if vals:
        # `normed` was removed in Matplotlib 3.1; `density` replaces it. With
        # stacked=True the sum of all layers is normalised to unit area.
        ax.hist(vals,
                density = True,
                stacked = True,
                histtype = "stepfilled",
                label = [str(mom) for mom in moms])
        ax.legend(title = "mom")
    ax.set_yscale(yscale)
    return ax

In [112]:
%matplotlib inline

def get_cmap(n, name='hsv'):
    '''Return a Matplotlib colormap resampled to ``n + 1`` discrete colours.

    Calling the result with an integer ``i`` in ``0 .. n-1`` returns the RGBA
    tuple of the i-th colour, so it can be passed as ``color=cmap(i)`` when
    drawing ``n`` series. One extra entry is allocated because cyclic maps such
    as 'hsv' start and end on the same colour; with ``n + 1`` entries the colours
    used for indices ``0`` and ``n - 1`` are distinct.

    @params:
        n    = number of distinct colours needed
        name = name of a registered Matplotlib colormap (default 'hsv')
    '''
    # plt.cm.get_cmap (matplotlib.cm.get_cmap) was deprecated in Matplotlib 3.7 and
    # removed in 3.9; pyplot.get_cmap takes the same (name, lut) arguments.
    return plt.get_cmap(name, n + 1)

def _category_sums_per_bin(data, target, nbins):
    """Bin ``data`` into ``nbins`` equal-width bins and sum the ``target`` rows per bin.

    Returns ``(edges, sums)``: ``edges`` holds the ``nbins + 1`` bin edges and
    ``sums[b, c]`` is the total of ``target[:, c]`` over samples whose value lies
    in bin ``b``. Bin ``b`` covers ``[edges[b], edges[b + 1])``; the last bin is
    also closed on the right so that ``max(data)`` is counted.
    """
    lo = np.amin(data)
    hi = np.amax(data)
    if hi == lo:
        # Constant data: widen the range so the bins have non-zero width.
        lo, hi = lo - 0.5, hi + 0.5
    edges = np.linspace(lo, hi, nbins + 1)
    bin_idx = np.searchsorted(edges, data, side="right") - 1
    # Values equal to the upper edge land one past the last bin; fold them back in.
    bin_idx = np.clip(bin_idx, 0, nbins - 1)
    sums = np.zeros((nbins, target.shape[1]))
    # Unbuffered add so that several samples falling in the same bin all accumulate.
    np.add.at(sums, bin_idx, target)
    return edges, sums


def predicted_histogram(data, target, labels = None, nbins = 10):
    """Stacked bar histogram of per-category predicted weight, binned by a parameter.

    Each row of ``target`` is L1-normalised, so every sample contributes a total
    weight of 1 split across categories by its predictions. Samples are grouped
    into ``nbins`` equal-width bins spanning ``[min(data), max(data)]``. For each
    bin, the summed weight of every category is drawn as one segment of a
    stacked bar.

    @params:
        data   = n parameter values, shape (n,) or (n, 1) (flattened internally)
        target = n x categories array of predictions, one row per sample
        labels = optional sequence of legend labels, one per category
        nbins  = number of equal-width bins (must be >= 1)
    @returns:
        the matplotlib Axes the bars were drawn on
    @raises:
        ValueError if nbins < 1 or data and target disagree on the sample count
    """
    if nbins < 1:
        raise ValueError("nbins must be at least 1")
    # Flatten so that an (n, 1) column is treated the same as a 1-D array.
    data = np.ravel(np.asarray(data, dtype=float))
    target = preprocessing.normalize(target, norm = "l1")
    if target.shape[0] != data.shape[0]:
        raise ValueError("data and target must have the same number of samples")
    n_categories = target.shape[1]
    # `is None` (not `== None`) so that array-valued labels are accepted.
    if labels is None:
        labels = [""] * n_categories

    edges, sums = _category_sums_per_bin(data, target, nbins)
    # Baseline of each segment = running total of the categories stacked below it.
    bottoms = np.cumsum(sums, axis=1) - sums
    width = edges[1] - edges[0]
    cmap = get_cmap(n_categories, 'gnuplot')

    fig, ax = plt.subplots(1, 1, figsize=(8, 5))
    for i in range(n_categories):
        # align="edge" anchors each bar's left side at its bin's lower edge.
        ax.bar(edges[:-1], sums[:, i],
               bottom = bottoms[:, i],
               width = width,
               align = "edge",
               color = cmap(i),
               label = labels[i])
    # Only draw a legend when at least one label is non-empty.
    if any(labels):
        ax.legend()
    return ax

[[ 0.80790136  0.10983708  0.0377026   0.04455896]
 [ 0.05050174  0.30910549  0.42908013  0.21131264]
 [ 0.3645101   0.07785208  0.10645016  0.45118766]
 [ 0.20786002  0.2123008   0.19062013  0.38921905]
 [ 0.10662716  0.52360468  0.0372157   0.33255246]
 [ 0.1399682   0.06051862  0.04438893  0.75512425]
 [ 0.52544248  0.15647124  0.05715551  0.26093077]
 [ 0.15577868  0.58128735  0.08151145  0.18142252]
 [ 0.04631088  0.48728348  0.19569862  0.27070702]
 [ 0.21792471  0.15505257  0.28904563  0.33797708]]
9 0
