In [6]:
%load_ext autoreload
%autoreload 2

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import CategoricalAccuracy
import matplotlib.pyplot as plt
from os import listdir

from models import *
from utils import *

import warnings
from scipy.sparse import SparseEfficiencyWarning
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=SparseEfficiencyWarning)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
class ModelEval():
    def __init__(self, print_freq=10, epochs=100, continual_learning=False,
                 lambda_=0.1, lr=0.001, verbose=True, print_previous_val=False):
        super(ModelEval, self).__init__()

        self.print_freq = print_freq
        self.tasks = load_tasks()
        self.prior_weights = dict()
        self.fisher_matrix = dict()
        self.loss_fn = CategoricalCrossentropy()
        self.mean_accuracy = GlobalAccuracy()
        self.lr = lr
        self.continual_learning = continual_learning
        self.epochs = epochs
        self.lambda_ = lambda_
        self.model = None
        self.verbose = verbose
        self.print_previous_val = print_previous_val
    
    def train(self, task_num):
        (x, adj, y), (mask_tr, _, _) = self.tasks[task_num]

        # initialize model if not already trained
        if not self.continual_learning or self.model is None:
            self.model = GNNNodeClassifier()
        
        accuracy = CategoricalAccuracy()
        optim = Adam(learning_rate=self.lr)
        
        for i in range(self.epochs+1):
            with tf.GradientTape() as tape:
                pred = self.model([x, adj])
                loss = self.loss_fn(y[mask_tr], pred[mask_tr])

                # if continual learning is enabled add l2 loss wrt previous weights
                if self.continual_learning:
                    for j in range(task_num):
                        fisher_mat = self.fisher_matrix[j]
                        theta = self.prior_weights[j]
                        theta_new = flat_tensor_list(self.model.trainable_variables)
                        penalty = self.lambda_ * tf.reduce_sum(fisher_mat * ((theta-theta_new)**2))
                        loss += penalty
                
            accuracy.update_state(y_pred=pred[mask_tr], y_true=y[mask_tr])

            gradients = tape.gradient(loss, self.model.trainable_variables)
            optim.apply_gradients(zip(gradients, self.model.trainable_variables))

            acc = accuracy.result().numpy()
            
            if i % self.print_freq == 0 and self.verbose:
                val_acc = self.test(task_num, validation=True, print_acc=False)
                s = f"epoch={i}, train_loss={loss:.3f}, train_acc={acc:.3f}, val_acc={val_acc:.3f}"
                
                if self.print_previous_val:
                    for j in range(task_num):
                        prev_acc = self.test(j, validation=True, print_acc=False)
                        s += f", task[{j}]_val_acc={prev_acc:.3f}"
                
                print(s)
        
        # after training on current task add accuracy on previous tasks
        for i in range(task_num+1):
            (x_prev, a_prev, y_prev), (_, _, mask_te_prev) = self.tasks[i]
            pred_prev = self.model([x_prev, a_prev])
            self.mean_accuracy.update_state(y_pred=pred_prev[mask_te_prev],
                                            y_true=y_prev[mask_te_prev])

        # save the flatted list of parameters and fisher matrix of current model
        self.prior_weights[task_num] = flat_tensor_list(self.model.trainable_variables)
        self.fisher_matrix[task_num] = self.compute_fisher(task_num)
    
    def compute_fisher(self, task_num):
        with tf.GradientTape() as tape:
            (x, a, y), (mask_tr, _, _) = self.tasks[task_num]
            pred = self.model([x, a])
            loss = self.loss_fn(y_pred=pred[mask_tr], y_true=y[mask_tr])
        
        gradients = tape.gradient(loss, self.model.trainable_variables)

        return flat_tensor_list(gradients)**2

    def test(self, task_num, validation=False, print_acc=True):
        (x, adj, y), (_, mask_va, mask_te) = self.tasks[task_num]

        mask = [mask_te, mask_va][validation]

        pred = self.model([x, adj])
        accuracy = CategoricalAccuracy()
        accuracy.update_state(y_pred=pred[mask], y_true=y[mask])
        val = accuracy.result().numpy()
        
        if print_acc:
            print(f"accuracy task {task_num} = {val:.3f}")
        
        return val
    
    def print_mean_accuracy(self):
        acc = self.mean_accuracy.result().numpy()
        print(f"mean_accuracy = {acc:.3f}")

In [81]:
n_tasks = 2
epochs = 80

model_eval = ModelEval(epochs=epochs, continual_learning=False, verbose=False)

print("Accuracy after training on the task:")
for i in range(n_tasks):
    model_eval.train(i)
    model_eval.test(i)

model_eval.print_mean_accuracy()

Pre-processing node features
Accuracy after training on the task:
accuracy task 0 = 0.788
accuracy task 1 = 0.851
mean_accuracy = 0.654


In [65]:
model_eval = ModelEval(verbose=False)

model_eval.train(0, )
model_eval.train(1)
#model_eval.train(2)

print("Accuracy after training on all 3 tasks:")
model_eval.test(0)
model_eval.test(1)
#model_eval.test(2)

model_eval.print_mean_accuracy()

Pre-processing node features
Accuracy after training on all 3 tasks:
accuracy task 0 = 0.285
accuracy task 1 = 0.886
mean_accuracy = 0.638


In [16]:
model_eval = ModelEval(continual_learning=True, lr=0.001, lambda_=1e4,
                       verbose=True, epochs=100, print_freq=25, print_previous_val=True)

num_tasks = 2

for i in range(num_tasks):
    print(f"Training on task {i}")
    model_eval.train(i)
    print()

for i in range(num_tasks):
    print(f"Testing on task {i}")
    model_eval.test(i)
    print()

model_eval.print_mean_accuracy()

Pre-processing node features
Training on task 0
epoch=0, train_loss=0.693, train_acc=0.686, val_acc=0.691
epoch=25, train_loss=0.589, train_acc=0.691, val_acc=0.691
epoch=50, train_loss=0.430, train_acc=0.691, val_acc=0.691
epoch=75, train_loss=0.283, train_acc=0.734, val_acc=0.735
epoch=100, train_loss=0.176, train_acc=0.780, val_acc=0.750

Training on task 1
epoch=0, train_loss=5.001, train_acc=0.529, val_acc=0.510, task[0]_val_acc=0.750
epoch=25, train_loss=1.118, train_acc=0.456, val_acc=0.460, task[0]_val_acc=0.309
epoch=50, train_loss=0.567, train_acc=0.529, val_acc=0.580, task[0]_val_acc=0.338
epoch=75, train_loss=0.450, train_acc=0.610, val_acc=0.730, task[0]_val_acc=0.353
epoch=100, train_loss=0.361, train_acc=0.664, val_acc=0.770, task[0]_val_acc=0.338

Testing on task 0
accuracy task 0 = 0.314

Testing on task 1
accuracy task 1 = 0.777

mean_accuracy = 0.651
