In [1]:
%matplotlib inline

In [9]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.colors as colors
import matplotlib.cm as cmx
import numpy as np

In [10]:
# Global seed so that every random draw in the notebook is reproducible
# under Restart Kernel -> Run All.
SEED = 0
np.random.seed(SEED)

N_shuffle = 14      # number of shuffles = columns of q_sgds
N_generation = 10   # number of generations; read by generate_mtx and the later cells
q_vals = [0.1, 0.3, 0.5, 0.7, 0.9]
N_q = len(q_vals)

# Synthetic batch quantile values, one per entry of q_vals; used as the means of
# the simulated SGD estimates below.
# NOTE(review): these are not monotone in q, unlike real quantiles -- confirm intended.
q_batches = [40, -37, 9.9, -22, -9.8]

# Row i: N_shuffle simulated SGD estimates centred on q_batches[i] (std 2).
q_sgds = np.zeros((N_q, N_shuffle))
for i in range(N_q):
    q_sgds[i] = np.random.normal(q_batches[i], 2, N_shuffle)
q_sgds

array([[ 34.91567312,  39.54396675,  40.97603758,  42.50653306,
         39.61639467,  40.21547644,  38.90662515,  39.32204441,
         41.74963086,  41.151529  ,  43.34524857,  38.55947317,
         39.66301054,  40.17966514],
       [-38.23123833, -37.71464305, -36.19460087, -38.84068848,
        -36.98953606, -36.59524624, -36.83944062, -40.83243183,
        -39.74379482, -37.2817071 , -38.2924016 , -36.86650398,
        -35.04448991, -36.0370665 ],
       [ 10.80704276,   9.53439256,  11.52180429,  13.06197995,
         13.07370963,  12.81894533,  12.64687333,   8.7396911 ,
          8.5241926 ,  12.46301935,   4.73869092,   9.35624978,
         10.15127028,  10.20966171],
       [-18.5530093 , -22.16396346, -24.67297023, -22.39028864,
        -18.13693893, -24.33570219, -22.90230161, -21.8208639 ,
        -21.74499427, -24.21026599, -19.17459128, -22.87638189,
        -19.96322721, -22.76681103],
       [-12.12981225, -10.06635572,  -9.70405143, -13.4842777 ,
        -11.94254849

In [11]:
# Legend suffixes and line styles, indexed 0 = batch, 1 = SGD.
names = ['-q batch', '-q sgd']
styles = ['-', '--']

# Maps a quantile level in [0, 1] to an RGBA colour from 'gist_rainbow'.
quantile_cmap = plt.get_cmap('gist_rainbow')
c_Norm = colors.Normalize(vmin=0, vmax=1)
scalarMap = cmx.ScalarMappable(norm=c_Norm, cmap=quantile_cmap)

In [15]:
def plot_quantile_shuffles(q_vals, q_batches, q_sgds):
    """Compare each batch quantile against its SGD estimates from several shuffles.

    Draws two stacked panels that share the x axis:
      * top:    one solid vertical line per quantile level, at its batch value;
      * bottom: one dashed vertical line per shuffle, at its SGD estimate.
    Colours come from the module-level ``scalarMap`` (quantile level -> colour).
    Line styles and legend suffixes come from the module-level ``styles`` and ``names``.

    Parameters
    ----------
    q_vals : sequence of float
        Quantile levels, used for colour lookup and legend labels.
    q_batches : sequence of float
        Batch quantile value for each entry of ``q_vals``.
    q_sgds : array-like, shape (len(q_vals), n_shuffle)
        SGD estimates; row ``idx`` holds every shuffle for ``q_vals[idx]``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    q_sgds = np.asarray(q_sgds)

    fig = plt.figure(figsize=(16,4))
    ax_batch = fig.add_subplot(211)
    ax_sgd = fig.add_subplot(212, sharex = ax_batch)
    fig.suptitle('Batch quantile VS SGD quantile')

    for idx, q in enumerate(q_vals):
        colorVal = scalarMap.to_rgba(q)
        qb = q_batches[idx]
        ax_batch.plot([qb, qb], [0, 1], styles[0], label=str(q) + names[0], color=colorVal)

        # Loop over this quantile's shuffles (the columns of row idx).
        # q_sgds.shape[0] is the number of quantiles, not the number of shuffles.
        for i, qs in enumerate(q_sgds[idx]):
            # Only the first line per quantile gets a label -> one legend entry each.
            extra = {'label': ' ' + str(q) + names[1] + '  '} if i == 0 else {}
            ax_sgd.plot([qs, qs], [0, 1], styles[1], color=colorVal, **extra)

    ax_batch.set_xlabel('batch value')
    ax_batch.xaxis.set_label_coords(-0.05, -0.05)
    ax_batch.set_ylim([-0.1,1.1])
    ax_batch.set_yticks([])

    # x is shared between the panels, so set the locator on an explicit Axes
    # rather than pyplot's implicit "current axes".
    ax_batch.locator_params(axis='x', nbins=10)

    ax_sgd.set_xlabel('sgd value')
    ax_sgd.xaxis.set_label_coords(-0.05, -0.05)
    # The SGD lines also span y in [0, 1]; the y axis carries no information.
    ax_sgd.set_ylim([-0.1,1.1])
    ax_sgd.set_yticks([])

    # Place both legends below their panels, one column per quantile level.
    ax_batch.legend(loc='lower center', bbox_to_anchor=(0.5, -1.8),
            frameon=False, ncol=len(q_vals))
    ax_sgd.legend(loc='lower center', bbox_to_anchor=(0.5, -0.8),
        frameon=False, ncol=len(q_vals))

    return fig
                    
# fig = plot_quantile_shuffles(q_vals, q_batches, q_sgds)


In [16]:
def generate_mtx(scale=5):
    """Simulate SGD quantile estimates for every quantile, generation and shuffle.

    Entry ``[q_idx, gen_idx, :]`` holds ``N_shuffle`` draws from a normal
    distribution centred on ``q_batches[q_idx]`` with standard deviation ``scale``.
    Shapes and means come from the module-level ``N_q``, ``N_generation``,
    ``N_shuffle`` and ``q_batches``.
    Draws are made one (quantile, generation) row at a time in quantile-major order.

    Parameters
    ----------
    scale : float, default 5
        Standard deviation of the simulated SGD noise.

    Returns
    -------
    numpy.ndarray, shape (N_q, N_generation, N_shuffle)
    """
    mtx = np.zeros((N_q, N_generation, N_shuffle))
    for q_idx in range(N_q):
        for gen_idx in range(N_generation):
            mtx[q_idx, gen_idx] = np.random.normal(q_batches[q_idx], scale, N_shuffle)
    return mtx

# Batch quantile estimates, shape (N_q, N_generation): each row is filled in place
# with draws centred on q_batches[row_idx] (std 2).
q_batches_lst = np.zeros((N_q, N_generation))
for row_idx, row in enumerate(q_batches_lst):
    row[:] = np.random.normal(q_batches[row_idx], 2, N_generation)

# Simulated SGD estimates, shape (N_q, N_generation, N_shuffle).
q_sgds_lst = generate_mtx()

In [17]:
# q_batches_lst: (N_q,N_generation)
# q_sgds_lst:  (N_q, N_generation, N_shuffle)

def plot_quantile_generations(q_vals, q_batches_lst, q_sgds_lst):
    fig = plt.figure(figsize=(16,4))
    ax_batch = fig.add_subplot(211)
    ax_sgd = fig.add_subplot(212, sharex = ax_batch)
    fig.suptitle('Batch quantile VS SGD quantile')
    bins = 50
    
    for q_idx in range(N_q):
        colorVal = scalarMap.to_rgba(q_vals[q_idx])
        for gen_idx in range(N_generation):
            qb = q_batches_lst[q_idx][gen_idx]
            if gen_idx==0:
                ax_batch.plot([qb,qb], [0,1], styles[0], label= str(q_vals[q_idx])+names[0], color=colorVal)
            else:
                ax_batch.plot([qb,qb], [0,1], styles[0], color=colorVal)
        
        mtx_q = q_sgds_lst[q_idx].reshape(-1)
        ax_sgd.hist(mtx_q, bins, alpha=0.5, label = ' '+str(q_vals[q_idx])+names[1]+'  ', color=colorVal)
        

    #set
    ax_batch.set_xlabel('batch value')
    ax_batch.xaxis.set_label_coords(0.04, 0.98)
#     ax_batch.set_ylim([-0.1,1.1])
    ax_batch.set_yticks([])

    plt.locator_params(axis='x', nbins=10)

    ax_sgd.set_xlabel('sgd value')
    ax_sgd.xaxis.set_label_coords(0.035, 0.98)
    ax_sgd.set_yticks([])


    # set position of legend
    ax_batch.legend(loc='lower center', bbox_to_anchor=(0.5, -1.8),
            frameon=False, ncol=len(q_vals))
    ax_sgd.legend(loc='lower center', bbox_to_anchor=(0.5, -0.8),
        frameon=False, ncol=len(q_vals))
        
    return fig
                    
# fig = plot_quantile_generations(q_vals, q_batches_lst, q_sgds_lst)

array([[[ 0.51967811,  0.40028173,  0.53910664, ...,  1.3202296 ,
         -2.14616434, -0.12513262],
        [ 0.27211068, -0.56608795,  2.51701533, ...,  0.56960873,
         -0.96089013, -1.14036916],
        [-0.45198768, -0.35282595,  0.1422503 , ...,  0.13383598,
          0.94578448, -0.56654507],
        ...,
        [-1.55076983, -0.05175916, -2.10675117, ...,  0.33567524,
          0.953328  ,  0.26744839],
        [ 0.35724911, -0.6843614 , -0.45656863, ..., -0.42856936,
          1.30875607,  0.90033222],
        [ 1.44441597,  1.16641531, -0.22491533, ..., -0.95811295,
          0.02894003,  0.06582314]],

       [[ 1.00797587, -0.13220323,  0.1576292 , ...,  1.20980154,
         -0.3413909 , -0.70576462],
        [-0.77905469,  0.62884281, -1.10670449, ..., -1.41236483,
          1.47922036,  0.15704651],
        [-0.10411607, -1.08454886, -0.87707018, ..., -0.40535383,
          0.33336862, -1.64124987],
        ...,
        [-0.64772167, -2.17114946,  0.24130834, ..., -