In [None]:
import sys
sys.path.append("..")
from train_by_reconnect.LaPerm import LaPermTrainLoop
from train_by_reconnect.weight_utils import agnosticize
from train_by_reconnect.viz_utils import Profiler

In [None]:
import math
import numpy as np
import tensorflow as tf
print(tf.__version__)
from tqdm.notebook import trange as nested_progress_bar
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [None]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.regularizers import l2

In [None]:
# Fix the random seeds so NumPy/TensorFlow-driven randomness downstream (weight
# initialization, data shuffling, and presumably agnosticize's pruning) repeats on
# Restart & Run All. GPU kernels may still introduce some nondeterminism.
SEED = 0
np.random.seed(SEED)
tf.random.set_seed(SEED)

# MNIST scaled to [0, 1] with an explicit trailing channel axis -> (N, 28, 28, 1),
# matching the Flatten(input_shape=(28, 28, 1)) layers used by the models below.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train.astype('float32') / 255.0, -1)
x_test = np.expand_dims(x_test.astype('float32') / 255.0, -1)
assert x_train.shape[1:] == (28, 28, 1), x_train.shape
assert x_test.shape[1:] == (28, 28, 1), x_test.shape

# No data augmentation: a default ImageDataGenerator.
# fit() is only required for featurewise centering/normalization or ZCA whitening,
# none of which are enabled here, so it computes no statistics that are used.
datagen = ImageDataGenerator()
datagen.fit(x_train)

# Training hyper-parameters
batch_size = 128
learning_rate = 0.001  # initial learning rate (decayed per epoch by lr_scheduler)
tsize = 30000          # number of training samples used to estimate train accuracy
vali_freq = 250        # run validation every vali_freq batches

def lr_scheduler(epoch):
    """Learning rate for ``epoch``: the initial rate decayed by 5% per epoch.

    Returns ``learning_rate * 0.95 ** epoch``, so epoch 0 yields ``learning_rate``.
    The module-level ``learning_rate`` is looked up at call time.
    """
    decay = pow(0.95, epoch)
    return learning_rate * decay

def k_scheduler(epoch):
    """Schedule passed to LaPermTrainLoop as ``k_schedule``: constant 250 for every epoch."""
    del epoch  # constant schedule: the epoch number is intentionally unused
    return 250

In [None]:
# Single shared value written into every (unpruned) kernel entry.
val = 0.08

# Pruning ratios passed to agnosticize, one pair per Dense kernel (F_1 has one).
# NOTE(review): exact meaning of each pair is defined by agnosticize — confirm there.
rate = [(4, 7)]

# Section 5.5 F_1: flatten the 28x28x1 image and map it directly to 10 softmax outputs.
F1 = Sequential([
    Flatten(input_shape=(28, 28, 1)),
    Dense(10, activation='softmax', use_bias=False),
])

# Make F1 weight agnostic, then profile its weights to confirm.
agnosticize(F1, val=val, prune_ratio=rate)
Profiler(F1, skip_1d=False)

# LaPerm training for 10 epochs, validating every `vali_freq` batches.
epochs = 10
loop = LaPermTrainLoop(
    model=F1,
    loss='sparse_categorical_crossentropy',
    inner_optimizer=tf.keras.optimizers.Adam(),
    k_schedule=k_scheduler,
    lr_schedule=lr_scheduler,
    skip_bias=False,
)
loop.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    datagen=datagen,
    validation_data=(x_test, y_test),
    validation_freq=vali_freq,
    tsize=tsize,
)

# Restore the weights with the best validation accuracy and re-check them with Keras.
F1.set_weights(loop.best_weights)
F1.compile(loss='sparse_categorical_crossentropy', metrics=['acc'])
F1.evaluate(x_test, y_test)

# The trained model should still be weight agnostic.
Profiler(F1, skip_1d=False)

In [None]:
# Same recipe as the F_1 cell above, applied to a deeper bias-free MLP.
val = 0.03
# One pruning pair per Dense kernel (three kernels below).
rate = [(1, 15), (1, 7), (1, 3)]

# Section 5.5 F_2 model definition
F2 = Sequential()
F2.add(Flatten(input_shape=(28, 28, 1)))
F2.add(Dense(128, activation='relu', use_bias=False, kernel_regularizer=l2(1e-4)))
F2.add(Dense(64, activation='relu', use_bias=False, kernel_regularizer=l2(1e-4)))
F2.add(Dense(10, activation='softmax', use_bias=False))

# Pass val/prune_ratio by keyword, exactly as the F_1 cell does, so the call does not
# silently depend on agnosticize's positional parameter order.
agnosticize(F2, val=val, prune_ratio=rate)
# Use the same Profiler options before and after training so the two views are comparable.
Profiler(F2, skip_1d=False)

epochs = 25
loop2 = LaPermTrainLoop(model=F2, loss='sparse_categorical_crossentropy', inner_optimizer=tf.keras.optimizers.Adam(),
                        k_schedule=k_scheduler,
                        lr_schedule=lr_scheduler,
                        skip_bias=False)
loop2.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, datagen=datagen,
          validation_data=(x_test, y_test), validation_freq=vali_freq, tsize=tsize)

# Restore the weights with the best validation accuracy and confirm with Keras.
F2.set_weights(loop2.best_weights)
F2.compile(loss='sparse_categorical_crossentropy', metrics=['acc'])
F2.evaluate(x_test, y_test)

# Confirm the model is still weight agnostic after training.
Profiler(F2, skip_1d=False)

In [None]:
# Visualize train and validation accuracies for both trained models.
# numpy/matplotlib are already imported in the top import cell; no re-import needed.
# NOTE(review): `_history` is a private attribute of LaPermTrainLoop; the keys
# 'accuracy' and 'val accuracy' are taken from the original cell.
fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
for ax, (model_label, train_loop) in zip(axes, [('$F_1$', loop), ('$F_2$', loop2)]):
    ax.plot(train_loop._history['val accuracy'], label='Validation Accuracy')
    ax.plot(train_loop._history['accuracy'], label='Train Accuracy')
    ax.set_title(model_label, size=15)
    # NOTE(review): validation runs every `vali_freq` batches, so these points may be
    # per validation step rather than per epoch — confirm before relying on this label.
    ax.set_xlabel('Epochs', size=15)
    ax.grid(linestyle='--')
    ax.legend(prop={'size': 15})
axes[0].set_ylabel('Accuracy', size=15)
plt.show()