In [57]:
%load_ext autoreload
%autoreload 2

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import CategoricalAccuracy
import matplotlib.pyplot as plt
from os import listdir

from models import *
from utils import *

import warnings
from scipy.sparse import SparseEfficiencyWarning
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=SparseEfficiencyWarning)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [58]:
class ModelEval():
    """Train and evaluate a GNN node classifier over a sequence of tasks, optionally with EWC.

    Each entry of ``self.tasks`` is unpacked as ``(x, adj, y), (mask_train, _, mask_test)``.
    After training on a task, its flattened weights and a diagonal Fisher estimate are stored.
    With ``continual_learning=True``, later tasks add a Fisher-weighted quadratic penalty
    that pulls the parameters back toward those stored weights (Elastic Weight Consolidation).
    """

    def __init__(self, print_freq=10):
        super(ModelEval, self).__init__()

        self.print_freq = print_freq            # print training stats every `print_freq` epochs
        self.tasks = load_tasks()
        self.prior_weights = dict()             # task index -> flattened params after training on it
        self.fisher_matrix = dict()             # task index -> flattened squared-gradient Fisher estimate
        self.loss_fn = CategoricalCrossentropy()
        # Project-defined metric, updated after each train() with accuracy on earlier tasks.
        self.mean_accuracy = GlobalAccuracy()
        self.model = None                       # created by train()

    def train(self, task_num, epochs=20, continual_learning=False, lambda_=0.1, lr=0.001, verbose=True):
        """Train on task `task_num`, then snapshot its weights and Fisher estimate.

        Args:
            task_num: index into ``self.tasks``.
            epochs: number of full-batch optimisation steps.
            continual_learning: if False, a fresh ``GNNNodeClassifier`` is built. If True, the
                current model is kept (one is built if none exists yet), and an EWC penalty
                toward every task ``0..task_num-1`` is added to the loss. Those tasks' weights
                and Fisher estimates must already have been stored by earlier ``train`` calls.
            lambda_: scale of the EWC penalty.
            lr: Adam learning rate.
            verbose: print loss and accuracy every ``self.print_freq`` epochs.
        """
        (x, adj, y), (mask_tr, _, _) = self.tasks[task_num]

        # Start from scratch unless we are continuing from an already-trained model.
        if not continual_learning or self.model is None:
            self.model = GNNNodeClassifier()

        optim = Adam(learning_rate=lr)

        for epoch in range(epochs):
            with tf.GradientTape() as tape:
                pred = self.model([x, adj])
                loss = self.loss_fn(y[mask_tr], pred[mask_tr])

                # EWC: quadratic penalty toward the weights learned on previous tasks.
                # Computed inside the tape so it contributes to the gradients.
                if continual_learning:
                    loss += lambda_ * self._ewc_penalty(task_num)

            gradients = tape.gradient(loss, self.model.trainable_variables)
            optim.apply_gradients(zip(gradients, self.model.trainable_variables))

            if verbose and epoch % self.print_freq == 0:
                # Use a fresh metric so the value reflects this epoch's predictions,
                # not a running average over all epochs so far.
                accuracy = CategoricalAccuracy()
                accuracy.update_state(y_true=y[mask_tr], y_pred=pred[mask_tr])
                acc = accuracy.result().numpy()
                print(f"epoch={epoch}, train_loss={float(loss):.3f}, train_acc={acc:.3f}")

        self._update_previous_task_accuracy(task_num)

        # Save the flattened parameters and Fisher estimate of the model for this task.
        self.prior_weights[task_num] = flat_tensor_list(self.model.trainable_variables)
        self.fisher_matrix[task_num] = self.compute_fisher(task_num)

    def _ewc_penalty(self, task_num):
        """Return sum over tasks t < task_num of sum(F_t * (theta_t - theta)^2)."""
        # The current parameters are the same for every previous task, so flatten them once.
        theta_new = flat_tensor_list(self.model.trainable_variables)
        penalty = 0.0
        for prev_task in range(task_num):
            fisher_mat = self.fisher_matrix[prev_task]
            theta = self.prior_weights[prev_task]
            penalty += tf.reduce_sum(fisher_mat * ((theta - theta_new) ** 2))
        return penalty

    def _update_previous_task_accuracy(self, task_num):
        """Feed the current model's predictions on earlier tasks' *training* nodes into mean_accuracy."""
        for prev_task in range(task_num):
            (x_prev, a_prev, y_prev), (mask_tr_prev, _, _) = self.tasks[prev_task]
            pred_prev = self.model([x_prev, a_prev])
            self.mean_accuracy.update_state(y_pred=pred_prev[mask_tr_prev],
                                            y_true=y_prev[mask_tr_prev])

    def compute_fisher(self, task_num):
        """Return a flattened diagonal Fisher estimate for task `task_num`.

        This is the element-wise square of the gradient of the training loss. The loss is
        averaged over the task's training nodes by the default Keras reduction. That is a
        coarse approximation of the empirical Fisher, which averages squared *per-node*
        gradients instead.
        """
        (x, a, y), (mask_tr, _, _) = self.tasks[task_num]
        with tf.GradientTape() as tape:
            pred = self.model([x, a])
            loss = self.loss_fn(y_pred=pred[mask_tr], y_true=y[mask_tr])

        gradients = tape.gradient(loss, self.model.trainable_variables)

        return flat_tensor_list(gradients)**2

    def test(self, task_num):
        """Return the current model's categorical accuracy on task `task_num`'s test nodes."""
        (x, adj, y), (_, _, mask_te) = self.tasks[task_num]

        pred = self.model([x, adj])
        accuracy = CategoricalAccuracy()
        accuracy.update_state(y_pred=pred[mask_te], y_true=y[mask_te])

        return accuracy.result().numpy()

In [59]:
model_eval = ModelEval()
n_tasks = 3
epochs = 80

# Baseline: with continual_learning=False every train() call builds a fresh model,
# so each line reports a model trained and tested on that single task only.
for task_id in range(n_tasks):
    model_eval.train(task_id, epochs=epochs, continual_learning=False, verbose=False)
    test_acc = model_eval.test(task_id)
    print(f"task={task_id}, test_acc={test_acc:.3f}")

Pre-processing node features
task=0, test_acc=0.759
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0
prev_acc=0.0

In [54]:
# Aggregate accuracy on earlier tasks' training nodes, as fed by ModelEval.train()
# (presumably accumulated across all train() calls on this `model_eval` — verify GlobalAccuracy).
mean_acc_metric = model_eval.mean_accuracy
mean_acc_metric.result()

<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

In [None]:
model_eval = ModelEval()
epochs = 100
task_ids = range(3)

# NOTE(review): with continual_learning=False (the default), every train() call builds a
# fresh model. The accuracies on tasks 0 and 1 below therefore come from a model trained
# only on task 2. For a sequential fine-tuning (catastrophic forgetting) baseline, consider
# training task 0 normally and later tasks with continual_learning=True, lambda_=0.0.
for task_id in task_ids:
    model_eval.train(task_id, epochs=epochs, lr=0.01, verbose=False)

test_accs = [model_eval.test(task_id) for task_id in task_ids]

print("Accuracy after training on three tasks without CL")
for task_id, acc in zip(task_ids, test_accs):
    print(f"task={task_id}, test_acc={acc:.3f}")

Pre-processing node features
Accuracy after training on two three without CL
task=0, test_acc=0.482
task=1, test_acc=0.248
task=2, test_acc=0.854


In [53]:
from sklearn.model_selection import ParameterGrid

model_eval = ModelEval(print_freq=25)

SEED = 0  # re-applied per configuration so runs are comparable
param_grid = {'epochs': [100, 200], 'lambda_': [0.1, 1, 10, 100], 'lr': [0.001, 0.01]}

# One record per configuration, kept in a list so partial results survive an interrupted run.
grid_results = []
for params in ParameterGrid(param_grid):
    # Reset TF's global seed so every configuration starts from the same random state
    # (assumes GNNNodeClassifier's initializers draw from the global seed — confirm).
    tf.random.set_seed(SEED)

    # Task 0 trains a fresh model; tasks 1 and 2 continue it with the EWC penalty.
    model_eval.train(0, verbose=False, **params)
    model_eval.train(1, continual_learning=True, verbose=False, **params)
    model_eval.train(2, continual_learning=True, verbose=False, **params)

    test_accs = [model_eval.test(task_id) for task_id in range(3)]
    avg_acc = np.mean(test_accs)
    grid_results.append({**params, "test_accs": test_accs, "avg_acc": avg_acc})
    print(f"epochs={params['epochs']}, lr={params['lr']}, lambda_={params['lambda_']}, avg_acc={avg_acc:.3f}")

Pre-processing node features
epochs=100, lr=0.001, lambda_=0.1, avg_acc=0.548
epochs=100, lr=0.01, lambda_=0.1, avg_acc=0.558
epochs=100, lr=0.001, lambda_=1, avg_acc=0.556
epochs=100, lr=0.01, lambda_=1, avg_acc=0.559
epochs=100, lr=0.001, lambda_=10, avg_acc=0.548
epochs=100, lr=0.01, lambda_=10, avg_acc=0.551
epochs=100, lr=0.001, lambda_=100, avg_acc=0.551
epochs=100, lr=0.01, lambda_=100, avg_acc=0.547
epochs=200, lr=0.001, lambda_=0.1, avg_acc=0.556
epochs=200, lr=0.01, lambda_=0.1, avg_acc=0.540
epochs=200, lr=0.001, lambda_=1, avg_acc=0.548
epochs=200, lr=0.01, lambda_=1, avg_acc=0.555
epochs=200, lr=0.001, lambda_=10, avg_acc=0.551
epochs=200, lr=0.01, lambda_=10, avg_acc=0.551


KeyboardInterrupt: 