In [None]:
# Automatically reload edited local modules (e.g. utils, model imported below)
# so changes to those files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from utils import create_permuted_mnist_task
from model import cnn_model

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
From IPython import display

In [None]:
import warnings


def warn(*args, **kwargs):
    """Accept any ``warnings.warn`` call signature and do nothing."""
    return None


# Silence every warning message by swapping the no-op above in for warnings.warn.
warnings.warn = warn

In [None]:
def acc_plot(name, acc, data_num):
    """Plot one model's test-accuracy curve into the shared 3-row figure.

    Every call draws into figure 1, row ``data_num`` of a 3x1 grid, so calling
    it for several models overlays their curves on the same panel. Dashed red
    vertical lines are drawn at x = 1 .. len(acc) - 1. The whole figure is saved
    to ``./images/permuted.png``, then displayed in the notebook with
    ``clear_output(wait=True)`` so the next display replaces it.

    Parameters:
    - name: the name of the model; used as the legend label.
    - acc: list of accuracy values, one per recorded training stage.
    - data_num: which dataset is plotted (1, 2 or 3 for D1, D2, D3); selects
      the subplot row.
    """
    plt.figure(1)
    # Use the explicit (nrows, ncols, index) form: recent matplotlib rejects the
    # string '31N' spec the original built.
    plt.subplot(3, 1, int(data_num))
    plt.title('test accuracy on {}th dataset'.format(data_num))
    plt.plot(acc, label=name)
    plt.ylabel('acc')
    plt.xlabel('training time')
    for i in range(len(acc) - 1):
        plt.vlines(i + 1, 0, 1, color='r', linestyles='dashed')
    plt.legend(loc='upper right')
    plt.subplots_adjust(wspace=1, hspace=1)
    # One file shared by all models, since the figure accumulates every curve.
    plt.savefig('./images/permuted.png')
    display.display(plt.gcf())
    display.clear_output(wait=True)
    
def _draw_task_boundaries(positions):
    """Draw dashed red vertical lines from y=0 to y=1 at each x in ``positions``."""
    for x in positions:
        plt.vlines(x, 0, 1, color='r', linestyles='dashed')


def plot_history(file_name, model, shift=10, net=None):
    """Plot training accuracy and validation accuracy from a training history.

    Draws two stacked subplots on the current figure:
    - top: ``net.history.history['acc']`` (accuracy on current training data);
    - bottom: ``net.history.history['val_acc']`` (validation accuracy, titled
      as accuracy on the original data).
    Both get dashed red vertical lines at ``net.epoch * k`` for k = 1 .. shift.
    The figure is saved to ``./images/<file_name>.png`` and shown in the
    notebook with ``clear_output(wait=True)``.

    Parameters:
    - file_name: base name (no extension) of the saved image.
    - model: legend label for the plotted curves.
    - shift: number of boundary markers to draw.
    - net: object exposing ``epoch`` and ``history`` (whose ``.history`` dict
      holds the 'acc' and 'val_acc' lists). The original body read these from
      an undefined ``self``. NOTE(review): presumably the trained cnn_model
      instance, and ``epoch`` the epochs per task -- confirm.

    Raises:
    - ValueError: if ``net`` is not supplied.
    """
    if net is None:
        # Previously every call died with NameError on ``self``; make the
        # missing data source an explicit, descriptive error instead.
        raise ValueError('plot_history needs net=<object with .epoch and .history>')

    boundaries = [net.epoch * (i + 1) for i in range(shift)]
    metrics = net.history.history
    label = '{}'.format(model)

    plt.subplot(211)
    plt.title('accuracy on current training data')
    _draw_task_boundaries(boundaries)
    plt.plot(metrics['acc'], label=label)
    plt.ylabel('acc')
    plt.xlabel('training time')
    plt.legend(loc='upper right')

    plt.subplot(212)
    plt.title('validation accuracy on original data')
    plt.plot(metrics['val_acc'], label=label)
    plt.ylabel('acc')
    plt.xlabel('training time')
    _draw_task_boundaries(boundaries)
    plt.legend(loc='upper right')

    plt.subplots_adjust(wspace=1, hspace=1)
    plt.savefig('./images/{}.png'.format(file_name))
    display.display(plt.gcf())
    display.clear_output(wait=True)

In [None]:
#record the test accuracy
def test_acc(model,acc_test_d1,acc_test_d2,acc_test_d3):
    acc_test_d1.append(model.evaluate(task[0].test.images,task[0].test.labels))
    acc_test_d2.append(model.evaluate(task[1].test.images,task[1].test.labels))
    acc_test_d3.append(model.evaluate(task[2].test.images,task[2].test.labels))

def save_acc(name, d, base_dir='./logs/permuted'):
    """Write the three per-dataset accuracy histories to JSON text files.

    ``d[i]`` is written to ``<base_dir>/acc<i+1>/<name>.txt`` for i = 0, 1, 2.
    Missing ``acc<i+1>`` directories (and their parents) are created first.
    Previously a missing directory raised FileNotFoundError.

    Parameters:
    - name: model name; used as the file name.
    - d: sequence of at least three JSON-serializable values (the D1, D2, D3
      accuracy lists). NOTE(review): numpy scalars returned by
      ``model.evaluate`` would not be JSON-serializable -- confirm types.
    - base_dir: root log directory; the default reproduces the original paths.

    Raises:
    - IndexError: if ``d`` has fewer than three entries.
    """
    import json
    import os

    for i in range(3):
        out_dir = os.path.join(base_dir, 'acc{}'.format(i + 1))
        os.makedirs(out_dir, exist_ok=True)
        path = os.path.join(out_dir, '{}.txt'.format(name))
        with open(path, 'w') as f:
            json.dump(d[i], f)

In [None]:
def train(name):
    """Train a fresh ``cnn_model`` sequentially over the permuted-MNIST tasks.

    Stages:
    1. Register ``task[0]``'s validation split via ``model.val_data`` and fit
       on ``task[0]``.
    2. Train on ``task[1] .. task[TASK_NUM - 1]`` one after another.
    3. Go back to ``task[0]``'s training data.
    After every stage, ``test_acc`` records the test result on the first three
    tasks.

    Parameters:
    - name: training strategy; also passed to ``model.save``.
      ``'kal'``: stage 2 uses ``model.transfer``, and stage 3 calls
      ``model.transfer`` on task 0 three times with ``num=1``, ``2``, ``3``.
      ``'nor'``: stage 2 uses ``model.fit``, and stage 3 is a single
      ``model.fit`` on task 0.
      Any other value: only stage 1 runs before the model is saved.

    Returns:
    - (acc_test_d1, acc_test_d2, acc_test_d3): one entry per recorded stage
      for the test splits of ``task[0]``, ``task[1]`` and ``task[2]``.

    Relies on the notebook globals ``task`` and ``TASK_NUM`` (set in the
    "Load the drift data" cell) and on ``cnn_model`` / ``test_acc``.
    """
    acc_test_d1, acc_test_d2, acc_test_d3 = [], [], []
    t1_x, t1_y = task[0].train.images, task[0].train.labels

    model = cnn_model()
    model.val_data(task[0].validation.images, task[0].validation.labels)

    def record():
        # Snapshot test performance on D1..D3 after the stage that just ran.
        test_acc(model, acc_test_d1, acc_test_d2, acc_test_d3)

    model.fit(t1_x, t1_y)
    record()

    if name == 'kal':
        for i in range(1, TASK_NUM):
            print('---' * 10, i, '---' * 10)
            model.transfer(task[i].train.images, task[i].train.labels)
            record()
        # Revisit task 0 with three successive transfer calls (num=1, 2, 3).
        for num in (1, 2, 3):
            model.transfer(t1_x, t1_y, num=num)
            record()
    elif name == 'nor':
        for i in range(1, TASK_NUM):
            print('---' * 10, i, '---' * 10)
            model.fit(task[i].train.images, task[i].train.labels)
            record()
        model.fit(t1_x, t1_y)
        record()

    model.save(name)
    return acc_test_d1, acc_test_d2, acc_test_d3

In [None]:
# Load the drifting data stream: TASK_NUM permuted-MNIST tasks.
# Each entry of `task` exposes .train / .validation / .test splits with
# .images and .labels (as used by train() and test_acc()).
TASK_NUM = 10

task = create_permuted_mnist_task(TASK_NUM)

In [None]:
# Train with the 'kal' strategy and persist its D1/D2/D3 accuracy histories.
print('--' * 10, 'kal', '--' * 10)
kal_d1, kal_d2, kal_d3 = train('kal')
save_acc('kal', [kal_d1, kal_d2, kal_d3])

In [None]:
# Train with the 'nor' strategy and persist its D1/D2/D3 accuracy histories.
print('--' * 10,
      'nor',
      '--' * 10)
nor_d1, nor_d2, nor_d3 = train('nor')
save_acc('nor',
         [nor_d1, nor_d2, nor_d3])