In [None]:
import tensorflow as tf
tf.enable_eager_execution()
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout
from tensorflow.keras.models import Model, Sequential
import tensorflow.keras.backend as K
import numpy as np
from tensorflow import contrib

In [None]:
def ds_map(image, label):
    """Convert one raw MNIST example into model-ready tensors.

    Args:
        image: a single image of integer pixel values (uint8 when it comes from
            `tf.keras.datasets.mnist.load_data`); reshaped here to 28x28x1.
        label: integer class index (0-9 for MNIST).

    Returns:
        (image, label): `image` is a float32 (28, 28, 1) tensor scaled to
        [0, 1]; `label` is a length-10 one-hot float32 vector.
    """
    image = tf.cast(image, tf.float32)
    image = tf.reshape(image, [28, 28, 1])
    # True division. Floor division (`//`) collapsed every pixel < 255 to 0,
    # turning the images into near-empty binary masks.
    image = image / 255.
    label = tf.one_hot(label, 10)
    return image, label

In [None]:
# Named *_data so the later `def train(...)` does not silently shadow the dataset.
train_data, test_data = tf.keras.datasets.mnist.load_data()

# Formatting test data: float32 pixels in [0, 1] with a trailing channel axis.
x_test, y_test = test_data
# True division (`/`), not floor division (`//`): `//` mapped every pixel < 255 to 0.
x_test = (x_test.astype(np.float32) / 255.).reshape(-1, 28, 28, 1)
y_test = tf.keras.utils.to_categorical(y_test, 10)

# Formatting training data: per-example preprocessing, full-set shuffle, batches of 32.
mnist_ds = tf.data.Dataset.from_tensor_slices(train_data)
mnist_ds = mnist_ds.map(ds_map).shuffle(len(train_data[0])).batch(32)

In [None]:
# Small CNN feature extractor followed by a bias-free linear "SVM" head.
# The head is named 'svm' so the training loops can look it up with
# `model.get_layer('svm')` and pass it to `svm_loss` for the L2 penalty.
# (Dropout(0.5) after the pooling layer and after the hidden Dense layer is
# currently disabled.)
model = Sequential([
    Conv2D(32, kernel_size=(3, 3), activation='relu'),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(10, activation='linear', use_bias=False, name='svm'),
])

In [None]:
def svm_loss(layer, y_true, y_pred, reg_weight=0.5, hinge_weight=0.4):
    """Multiclass hinge loss plus an L2 penalty on `layer`'s first weight.

    Args:
        layer: Keras layer whose `layer.weights[0]` is L2-penalised; in this
            notebook the bias-free final Dense layer named 'svm'.
        y_true: one-hot targets; assumed shape (batch, num_classes).
        y_pred: raw (linear) class scores, same shape as `y_true`.
        reg_weight: multiplier on sum(W**2). Default 0.5 keeps the original value.
        hinge_weight: multiplier on the batch-mean hinge. Default 0.4 keeps the
            original value.

    Returns:
        Scalar tensor: reg_weight * sum(W**2) + hinge_weight * mean(hinge).
    """
    weights = layer.weights[0]

    # Score of the correct class for each example.
    pos = K.sum(y_true * y_pred, axis=-1)
    # Highest score among the *wrong* classes. The true-class entry is pushed
    # far down rather than zeroed: zeroing made `neg` >= 0 whenever every
    # wrong-class score was negative, so the margin was measured against 0
    # instead of against the best competing class.
    neg = K.max(y_pred - y_true * 1e9, axis=-1)
    hinge_loss = K.mean(K.maximum(0.0, neg - pos + 1.0))

    regularization_loss = reg_weight * tf.reduce_sum(tf.square(weights))
    return regularization_loss + hinge_weight * hinge_loss

In [None]:
# Optimizer: Adam with a fixed learning rate of 1e-3.
optimizer_adam = tf.train.AdamOptimizer(learning_rate=1e-3)

In [None]:
def train(dataset, epochs, optimizer):
    """Custom training loop that tracks per-epoch mean loss and accuracy.

    Each epoch makes one pass over `dataset`, updating the global `model` with
    `svm_loss` and `optimizer`. The epoch's mean loss and accuracy are appended
    to the global lists `train_loss_results` / `train_accuracy_results` (which
    must exist before calling) and printed.

    Args:
        dataset: iterable of (images, one-hot targets) batches, e.g. `mnist_ds`.
        epochs: number of passes over `dataset`.
        optimizer: optimizer whose `apply_gradients` updates the model.
    """
    svm_layer = model.get_layer('svm')
    for epoch in range(epochs):
        # Fresh metrics every epoch. A single pair created before the loop
        # reported a running average over all epochs so far, not this epoch.
        # Uses `contrib` from the import cell rather than the `tfe` alias,
        # which is only defined in a later cell.
        epoch_loss_avg = contrib.eager.metrics.Mean()
        epoch_accuracy = contrib.eager.metrics.Accuracy()
        for input_image, target in dataset:
            with tf.GradientTape() as grad_tape:
                # Forward pass and SVM loss for this batch.
                model_out = model(input_image)
                model_loss = svm_loss(svm_layer, target, model_out)

            gradients = grad_tape.gradient(model_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
            epoch_loss_avg(model_loss)
            # Accuracy takes (labels, predictions). Reuse the forward pass above
            # instead of running the model a second time.
            epoch_accuracy(tf.argmax(target, axis=1, output_type=tf.int32),
                           tf.argmax(model_out, axis=1, output_type=tf.int32))

        train_loss_results.append(epoch_loss_avg.result())
        train_accuracy_results.append(epoch_accuracy.result())
        print("Epoch {:03d}: Loss: {:.3f}, Accuracy: {:.3%}".format(epoch,
                                                                epoch_loss_avg.result(),
                                                                epoch_accuracy.result()))

In [None]:
def train2(dataset, epochs, optimizer):
    """Custom training loop that prints mean loss and accuracy after each epoch.

    Updates the global `model` with `svm_loss` and `optimizer`; does not touch
    the `train_loss_results` / `train_accuracy_results` globals.

    Args:
        dataset: iterable of (images, one-hot targets) batches, e.g. `mnist_ds`.
        epochs: number of passes over `dataset`.
        optimizer: optimizer whose `apply_gradients` updates the model.
    """
    svm_layer = model.get_layer('svm')
    for epoch in range(epochs):
        batch_losses = []
        batch_correct = []
        for input_image, target in dataset:
            with tf.GradientTape() as grad_tape:
                # Forward pass and SVM loss for this batch.
                model_out = model(input_image)
                model_loss = svm_loss(svm_layer, target, model_out)

            gradients = grad_tape.gradient(model_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
            # Record the loss itself. Previously the raw model outputs were
            # stored here, so the printed "loss" was a sum of logits.
            batch_losses.append(model_loss)
            batch_correct.append(tf.equal(tf.argmax(model_out, axis=1, output_type=tf.int32),
                                          tf.argmax(target, axis=1, output_type=tf.int32)))

        if not batch_losses:
            # Empty dataset: nothing to report (and tf.concat rejects an empty list).
            return

        # Mean of per-batch losses. Accuracy is the fraction of correctly
        # classified examples, not a raw count. concat (rather than stacking
        # the list) tolerates a smaller final batch.
        loss_ = tf.reduce_mean(tf.stack(batch_losses))
        acc_ = tf.reduce_mean(tf.cast(tf.concat(batch_correct, axis=0), tf.float32))

        print('Epoch {:03d}: Loss: {:.3f}, Accuracy: {:.3%}'.format(epoch, float(loss_), float(acc_)))

In [None]:
# Training configuration.
EPOCHS = 10

# Globals read by the metrics-based `train` loop: the eager-contrib alias and
# the per-epoch history lists it appends to. `train2` uses none of these.
tfe = contrib.eager
train_loss_results, train_accuracy_results = [], []

# Run the training loop on the batched, shuffled MNIST training set.
train2(mnist_ds, EPOCHS, optimizer_adam)

In [None]:
def evaluate_svm(keras_model, images, labels, batch_size=256):
    """Evaluate `keras_model` with `svm_loss` on (images, one-hot labels).

    `model` is trained with a custom loop and never compiled, so
    `model.evaluate` raises. This computes the same two numbers manually.

    Args:
        keras_model: model containing a layer named 'svm' (penalised by `svm_loss`).
        images: array of input images (cast to float32 here).
        labels: one-hot label array aligned with `images`.
        batch_size: examples per forward pass, to bound memory use.

    Returns:
        [mean loss, accuracy], the same layout `model.evaluate` uses.

    Raises:
        ValueError: if `images` is empty.
    """
    svm_layer = keras_model.get_layer('svm')
    eval_ds = tf.data.Dataset.from_tensor_slices(
        (np.asarray(images, dtype=np.float32), np.asarray(labels, dtype=np.float32))
    ).batch(batch_size)

    total_loss, total_correct, total_seen = 0.0, 0, 0
    for batch_images, batch_labels in eval_ds:
        logits = keras_model(batch_images)
        n = int(batch_images.shape[0])
        # svm_loss returns a batch mean. Weight it by batch size so a smaller
        # final batch is counted correctly.
        total_loss += float(svm_loss(svm_layer, batch_labels, logits)) * n
        total_correct += int(tf.reduce_sum(tf.cast(
            tf.equal(tf.argmax(logits, axis=1), tf.argmax(batch_labels, axis=1)), tf.int32)))
        total_seen += n

    if total_seen == 0:
        raise ValueError('evaluate_svm: no test examples')
    return [total_loss / total_seen, total_correct / total_seen]


score = evaluate_svm(model, x_test, y_test)
print('Test loss:', score[0])
print('Test accuracy:', score[1])