In [46]:
%load_ext autoreload
%autoreload 2

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import CategoricalAccuracy

from models import *
from utils import *

import warnings
from scipy.sparse import SparseEfficiencyWarning
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=SparseEfficiencyWarning)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [30]:
class ModelEval():
    def __init__(self, print_freq=10, continual_learning=False, ewc_type='add',
                 lambda_=0.1, verbose=True, print_previous_val=False, use_dropout=False):
        super(ModelEval, self).__init__()

        self.print_freq = print_freq
        self.tasks = load_tasks()
        self.prior_weights = dict()
        self.fisher_matrix = dict()
        self.loss_fn = CategoricalCrossentropy()
        self.mean_accuracy = GlobalAccuracy()
        self.backward_transfer = BackwardTransfer(len(self.tasks))
        self.continual_learning = continual_learning
        self.lambda_ = lambda_
        self.model = None
        self.verbose = verbose
        self.print_previous_val = print_previous_val
        self.optim = None
        self.use_dropout = use_dropout

        assert(ewc_type in ['add', 'combine', 'last'])
        self.ewc_type = ewc_type
    
    def train(self, task_num, epochs=50, lr=0.001):
        (x, adj, y), (mask_tr, _, _) = self.tasks[task_num]

        # initialize model if not already trained
        if self.model is None:
            self.model = GNNNodeClassifier(use_dropout=self.use_dropout)
            self.optim = Adam(learning_rate=lr)
        
        accuracy = CategoricalAccuracy()

        train_acc_list = []
        train_loss_list = []
        val_acc_list = {i: [] for i in range(task_num+1)}
        val_loss_list = {i: [] for i in range(task_num+1)}
        
        for i in range(epochs+1):
            with tf.GradientTape() as tape:
                pred = self.model([x, adj])
                train_loss = self.loss_fn(y[mask_tr], pred[mask_tr])

                # add EWC penalty if training with continual learning
                if self.continual_learning:
                    train_loss += self.ewc_penalty(task_num)
                
            accuracy.update_state(y_pred=pred[mask_tr], y_true=y[mask_tr])

            gradients = tape.gradient(train_loss, self.model.trainable_variables)
            self.optim.apply_gradients(zip(gradients, self.model.trainable_variables))

            train_acc = accuracy.result().numpy()
            
            if i % self.print_freq == 0 and self.verbose:
                val_loss, val_acc = self.test(task_num, validation=True, print_acc=False)
                s = f"epoch={i}, train_loss={train_loss:.3f}, train_acc={train_acc:.3f}, val_acc={val_acc:.3f}"

                # saving data for plotting
                train_acc_list.append(train_acc)
                train_loss_list.append(train_loss.numpy())
                val_acc_list[task_num].append(val_acc)
                val_loss_list[task_num].append(val_loss.numpy())

                # appending summary on previous training set
                if self.print_previous_val:
                    for j in range(task_num):
                        prev_loss, prev_acc = self.test(j, validation=True, print_acc=False)
                        s += f", task[{j}]_val_acc={prev_acc:.3f}"
                        val_acc_list[j].append(prev_acc)
                        val_loss_list[j].append(prev_loss.numpy())

                print(s)
        
        # after training on current task add accuracy on previous tasks
        for i in range(task_num+1):
            (x_i, a_i, y_i), (_, _, mask_te_i) = self.tasks[i]
            y_pred_i = self.model([x_i, a_i])[mask_te_i]
            y_true_i = y_i[mask_te_i]

            self.mean_accuracy.update_state(y_pred=y_pred_i, y_true=y_true_i)
            
            self.backward_transfer.update_state(task_i=i, task_j=task_num,
                                                y_pred=y_pred_i, y_true=y_true_i)

        # save the flatted list of parameters and fisher matrix of current model
        self.prior_weights[task_num] = flat_tensor_list(self.model.trainable_variables)
        self.fisher_matrix[task_num] = self.compute_fisher(task_num)

        return train_loss_list, train_acc_list, val_loss_list, val_acc_list
    
    def compute_fisher(self, task_num):
        """ Computes the diagonal Fisher matrix of the model parameters on the target task """
        with tf.GradientTape() as tape:
            (x, a, y), (mask_tr, _, _) = self.tasks[task_num]
            pred = self.model([x, a])
            loss = self.loss_fn(y_pred=pred[mask_tr], y_true=y[mask_tr])
        
        gradients = tape.gradient(loss, self.model.trainable_variables)

        return flat_tensor_list(gradients)**2
    
    def ewc_penalty(self, task_num):
        """ Computes the EWC penalty for the target task """
        penalty = 0

        # current parameters
        theta_new = flat_tensor_list(self.model.trainable_variables)
        matrices = list(self.fisher_matrix.values())

        # the penalty is computed depending on the type of aggregation chosen
        if self.ewc_type == 'add':
            # compute an individual penalty loss for each previous task
            for j in range(task_num):
                fisher_mat = self.fisher_matrix[j]
                theta = self.prior_weights[j]
                penalty += self.lambda_ * tf.reduce_sum(fisher_mat * ((theta-theta_new)**2))
        elif self.ewc_type == 'combine':
            # sum an individual penalty loss using the same aggregate fisher matrix (sum)
            if len(matrices):
                fisher_sum = tf.reduce_sum(list(model_eval.fisher_matrix.values()), axis=0)

                for j in range(task_num):
                    theta = self.prior_weights[j]
                    penalty += self.lambda_ * tf.reduce_sum(fisher_sum * ((theta-theta_new)**2))
        elif self.ewc_type == 'last':
            # use only the last fisher matrix
            if len(matrices):
                theta = self.prior_weights[-1]
                penalty += self.lambda_ * tf.reduce_sum(matrices[-1] * ((theta-theta_new)**2))

        return penalty

    def test(self, task_num, validation=False, print_acc=True):
        """ Test the model on task_num and return the accuracy

        Args:
            task_num (int): target task
            validation (bool, optional): if True performs the test on the validation set. Defaults to False.
            print_acc (bool, optional): if True prints the accuracy on console. Defaults to True.

        Returns:
            float: test accuracy
        """
        (x, adj, y), (_, mask_va, mask_te) = self.tasks[task_num]

        mask = [mask_te, mask_va][validation]

        pred = self.model([x, adj])
        loss = self.loss_fn(y[mask], pred[mask])

        accuracy = CategoricalAccuracy()
        accuracy.update_state(y_pred=pred[mask], y_true=y[mask])
        acc = accuracy.result().numpy()
        
        if print_acc:
            print(f"accuracy task {task_num} = {acc:.3f}")
        
        return loss, acc
    
    def print_mean_accuracy(self):
        acc = self.mean_accuracy.result().numpy()
        print(f"mean_accuracy = {acc:.3f}")
    
    def print_backward_transfer(self):
        acc = self.backward_transfer.result().numpy()
        print(f"backward_transfer = {acc:.3f}")

In [8]:
n_tasks = 3
model_eval = ModelEval(continual_learning=False, verbose=False)

print("Accuracy for target task after training on it from scratch:")
for i in range(n_tasks):
    model_eval.model = None # reset the model
    model_eval.train(i, epochs=100)
    model_eval.test(i)

Pre-processing node features
Accuracy for target task after training on it from scratch:
accuracy task 0 = 0.836
accuracy task 1 = 0.812
accuracy task 2 = 0.835


In [47]:
model_eval = ModelEval(continual_learning=False, verbose=True, print_freq=10, print_previous_val=True)

training_data = dict()

for i in range(n_tasks):
    training_data[i] = model_eval.train(i, epochs=50)

print("Accuracy after training on all 3 tasks:")
model_eval.test(0)
model_eval.test(1)
model_eval.test(2)

model_eval.print_mean_accuracy()

Pre-processing node features
epoch=0, train_loss=1.099, train_acc=0.136, val_acc=0.755
epoch=10, train_loss=1.012, train_acc=0.690, val_acc=0.755
epoch=20, train_loss=0.831, train_acc=0.716, val_acc=0.755
epoch=30, train_loss=0.717, train_acc=0.726, val_acc=0.755
epoch=40, train_loss=0.692, train_acc=0.731, val_acc=0.755
epoch=50, train_loss=0.662, train_acc=0.734, val_acc=0.755
epoch=0, train_loss=0.941, train_acc=0.618, val_acc=0.626, task[0]_val_acc=0.755
epoch=10, train_loss=0.865, train_acc=0.618, val_acc=0.626, task[0]_val_acc=0.755
epoch=20, train_loss=0.795, train_acc=0.618, val_acc=0.626, task[0]_val_acc=0.755
epoch=30, train_loss=0.692, train_acc=0.618, val_acc=0.626, task[0]_val_acc=0.755
epoch=40, train_loss=0.604, train_acc=0.632, val_acc=0.708, task[0]_val_acc=0.683
epoch=50, train_loss=0.540, train_acc=0.654, val_acc=0.708, task[0]_val_acc=0.600
epoch=0, train_loss=1.297, train_acc=0.524, val_acc=0.565, task[0]_val_acc=0.582, task[1]_val_acc=0.711
epoch=10, train_loss=0.

In [44]:
import matplotlib.pyplot as plt

# train_loss_list, train_acc_list, val_loss_list, val_acc_list

epochs = list(range(0, 51, 10))

training_data[0][1], training_data[1][1]

([0.13326654, 0.68983424, 0.71633744, 0.7257418, 0.7305587, 0.7334866],
 [0.6182365, 0.6182365, 0.6182365, 0.6182365, 0.62004495, 0.64499587])

In [51]:
model_eval = ModelEval(continual_learning=True, lambda_=0.1,
                       verbose=True, print_freq=10, print_previous_val=True)

num_tasks = 3

for i in range(num_tasks):
    print(f"Training on task {i}")
    model_eval.train(i, epochs=50)
    print()

for i in range(num_tasks):
    print(f"Testing on task {i}")
    model_eval.test(i)
    print()

model_eval.print_mean_accuracy()
model_eval.print_backward_transfer()

Pre-processing node features
Training on task 0
epoch=0, train_loss=1.099, train_acc=0.436, val_acc=0.755
epoch=10, train_loss=1.012, train_acc=0.717, val_acc=0.755
epoch=20, train_loss=0.828, train_acc=0.731, val_acc=0.755
epoch=30, train_loss=0.715, train_acc=0.736, val_acc=0.755
epoch=40, train_loss=0.686, train_acc=0.738, val_acc=0.755
epoch=50, train_loss=0.659, train_acc=0.739, val_acc=0.755

Training on task 1
epoch=0, train_loss=0.939, train_acc=0.618, val_acc=0.626, task[0]_val_acc=0.755
epoch=10, train_loss=0.864, train_acc=0.618, val_acc=0.626, task[0]_val_acc=0.755
epoch=20, train_loss=0.807, train_acc=0.618, val_acc=0.626, task[0]_val_acc=0.755
epoch=30, train_loss=0.717, train_acc=0.618, val_acc=0.626, task[0]_val_acc=0.755
epoch=40, train_loss=0.628, train_acc=0.626, val_acc=0.711, task[0]_val_acc=0.707
epoch=50, train_loss=0.559, train_acc=0.649, val_acc=0.713, task[0]_val_acc=0.641

Training on task 2
epoch=0, train_loss=1.249, train_acc=0.515, val_acc=0.558, task[0]_v