In [None]:
!git clone https://github.com/ninomiyalab/Memory_Less_Momentum_Quasi_Newton

In [None]:
import tensorflow as tf
import tensorflow.keras
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Dense, Activation, Conv2D, Flatten
from tensorflow.keras import optimizers
from Memory_Less_Momentum_Quasi_Newton.MLQN import *
from Memory_Less_Momentum_Quasi_Newton.MLMoQ import *
import matplotlib.pyplot as plt
import numpy as np
import csv

def compare_MNIST(i = 0, *, epochs = 2000, verbose = False, graph = True):
    """Train one small MNIST classifier with MLQN, MLMoQ and Adam and compare them.

    A single randomly initialised model is saved to ``model.h5`` and reloaded
    before each optimizer run, so all three optimizers start from identical
    weights. Every run uses full-batch training (``batch_size`` equals the
    number of training samples), so one epoch is exactly one parameter update;
    this is why the plots label the x-axis "Iterations".

    Args:
        i: Seed for this trial. Seeds both NumPy and TensorFlow so that the
            initial weights differ between trials but are reproducible.
        epochs: Number of full-batch iterations per optimizer (default 2000).
        verbose: Forwarded to ``model.fit`` as its ``verbose`` argument.
        graph: If true, show train and test loss/accuracy curves.

    Returns:
        dict mapping optimizer name ("MLQN", "MLMoQ", "Adam") to the object
        returned by ``model.fit`` for that optimizer.

    Side effects:
        Writes ``model.h5`` to the current working directory.
    """
    # Seed both RNGs: Keras weight initialisers draw from TensorFlow's RNG,
    # so np.random.seed alone does not make the initial weights reproducible.
    np.random.seed(i)
    tf.random.set_seed(i)

    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # Scale pixel values by 1/255, add a trailing channel axis, one-hot the labels.
    x_train, x_test = x_train[..., np.newaxis] / 255.0, x_test[..., np.newaxis] / 255.0
    y_train, y_test = tf.keras.utils.to_categorical(y_train), tf.keras.utils.to_categorical(y_test)

    # Network: Flatten -> Dense(10, sigmoid) -> Dense(Output_dim, softmax).
    def My_Model(input_shape, Output_dim):
        inputs = Input(shape = input_shape)
        x = Flatten()(inputs)
        x = Dense(10, activation="sigmoid")(x)
        outputs = Dense(Output_dim, activation="softmax")(x)
        model = Model(inputs = [inputs], outputs = [outputs])
        return model

    model = My_Model(x_train.shape[1:], Output_dim=10)

    # Persist the freshly initialised weights once so that every optimizer
    # below starts from exactly the same point.
    model.save("model.h5")

    loss_fn = tf.keras.losses.CategoricalCrossentropy()

    def _train(make_optimizer):
        # Reload the shared initial model and train it full-batch with one optimizer.
        net = load_model("model.h5")
        net.compile(loss=loss_fn, optimizer=make_optimizer(), metrics=['accuracy'])
        # batch_size == number of training samples -> one update per epoch.
        return net.fit(x_train, y_train, epochs = epochs, verbose = verbose,
                       batch_size = x_train.shape[0], validation_data = (x_test, y_test))

    # Dict literals evaluate left to right, so the run order (and the plot /
    # legend order below) is MLQN, MLMoQ, Adam.
    histories = {
        "MLQN": _train(MLQN),
        "MLMoQ": _train(MLMoQ),
        "Adam": _train(tf.keras.optimizers.Adam),
    }

    if graph:
        colors = {"MLQN": "blue", "MLMoQ": "m", "Adam": "orange"}

        def _plot_split(split, key_prefix):
            # One figure: loss (left) and accuracy (right) vs. iteration for every optimizer.
            _, axes = plt.subplots(ncols=2, figsize=(10,4))
            for ax, metric in zip(axes, ("Loss", "Accuracy")):
                key = key_prefix + metric.lower()
                ax.set_title(split + "_" + metric)
                for name, history in histories.items():
                    ax.plot(history.history[key], color=colors[name], label=name)
                ax.set_xlabel('Iterations')
                ax.set_ylabel(split + " " + metric)
                ax.legend()
            plt.show()

        _plot_split("Train", "")
        _plot_split("Test", "val_")

    return histories
        
        

# Ten independent comparison trials: trial k (0-based) uses seed k and a
# 1-based progress number is printed before each one.
for i in range(10):
    trial_number = i + 1
    print(trial_number)
    compare_MNIST(i = i)