In [8]:
%load_ext autoreload
%autoreload 2

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import CategoricalAccuracy
import matplotlib.pyplot as plt
from os import listdir

from models import *
from utils import *

import warnings
from scipy.sparse import SparseEfficiencyWarning
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=SparseEfficiencyWarning)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
class ModelEval():
    """Train a node classifier on a sequence of tasks and track evaluation metrics.

    Each entry of ``self.tasks`` is unpacked as
    ``(x, adj, y), (mask_tr, mask_va, mask_te)``: model inputs, labels and the
    train / validation / test masks used to index predictions and labels.

    With ``continual_learning=False`` a new ``GNNNodeClassifier`` is created at
    the start of every ``train`` call. With ``continual_learning=True`` the same
    model is reused across calls and the training loss gets an additional
    quadratic penalty (EWC-style) that pulls the parameters towards the values
    stored after each previously trained task, weighted per parameter by the
    stored squared gradients.

    NOTE(review): ``load_tasks_new``, ``GlobalAccuracy``, ``BackwardTransfer``,
    ``GNNNodeClassifier`` and ``flat_tensor_list`` are not defined in this
    notebook; presumably they come from the ``models`` / ``utils`` star
    imports -- confirm.
    """

    def __init__(self, print_freq=10, continual_learning=False,
                 lambda_=0.1, verbose=True, print_previous_val=False):
        """Load the tasks and set up metrics and bookkeeping.

        Args:
            print_freq: a progress line is printed every ``print_freq`` steps
                (only when ``verbose`` is true).
            continual_learning: reuse the model across tasks and add the
                quadratic penalty on previously stored weights.
            lambda_: multiplier applied to each penalty term.
            verbose: enable progress printing during ``train``.
            print_previous_val: also print validation accuracy on every
                earlier task in each progress line.
        """
        super(ModelEval, self).__init__()

        self.print_freq = print_freq
        self.tasks = load_tasks_new()
        # Both dicts are keyed by task number and filled at the end of `train`.
        self.prior_weights = dict()
        self.fisher_matrix = dict()
        self.loss_fn = CategoricalCrossentropy()
        self.mean_accuracy = GlobalAccuracy()
        self.backward_transfer = BackwardTransfer(len(self.tasks))
        self.continual_learning = continual_learning
        self.lambda_ = lambda_
        # Created lazily by the first `train` call.
        self.model = None
        self.verbose = verbose
        self.print_previous_val = print_previous_val

    def train(self, task_num, epochs=100, lr=0.001):
        """Train on task ``task_num`` and record metrics and penalty statistics.

        Args:
            task_num: index into ``self.tasks``. With continual learning
                enabled, every task ``j < task_num`` must already have been
                trained, because its stored weights and squared gradients are
                looked up by key.
            epochs: the loop runs ``epochs + 1`` optimisation steps
                (steps ``0 .. epochs`` inclusive).
            lr: learning rate handed to the Adam optimizer.
        """
        (x, adj, y), (mask_tr, _, _) = self.tasks[task_num]

        # initialize model if not already trained
        if not self.continual_learning or self.model is None:
            self.model = GNNNodeClassifier()

        # The metric is created once and never reset inside the loop, so the
        # reported train accuracy accumulates over all steps taken so far.
        accuracy = CategoricalAccuracy()
        # Use the `lr` argument; the instance has no `lr` attribute, so reading
        # `self.lr` here would raise AttributeError.
        optim = Adam(learning_rate=lr)

        for i in range(epochs+1):
            with tf.GradientTape() as tape:
                pred = self.model([x, adj])
                loss = self.loss_fn(y[mask_tr], pred[mask_tr])

                # if continual learning is enabled add l2 loss wrt previous weights
                if self.continual_learning and task_num > 0:
                    # Flatten the current parameters once per step; this must
                    # stay inside the tape so the penalty contributes gradients.
                    theta_new = flat_tensor_list(self.model.trainable_variables)
                    for j in range(task_num):
                        fisher_mat = self.fisher_matrix[j]
                        theta = self.prior_weights[j]
                        penalty = self.lambda_ * tf.reduce_sum(fisher_mat * ((theta-theta_new)**2))
                        loss += penalty

            accuracy.update_state(y_pred=pred[mask_tr], y_true=y[mask_tr])

            gradients = tape.gradient(loss, self.model.trainable_variables)
            optim.apply_gradients(zip(gradients, self.model.trainable_variables))

            train_acc = accuracy.result().numpy()

            if i % self.print_freq == 0 and self.verbose:
                val_acc = self.test(task_num, validation=True, print_acc=False)
                s = f"epoch={i}, train_loss={loss:.3f}, train_acc={train_acc:.3f}, val_acc={val_acc:.3f}"

                if self.print_previous_val:
                    for j in range(task_num):
                        prev_acc = self.test(j, validation=True, print_acc=False)
                        s += f", task[{j}]_val_acc={prev_acc:.3f}"

                print(s)

        # after training on current task add accuracy on previous tasks
        for i in range(task_num+1):
            (x_i, a_i, y_i), (_, _, mask_te_i) = self.tasks[i]
            y_pred_i = self.model([x_i, a_i])[mask_te_i]
            y_true_i = y_i[mask_te_i]

            self.mean_accuracy.update_state(y_pred=y_pred_i, y_true=y_true_i)

            self.backward_transfer.update_state(task_i=i, task_j=task_num,
                                                y_pred=y_pred_i, y_true=y_true_i)

        # save the flattened list of parameters and fisher matrix of current model
        self.prior_weights[task_num] = flat_tensor_list(self.model.trainable_variables)
        self.fisher_matrix[task_num] = self.compute_fisher(task_num)

    def compute_fisher(self, task_num):
        """Return the element-wise squared gradient of the training loss for a task.

        The loss is evaluated once on the task's training mask with the
        current model; its gradient with respect to the trainable variables is
        flattened and squared.

        NOTE(review): this squares the gradient of the already-reduced loss
        rather than averaging per-sample squared gradients -- confirm this
        approximation is intended.
        """
        with tf.GradientTape() as tape:
            (x, a, y), (mask_tr, _, _) = self.tasks[task_num]
            pred = self.model([x, a])
            loss = self.loss_fn(y_pred=pred[mask_tr], y_true=y[mask_tr])

        gradients = tape.gradient(loss, self.model.trainable_variables)

        return flat_tensor_list(gradients)**2

    def test(self, task_num, validation=False, print_acc=True):
        """Evaluate the current model on one task and return its accuracy.

        Args:
            task_num: index into ``self.tasks``.
            validation: use the validation mask when true, the test mask
                otherwise.
            print_acc: print the accuracy in addition to returning it.

        Returns:
            The categorical accuracy on the selected mask (``.numpy()`` of the
            metric result).
        """
        (x, adj, y), (_, mask_va, mask_te) = self.tasks[task_num]

        mask = mask_va if validation else mask_te

        pred = self.model([x, adj])
        accuracy = CategoricalAccuracy()
        accuracy.update_state(y_pred=pred[mask], y_true=y[mask])
        val = accuracy.result().numpy()

        if print_acc:
            print(f"accuracy task {task_num} = {val:.3f}")

        return val

    def print_mean_accuracy(self):
        """Print the current value of the ``mean_accuracy`` metric."""
        acc = self.mean_accuracy.result().numpy()
        print(f"mean_accuracy = {acc:.3f}")

    def print_backward_transfer(self):
        """Print the current value of the ``backward_transfer`` metric."""
        acc = self.backward_transfer.result().numpy()
        print(f"backward_transfer = {acc:.3f}")

In [4]:
# Baseline run: with continual_learning=False each train() call starts from a
# new model, and the task is evaluated right after it has been trained.
n_tasks = 3
model_eval = ModelEval(verbose=False, continual_learning=False)

print("Accuracy after training on the task:")
for task_id in range(n_tasks):
    model_eval.train(task_id)
    model_eval.test(task_id)

model_eval.print_mean_accuracy()

Pre-processing node features


  r_inv = np.power(rowsum, -1).flatten()
  self._set_arrayXarray(i, j, x)


Accuracy after training on the task:
accuracy task 0 = 0.838
accuracy task 1 = 0.815
accuracy task 2 = 0.829
mean_accuracy = 0.640


In [5]:
# Train on every task first, then evaluate each task with whatever model is
# left after the last train() call.
model_eval = ModelEval(verbose=False)

task_ids = (0, 1, 2)
for task_id in task_ids:
    model_eval.train(task_id)

print("Accuracy after training on all 3 tasks:")
for task_id in task_ids:
    model_eval.test(task_id)

model_eval.print_mean_accuracy()
model_eval.print_backward_transfer()

Pre-processing node features


  r_inv = np.power(rowsum, -1).flatten()
  self._set_arrayXarray(i, j, x)


Accuracy after training on all 3 tasks:
accuracy task 0 = 0.359
accuracy task 1 = 0.450
accuracy task 2 = 0.829
mean_accuracy = 0.641
backward_transfer = -1.120


In [10]:
# Continual-learning run: one model is kept across tasks and the penalty on
# previously stored weights is scaled by lambda_.
num_tasks = 3
model_eval = ModelEval(
    continual_learning=True,
    lambda_=1e4,
    verbose=True,
    print_freq=50,
    print_previous_val=True,
)

# Sequential training, with progress lines every `print_freq` steps.
for task_id in range(num_tasks):
    print(f"Training on task {task_id}")
    model_eval.train(task_id)
    print()

# Final evaluation of every task with the model left after the last task.
for task_id in range(num_tasks):
    print(f"Testing on task {task_id}")
    model_eval.test(task_id)
    print()

model_eval.print_mean_accuracy()
model_eval.print_backward_transfer()

Pre-processing node features
Training on task 0
epoch=0, train_loss=1.099, train_acc=0.424, val_acc=0.755
epoch=50, train_loss=0.657, train_acc=0.739, val_acc=0.755
epoch=100, train_loss=0.319, train_acc=0.771, val_acc=0.836

Training on task 1
epoch=0, train_loss=1.365, train_acc=0.563, val_acc=0.537, task[0]_val_acc=0.829
epoch=50, train_loss=0.557, train_acc=0.627, val_acc=0.785, task[0]_val_acc=0.621
epoch=100, train_loss=0.338, train_acc=0.742, val_acc=0.836, task[0]_val_acc=0.561

Training on task 2
epoch=0, train_loss=1.905, train_acc=0.409, val_acc=0.435, task[0]_val_acc=0.532, task[1]_val_acc=0.824
epoch=50, train_loss=0.477, train_acc=0.694, val_acc=0.853, task[0]_val_acc=0.383, task[1]_val_acc=0.465
epoch=100, train_loss=0.291, train_acc=0.798, val_acc=0.862, task[0]_val_acc=0.377, task[1]_val_acc=0.448

Testing on task 0
accuracy task 0 = 0.365

Testing on task 1
accuracy task 1 = 0.465

Testing on task 2
accuracy task 2 = 0.839

mean_accuracy = 0.652
backward_transfer = -1