# Import libraries

In [None]:
# Reload modules
%load_ext autoreload
%autoreload 2
# Inline plots
%matplotlib inline

# Standard
import numpy as np
import matplotlib.pyplot as plt
import time as time
from copy import deepcopy

# tf
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Mine
from utility import display_image, permutate # Utilities
from model import Net   # Model definition
from train import train # Training code

# Set parameters

In [None]:
# Graph-level TensorFlow seed; the same value is re-applied before each
# model is built in the model-creation cell.
num_seed = 1
tf.set_random_seed(num_seed)

# Training schedule
N_task = 2        # How many tasks are learned one after another
N_it = 800        # Training iterations per task
batch_size = 100  # Samples per minibatch

# Network architecture
hidden_size = 50  # Neurons in the hidden layer

# Penalty strengths and Fisher estimation
lambda_L2 = 10 ** 6     # Weight of the L2 penalty
lambda_EWC = 10 ** 6    # Weight of the EWC penalty
sample_size_Fish = 200  # Samples used to estimate the Fisher information

# Bookkeeping
ep_rec = 20              # Test accuracy is recorded every ep_rec iterations
ep_time = int(N_it / 1)  # Timing is displayed every ep_time iterations
                         # (divisor 1 => once per task, at iteration N_it)

# Load data

In [None]:
# Read MNIST from the local 'MNIST_data' directory with one-hot labels.
# Later cells use the .train and .validation splits of the returned object.
dataset = input_data.read_data_sets(train_dir='MNIST_data', one_hot=True)

# Create permuted datasets

In [None]:
# Task 0 is the unmodified dataset; every further task gets its own
# permutate(dataset) result, so the list ends up with N_task entries.
# NOTE(review): presumably permutate() applies a fresh random pixel
# permutation per call -- confirm in utility.permutate.
datasets = [dataset] + [permutate(dataset) for _ in range(N_task - 1)]

# Plot sample images

In [None]:
sample_index = 10  # Training sample shown for every task

# Labels were loaded one-hot, so the position of the single non-zero
# entry is the class index.
one_hot = datasets[0].train.labels[sample_index]
label = np.nonzero(one_hot)[0][0]
print('label = {}'.format(label))

# One panel per task, each showing the same training sample
plt.figure()
for task in range(N_task):
    plt.subplot(1, N_task, task + 1)
    display_image(datasets[task].train.images[sample_index])

# Create models

In [None]:
sess = tf.InteractiveSession()  # Session shared by all three models

# Placeholders sized from the data: second axis of the training images and
# labels gives the input and output widths; batch dimension is left open.
input_dim = datasets[0].train.images.shape[1]
output_dim = datasets[0].train.labels.shape[1]
inputs = tf.placeholder(tf.float32, shape=[None, input_dim])
outputs = tf.placeholder(tf.float32, shape=[None, output_dim])

# Instantiate the three models (vanilla, L2, EWC) on the same placeholders.
# The hidden width comes from the `hidden_size` parameter instead of a
# hard-coded 50, so changing the parameter actually changes the networks.
# NOTE(review): the graph-level seed is reset before each model, presumably
# to give all three the same initial weights. In TF1 the op-level seed
# derived from a graph-level seed can also depend on op creation order, so
# identical initialisation is not guaranteed by this alone -- confirm in Net.
tf.set_random_seed(num_seed)
model = Net(inputs, outputs, hidden_dim=hidden_size)
tf.set_random_seed(num_seed)
model_L2 = Net(inputs, outputs, hidden_dim=hidden_size)
tf.set_random_seed(num_seed)
model_EWC = Net(inputs, outputs, hidden_dim=hidden_size)

# Initialize every variable created above
sess.run(tf.global_variables_initializer())

# Train

In [None]:
# One (initially empty) accuracy history per task, for each training variant.
# The values returned by train() are appended to these after every task.
test_acc_lists = [np.zeros([0]) for _ in range(N_task)]
test_acc_lists_L2 = [np.zeros([0]) for _ in range(N_task)]
test_acc_lists_EWC = [np.zeros([0]) for _ in range(N_task)]

separator = '------------------------------------------------------------'

time_start = time.time()  # Start of the whole run
time_p1 = time.time()     # Start of the current task
print(separator)
for i in range(N_task):

    print('Task {}'.format(i))
    print(separator)

    # Train the three variants on task i, always in the order vanilla, L2, EWC.
    # NOTE(review): the two arguments after batch_size look like a
    # regularization-mode flag (0/1/2) and its penalty weight -- confirm
    # against train().
    # Vanilla
    test_acc_list = train(sess, model, datasets[i], datasets, inputs, outputs,
                          N_it, batch_size, 0, 0, ep_rec, ep_time)
    # L2
    test_acc_list_L2 = train(sess, model_L2, datasets[i], datasets, inputs, outputs,
                             N_it, batch_size, 1, lambda_L2, ep_rec, ep_time)
    # EWC
    test_acc_list_EWC = train(sess, model_EWC, datasets[i], datasets, inputs, outputs,
                              N_it, batch_size, 2, lambda_EWC, ep_rec, ep_time)

    # Save parameters of the regularized models once the task is finished
    model_L2.save_parameters()
    model_EWC.save_parameters()

    # Compute Fisher Info. on this task's validation images, and time it
    time_F1 = time.time()
    model_EWC.compute_Fisher(datasets[i].validation.images, sess, sample_size_Fish)
    time_F2 = time.time()
    print('Fisher computation runtime: {} s'.format(time_F2 - time_F1))

    # Append this task's accuracy curves to the running histories
    for histories, latest in ((test_acc_lists, test_acc_list),
                              (test_acc_lists_L2, test_acc_list_L2),
                              (test_acc_lists_EWC, test_acc_list_EWC)):
        for j in range(N_task):
            histories[j] = np.hstack([histories[j], latest[j]])

    time_p2 = time.time()
    print('Total runtime for task {}: {} s'.format(i, time_p2 - time_p1))
    time_p1 = time_p2  # The next task's timer starts here

    print(separator)
time_finish = time.time()
print('Total runtime: {} s'.format(time_finish - time_start))


# Plot

In [None]:
# Plot the recorded test-accuracy histories: one panel per training variant,
# one curve per task.
plt.figure()

panels = [
    ('Vanilla', test_acc_lists),
    ('L2', test_acc_lists_L2),
    ('EWC', test_acc_lists_EWC),
]
for col, (title, acc_lists) in enumerate(panels):
    plt.subplot(1, len(panels), col + 1)
    plt.title(title)
    for i in range(N_task):
        plt.plot(acc_lists[i], label='Task ' + str(i))
    # The curves carry labels, but without a legend they were never shown
    plt.legend()

plt.show()

In [None]:
#a=tf.constant(0.1,shape=[50])
#x = tf.placeholder(tf.float32,shape=[None,10])
#a=tf.random_uniform([5,3])
#a.eval()