In [62]:
import numpy as np
import pickle
import matplotlib.pyplot as plt
import tensorflow
import matplotlib
from tensorflow import keras
from tensorflow.keras.datasets import cifar10,mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import SGD
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os
import sys
sys.path.insert(0, '..')
from binarization_utils import *
from model_architectures import get_model

from tensorflow.keras import activations

In [63]:
# Record the TensorFlow and Keras versions this notebook was executed with.
for framework in (tf, keras):
    print(framework.__version__)

2.4.4
2.4.0


In [64]:
# Experiment configuration shared by the cells below.
# `dataset` selects the branch taken when loading data and building the model
# ("MNIST" or "CIFAR-10").
dataset = 'MNIST'
# Mode switches; no cell visible in this part of the notebook reads them.
Train, Evaluate = True, False
# Mini-batch size and number of epochs handed to model.fit.
batch_size, epochs = 100, 200

In [65]:
# Load the selected dataset and decide whether training goes through an
# augmenting image generator (`use_generator`).
if dataset=="MNIST":
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    # flatten every image into a 784-element vector for the dense network
    X_train = X_train.reshape(-1,784)
    X_test = X_test.reshape(-1,784)
    use_generator=False
elif dataset=="CIFAR-10":
    use_generator=True
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
else:
    # Fail here with a clear message instead of a NameError on X_train in a
    # later cell.
    raise ValueError("Unsupported dataset: %r (expected 'MNIST' or 'CIFAR-10')" % (dataset,))

In [66]:
# Cast to float32, divide by 255 and apply 2*x-1, so inputs in [0, 255]
# end up in [-1, 1].
X_train = 2 * (X_train.astype(np.float32) / 255) - 1
X_test = 2 * (X_test.astype(np.float32) / 255) - 1

# One-hot versions of the integer labels (10 classes).
Y_train = to_categorical(y_train, 10)
Y_test = to_categorical(y_test, 10)


print('X_train shape:', X_train.shape)
for split_name, split in (('train', X_train), ('test', X_test)):
    print(split.shape[0], split_name + ' samples')


X_train shape: (60000, 784)
60000 train samples
10000 test samples


In [67]:
# the following cell block defines the activation layer that simulates the errors

# prob stores the probability of a sign flipping (1 -> -1 or -1 -> 1)
prob = 0

class Nonideal_sign(Layer):
    """Multi-level binarizing activation with random sign errors.

    The input is approximated by a sum of `levels` binarized terms, each scaled
    by the absolute value of a trainable scalar in `self.means`. Every term is
    additionally multiplied by a random factor that is -1 with probability
    `prob` (the notebook-level global) and +1 otherwise.

    Args:
        levels: number of residual binarization levels (default 1).
        **kwargs: forwarded to the base `Layer` constructor.
    """
    def __init__(self, levels=1,**kwargs):
        self.levels=levels
        super(Nonideal_sign, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create one trainable scalar per level, initialised to
        [levels, levels-1, ..., 1] normalised to sum to 1."""
        ars=np.arange(self.levels)+1.0
        ars=ars[::-1]
        means=ars/np.sum(ars)
        self.means=[K.variable(m) for m in means]
        self._trainable_weights = self.means

    def call(self, x, mask=None):
        """Return the noisy multi-level binarization of `x`.

        `binarize` comes from binarization_utils (presumably a sign function
        with a straight-through gradient -- verify there).
        """
        resid = x
        out_bin=0
        for l in range(self.levels):
            # `prob` is looked up in the global namespace each time this method
            # runs (in graph mode presumably only when the layer is traced --
            # TODO confirm).
            # NOTE(review): each mean is a scalar variable, so the random draw
            # below has shape () and a single draw flips the sign of the whole
            # level for the entire batch. If independent per-activation flips
            # are intended, draw with shape tf.shape(resid) instead.
            flip=(2*tf.cast(tf.random.uniform(self.means[l].shape) > prob, tf.float32)) - 1
            out=binarize(resid)*(K.abs(self.means[l])) *flip
            out_bin=out_bin+out
            resid=resid-out
        return out_bin

    def get_output_shape_for(self,input_shape):
        """Legacy shape hook; the output shape equals the input shape."""
        return input_shape

    def compute_output_shape(self,input_shape):
        """The output shape equals the input shape."""
        return input_shape

    def set_means(self,X):
        """Re-initialise the per-level scales from sample activations `X`.

        `X` is clipped to [-1, 1]; for each level the mean absolute residual is
        taken as that level's scale and sign(residual)*scale is subtracted from
        the residual. The resulting scales are normalised to sum to 1 and
        written into the layer's variables.

        Args:
            X: numpy array of sample activations.
        """
        means=np.zeros((self.levels))
        resid=np.clip(X,-1,1)
        for l in range(self.levels):
            m=np.mean(np.absolute(resid))
            resid=resid-np.sign(resid)*m
            means[l]=m

        means=means/np.sum(means)
        # self.means is a Python list of scalar variables, so each one is
        # assigned individually (a list has no .assign method).
        for mean_var, mean_value in zip(self.means, means):
            K.set_value(mean_var, mean_value)

    def get_config(self):
        """Return the base config extended with `levels` for serialization."""
        config = super().get_config().copy()
        config.update({
            'levels': self.levels
        })
        return config


In [97]:
# binary dense layer with introduction of errors
class binary_dense_error(Layer):
    """Dense layer with binarized weights whose signs are randomly flipped.

    On every forward pass each entry of the binarized weight matrix is
    multiplied by -1 with probability `p` (independently per weight) and by +1
    otherwise, then scaled by |gamma|.

    Args:
        n_in: number of input features.
        n_out: number of output units.
        p: probability of flipping the sign of a weight.
        **kwargs: forwarded to the base `Layer` constructor.
    """

    def __init__(self, n_in, n_out, p, **kwargs):
        self.n_in = n_in
        self.n_out = n_out
        self.p = p
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Create the latent weights (normal, std 1/sqrt(n_in)) and scale gamma."""
        init_std = 1 / np.sqrt(self.n_in)
        initial_w = np.random.normal(loc=0.0, scale=init_std, size=[self.n_in, self.n_out])
        self.w = K.variable(initial_w.astype(np.float32))
        self.gamma = K.variable(1.0)
        self._trainable_weights = [self.w, self.gamma]

    def call(self, x, mask=None):
        """Multiply `x` by the scaled, binarized and sign-perturbed weights."""
        scale = K.abs(self.gamma)  # keeps the scale non-negative
        scaled_binary_w = scale * binarize(self.w)
        # +1 where the uniform draw exceeds p, -1 otherwise
        kept = tf.cast(tf.random.uniform(self.w.shape) > self.p, tf.float32)
        self.clamped_w = scaled_binary_w * ((2 * kept) - 1)
        self.out = K.dot(x, self.clamped_w)
        return self.out

    def compute_output_shape(self, input_shape):
        """Output keeps the leading dimension and has `n_out` features."""
        return (input_shape[0], self.n_out)

    # legacy name for the same shape rule
    get_output_shape_for = compute_output_shape

    def get_config(self):
        """Return the base config extended with the constructor arguments."""
        return dict(
            super().get_config(),
            n_in=self.n_in,
            n_out=self.n_out,
            p=self.p,
        )
    
class binary_dense_errordot(Layer):
    """Dense layer with binarized weights whose *outputs* are randomly flipped.

    The weights are binarized and scaled by |gamma| without noise; after the
    matrix product each output element is multiplied by -1 with probability
    `p` and by +1 otherwise.

    Args:
        n_in: number of input features.
        n_out: number of output units.
        p: probability of flipping the sign of an output element.
        **kwargs: forwarded to the base `Layer` constructor.
    """
    def __init__(self,n_in,n_out,p,**kwargs):
        self.n_in=n_in
        self.n_out=n_out
        self.p=p
        # Must name this class: super(binary_dense_error, self) raises
        # TypeError because self is not a binary_dense_error instance.
        super(binary_dense_errordot,self).__init__(**kwargs)
    def build(self, input_shape):
        """Create the latent weights (normal, std 1/sqrt(n_in)) and scale gamma."""
        stdv=1/np.sqrt(self.n_in)
        w = np.random.normal(loc=0.0, scale=stdv,size=[self.n_in,self.n_out]).astype(np.float32)
        self.w=K.variable(w)
        self.gamma=K.variable(1.0)
        self._trainable_weights=[self.w,self.gamma]

    def call(self, x,mask=None):
        """Return x @ (|gamma| * binarize(w)) with random per-element sign flips."""
        constraint_gamma=K.abs(self.gamma)#K.clip(self.gamma,0.01,10)
        self.clamped_w=constraint_gamma*binarize(self.w)
        self.out=K.dot(x,self.clamped_w)
        # Use the dynamic shape: the static shape has an unknown (None) batch
        # dimension while the model is being built, which tf.random.uniform
        # cannot accept.
        flip=(2*tf.cast(tf.random.uniform(tf.shape(self.out)) > self.p, tf.float32)) - 1
        return self.out*flip
    def  get_output_shape_for(self,input_shape):
        """Legacy shape hook; same rule as compute_output_shape."""
        return (input_shape[0], self.n_out)
    def compute_output_shape(self,input_shape):
        """Output keeps the leading dimension and has `n_out` features."""
        return (input_shape[0], self.n_out)

    def get_config(self):
        """Return the base config extended with the constructor arguments."""
        config = super().get_config().copy()
        config.update({
            'n_in': self.n_in,
            'n_out': self.n_out,
            'p': self.p
        })
        return config

In [95]:
# probability of weights flipping sign
prob = 0.01

# enter the model name
model_name = "newmodel"

# Create models/<model_name>/ (and models/ itself) if missing; exist_ok makes
# the cell safe to re-run.
os.makedirs(os.path.join('models', model_name), exist_ok=True)
# Not used below in this cell; kept in case later cells rely on `sess`.
sess=tf.compat.v1.keras.backend.get_session()

resid_levels=1
batch_norm_eps=1e-4
batch_norm_alpha=0.1#(this is same as momentum)


def _add_batch_norm(net, with_sign=True):
    """Append a BatchNormalization layer and, optionally, a Nonideal_sign activation to `net`."""
    net.add(BatchNormalization(axis=-1, momentum=batch_norm_alpha, epsilon=batch_norm_eps))
    if with_sign:
        net.add(Nonideal_sign(levels=resid_levels))


if dataset=="MNIST":
    # 784 -> 256 -> 256 -> 256 -> 256 -> 10, all dense layers with weight sign errors
    model=Sequential()
    model.add(binary_dense_error(n_in=784,n_out=256,input_shape=[784],p=prob))
    _add_batch_norm(model)
    for _ in range(3):
        model.add(binary_dense_error(n_in=int(model.output.get_shape()[1]),n_out=256,p=prob))
        _add_batch_norm(model)
    model.add(binary_dense_error(n_in=int(model.output.get_shape()[1]),n_out=10,p=prob))
    _add_batch_norm(model, with_sign=False)
    model.add(Activation('softmax'))
elif dataset=="CIFAR-10":
    # NOTE(review): this branch uses binary_conv / binary_dense from
    # binarization_utils, so `prob` is not passed to any weight layer here;
    # only the Nonideal_sign activations read it. Confirm that is intended.
    model=Sequential()
    model.add(binary_conv(nfilters=64,ch_in=3,k=3,padding='valid',input_shape=[32,32,3]))
    _add_batch_norm(model)
    model.add(binary_conv(nfilters=64,ch_in=64,k=3,padding='valid'))
    _add_batch_norm(model)
    model.add(MaxPooling2D(pool_size=(2, 2),strides=(2,2)))

    model.add(binary_conv(nfilters=128,ch_in=64,k=3,padding='valid'))
    _add_batch_norm(model)
    model.add(binary_conv(nfilters=128,ch_in=128,k=3,padding='valid'))
    _add_batch_norm(model)
    model.add(MaxPooling2D(pool_size=(2, 2),strides=(2,2)))

    model.add(binary_conv(nfilters=256,ch_in=128,k=3,padding='valid'))
    _add_batch_norm(model)
    model.add(binary_conv(nfilters=256,ch_in=256,k=3,padding='valid'))
    _add_batch_norm(model)
    #model.add(MaxPooling2D(pool_size=(2, 2),strides=(2,2)))

    model.add(my_flat())

    for n_units in (512, 512):
        model.add(binary_dense(n_in=int(model.output.get_shape()[1]),n_out=n_units))
        _add_batch_norm(model)
    model.add(binary_dense(n_in=int(model.output.get_shape()[1]),n_out=10))
    _add_batch_norm(model, with_sign=False)
    model.add(Activation(activations.softmax))
else:
    # Without this, `model` would be left unbound (or stale from an earlier run).
    raise ValueError("Unsupported dataset: %r (expected 'MNIST' or 'CIFAR-10')" % (dataset,))

# the following is a workaround so that the model weights can be saved
# https://github.com/tensorflow/tensorflow/issues/46871
for weight_index, w in enumerate(model.weights):
    w._handle_name = 'model_' + str(weight_index) + w.name
    

In [96]:
#gather all binary dense and binary convolution layers:
# NOTE(review): the MNIST model in this notebook is built from
# binary_dense_error and Nonideal_sign, which derive directly from Layer, so
# for that model both lists below stay empty. Neither list is used in this
# cell; confirm whether later cells expect them to be populated.
binary_layers=[layer for layer in model.layers if isinstance(layer,(binary_dense,binary_conv))]

#gather all residual binary activation layers:
resid_bin_layers=[layer for layer in model.layers if isinstance(layer,Residual_sign)]

lr=0.01
opt = keras.optimizers.Adam(learning_rate=lr,decay=1e-6)#SGD(lr=lr,momentum=0.9,decay=1e-5)
# Sparse loss: training uses the integer labels y_train / y_test directly.
model.compile(loss='sparse_categorical_crossentropy',optimizer=opt,metrics=['accuracy'])


# NOTE(review): the checkpoint is written to models/<model_name>.h5, not into
# the models/<model_name>/ directory created in the model-building cell.
weights_path='models/'+model_name+'.h5'
cback=keras.callbacks.ModelCheckpoint(weights_path, monitor='val_accuracy', save_best_only=True)

# The former keras.__version__ '1'/'2' branches are gone: the Keras-1 calls
# could never run under tensorflow.keras, and `history` was left unbound for
# any other major version. Model.fit accepts generators directly in TF 2.x.
if use_generator:
    # Always bound now; previously undefined unless dataset was CIFAR-10.
    horizontal_flip = dataset=="CIFAR-10"
    datagen = ImageDataGenerator(
        width_shift_range=0.15,  # randomly shift images horizontally (fraction of total width)
        height_shift_range=0.15,  # randomly shift images vertically (fraction of total height)
        horizontal_flip=horizontal_flip)  # randomly flip images
    # Whole number of batches per epoch (rounded up) instead of a float.
    steps_per_epoch=(X_train.shape[0]+batch_size-1)//batch_size
    history=model.fit(datagen.flow(X_train, y_train,batch_size=batch_size),steps_per_epoch=steps_per_epoch,
        epochs=epochs,validation_data=(X_test, y_test),verbose=2,callbacks=[cback])
else:
    history=model.fit(X_train, y_train,batch_size=batch_size,validation_data=(X_test, y_test), verbose=2,epochs=epochs,callbacks=[cback])

# Persist the training curves; the context manager closes the file even if
# pickling fails.
dic={'hard':history.history}
with open('models/'+model_name+'.pkl','wb') as foo:
    pickle.dump(dic,foo)

Epoch 1/200
600/600 - 6s - loss: 0.6048 - accuracy: 0.8650 - val_loss: 0.3443 - val_accuracy: 0.9079
Epoch 2/200
600/600 - 4s - loss: 0.3665 - accuracy: 0.9129 - val_loss: 0.4049 - val_accuracy: 0.9179
Epoch 3/200
600/600 - 4s - loss: 0.3634 - accuracy: 0.9244 - val_loss: 0.3450 - val_accuracy: 0.9262
Epoch 4/200
600/600 - 5s - loss: 0.2687 - accuracy: 0.9410 - val_loss: 0.3847 - val_accuracy: 0.9217
Epoch 5/200
600/600 - 5s - loss: 0.2686 - accuracy: 0.9449 - val_loss: 0.5071 - val_accuracy: 0.9081
Epoch 6/200
600/600 - 4s - loss: 0.2731 - accuracy: 0.9474 - val_loss: 0.2958 - val_accuracy: 0.9199
Epoch 7/200
600/600 - 5s - loss: 0.2406 - accuracy: 0.9517 - val_loss: 0.4012 - val_accuracy: 0.9354
Epoch 8/200
600/600 - 5s - loss: 0.1708 - accuracy: 0.9607 - val_loss: 0.3551 - val_accuracy: 0.9226
Epoch 9/200
600/600 - 4s - loss: 0.2507 - accuracy: 0.9544 - val_loss: 0.4845 - val_accuracy: 0.9114
Epoch 10/200
600/600 - 4s - loss: 0.1905 - accuracy: 0.9607 - val_loss: 0.2759 - val_accura

Epoch 82/200
600/600 - 4s - loss: 0.1400 - accuracy: 0.9813 - val_loss: 0.3529 - val_accuracy: 0.9295
Epoch 83/200
600/600 - 4s - loss: 0.0707 - accuracy: 0.9883 - val_loss: 0.3319 - val_accuracy: 0.9405
Epoch 84/200
600/600 - 4s - loss: 0.0523 - accuracy: 0.9899 - val_loss: 0.2835 - val_accuracy: 0.9529
Epoch 85/200
600/600 - 5s - loss: 0.2181 - accuracy: 0.9729 - val_loss: 0.1806 - val_accuracy: 0.9552
Epoch 86/200
600/600 - 5s - loss: 0.1503 - accuracy: 0.9806 - val_loss: 0.3370 - val_accuracy: 0.9333
Epoch 87/200
600/600 - 5s - loss: 0.1355 - accuracy: 0.9825 - val_loss: 0.1679 - val_accuracy: 0.9599
Epoch 88/200
600/600 - 5s - loss: 0.1534 - accuracy: 0.9802 - val_loss: 0.2961 - val_accuracy: 0.9411
Epoch 89/200
600/600 - 5s - loss: 0.1363 - accuracy: 0.9824 - val_loss: 0.2479 - val_accuracy: 0.9590
Epoch 90/200
600/600 - 5s - loss: 0.1362 - accuracy: 0.9829 - val_loss: 0.3004 - val_accuracy: 0.9439
Epoch 91/200
600/600 - 5s - loss: 0.1341 - accuracy: 0.9823 - val_loss: 0.5001 - v

Epoch 162/200
600/600 - 4s - loss: 0.1275 - accuracy: 0.9849 - val_loss: 0.2708 - val_accuracy: 0.9541
Epoch 163/200
600/600 - 4s - loss: 0.1742 - accuracy: 0.9799 - val_loss: 0.3085 - val_accuracy: 0.9431
Epoch 164/200
600/600 - 4s - loss: 0.1123 - accuracy: 0.9863 - val_loss: 0.1739 - val_accuracy: 0.9541
Epoch 165/200
600/600 - 4s - loss: 0.0786 - accuracy: 0.9894 - val_loss: 0.2613 - val_accuracy: 0.9554
Epoch 166/200
600/600 - 4s - loss: 0.0851 - accuracy: 0.9871 - val_loss: 0.2323 - val_accuracy: 0.9396
Epoch 167/200
600/600 - 4s - loss: 0.2364 - accuracy: 0.9729 - val_loss: 0.4009 - val_accuracy: 0.9255
Epoch 168/200
600/600 - 4s - loss: 0.0933 - accuracy: 0.9879 - val_loss: 0.1812 - val_accuracy: 0.9558
Epoch 169/200
600/600 - 4s - loss: 0.1287 - accuracy: 0.9846 - val_loss: 0.4051 - val_accuracy: 0.9399
Epoch 170/200
600/600 - 4s - loss: 0.0940 - accuracy: 0.9888 - val_loss: 0.2947 - val_accuracy: 0.9414
Epoch 171/200
600/600 - 4s - loss: 0.1923 - accuracy: 0.9774 - val_loss: 

In [101]:
# model with binary dense and batchnorm layers
# (same dense/batch-norm stack as the MNIST model, without the Nonideal_sign
# activations)
prob = 0.01

model=Sequential()
model.add(binary_dense_error(n_in=784,n_out=256,input_shape=[784],p=prob))
model.add(BatchNormalization(axis=-1, momentum=batch_norm_alpha, epsilon=batch_norm_eps))
for _ in range(3):
    model.add(binary_dense_error(n_in=int(model.output.get_shape()[1]),n_out=256,p=prob))
    model.add(BatchNormalization(axis=-1, momentum=batch_norm_alpha, epsilon=batch_norm_eps))
# NOTE(review): the last dense layer has 256 outputs, so the softmax below is
# over 256 units although the labels have 10 classes. The earlier MNIST model
# ends with an n_out=10 layer; confirm whether that final layer was dropped
# here by accident.
model.add(Activation('softmax'))

# workaround so that the model weights can be saved
# https://github.com/tensorflow/tensorflow/issues/46871
for weight_index, w in enumerate(model.weights):
    w._handle_name = 'model_' + str(weight_index) + w.name

lr=0.01
opt = keras.optimizers.Adam(learning_rate=lr,decay=1e-6)#SGD(lr=lr,momentum=0.9,decay=1e-5)
model.compile(loss='sparse_categorical_crossentropy',optimizer=opt,metrics=['accuracy'])


# NOTE(review): unless model_name is changed before running this cell, these
# paths are the same as in the previous training cell, so that run's
# checkpoint and history file are overwritten.
weights_path='models/'+model_name+'.h5'
cback=keras.callbacks.ModelCheckpoint(weights_path, monitor='val_accuracy', save_best_only=True)

history=model.fit(X_train, y_train,batch_size=batch_size,validation_data=(X_test, y_test), verbose=2,epochs=epochs,callbacks=[cback])

# Persist the training curves; the context manager closes the file even if
# pickling fails.
dic={'hard':history.history}
with open('models/'+model_name+'.pkl','wb') as foo:
    pickle.dump(dic,foo)


Epoch 1/200
600/600 - 5s - loss: 0.8243 - accuracy: 0.8625 - val_loss: 0.3898 - val_accuracy: 0.8900
Epoch 2/200
600/600 - 4s - loss: 0.3971 - accuracy: 0.8897 - val_loss: 0.3435 - val_accuracy: 0.9032
Epoch 3/200
600/600 - 5s - loss: 0.3607 - accuracy: 0.8962 - val_loss: 0.3339 - val_accuracy: 0.9041
Epoch 4/200
600/600 - 5s - loss: 0.3510 - accuracy: 0.8989 - val_loss: 0.3294 - val_accuracy: 0.9082
Epoch 5/200
600/600 - 5s - loss: 0.3384 - accuracy: 0.9022 - val_loss: 0.3316 - val_accuracy: 0.9041
Epoch 6/200
600/600 - 5s - loss: 0.3352 - accuracy: 0.9036 - val_loss: 0.3366 - val_accuracy: 0.9048
Epoch 7/200
600/600 - 5s - loss: 0.3278 - accuracy: 0.9052 - val_loss: 0.3138 - val_accuracy: 0.9088
Epoch 8/200
600/600 - 4s - loss: 0.3244 - accuracy: 0.9062 - val_loss: 0.3167 - val_accuracy: 0.9060
Epoch 9/200
600/600 - 4s - loss: 0.3224 - accuracy: 0.9066 - val_loss: 0.3135 - val_accuracy: 0.9110
Epoch 10/200
600/600 - 4s - loss: 0.3187 - accuracy: 0.9074 - val_loss: 0.3186 - val_accura

Epoch 82/200
600/600 - 5s - loss: 0.2874 - accuracy: 0.9173 - val_loss: 0.3179 - val_accuracy: 0.9086
Epoch 83/200
600/600 - 5s - loss: 0.2886 - accuracy: 0.9165 - val_loss: 0.3104 - val_accuracy: 0.9116
Epoch 84/200
600/600 - 5s - loss: 0.2856 - accuracy: 0.9164 - val_loss: 0.3168 - val_accuracy: 0.9140
Epoch 85/200
600/600 - 5s - loss: 0.2893 - accuracy: 0.9165 - val_loss: 0.2941 - val_accuracy: 0.9180
Epoch 86/200
600/600 - 4s - loss: 0.2887 - accuracy: 0.9167 - val_loss: 0.3008 - val_accuracy: 0.9166
Epoch 87/200
600/600 - 5s - loss: 0.2876 - accuracy: 0.9167 - val_loss: 0.3636 - val_accuracy: 0.8985
Epoch 88/200
600/600 - 5s - loss: 0.2883 - accuracy: 0.9173 - val_loss: 0.3034 - val_accuracy: 0.9150
Epoch 89/200
600/600 - 4s - loss: 0.2881 - accuracy: 0.9174 - val_loss: 0.3203 - val_accuracy: 0.9100
Epoch 90/200
600/600 - 4s - loss: 0.2859 - accuracy: 0.9180 - val_loss: 0.3013 - val_accuracy: 0.9127
Epoch 91/200
600/600 - 4s - loss: 0.2858 - accuracy: 0.9185 - val_loss: 0.3168 - v

Epoch 162/200
600/600 - 5s - loss: 0.2748 - accuracy: 0.9210 - val_loss: 0.3186 - val_accuracy: 0.9097
Epoch 163/200
600/600 - 5s - loss: 0.2731 - accuracy: 0.9220 - val_loss: 0.3232 - val_accuracy: 0.9055
Epoch 164/200
600/600 - 5s - loss: 0.2746 - accuracy: 0.9212 - val_loss: 0.3792 - val_accuracy: 0.8826
Epoch 165/200
600/600 - 4s - loss: 0.2739 - accuracy: 0.9218 - val_loss: 0.3299 - val_accuracy: 0.9024
Epoch 166/200
600/600 - 4s - loss: 0.2736 - accuracy: 0.9215 - val_loss: 0.3220 - val_accuracy: 0.9094
Epoch 167/200
600/600 - 4s - loss: 0.2752 - accuracy: 0.9209 - val_loss: 0.3185 - val_accuracy: 0.9086
Epoch 168/200
600/600 - 4s - loss: 0.2736 - accuracy: 0.9212 - val_loss: 0.3322 - val_accuracy: 0.9018
Epoch 169/200
600/600 - 4s - loss: 0.2742 - accuracy: 0.9214 - val_loss: 0.3065 - val_accuracy: 0.9149
Epoch 170/200
600/600 - 5s - loss: 0.2736 - accuracy: 0.9214 - val_loss: 0.3078 - val_accuracy: 0.9117
Epoch 171/200
600/600 - 5s - loss: 0.2734 - accuracy: 0.9229 - val_loss: 