In [1]:
import sys
from tqdm import tqdm
import pandas as pd
import numpy as np
from sklearn import datasets
import matplotlib
%matplotlib inline

import edward as ed
import tensorflow as tf
from edward.models import Bernoulli, MultivariateNormalTriL, Normal, Categorical
from edward.util import rbf
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam
from sklearn.preprocessing import OneHotEncoder

Using TensorFlow backend.


In [2]:
iris = datasets.load_iris()
df = pd.DataFrame(data=iris['data'], columns=iris['feature_names'])
i = pd.Series(iris['target'])
names = i.map(lambda x: iris['target_names'][x])
df['index'] = i
df['category_name'] = names
df

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),index,category_name
0,5.1,3.5,1.4,0.2,0,setosa
1,4.9,3.0,1.4,0.2,0,setosa
2,4.7,3.2,1.3,0.2,0,setosa
3,4.6,3.1,1.5,0.2,0,setosa
4,5.0,3.6,1.4,0.2,0,setosa
5,5.4,3.9,1.7,0.4,0,setosa
6,4.6,3.4,1.4,0.3,0,setosa
7,5.0,3.4,1.5,0.2,0,setosa
8,4.4,2.9,1.4,0.2,0,setosa
9,4.9,3.1,1.5,0.1,0,setosa


In [3]:
# shuffle
df = df.sample(frac=1)

ys = df['index'].values
xs = df[['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']].values
x_train = xs[:100]
y_train = ys[:100]
x_test = xs[100:]
y_test = ys[100:]
print("Number of train data: {}".format(x_train.shape[0]))
print("Number of features: {}".format(x_train.shape[1]))

Number of train data: 100
Number of features: 4


In [4]:
encoder = OneHotEncoder(sparse=False)
y_train_ = encoder.fit_transform(y_train.reshape(-1, 1))
y_test_ = encoder.fit_transform(y_test.reshape(-1, 1))


model = Sequential()

model.add(Dense(10, input_shape=(4,), activation='relu', name='fc1'))
model.add(Dense(10, activation='relu', name='fc2'))
model.add(Dense(3, activation='softmax', name='output'))
# model.add(Dense(3, input_shape=(4, ), activation='softmax', name='output'))

optimizer = Adam(lr=0.001)
model.compile(optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

print('Neural Network Model Summary: ')
print(model.summary())

model.fit(x_train, y_train_, verbose=2, batch_size=5, epochs=30)

results = model.evaluate(x_test, y_test_)
results

Neural Network Model Summary: 
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
fc1 (Dense)                  (None, 10)                50        
_________________________________________________________________
fc2 (Dense)                  (None, 10)                110       
_________________________________________________________________
output (Dense)               (None, 3)                 33        
Total params: 193
Trainable params: 193
Non-trainable params: 0
_________________________________________________________________
None
Epoch 1/30
 - 0s - loss: 1.4916 - acc: 0.3500
Epoch 2/30
 - 0s - loss: 1.2996 - acc: 0.3500
Epoch 3/30
 - 0s - loss: 1.2097 - acc: 0.3500
Epoch 4/30
 - 0s - loss: 1.1035 - acc: 0.3700
Epoch 5/30
 - 0s - loss: 1.0217 - acc: 0.5800
Epoch 6/30
 - 0s - loss: 0.9488 - acc: 0.6200
Epoch 7/30
 - 0s - loss: 0.8848 - acc: 0.6300
Epoch 8/30
 - 0s - loss: 0.8344 - acc: 0.6400
Epoc

[0.4225070309638977, 0.8000000071525574]

In [5]:
N = 100
num_feature = 4
def neural_network(x):
    h = tf.nn.relu(tf.matmul(x, w1) + b1)
    h = tf.nn.relu(tf.matmul(h, w2) + b2)
    h = tf.nn.softmax(tf.matmul(h, w3) + b3)
    return h

with tf.name_scope("model"):
    w1 = Normal(loc=tf.zeros([num_feature, 10]), scale=tf.ones([num_feature, 10]), name="w1")
    w2 = Normal(loc=tf.zeros([10, 10]), scale=tf.ones([10, 10]), name="w2")
    w3 = Normal(loc=tf.zeros([10, 3]), scale=tf.ones([10, 3]), name="w3")
    b1 = Normal(loc=tf.zeros(10), scale=tf.ones(10), name="b1")
    b2 = Normal(loc=tf.zeros(10), scale=tf.ones(10), name="b2")
    b3 = Normal(loc=tf.zeros(3), scale=tf.ones(3), name="b3")

    x = tf.placeholder(tf.float32, [None, num_feature], name="x")
    y = Categorical(neural_network(x), name="y")

    
def bayesian_neural_network(x):
    h = tf.nn.relu(tf.matmul(x, qw1) + qb1)
    h = tf.nn.relu(tf.matmul(h, qw2) + qb2)
    h = tf.nn.softmax(tf.matmul(h, qw3) + qb3)
    return h
    

# INFERENCE
with tf.variable_scope("posterior"):
    with tf.variable_scope("qw1"):
        loc = tf.get_variable("loc", [num_feature, 10])
        scale = tf.nn.softplus(tf.get_variable("scale", [num_feature, 10]))
        qw1 = Normal(loc=loc, scale=scale)
    with tf.variable_scope("qw2"):
        loc = tf.get_variable("loc", [10, 10])
        scale = tf.nn.softplus(tf.get_variable("scale", [10, 10]))
        qw2 = Normal(loc=loc, scale=scale)
    with tf.variable_scope("qw3"):
        loc = tf.get_variable("loc", [10, 3])
        scale = tf.nn.softplus(tf.get_variable("scale", [10, 3]))
        qw3 = Normal(loc=loc, scale=scale)
    with tf.variable_scope("qb1"):
        loc = tf.get_variable("loc", [10])
        scale = tf.nn.softplus(tf.get_variable("scale", [10]))
        qb1 = Normal(loc=loc, scale=scale)
    with tf.variable_scope("qb2"):
        loc = tf.get_variable("loc", [10])
        scale = tf.nn.softplus(tf.get_variable("scale", [10]))
        qb2 = Normal(loc=loc, scale=scale)
    with tf.variable_scope("qb3"):
        loc = tf.get_variable("loc", [3])
        scale = tf.nn.softplus(tf.get_variable("scale", [3]))
        qb3 = Normal(loc=loc, scale=scale)

    qy = Categorical(bayesian_neural_network(x), name="qy")
        
        
inference = ed.KLqp({w1: qw1, b1: qb1,
                   w2: qw2, b2: qb2,
                   w3: qw3, b3: qb3}, data={y: qy})
inference.initialize()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    samples_num = 100
    BATCH_SIZE = 10
    for epoch in tqdm(range(1000), file=sys.stdout):
        perm = np.random.permutation(N)
        for i in range(0, N, BATCH_SIZE):
            batch_x = x_train[perm[i:i+BATCH_SIZE]]
            batch_y = y_train[perm[i:i+BATCH_SIZE]]
            inference.update(feed_dict={x: batch_x, y: batch_y})
        y_samples = qy.sample(samples_num).eval(feed_dict={x: x_train})
        acc = (np.round(y_samples.sum(axis=0) / samples_num) == y_train).mean()
        y_samples = qy.sample(samples_num).eval(feed_dict={x: x_test})
        test_acc = (np.round(y_samples.sum(axis=0) / samples_num) == y_test).mean()
        if (epoch+1) % 1 == 0:
            tqdm.write('epoch:\t{}\taccuracy:\t{}\tvaridation accuracy:\t{}'.format(epoch+1, acc, test_acc))


epoch:	1	accuracy:	0.37	varidation accuracy:	0.3
epoch:	2	accuracy:	0.37	varidation accuracy:	0.3
epoch:	3	accuracy:	0.36	varidation accuracy:	0.3
epoch:	4	accuracy:	0.34	varidation accuracy:	0.3
epoch:	5	accuracy:	0.38	varidation accuracy:	0.32
epoch:	6	accuracy:	0.37	varidation accuracy:	0.32
epoch:	7	accuracy:	0.33	varidation accuracy:	0.3
epoch:	8	accuracy:	0.35	varidation accuracy:	0.3
epoch:	9	accuracy:	0.34	varidation accuracy:	0.28
epoch:	10	accuracy:	0.37	varidation accuracy:	0.3
epoch:	11	accuracy:	0.35	varidation accuracy:	0.3
epoch:	12	accuracy:	0.35	varidation accuracy:	0.28
epoch:	13	accuracy:	0.36	varidation accuracy:	0.3
epoch:	14	accuracy:	0.35	varidation accuracy:	0.3
epoch:	15	accuracy:	0.34	varidation accuracy:	0.28
epoch:	16	accuracy:	0.36	varidation accuracy:	0.32
epoch:	17	accuracy:	0.35	varidation accuracy:	0.32
epoch:	18	accuracy:	0.35	varidation accuracy:	0.32
epoch:	19	accuracy:	0.33	varidation accuracy:	0.3
epoch:	20	accuracy:	0.33	varidation accuracy:	0.34


epoch:	162	accuracy:	0.33	varidation accuracy:	0.3
epoch:	163	accuracy:	0.35	varidation accuracy:	0.3
epoch:	164	accuracy:	0.37	varidation accuracy:	0.3
epoch:	165	accuracy:	0.36	varidation accuracy:	0.32
epoch:	166	accuracy:	0.35	varidation accuracy:	0.3
epoch:	167	accuracy:	0.34	varidation accuracy:	0.3
epoch:	168	accuracy:	0.35	varidation accuracy:	0.3
epoch:	169	accuracy:	0.35	varidation accuracy:	0.3
epoch:	170	accuracy:	0.4	varidation accuracy:	0.3 
epoch:	171	accuracy:	0.35	varidation accuracy:	0.3
epoch:	172	accuracy:	0.34	varidation accuracy:	0.28
epoch:	173	accuracy:	0.35	varidation accuracy:	0.3
epoch:	174	accuracy:	0.35	varidation accuracy:	0.3
epoch:	175	accuracy:	0.36	varidation accuracy:	0.3
epoch:	176	accuracy:	0.32	varidation accuracy:	0.24
epoch:	177	accuracy:	0.36	varidation accuracy:	0.32
epoch:	178	accuracy:	0.34	varidation accuracy:	0.26
epoch:	179	accuracy:	0.37	varidation accuracy:	0.3
epoch:	180	accuracy:	0.38	varidation accuracy:	0.3
epoch:	181	accuracy:	0.35	

epoch:	321	accuracy:	0.35	varidation accuracy:	0.28
epoch:	322	accuracy:	0.31	varidation accuracy:	0.3
epoch:	323	accuracy:	0.35	varidation accuracy:	0.32
epoch:	324	accuracy:	0.34	varidation accuracy:	0.3
epoch:	325	accuracy:	0.34	varidation accuracy:	0.3
epoch:	326	accuracy:	0.35	varidation accuracy:	0.3
epoch:	327	accuracy:	0.35	varidation accuracy:	0.3
epoch:	328	accuracy:	0.35	varidation accuracy:	0.3
epoch:	329	accuracy:	0.33	varidation accuracy:	0.3
epoch:	330	accuracy:	0.35	varidation accuracy:	0.3
epoch:	331	accuracy:	0.33	varidation accuracy:	0.3
epoch:	332	accuracy:	0.36	varidation accuracy:	0.32
epoch:	333	accuracy:	0.37	varidation accuracy:	0.28
epoch:	334	accuracy:	0.35	varidation accuracy:	0.3
epoch:	335	accuracy:	0.36	varidation accuracy:	0.3
epoch:	336	accuracy:	0.34	varidation accuracy:	0.3
epoch:	337	accuracy:	0.36	varidation accuracy:	0.3
epoch:	338	accuracy:	0.35	varidation accuracy:	0.28
epoch:	339	accuracy:	0.34	varidation accuracy:	0.26
epoch:	340	accuracy:	0.35

epoch:	480	accuracy:	0.35	varidation accuracy:	0.28
epoch:	481	accuracy:	0.35	varidation accuracy:	0.32
epoch:	482	accuracy:	0.34	varidation accuracy:	0.3
epoch:	483	accuracy:	0.35	varidation accuracy:	0.28
epoch:	484	accuracy:	0.33	varidation accuracy:	0.32
epoch:	485	accuracy:	0.34	varidation accuracy:	0.3
epoch:	486	accuracy:	0.35	varidation accuracy:	0.36
epoch:	487	accuracy:	0.34	varidation accuracy:	0.3
epoch:	488	accuracy:	0.4	varidation accuracy:	0.32
epoch:	489	accuracy:	0.35	varidation accuracy:	0.3
epoch:	490	accuracy:	0.33	varidation accuracy:	0.28
epoch:	491	accuracy:	0.35	varidation accuracy:	0.3
epoch:	492	accuracy:	0.38	varidation accuracy:	0.32
epoch:	493	accuracy:	0.36	varidation accuracy:	0.28
epoch:	494	accuracy:	0.37	varidation accuracy:	0.3
epoch:	495	accuracy:	0.35	varidation accuracy:	0.32
epoch:	496	accuracy:	0.36	varidation accuracy:	0.3
epoch:	497	accuracy:	0.32	varidation accuracy:	0.3
epoch:	498	accuracy:	0.37	varidation accuracy:	0.28
epoch:	499	accuracy:	

epoch:	639	accuracy:	0.35	varidation accuracy:	0.3
epoch:	640	accuracy:	0.35	varidation accuracy:	0.3
epoch:	641	accuracy:	0.37	varidation accuracy:	0.3
epoch:	642	accuracy:	0.37	varidation accuracy:	0.3
epoch:	643	accuracy:	0.37	varidation accuracy:	0.3
epoch:	644	accuracy:	0.36	varidation accuracy:	0.3
epoch:	645	accuracy:	0.35	varidation accuracy:	0.34
epoch:	646	accuracy:	0.37	varidation accuracy:	0.28
epoch:	647	accuracy:	0.35	varidation accuracy:	0.3
epoch:	648	accuracy:	0.35	varidation accuracy:	0.34
epoch:	649	accuracy:	0.35	varidation accuracy:	0.28
epoch:	650	accuracy:	0.31	varidation accuracy:	0.3
epoch:	651	accuracy:	0.35	varidation accuracy:	0.32
epoch:	652	accuracy:	0.37	varidation accuracy:	0.3
epoch:	653	accuracy:	0.35	varidation accuracy:	0.28
epoch:	654	accuracy:	0.34	varidation accuracy:	0.3
epoch:	655	accuracy:	0.35	varidation accuracy:	0.32
epoch:	656	accuracy:	0.35	varidation accuracy:	0.32
epoch:	657	accuracy:	0.35	varidation accuracy:	0.28
epoch:	658	accuracy:	0

epoch:	798	accuracy:	0.34	varidation accuracy:	0.34
epoch:	799	accuracy:	0.35	varidation accuracy:	0.3
epoch:	800	accuracy:	0.36	varidation accuracy:	0.3
epoch:	801	accuracy:	0.35	varidation accuracy:	0.28
epoch:	802	accuracy:	0.34	varidation accuracy:	0.28
epoch:	803	accuracy:	0.35	varidation accuracy:	0.3
epoch:	804	accuracy:	0.35	varidation accuracy:	0.3
epoch:	805	accuracy:	0.36	varidation accuracy:	0.32
epoch:	806	accuracy:	0.35	varidation accuracy:	0.3
epoch:	807	accuracy:	0.36	varidation accuracy:	0.3
epoch:	808	accuracy:	0.35	varidation accuracy:	0.3
epoch:	809	accuracy:	0.33	varidation accuracy:	0.32
epoch:	810	accuracy:	0.35	varidation accuracy:	0.3
epoch:	811	accuracy:	0.35	varidation accuracy:	0.3
epoch:	812	accuracy:	0.36	varidation accuracy:	0.3
epoch:	813	accuracy:	0.35	varidation accuracy:	0.28
epoch:	814	accuracy:	0.35	varidation accuracy:	0.3
epoch:	815	accuracy:	0.35	varidation accuracy:	0.3
epoch:	816	accuracy:	0.35	varidation accuracy:	0.3
epoch:	817	accuracy:	0.35

epoch:	957	accuracy:	0.34	varidation accuracy:	0.3
epoch:	958	accuracy:	0.35	varidation accuracy:	0.28
epoch:	959	accuracy:	0.35	varidation accuracy:	0.26
epoch:	960	accuracy:	0.35	varidation accuracy:	0.3
epoch:	961	accuracy:	0.35	varidation accuracy:	0.3
epoch:	962	accuracy:	0.35	varidation accuracy:	0.32
epoch:	963	accuracy:	0.31	varidation accuracy:	0.26
epoch:	964	accuracy:	0.33	varidation accuracy:	0.3
epoch:	965	accuracy:	0.35	varidation accuracy:	0.3
epoch:	966	accuracy:	0.33	varidation accuracy:	0.3
epoch:	967	accuracy:	0.36	varidation accuracy:	0.28
epoch:	968	accuracy:	0.35	varidation accuracy:	0.28
epoch:	969	accuracy:	0.35	varidation accuracy:	0.28
epoch:	970	accuracy:	0.35	varidation accuracy:	0.3
epoch:	971	accuracy:	0.34	varidation accuracy:	0.26
epoch:	972	accuracy:	0.35	varidation accuracy:	0.3
epoch:	973	accuracy:	0.35	varidation accuracy:	0.32
epoch:	974	accuracy:	0.33	varidation accuracy:	0.28
epoch:	975	accuracy:	0.36	varidation accuracy:	0.3
epoch:	976	accuracy:	