# Classification Distributions Visualizer

To visualize the old DRAM model, DRAM_classify_blobs, use these settings:
min_edge = 2,
max_edge = 5,
min_blobs = 1,
max_blobs = 9,
batch_size = 9000.
Change analysis.py to import DRAM_classify_blobs and load the DRAM_test_square checkpoint.

In [1]:
print("Setting everything up!")
import warnings
warnings.filterwarnings('ignore')
from bokeh.io import push_notebook, show, output_notebook
output_notebook()
from bokeh.layouts import row, column
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, HoverTool, CustomJS, FixedTicker
import bokeh.palettes as pal
from bokeh.layouts import layout, Spacer, gridplot

Setting everything up!


In [2]:
import ipywidgets as widgets
from ipywidgets import *
from IPython.display import display, clear_output

In [3]:
import numpy as np
from bokeh.charts import Bar, Histogram

The bokeh.charts API has moved to a separate 'bkcharts' package.

This compatibility shim will remain until Bokeh 1.0 is released.
After that, if you want to use this API you will have to install
the bkcharts package explicitly.

  warn(message)


In [4]:
from analysis import classify_image, glimpses, read_n, classify_imgs2

['/usr/local/lib/python3.5/dist-packages/ipykernel_launcher.py', 'true', 'true', 'true', 'true', 'true', 'true', 'model_runs/move_attn2/classify_log.csv', 'model_runs/move_attn2/classifymodel_0.ckpt', 'model_runs/move_attn2/classifymodel_', 'model_runs/move_attn2/zzzdraw_data_5000.npy', 'false', 'true', 'false', 'false', 'true']
analysis.py


In [5]:
# import numpy as np
# import scipy.special

%matplotlib inline

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from random import randint

from scipy import linspace
from scipy import pi,sqrt,exp
from scipy.special import erf
import scipy

from bokeh.layouts import gridplot
from bokeh.plotting import figure, show

In [6]:
def get_pdf(sigma, x, mu):
    """Evaluate the normal (Gaussian) density at ``x``.

    Args:
        sigma: Standard deviation of the distribution.
        x: Point(s) at which to evaluate; a scalar or a NumPy array.
        mu: Mean of the distribution.

    Returns:
        The density value(s), with the same shape as ``x``.
    """
    normalizer = 1 / (sigma * np.sqrt(2 * np.pi))
    exponent = -(x - mu) ** 2 / (2 * sigma ** 2)
    return normalizer * np.exp(exponent)


def get_cdf(sigma, x, mu, a):
    """Evaluate the normal cumulative distribution at ``a * x``.

    Args:
        sigma: Standard deviation of the distribution.
        x: Point(s) at which to evaluate; a scalar or a NumPy array.
        mu: Mean of the distribution.
        a: Multiplier applied to ``x`` before the mean is subtracted.

    Returns:
        ``(1 + erf((a*x - mu) / sqrt(2*sigma**2))) / 2``.
    """
    shifted = a * x - mu
    scale = np.sqrt(2 * sigma ** 2)
    return (1 + scipy.special.erf(shifted / scale)) / 2


def get_p(pdf, cdf):
    """Combine pdf and cdf values into the quantity used for the skew curve.

    Args:
        pdf: Density value(s), e.g. from ``get_pdf``.
        cdf: Cumulative value(s), e.g. from ``get_cdf``.

    Returns:
        ``2 / pdf * cdf / 100000``, evaluated left to right.
    """
    # NOTE(review): a textbook skew-normal density is 2 * pdf * cdf; this
    # expression divides by the pdf and rescales by 1e-5 -- confirm intent.
    inverse_density = 2 / pdf
    weighted = inverse_density * cdf
    return weighted / 100000

    
def curve(sigma, x, mu, plot, label="", color="gray"):
    """Draw a normal-distribution curve on ``plot``.

    Args:
        sigma: Standard deviation passed to ``get_pdf``.
        x: x-values at which the density is evaluated.
        mu: Mean passed to ``get_pdf``.
        plot: Object with a Bokeh-style ``line`` method to draw on.
        label: Legend entry for the line (empty by default).
        color: Line colour.
    """
    # The x-values are shifted by a hard-coded minimum blob count of 1,
    # presumably so a zero-based class index lines up with the blob-count axis.
    blob_offset = 1
    density = get_pdf(sigma, x, mu)
    plot.line(
        x + blob_offset,
        density,
        line_color=color,
        line_width=8,
        alpha=1,
        legend=label,
    )
    
    
def skew_curve(sigma, x, mu, a, plot):
    """Add a skewed curve to the plot.

    The curve is ``get_p(pdf, cdf)`` divided by the peak height of the normal
    density, drawn against the mirrored axis ``9 - x``.

    Args:
        sigma: Standard deviation of the underlying normal distribution.
        x: x-values at which the curve is evaluated.
        mu: Mean of the underlying normal distribution.
        a: Multiplier forwarded to ``get_cdf``; also used as the starting
            guess for the peak search.
        plot: Object with a Bokeh-style ``line`` method to draw on.
    """
    # The file only does ``import scipy`` and imports from ``scipy.special``,
    # which does not guarantee that the ``scipy.optimize`` submodule is loaded.
    # Import it here so ``scipy.optimize.fmin`` is always available.
    import scipy.optimize

    p = get_p(get_pdf(sigma, x, mu), get_cdf(sigma, x, mu, a))

    def f(v):
        # Same normal density as get_pdf, reused instead of duplicated inline.
        return get_pdf(sigma, v, mu)

    # Numerically locate the maximum of the density, starting from ``a``.
    max_x = scipy.optimize.fmin(lambda v: -f(v), a)

    plot.line(9 - x, p / f(max_x), line_color="blue", line_width=8, alpha=0.4)
    
    
def combined_curve(sigma, x, mu, a, plot):
    """Add a curve combining the skew term and the normal density to the plot.

    The plotted values are ``(pdf(x) + skew) / f2(max_x2)``, where ``skew`` is
    ``get_p(pdf, cdf)`` divided by the peak height of the normal density.

    Args:
        sigma: Standard deviation of the underlying normal distribution.
        x: x-values at which the curve is evaluated.
        mu: Mean of the underlying normal distribution.
        a: Multiplier forwarded to ``get_cdf``; also used as the starting
            guess for both peak searches.
        plot: Object with a Bokeh-style ``line`` method to draw on.
    """
    # The file only does ``import scipy`` and imports from ``scipy.special``,
    # which does not guarantee that the ``scipy.optimize`` submodule is loaded.
    # Import it here so ``scipy.optimize.fmin`` is always available.
    import scipy.optimize

    p = get_p(get_pdf(sigma, x, mu), get_cdf(sigma, x, mu, a))

    def f(v):
        # Same normal density as get_pdf, reused instead of duplicated inline.
        return get_pdf(sigma, v, mu)

    max_x = scipy.optimize.fmin(lambda v: -f(v), a)

    # Skew term normalised by the peak height of the normal density.
    skew = p / f(max_x)

    def f2(v):
        return f(v) + skew

    # NOTE(review): this search maximises ``f`` again, not ``f2``, so max_x2
    # matches max_x. ``f2`` returns an array (``skew`` spans every x-value),
    # so it cannot be handed to fmin as is -- confirm the intended normaliser.
    max_x2 = scipy.optimize.fmin(lambda v: -f(v), a)

    plot.line(x, (f(x) + skew) / f2(max_x2),
              line_color="purple", line_width=2, alpha=1)

In [7]:
def set_figure_colors(p, bg, fg):
    """Apply a two-colour theme to figure ``p``.

    Args:
        p: Figure exposing ``border_fill_color``, ``title``, ``xaxis`` and
            ``yaxis`` attributes.
        bg: Colour for the border fill.
        fg: Colour for the title text and for every axis label, axis line,
            tick label and tick mark.
    """
    p.border_fill_color = bg
    p.title.text_color = fg

    axis_color_props = (
        "axis_label_text_color",
        "axis_line_color",
        "major_label_text_color",
        "major_tick_line_color",
        "minor_tick_line_color",
    )
    # Both axes receive exactly the same set of foreground-coloured properties.
    for axis in (p.xaxis, p.yaxis):
        for prop in axis_color_props:
            setattr(axis, prop, fg)

In [8]:
clear_output()

# Module-level slot that update_curves() declares ``global``.
data = None

# Button that triggers a (re)computation of the plots.
b2 = Button(description="Click to Start", icon="arrow", width=400)

# Iteration selector. update_curves() converts the selected string to an int
# and passes it to classify_imgs2. The selectable values are:
# 0-5000 in steps of 1000, 10000-200000 in steps of 10000, 250000,
# 300000-900000 in steps of 100000, 910000, 920000, and
# 1000000-1500000 in steps of 100000.
dropdown2 = Dropdown(
    options=[
        str(iteration)
        for iteration in (
            list(range(0, 6000, 1000))
            + list(range(10000, 210000, 10000))
            + [250000]
            + list(range(300000, 1000000, 100000))
            + [910000, 920000]
            + list(range(1000000, 1600000, 100000))
        )
    ],
    value='10000',
    description='Iteration:',
)

def update_curves():
    """Classify a batch of images and plot classification distributions per blob count.

    Calls ``classify_imgs2`` with the iteration selected in ``dropdown2`` and,
    for every blob count from ``min_blobs`` to ``max_blobs``:

    * averages the classification vectors and the arg-max choices over the
      images whose label marks that blob count (10 glimpses per image),
    * draws both histograms as bars on a dedicated figure and shows it,
    * fits a normal curve (mean and standard deviation of the choice
      histogram) and draws it on that figure.

    Finally all fitted curves are drawn together on one summary figure.

    Side effects: clears the cell output, prints progress, shows Bokeh
    figures, and rebinds the module-level ``data`` (see note in the loop).
    """
    clear_output()
    global data
    num_imgs = 9000
    print("number of images: %d" % num_imgs)
    # NOTE(review): the meaning of the ``True`` argument is defined by
    # analysis.classify_imgs2, which is not visible here.
    imgs_data = classify_imgs2(int(dropdown2.value), True, num_imgs)
    
#     num_blobs = randint(0, 9)
    max_blobs = 9
    min_blobs = 1
    
    # Arguments for the summary-figure curves, drawn after the main loop.
    curves = list()
    dark = "#111111"
    light = "#DDDDDD"
    # p2 is the summary figure holding one fitted curve per blob count.
    p2 = figure(title="Blob Number Classification Probabilities Distributions", y_range=(0, 1), tools="save", background_fill_color=dark)
    set_figure_colors(p2, dark, light)


    for num_blobs in range(min_blobs, max_blobs + 1):
        print("number of blobs: ", num_blobs)
        

        # p1 is the per-blob-count figure (histogram bars plus fitted curve).
        p1 = figure(title="Blob Number Classification Probabilities Distribution for %d Blobs" % num_blobs , y_range=(0, 1), tools="save",
                    background_fill_color=dark)
        set_figure_colors(p1, dark, light)

        # m is the margin trimmed from each side of a unit-wide bar.
        m = 0.1
        # Number of classes, one per possible blob count.
        z_size = max_blobs - min_blobs + 1

        new_hist = np.zeros(z_size)
        choice_hist = np.zeros(z_size)
        # NOTE(review): value_counts is never read in this function.
        value_counts = np.zeros(z_size)
        values_sum = 0
        sqr_sum = 0
        num_imgs_with_num_blobs = 0

        # NOTE(review): because of ``global data`` above, this loop rebinds the
        # module-level ``data`` on every iteration; ``idx`` is unused.
        for idx, data in enumerate(imgs_data):

            if data["label"][(num_blobs - min_blobs)] == 1: # data is for an image with num_blobs blobs
                num_imgs_with_num_blobs += 1
                
                # NOTE(review): max_glimpse / min_glimpse are unused; the loop
                # below is hard-coded to glimpse indices 0-9. ``glimpses`` here
                # is a local and does not touch the name imported from analysis.
                max_glimpse = 2
                min_glimpse = 0
                glimpses = 10#max_glimpse - min_glimpse + 1
                
                for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:#range(min_glimpse, max_glimpse + 1):
                    # Histogram of softmaxes
                    new_hist += data["classifications"][i][0] / glimpses
                    
                    # Histogram of choices
                    choice = np.argmax(data["classifications"][i][0])
                    choice_list = [0] * z_size
                    choice_list[choice] = 1 / glimpses
                    choice_hist += choice_list
                    
#                 glimpse = 1
        
#                 new_hist += data["classifications"][glimpse][0]
        
#                 choice = np.argmax(data["classifications"][glimpse][0])
#                 choice_list = [0] * z_size
#                 choice_list[choice] = 1
#                 choice_hist += choice_list
                
        print("num_imgs_with_num_blobs: ", num_imgs_with_num_blobs)
        
        # Turn the accumulated sums into per-image averages.
        # NOTE(review): divides by zero if no image has this blob count.
        new_hist = new_hist / num_imgs_with_num_blobs
        print("new_hist: ", new_hist)
        choice_hist = choice_hist / num_imgs_with_num_blobs

        x = np.linspace(-2, 11.0, 1000)
        # Bars are 0.8 wide; the two series are offset by -0.05 / +0.05 around
        # each blob count so that they overlap but stay distinguishable.
        source = ColumnDataSource(data=dict(color=["red"] * z_size, top=new_hist, bottom=np.zeros(z_size), left=np.arange(min_blobs, max_blobs + 1) + m - 0.55, right=np.arange(min_blobs + 1, max_blobs + 2) - m - 0.55))
        source2 = ColumnDataSource(data=dict(color=["yellow"] * z_size, top=choice_hist, bottom=np.zeros(z_size), left=np.arange(min_blobs, max_blobs + 1) + m - 0.45, right=np.arange(min_blobs + 1, max_blobs + 2) - m - 0.45))
        # Highlight the bar of the true blob count.
        source.data["color"][(num_blobs - min_blobs)] = "lime"
        p1.quad('left', 'right', 'top', 'bottom', source=source, color="color", alpha=1)
        p1.quad('left', 'right', 'top', 'bottom', source=source2, color="color", alpha=0.5)


        # FORMAT PLOT ##############################

        p1.xaxis.axis_label = 'Number of Blobs'
        p1.yaxis.axis_label = 'Classification Probability'
        p1.xaxis[0].ticker=FixedTicker(ticks=np.arange(min_blobs, max_blobs + 1))
        
        
        # PLOT CURVES #############################
        
        # Find the mean
        # (over the zero-based class index j, not the blob count itself)
        for j in range(z_size):
            values_sum += j * choice_hist[j] # curve based on classification distribution. Use new_hist to get curve based on softmax.
        mu = values_sum
        print("mu: ", mu)

        # Find the standard deviation
        for k in range(z_size):
            sqr_sum += choice_hist[k] * ((k - mu) ** 2)
        sigma = np.sqrt(sqr_sum)
        
        # Queue this curve for the summary figure with a random colour built
        # from six digits in 2-9.
        # NOTE(review): the summary curve uses sigma * 2 while the per-figure
        # curve below uses sigma -- confirm this is intentional.
        curves.append((sigma * 2, x, mu, p2, str(num_blobs),
                       "#" + str(randint(2, 9)) + str(randint(2, 9))+ str(randint(2, 9))+ str(randint(2, 9))+ str(randint(2, 9))+ str(randint(2, 9))))
        
        curve(sigma, x, mu, p1) # plot gaussian curve
#         a = np.argmax(new_hist)
#         skew_curve(sigma, x, mu, a, p1)
#         combined_curve(sigma, x, mu, a, p1)

        show(gridplot(p1, ncols=2, plot_width=500, plot_height=400, toolbar_location=None))
            
    # Draw every queued curve on the summary figure and display it.
    for stats in curves:
        curve(*stats)
    show(p2)
    

def on_click2(b2, new_image=True):
    """Recompute and redraw the classification plots when the button is clicked.

    Args:
        b2: The button that was clicked; its description is used as a
            progress indicator while the plots are rebuilt.
        new_image: Accepted for compatibility with callers; not used.
    """
    b2.description = "Loading..."
    try:
        update_curves()
    finally:
        # Restore the label even when update_curves() raises, so the button
        # is not left showing "Loading..." after a failure.
        b2.description = "Next (Random) Image"

b2.on_click(on_click2)


def on_change2(change):
    """Redraw the plots when the dropdown's selected value changes.

    Args:
        change: Change notification dict; only notifications with
            ``type == 'change'`` and ``name == 'value'`` trigger a redraw.
    """
    if change['type'] == 'change' and change['name'] == 'value':
        # The click handler in this file is on_click2; the previous call to
        # ``on_click`` referenced a name that is not defined here.
        on_click2(b2, new_image=False)
        

# Register the dropdown handler; no ``names`` filter is passed here, and
# on_change2 itself checks the notification's type and name.
dropdown2.observe(on_change2)
# Show the button and the iteration dropdown together in one HBox.
display(HBox([b2, dropdown2]))

number of images: 9000
INFO:tensorflow:Restoring parameters from model_runs/DRAM_test_square/classifymodel_300000.ckpt
number of blobs:  1
num_imgs_with_num_blobs:  1000
new_hist:  [  9.85630519e-01   1.39853762e-02   3.68177500e-04   1.48928526e-05
   8.69502608e-07   5.56483565e-08   6.29355866e-08   5.11211316e-08
   1.75462587e-09]
mu:  0.0022


number of blobs:  2
num_imgs_with_num_blobs:  1000
new_hist:  [  5.02940173e-02   7.04388310e-01   2.15246966e-01   2.87745849e-02
   1.21089023e-03   8.14057496e-05   3.29661836e-06   4.58399532e-07
   7.41713137e-08]
mu:  1.1077


number of blobs:  3
num_imgs_with_num_blobs:  1000
new_hist:  [  2.32003442e-03   2.11603570e-01   4.76487827e-01   2.69454845e-01
   3.15293526e-02   7.32514607e-03   1.12004630e-03   1.30342860e-04
   2.88395516e-05]
mu:  2.0312


number of blobs:  4
num_imgs_with_num_blobs:  1000
new_hist:  [  1.71729976e-07   3.04953064e-02   2.54781294e-01   4.76466985e-01
   1.36395767e-01   7.02591321e-02   2.43179169e-02   5.28121311e-03
   2.00221438e-03]
mu:  2.9021


number of blobs:  5
num_imgs_with_num_blobs:  1000
new_hist:  [  7.70834855e-10   3.13642764e-03   8.53670935e-02   3.64267546e-01
   2.10274533e-01   1.91219968e-01   1.00120277e-01   3.04385342e-02
   1.51756218e-02]
mu:  3.77


number of blobs:  6
num_imgs_with_num_blobs:  1000
new_hist:  [  1.47127867e-13   8.45661228e-04   2.33924599e-02   1.85742024e-01
   1.80161117e-01   2.56712140e-01   1.99298727e-01   9.01764443e-02
   6.36714280e-02]
mu:  4.8101


number of blobs:  7
num_imgs_with_num_blobs:  1000
new_hist:  [  4.72562558e-17   2.13479691e-05   4.44533355e-03   6.50863939e-02
   1.01068292e-01   2.23732806e-01   2.58715317e-01   1.71756159e-01
   1.75174351e-01]
mu:  5.9089


number of blobs:  8
num_imgs_with_num_blobs:  1000
new_hist:  [  6.86679433e-20   1.27662381e-06   7.34011049e-04   1.83516272e-02
   4.57582139e-02   1.49982805e-01   2.44036678e-01   2.21171017e-01
   3.19964372e-01]
mu:  6.7974


number of blobs:  9
num_imgs_with_num_blobs:  1000
new_hist:  [  1.97443991e-20   1.61444355e-07   1.04827453e-04   4.34823131e-03
   1.66719196e-02   8.02625560e-02   1.86746258e-01   2.37967545e-01
   4.73898501e-01]
mu:  7.431
