In [None]:
import pickle 
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
import scipy.spatial as sp

# *Plots for permuted MNIST Mean Accuracy across tasks*

In [None]:
def _load_seed_accuracies(path_prefix, ntasks, nseeds):
    """Collect the per-seed accuracy curves of one method.

    Reads ``{path_prefix}_{seed}.pkl`` for every ``seed`` in ``range(nseeds)``.
    Each pickle is unpacked as a 2-tuple: the first element is stored as the
    accuracy curve, the second element is ignored.

    Returns an array of shape ``(ntasks, nseeds)`` with one column per seed.
    """
    perf = np.zeros((ntasks, nseeds))
    for seed in range(nseeds):
        # NOTE: pickle.load can execute arbitrary code; only load result files you trust.
        with open(f'{path_prefix}_{seed}.pkl', 'rb') as loadf:
            mean_acc, _ = pickle.load(loadf)
        perf[:, seed] = mean_acc[:]
    return perf


def _plot_mean_band(ax, perf, label, color, **line_kwargs):
    """Draw the across-seed mean of ``perf`` with a +/- one std band on ``ax``.

    ``perf`` has one row per task and one column per seed. Extra keyword
    arguments (e.g. ``linestyle``) are forwarded to ``sns.lineplot``.

    Returns the ``(mean, std)`` curves that were plotted.
    """
    mean = np.mean(perf, axis=1)
    std = np.std(perf, axis=1)
    sns.lineplot(data=mean, label=label, color=color, ax=ax, **line_kwargs)
    ax.fill_between(np.arange(perf.shape[0]), mean - std, mean + std, color=color, alpha=0.5)
    return mean, std


ntasks = 50
nseeds = 5

fig, ax = plt.subplots(figsize=(10, 7), dpi=100)
# tick_params keeps the label size for ticks created after plotting as well.
ax.tick_params(axis='both', labelsize=16)

## LXDG + EWC
lxdg_perf = _load_seed_accuracies('results/permuted_MNIST_perm_LXDG_EWC', ntasks, nseeds)
lxdg_perf_c, lxdg_perf_std = _plot_mean_band(
    ax, lxdg_perf, 'LXDG + EWC, learned context', color='red')

## EWC
ewc_perf = _load_seed_accuracies('results/permuted_MNIST_perm_EWC_only', ntasks, nseeds)
ewc_perf_c, ewc_perf_std = _plot_mean_band(
    ax, ewc_perf, 'EWC only', color='blue', linestyle='dotted')

## XDG + EWC
xdgewc_perf = _load_seed_accuracies('results/permuted_MNIST_perm_XDG_EWC', ntasks, nseeds)
xdgewc_perf_c, xdgewc_perf_std = _plot_mean_band(
    ax, xdgewc_perf, 'XDG + EWC, hand input context', color='grey', linestyle='--')

## No constraint
nocon_perf = _load_seed_accuracies('results/permuted_MNIST_perm_NOCON', ntasks, nseeds)
nocon_perf_c, nocon_perf_std = _plot_mean_band(
    ax, nocon_perf, 'No constraint', color='green', linestyle='-.')

ax.set_xlabel('Permuted MNIST task', fontsize=22)
ax.set_ylabel('Mean Accuracy', fontsize=22)

ax.set_xlim(0, ntasks - 1)
ax.set_ylim(0., 1.)
ax.legend(fontsize=14)

plt.show()

# *Plots for rotated MNIST Mean Accuracy across tasks*

In [None]:
def _load_seed_accuracies(path_prefix, ntasks, nseeds):
    """Collect the per-seed accuracy curves of one method.

    Reads ``{path_prefix}_{seed}.pkl`` for every ``seed`` in ``range(nseeds)``.
    Each pickle is unpacked as a 2-tuple: the first element is stored as the
    accuracy curve, the second element is ignored.

    Returns an array of shape ``(ntasks, nseeds)`` with one column per seed.
    """
    perf = np.zeros((ntasks, nseeds))
    for seed in range(nseeds):
        # NOTE: pickle.load can execute arbitrary code; only load result files you trust.
        with open(f'{path_prefix}_{seed}.pkl', 'rb') as loadf:
            mean_acc, _ = pickle.load(loadf)
        perf[:, seed] = mean_acc[:]
    return perf


def _plot_mean_band(ax, perf, label, color, **line_kwargs):
    """Draw the across-seed mean of ``perf`` with a +/- one std band on ``ax``.

    ``perf`` has one row per task and one column per seed. Extra keyword
    arguments (e.g. ``linestyle``) are forwarded to ``sns.lineplot``.

    Returns the ``(mean, std)`` curves that were plotted.
    """
    mean = np.mean(perf, axis=1)
    std = np.std(perf, axis=1)
    sns.lineplot(data=mean, label=label, color=color, ax=ax, **line_kwargs)
    ax.fill_between(np.arange(perf.shape[0]), mean - std, mean + std, color=color, alpha=0.5)
    return mean, std


ntasks = 50
nseeds = 5

fig, ax = plt.subplots(figsize=(10, 7), dpi=100)
# tick_params keeps the label size for ticks created after plotting as well.
ax.tick_params(axis='both', labelsize=16)

## LXDG + EWC
lxdg_perf = _load_seed_accuracies('results/rotated_MNIST_rot_LXDG_EWC', ntasks, nseeds)
lxdg_perf_c, lxdg_perf_std = _plot_mean_band(
    ax, lxdg_perf, 'LXDG + EWC, learned context', color='red')

## EWC
ewc_perf = _load_seed_accuracies('results/rotated_MNIST_rot_EWC_only', ntasks, nseeds)
ewc_perf_c, ewc_perf_std = _plot_mean_band(
    ax, ewc_perf, 'EWC only', color='blue', linestyle='dotted')

## XDG + EWC
xdgewc_perf = _load_seed_accuracies('results/rotated_MNIST_rot_XDG_EWC', ntasks, nseeds)
xdgewc_perf_c, xdgewc_perf_std = _plot_mean_band(
    ax, xdgewc_perf, 'XDG + EWC, hand input context', color='grey', linestyle='--')

## No constraint
# The original cell wrote into the undefined `nocon_perff` and plotted the
# undefined `lxdgff`; the curve is now loaded, averaged and drawn like the others.
nocon_perf = _load_seed_accuracies('results/rotated_MNIST_rot_NOCON', ntasks, nseeds)
nocon_perf_c, nocon_perf_std = _plot_mean_band(
    ax, nocon_perf, 'No constraint', color='green', linestyle='-.')

ax.set_xlabel('Rotated MNIST task', fontsize=22)
ax.set_ylabel('Mean Accuracy', fontsize=22)

ax.set_xlim(0, ntasks - 1)
ax.set_ylim(0., 1.)
ax.legend(fontsize=14)

plt.show()


# *Function for performing PCA on the gates*

In [None]:
def gate_pca(max_ctxt = 3, prep = 0, pca_input = None, layer = 1, name= ''):
    """Binarise stored gate vectors and project them with PCA.

    For each task ``n`` in ``range(max_ctxt)`` the file
    ``results/gate_vectors/{name}_gate_vector_{layer}_trained_{max_ctxt-1}_task_{n}.dat``
    is unpickled and thresholded at 0.1 (values above the threshold become 1,
    all others 0).

    Parameters
    ----------
    max_ctxt : int
        Number of task gate vectors to load; also selects the
        ``trained_{max_ctxt-1}`` files.
    prep : int
        Unused; kept so existing callers keep working.
    pca_input : object or None
        An already fitted transformer exposing ``transform``. When ``None`` a
        new 2-component ``PCA`` is fitted on the loaded gates.
    layer : int
        Layer index used in the file name.
    name : str
        Run prefix used in the file name.

    Returns
    -------
    (decomp, pca)
        The projected gates (one row per task) and the transformer that
        produced them.
    """
    incl_th = 0.1  # a gate counts as active when its value exceeds this
    gates = []

    for n in range(max_ctxt):
        path = f'results/gate_vectors/{name}_gate_vector_{layer}_trained_{max_ctxt-1}_task_{n}.dat'
        # Context manager closes the handle; the original left it open.
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(path, 'rb') as gate_file:
            gate_vector = pickle.load(gate_file)
        # assumes the pickled object is a NumPy array (needs `>` and `.astype`) -- TODO confirm
        gates.append((gate_vector > incl_th).astype(int))

    # Identity check: `== None` would call a custom __eq__ on the transformer.
    if pca_input is None:
        pca = PCA(n_components=2)
        decomp = pca.fit_transform(gates)
    else:
        pca = pca_input
        decomp = pca.transform(gates)

    return decomp, pca

# *Plots for permuted MNIST PCA*

In [None]:
def plot_pca(ax, layer, maxtasks):
    """Scatter the 2-D PCA projection of the permuted-MNIST gates on ``ax``.

    The projection comes from ``gate_pca`` for the run ``perm_LXDG_EWC_0``,
    fitted on the gates of ``layer`` for ``maxtasks`` tasks. Each task is one
    'x' marker: the first task is green, later tasks move from blue towards
    red as the task index grows.

    The original reseeded the global NumPy RNG on every iteration although
    nothing random is drawn here; that stray side effect has been removed.
    """
    decomp, _ = gate_pca(max_ctxt=maxtasks, pca_input=None, layer=layer, name='perm_LXDG_EWC_0')
    for count, (x, y) in enumerate(decomp):
        if count == 0:
            c = (0, 0.5, 0)  # green for first task
        else:
            # RGB: red rises and blue falls with the task index
            c = (count / maxtasks, 0, 1 - count / maxtasks)
        ax.scatter(x, y, 100, marker='x', color=c, alpha=1.)

    ax.set_xlabel('PC1', fontsize=22)
    ax.set_ylabel('PC2', fontsize=22)
    ax.set_title(f'Layer {layer}', fontsize=22)


def set_tick_params(ax, labelsize):
    """Give the x- and y-axis tick labels of ``ax`` the font size ``labelsize``."""
    for axis_name in ('xaxis', 'yaxis'):
        getattr(ax, axis_name).set_tick_params(labelsize=labelsize)


# Number of permuted-MNIST tasks whose gate vectors are projected.
maxtasks = 50
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 9), dpi=100)

# One panel per layer: layer 1 on the left, layer 2 on the right.
for panel, gate_layer in ((ax1, 1), (ax2, 2)):
    set_tick_params(panel, 16)
    plot_pca(panel, layer=gate_layer, maxtasks=maxtasks)

plt.show()


# *Plots for rotated MNIST PCA*

In [None]:
def plot_pca_with_rotation(ax, layer, maxtasks, task_rot):
    """Scatter the 2-D PCA projection of the rotated-MNIST gates on ``ax``.

    The projection comes from ``gate_pca`` for the run ``rot_LXDG_EWC_0``.
    Each task is one 'x' marker coloured by ``task_rot[task] / 180``: the red
    channel takes that fraction and the blue channel its complement.
    """
    projected, _ = gate_pca(max_ctxt=maxtasks, pca_input=None, layer=layer, name='rot_LXDG_EWC_0')
    for task_idx, (pc1, pc2) in enumerate(projected):
        rot_frac = task_rot[task_idx] / 180.
        rgb = (rot_frac, 0, 1 - rot_frac)
        ax.scatter(pc1, pc2, 100, marker='x', color=rgb, alpha=1.)

    ax.set_xlabel('PC1', fontsize=22)
    ax.set_ylabel('PC2', fontsize=22)
    ax.set_title(f'Layer {layer}', fontsize=22)


def set_tick_params(ax, labelsize):
    """Set the tick-label font size on both axes of ``ax``."""
    for axis_name in ('xaxis', 'yaxis'):
        axis = getattr(ax, axis_name)
        axis.set_tick_params(labelsize=labelsize)


# Number of rotated-MNIST tasks whose gate vectors are projected.
maxtasks = 50

# Draw one rotation angle in [0, 180) per task, reseeding the global NumPy RNG
# with 42 * 100 + i before each draw so every angle is reproducible on its own.
# presumably this mirrors how the angles were generated at training time -- TODO confirm
# (np.random.seed returns None, so its result is no longer bound to a name.)
task_rot = []
for i in range(maxtasks):
    np.random.seed(42 * 100 + i)
    task_rot.append(np.random.uniform(0, 180))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 9), dpi=100)

set_tick_params(ax1, 16)
set_tick_params(ax2, 16)

# Layer 1 on the left panel, layer 2 on the right.
plot_pca_with_rotation(ax1, layer=1, maxtasks=maxtasks, task_rot=task_rot)
plot_pca_with_rotation(ax2, layer=2, maxtasks=maxtasks, task_rot=task_rot)

plt.show()