# In [None]:
import import_ipynb
import augmentation
import testing
import imagenet
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt

# Training hyperparameters.

# Schedule: outer-loop epoch count and the number of training batches
# drawn in each epoch.
no_of_epochs = 25
total_batches = 25

# Settings passed to the SGDW optimizer.
learning_rate = 1e-2
momentum = 0.9
weight_decay = 5e-4

# Placeholder for the ground-truth labels: one float32 row per example with
# imagenet.n_classes columns. The accuracy ops below take argmax over axis 1,
# so rows are presumably one-hot -- TODO confirm against the data loaders.
y = tf.placeholder(tf.float32, [None, imagenet.n_classes])

# cost function and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    logits=imagenet.out,
    labels=y))
# NOTE(review): SGDW is a Keras-style optimizer whose minimize() is documented
# with a var_list argument; confirm that minimize(cost) on a graph tensor works
# with the installed TensorFlow / TensorFlow Addons versions.
optimizer = tfa.optimizers.SGDW(
    learning_rate=learning_rate, momentum=momentum, weight_decay=weight_decay, nesterov=True, name='SGDW').minimize(cost)

# accuracy functions
top_1 = tf.equal(tf.argmax(imagenet.out, 1), tf.argmax(y, 1))
top_1_accuracy = tf.reduce_mean(tf.cast(top_1, tf.float32))
# in_top_k expects per-class scores as `predictions` and integer class ids as
# `targets`. Keyword arguments are used because the positional order differs
# between the TF1 and TF2 signatures, and the labels are reduced to class ids
# with argmax rather than passed as a float matrix.
top_5 = tf.math.in_top_k(
    predictions=imagenet.out,
    targets=tf.argmax(y, 1),
    k=5)
top_5_accuracy = tf.reduce_mean(tf.cast(top_5, tf.float32))

init = tf.global_variables_initializer()

def _plot_train_vs_test(figure_no, train_values, test_values,
                        train_label, test_label, title, y_label):
    """Plot a per-epoch training curve against its test curve and show it.

    Args:
        figure_no: matplotlib figure number to draw into.
        train_values: one value per epoch measured on training data.
        test_values: one value per epoch measured on test data.
        train_label: legend entry for the training curve (drawn in blue).
        test_label: legend entry for the test curve (drawn in red).
        title: figure title.
        y_label: label for the y axis.
    """
    epochs = range(len(train_values))
    plt.figure(figure_no)
    plt.plot(epochs, train_values, 'b', label=train_label)
    plt.plot(epochs, test_values, 'r', label=test_label)
    plt.title(title)
    plt.xlabel('Epochs ', fontsize=16)
    plt.ylabel(y_label, fontsize=16)
    plt.legend()
    # No extra plt.figure() here: it would open an additional empty window.
    plt.show()


# per-epoch history, kept for plotting
train_loss = []
test_loss = []
train_accuracy_top_1 = []
train_accuracy_top_5 = []
test_accuracy_top_1 = []
test_accuracy_top_5 = []

with tf.Session() as sess:
    sess.run(init)

    # Fetch the test sample once so every epoch is evaluated on the same data.
    # NOTE(review): the original called `testset.getTestSample()`, but no
    # `testset` name exists in this file; the imported `testing` module is
    # assumed to be the intended provider -- confirm.
    inp_test, out_test = testing.getTestSample()

    summary_writer = tf.summary.FileWriter('./Output', sess.graph)
    try:
        for epoch in range(no_of_epochs):
            # running the architecture for 1 epoch
            for batch_no in range(total_batches):
                inp, out = augmentation.getTrainingSample()
                sess.run([optimizer],
                         feed_dict={
                             imagenet.input_images: inp,
                             y: out})

            # accuracy and loss after 1 epoch, measured on the last training batch
            top_1_acc, top_5_acc, loss = sess.run(
                [top_1_accuracy, top_5_accuracy, cost],
                feed_dict={
                    imagenet.input_images: inp,
                    y: out})
            print("Epoch: {}, Top 1 Acc (training): {}, Top 5 Acc (training) = {}, Loss: {} ".format(epoch, top_1_acc, top_5_acc, loss))

            # accuracy and loss after 1 epoch on the fixed test sample
            test_top_1_acc, test_top_5_acc, valid_loss = sess.run(
                [top_1_accuracy, top_5_accuracy, cost],
                feed_dict={
                    imagenet.input_images: inp_test,
                    y: out_test})
            # The format string has three fields; the loss argument was missing,
            # which raised IndexError on the first epoch.
            print("Top 1 Acc (testing): {}, Top 5 Acc (testing) = {}, Loss: {}".format(test_top_1_acc, test_top_5_acc, valid_loss))

            # storing for plotting
            train_loss.append(loss)
            test_loss.append(valid_loss)
            train_accuracy_top_1.append(top_1_acc)
            train_accuracy_top_5.append(top_5_acc)
            test_accuracy_top_1.append(test_top_1_acc)
            test_accuracy_top_5.append(test_top_5_acc)
    finally:
        # Close the event file even if training is interrupted.
        summary_writer.close()

# Plotting only needs the Python lists collected above, so it runs after the
# session has been closed.

# plotting training and testing - loss
_plot_train_vs_test(0, train_loss, test_loss,
                    'Training loss', 'Test loss',
                    'Training and Test loss', 'Loss')

# plotting training and testing - top 1 accuracy
_plot_train_vs_test(1, train_accuracy_top_1, test_accuracy_top_1,
                    'Top 1 Training Accuracy', 'Top 1 Test Accuracy',
                    'Top 1 Training and Test Accuracy', 'Accuracy')

# plotting training and testing - top 5 accuracy
_plot_train_vs_test(2, train_accuracy_top_5, test_accuracy_top_5,
                    'Top 5 Training Accuracy', 'Top 5 Test Accuracy',
                    'Top 5 Training and Test Accuracy', 'Accuracy')