In [53]:
import tensorflow as tf
import numpy as np
from tqdm import tqdm
from itertools import product
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D

# Interactive (Qt) backend: ZoomPlot below relies on mouse press/release
# events, which only reach the figure with an interactive backend.
%matplotlib qt

In [77]:
seed = 0
tf.reset_default_graph()
tf.random.set_random_seed(seed)
np.random.seed(seed)

# Network input: a batch of 2-D points.
x = tf.placeholder(dtype=tf.float32, shape=[None, 2])


def _he_normal(fan_in, fan_out):
    """Draw a float32 [fan_in, fan_out] matrix from N(0, sqrt(2 / fan_in)).

    Uses numpy's global RNG (seeded above), so the call order fixes the values.
    """
    return np.random.normal(scale=np.sqrt(2. / fan_in),
                            size=[fan_in, fan_out]).astype(np.float32)


# 2 -> 512 -> 512 -> 1 network. Weights are drawn in the order w1, w2, w3
# so the seeded values are unchanged; biases start at zero.
w1 = tf.Variable(_he_normal(2, 512))
b1 = tf.Variable(np.zeros(512, dtype=np.float32))
w2 = tf.Variable(_he_normal(512, 512))
b2 = tf.Variable(np.zeros(512, dtype=np.float32))
w3 = tf.Variable(_he_normal(512, 1))
b3 = tf.Variable(np.zeros(1, dtype=np.float32))

params = [w1, b1, w2, b2, w3, b3]
nr_params = sum(np.prod(p.get_shape().as_list()) for p in params)

# Down-scaling factor used by get_logits. 1 / 2**126 == 2**-126 is the smallest
# positive normal float32, so dividing activations by this pushes them toward
# the subnormal range where float32 rounding becomes nonlinear.
scaling = 2**126

def get_logits(par):
    """Build the output of the "float32 nonlinearity" network (no activation functions).

    The first layer's affine output is divided by the module-level `scaling`,
    moving the hidden activations down toward float32's smallest magnitudes.
    The following layers stay in that scaled space: their biases are divided by
    `scaling` to match, and the final result is multiplied by `scaling` to undo
    it. Mathematically the network is affine; any nonlinearity comes only from
    float32 rounding/underflow of the tiny intermediate values.

    NOTE(review): this relies on the runtime keeping subnormal floats; if the
    backend flushes denormals to zero the output may collapse — confirm.

    param par: list [w1, b1, w2, b2, w3, b3] (e.g. `params`).
    Returns the output tensor for placeholder `x` ([None, 1] with `params`).
    """
    # Layer 1: ordinary affine map, then scale down into the tiny-value regime.
    h1 = tf.nn.bias_add(tf.matmul(x , par[0]), par[1]) / scaling
    # Layer 2: linear in the scaled space; the bias is scaled to the same magnitude.
    h2 = tf.nn.bias_add(tf.matmul(h1, par[2]) , par[3] / scaling)   
    # Output: affine in the scaled space, then undo the scaling.
    o =   tf.nn.bias_add(tf.matmul(h2, par[4]), par[5] / scaling)*scaling
    return o
    
def get_logits_non(par):
    """Build the output of the reference ReLU network on placeholder `x`.

    Two hidden layers of affine map + ReLU, followed by an affine output
    layer with no activation.

    param par: list [w1, b1, w2, b2, w3, b3] (e.g. `params`).
    Returns the output tensor.
    """
    hidden = x
    for weight, bias in ((par[0], par[1]), (par[2], par[3])):
        hidden = tf.nn.relu(tf.nn.bias_add(tf.matmul(hidden, weight), bias))
    return tf.nn.bias_add(tf.matmul(hidden, par[4]), par[5])
    
# Both graph outputs share the same variables, so the float32-nonlinearity
# network and its ReLU counterpart are compared with identical weights.
output = get_logits(params)
output_nonlin = get_logits_non(params)

In [81]:
import matplotlib.pyplot as plt
import numpy as np

class ZoomPlot():
    """Interactive heat map of a network's output with drag-to-zoom.

    Press and drag with the left mouse button to select a rectangle; on
    release the network is re-evaluated on a `resolution` x `resolution`
    grid covering that rectangle and redrawn.
    """

    # Factor applied to plot coordinates before they are fed to placeholder `x`.
    input_scale = 30.0

    def __init__(self, resolution, fn="nlln"):
        """
        param resolution: integer, number of grid points along each axis
        param fn: string, either "nlln" for float32 nonlinearity
                  or "nln" for relu nonlinearity (any other value
                  falls back to "nlln")
        """
        self.fn = fn
        self.fig = plt.figure(figsize=(12, 10))
        self.ax = self.fig.add_subplot(111)
        self.xmin, self.xmax = -2.5, 1.0
        self.ymin, self.ymax = -1.5, 1.5
        self.xpress = self.xmin
        self.xrelease = self.xmax
        self.ypress = self.ymin
        self.yrelease = self.ymax
        self.resolution = resolution

        # Mouse handlers implement the drag-to-zoom interaction.
        self.fig.canvas.mpl_connect('button_press_event', self.onpress)
        self.fig.canvas.mpl_connect('button_release_event', self.onrelease)
        self.plot_fixed_resolution(self.xmin, self.xmax,
                                   self.ymin, self.ymax)

    def _evaluate(self, tensor, X, Y):
        """Evaluate `tensor` at every grid point of (X, Y).

        param tensor: graph tensor that depends on placeholder `x`
        param X, Y: meshgrid arrays of shape (resolution, resolution)
        Returns an array of shape (resolution, resolution).
        """
        # (res, res, 2) -> (res * res, 2): one row per grid point.
        ins = np.dstack((X, Y)).reshape(-1, 2)
        with tf.Session() as sess:
            # A fresh session needs its variables initialised; the graph's
            # variables are built from fixed numpy arrays, so every call
            # sees the same weights.
            sess.run(tf.global_variables_initializer())
            vals = sess.run(tensor, {x: ins * self.input_scale})
        return vals.reshape(self.resolution, self.resolution)

    def nlln(self, X, Y):
        """Evaluate the float32-nonlinearity network (`output`) on the grid."""
        return self._evaluate(output, X, Y)

    def nln(self, X, Y):
        """Evaluate the ReLU network (`output_nonlin`) on the grid."""
        return self._evaluate(output_nonlin, X, Y)

    def plot_fixed_resolution(self, x1, x2, y1, y2):
        """Redraw the selected network over [x1, x2] x [y1, y2]."""
        x_ = np.linspace(x1, x2, self.resolution)
        y_ = np.linspace(y1, y2, self.resolution)
        X, Y = np.meshgrid(x_, y_)

        if self.fn == "nln":
            C = self.nln(X, Y)
        else:  # just give the nlln
            C = self.nlln(X, Y)

        self.ax.clear()
        self.ax.set_xlim(x1, x2)
        self.ax.set_ylim(y1, y2)
        self.ax.pcolormesh(X, Y, C, cmap="jet")
        self.fig.canvas.draw()

    def onpress(self, event):
        """Remember where a left-button drag starts (None if outside the axes)."""
        if event.button != 1: return
        self.xpress = event.xdata
        self.ypress = event.ydata

    def onrelease(self, event):
        """Zoom to the rectangle spanned by the last left-button press and this release.

        Selections that start or end outside the axes (no data coordinates)
        or that have zero width or height (a plain click) are ignored and
        leave the current view unchanged.
        """
        if event.button != 1: return
        self.xrelease = event.xdata
        self.yrelease = event.ydata
        corners = (self.xpress, self.ypress, self.xrelease, self.yrelease)
        if any(c is None for c in corners):
            return
        xmin, xmax = sorted((self.xpress, self.xrelease))
        ymin, ymax = sorted((self.ypress, self.yrelease))
        if xmin == xmax or ymin == ymax:
            # Degenerate selection: identical axis limits / empty mesh.
            return
        self.xmin, self.xmax = xmin, xmax
        self.ymin, self.ymax = ymin, ymax
        self.plot_fixed_resolution(self.xmin, self.xmax,
                                   self.ymin, self.ymax)


        
    
# Open the interactive zoom figure for the float32-nonlinearity network.
# Each (re)render evaluates resolution**2 = 1,000,000 grid points.
plot = ZoomPlot(resolution=1000, fn="nlln")
plt.show()