In [1]:
# %load_ext tensorboard
# fmt: off
from Utilities.Interpretability.ModelAugmentation import *
from Utilities.Interpretability.InterpretabilityMethods import *
from Utilities.Tasks.CIFAR10ClassificationTask import CIFAR10ClassificationTask as Task
from Utilities.SequentialLearning.EWC_Methods.EWC_Methods import *
from Utilities.SequentialLearning.EWC_Methods.ImportanceMeasureOrderings import validation_loss_by_importance_threshold
from Utilities.Interpretability.ModelAugmentation import ComparisonMethod, AggregationLevel, AggregationMethod
import pandas as pd
import matplotlib.pyplot as plt


import os
import shutil
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
# fmt: on
# Show the GPU devices TensorFlow can see, so the run's hardware is recorded in the output.
print(tf.config.list_physical_devices('GPU'))

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [2]:
# --- Persistence and execution switches ---
# Directory the model is loaded from (when LOAD_MODEL is set) and saved to.
MODEL_SAVE_PATH = "models/CIFAR10_medium_model/"
LOAD_MODEL = False
# Training is the complement of loading: exactly one of the two flags is True.
TRAIN_MODEL = not LOAD_MODEL
RUN_EAGERLY = False

# --- Task and data settings ---
batch_size = 32
# NOTE(review): presumably 0 means "no cap on the number of batches" - confirm against Task.
training_batches = validation_batches = 0
# One label per CIFAR-10 class (0 through 9).
task_labels = list(range(10))
# The model consumes images at the task's native size.
model_input_shape = image_size = Task.IMAGE_SIZE

# Model Training and Architecture

In [3]:
# Obtain the base model: load a saved one when requested and available, otherwise
# build a fresh CNN. Then pick the loss matching the number of task labels.
model: tf.keras.Model
# Note: when LOAD_MODEL is set but nothing exists at MODEL_SAVE_PATH, this falls
# through to building a brand-new model.
if LOAD_MODEL and os.path.exists(MODEL_SAVE_PATH):
    # Try to load model directly, if one exists
    print("LOADING MODEL")
    # compile=False: the model is loaded uncompiled.
    model = tf.keras.models.load_model(MODEL_SAVE_PATH, compile=False)  # type: ignore
else:
    # Otherwise, make an entire new model!
    print("CREATING MODEL")
    # `model_inputs` keeps the input tensor; `model_layer` is threaded through every layer below.
    model_inputs = model_layer = tf.keras.Input(shape=model_input_shape)
    # Conv group 1: 32 -> 32 -> 64 filters, then batch norm and 2x2 max pooling.
    model_layer = tf.keras.layers.Conv2D(32, (3,3), activation="relu", name="conv2d_0")(model_layer)
    model_layer = tf.keras.layers.Conv2D(32, (3,3), activation="relu", name="conv2d_1")(model_layer)
    model_layer = tf.keras.layers.Conv2D(64, (3,3), activation="relu", name="conv2d_2")(model_layer)
    model_layer = tf.keras.layers.BatchNormalization()(model_layer)
    model_layer = tf.keras.layers.MaxPool2D((2,2))(model_layer)
    # Conv group 2: 64 -> 128 -> 128 filters, then batch norm (pooling is disabled here).
    model_layer = tf.keras.layers.Conv2D(64, (3,3), activation="relu", name="conv2d_3")(model_layer)
    model_layer = tf.keras.layers.Conv2D(128, (3,3), activation="relu", name="conv2d_4")(model_layer)
    model_layer = tf.keras.layers.Conv2D(128, (3,3), activation="relu", name="conv2d_5")(model_layer)
    model_layer = tf.keras.layers.BatchNormalization()(model_layer)
    # model_layer = tf.keras.layers.MaxPool2D((2,2))(model_layer)
    # Conv group 3: 128 -> 128 -> 256 filters, with no normalisation or pooling afterwards.
    model_layer = tf.keras.layers.Conv2D(128, (3,3), activation="relu", name="conv2d_6")(model_layer)
    model_layer = tf.keras.layers.Conv2D(128, (3,3), activation="relu", name="conv2d_7")(model_layer)
    model_layer = tf.keras.layers.Conv2D(256, (3,3), activation="relu", name="conv2d_8")(model_layer)
    # Classifier head: two 128-unit dense layers, each followed by 20% dropout.
    model_layer = tf.keras.layers.Flatten()(model_layer)
    model_layer = tf.keras.layers.Dense(128, activation="relu")(model_layer)
    model_layer = tf.keras.layers.Dropout(0.2)(model_layer)
    model_layer = tf.keras.layers.Dense(128, activation="relu")(model_layer)
    model_layer = tf.keras.layers.Dropout(0.2)(model_layer)
    # One output per task label and no activation, so the model emits raw logits.
    model_layer = tf.keras.layers.Dense(len(task_labels))(model_layer)
    model = tf.keras.Model(inputs=model_inputs, outputs=model_layer, name="base_model")
# Both losses use from_logits=True to match the activation-free output layer built above.
# NOTE(review): CategoricalCrossentropy expects one-hot targets - confirm Task supplies them.
if len(task_labels) == 2:
    loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
else:
    loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
model.summary()

CREATING MODEL
Model: "base_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 conv2d_0 (Conv2D)           (None, 30, 30, 32)        896       
                                                                 
 conv2d_1 (Conv2D)           (None, 28, 28, 32)        9248      
                                                                 
 conv2d_2 (Conv2D)           (None, 26, 26, 64)        18496     
                                                                 
 batch_normalization (BatchN  (None, 26, 26, 64)       256       
 ormalization)                                                   
                                                                 
 max_pooling2d (MaxPooling2D  (None, 13, 13, 64)       0         
 )                                       

In [4]:
# Image augmentation pipeline handed to the task: a random zoom followed by a
# small random rotation.
# NOTE(review): presumably applied to training images only, going by the Task
# parameter name - confirm in Task.
training_image_augmentation = tf.keras.Sequential([
    # tf.keras.layers.RandomFlip("horizontal"),
    # Factor tuples are (lower, upper); negative values zoom in, so each axis is
    # zoomed in by between 5% and 25%.
    tf.keras.layers.RandomZoom(
            height_factor=(-0.25, -0.05),
            width_factor=(-0.25, -0.05)),
    # Rotate by up to 0.05 of a full turn either way; "constant" is the fill mode.
    tf.keras.layers.RandomRotation(0.05, "constant")
])

# Build the single task trained in this notebook. The batch limits come from the
# configuration cell so they only have to be changed in one place.
task = Task(
        name="Task 0",
        model=model,
        model_base_loss=loss_fn,
        task_labels=task_labels,
        training_batches=training_batches,
        validation_batches=validation_batches,
        batch_size=batch_size,
        training_image_augmentation=training_image_augmentation,
        run_eagerly=RUN_EAGERLY
    )

In [5]:
# Importance measures compared in this experiment (commented entries are disabled).
ewc_methods = [
    EWC_Method.FISHER_MATRIX,
    EWC_Method.SIGN_FLIPPING,
    EWC_Method.MOMENTUM_BASED,
    EWC_Method.WEIGHT_CHANGE,
    # EWC_Method.WEIGHT_MAGNITUDE,
    # EWC_Method.INVERSE_WEIGHT_MAGNITUDE,
    EWC_Method.RANDOM,
]

# Granularities at which each importance measure is evaluated.
aggregation_levels = [
    AggregationLevel.NO_AGGREGATION,
    # AggregationLevel.UNIT,
    AggregationLevel.CONV_FILTER,
]

# ewc_methods = [EWC_Method.WEIGHT_CHANGE, EWC_Method.INVERSE_WEIGHT_MAGNITUDE, EWC_Method.RANDOM, EWC_Method.FISHER_MATRIX]
# One term creator per method, all attached to the same model.
ewc_term_creators = [
    EWC_Term_Creator(method, model, callback_kwargs={"reset_on_train_begin": False})
    for method in ewc_methods
]

# Gather every creator's callbacks into one flat list for training.
callbacks = []
for ewc_term_creator in ewc_term_creators:
    callbacks.extend(ewc_term_creator.callback_dict.values())

In [7]:
# Training schedule: train for num_epochs in total, measuring every sample_period epochs.
num_epochs = 25
sample_period = 5

# Evenly spaced threshold values from 0 to 1 inclusive (num_samples + 1 points).
num_samples = 25
sample_array = [step * (1/num_samples) for step in range(0, num_samples + 1)]

# Schema of the results table, and the empty accumulator that the measurements are appended to.
column_names = [
    "Epoch",
    "EWC Method",
    "Aggregation Level",
    "Threshold Value",
    "Loss",
    "Validation Loss",
]
all_results_dataframe = pd.DataFrame(columns=column_names)

In [9]:
def measure_val_loss_over_threshold(epoch_number, ewc_term_creators: List[EWC_Term_Creator]):
    """Measure loss across importance thresholds for every EWC method and aggregation level.

    For each term creator and each entry of the notebook-level
    ``aggregation_levels`` list, an EWC term is created for the notebook-level
    ``task`` and its ``omega_matrix`` is passed to
    ``validation_loss_by_importance_threshold`` together with ``sample_array``.
    The returned rows are tagged with the method name, aggregation level name
    and ``epoch_number``.

    Args:
        epoch_number: Value written to the "Epoch" column of every row.
        ewc_term_creators: Term creators to evaluate, one per EWC method.

    Returns:
        A DataFrame with one block of rows per successful combination. Its
        leading columns are ``column_names``, followed by any extra columns the
        measurement produced. If no combination succeeds, the frame is empty
        with ``column_names`` as its columns.

    A combination that raises is reported on stdout and skipped (best effort),
    so one failing method does not discard the rest of the measurements.
    """
    # Collect the per-combination frames and concatenate once at the end. Growing
    # a DataFrame inside the loop is quadratic, and seeding it with an empty frame
    # triggers pandas' empty-entry concatenation FutureWarning.
    collected_frames = []
    for ewc_term_creator in ewc_term_creators:
        for aggregation_level in aggregation_levels:
            try:
                print(f"CURRENT TERM: {ewc_term_creator.ewc_method.name}, AGGREGATION LEVEL: {aggregation_level.name}")
                # NOTE(review): the term is rebuilt for every aggregation level; whether
                # it could be created once per method depends on create_term - confirm.
                ewc_term = ewc_term_creator.create_term(ewc_lambda = 1, task=task)
                method_results = validation_loss_by_importance_threshold(task, ewc_term.omega_matrix, sample_array, aggregation_level=aggregation_level, show_plot=False)
                method_results["EWC Method"] = ewc_term_creator.ewc_method.name
                method_results["Aggregation Level"] = aggregation_level.name
                method_results["Epoch"] = epoch_number
                collected_frames.append(method_results)
            except Exception as e:
                # Deliberate best effort: report the failure and move on.
                print(f"EXCEPTION {e}")
                continue

    if not collected_frames:
        return pd.DataFrame(columns=column_names)

    epoch_results = pd.concat(collected_frames, ignore_index=True)
    # Keep the canonical columns first (absent ones are added as NaN), then any extras.
    extra_columns = [column for column in epoch_results.columns if column not in column_names]
    return epoch_results.reindex(columns=column_names + extra_columns)

# Train in rounds of `sample_period` epochs. After each round, measure every EWC
# method, append the rows to the running results table, and checkpoint both the
# results CSV and the model.
# NOTE: re-running this cell continues training the same model and appends
# further rows to `all_results_dataframe`.
if sample_period <= 0:
    # A non-positive period would never advance epoch_index and loop forever.
    raise ValueError(f"sample_period must be positive, got {sample_period}")

results_csv_path = "data/validation_loss_over_threshold.csv"
# Create the output directory up front, so a missing folder cannot fail the run
# after a full round of training has already been spent.
os.makedirs(os.path.dirname(results_csv_path), exist_ok=True)

epoch_index = 0
while epoch_index < num_epochs:
    # Never train past num_epochs, even when it is not a multiple of sample_period.
    epochs_this_round = min(sample_period, num_epochs - epoch_index)
    epoch_index += epochs_this_round
    task.train_on_task(epochs=epochs_this_round, callbacks=callbacks)

    epoch_results = measure_val_loss_over_threshold(epoch_index, ewc_term_creators)
    # Leave empty frames out of the concatenation (the accumulator is empty on the
    # first round); pandas deprecates concatenating empty entries.
    frames_to_join = [frame for frame in (all_results_dataframe, epoch_results) if not frame.empty]
    if frames_to_join:
        all_results_dataframe = pd.concat(frames_to_join, ignore_index=True)
    all_results_dataframe.to_csv(results_csv_path)
    task.model.save(filepath=MODEL_SAVE_PATH)

Epoch 1/5




Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
CURRENT TERM: FISHER_MATRIX, AGGREGATION LEVEL: NO_AGGREGATION
CURRENT TERM: FISHER_MATRIX, AGGREGATION LEVEL: CONV_FILTER                                         
CURRENT TERM: SIGN_FLIPPING, AGGREGATION LEVEL: NO_AGGREGATION                                      
CURRENT TERM: SIGN_FLIPPING, AGGREGATION LEVEL: CONV_FILTER                                         
CURRENT TERM: MOMENTUM_BASED, AGGREGATION LEVEL: NO_AGGREGATION                                     
CURRENT TERM: MOMENTUM_BASED, AGGREGATION LEVEL: CONV_FILTER                                        
CURRENT TERM: WEIGHT_CHANGE, AGGREGATION LEVEL: NO_AGGREGATION                                      
CURRENT TERM: WEIGHT_CHANGE, AGGREGATION LEVEL: CONV_FILTER                                         
CURRENT TERM: RANDOM, AGGREGATION LEVEL: NO_AGGREGATION                                             
CURRENT TERM: RANDOM, AGGREGATION LEVEL: CONV_FILTER                                     



INFO:tensorflow:Assets written to: models/CIFAR10_medium_model/assets


INFO:tensorflow:Assets written to: models/CIFAR10_medium_model/assets


Epoch 1/5




Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
CURRENT TERM: FISHER_MATRIX, AGGREGATION LEVEL: NO_AGGREGATION
CURRENT TERM: FISHER_MATRIX, AGGREGATION LEVEL: CONV_FILTER                                         
CURRENT TERM: SIGN_FLIPPING, AGGREGATION LEVEL: NO_AGGREGATION                                      
CURRENT TERM: SIGN_FLIPPING, AGGREGATION LEVEL: CONV_FILTER                                         
CURRENT TERM: MOMENTUM_BASED, AGGREGATION LEVEL: NO_AGGREGATION                                     
CURRENT TERM: MOMENTUM_BASED, AGGREGATION LEVEL: CONV_FILTER                                        
CURRENT TERM: WEIGHT_CHANGE, AGGREGATION LEVEL: NO_AGGREGATION                                      
CURRENT TERM: WEIGHT_CHANGE, AGGREGATION LEVEL: CONV_FILTER                                         
CURRENT TERM: RANDOM, AGGREGATION LEVEL: NO_AGGREGATION                                             
CURRENT TERM: RANDOM, AGGREGATION LEVEL: CONV_FILTER                                     



INFO:tensorflow:Assets written to: models/CIFAR10_medium_model/assets


INFO:tensorflow:Assets written to: models/CIFAR10_medium_model/assets


Epoch 1/5




Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
CURRENT TERM: FISHER_MATRIX, AGGREGATION LEVEL: NO_AGGREGATION
CURRENT TERM: FISHER_MATRIX, AGGREGATION LEVEL: CONV_FILTER                                         
CURRENT TERM: SIGN_FLIPPING, AGGREGATION LEVEL: NO_AGGREGATION                                      
CURRENT TERM: SIGN_FLIPPING, AGGREGATION LEVEL: CONV_FILTER                                         
CURRENT TERM: MOMENTUM_BASED, AGGREGATION LEVEL: NO_AGGREGATION                                     
CURRENT TERM: MOMENTUM_BASED, AGGREGATION LEVEL: CONV_FILTER                                        
CURRENT TERM: WEIGHT_CHANGE, AGGREGATION LEVEL: NO_AGGREGATION                                      
CURRENT TERM: WEIGHT_CHANGE, AGGREGATION LEVEL: CONV_FILTER                                         
CURRENT TERM: RANDOM, AGGREGATION LEVEL: NO_AGGREGATION                                             
CURRENT TERM: RANDOM, AGGREGATION LEVEL: CONV_FILTER                                     



INFO:tensorflow:Assets written to: models/CIFAR10_medium_model/assets


INFO:tensorflow:Assets written to: models/CIFAR10_medium_model/assets


Epoch 1/5




Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
CURRENT TERM: FISHER_MATRIX, AGGREGATION LEVEL: NO_AGGREGATION
CURRENT TERM: FISHER_MATRIX, AGGREGATION LEVEL: CONV_FILTER                                         
CURRENT TERM: SIGN_FLIPPING, AGGREGATION LEVEL: NO_AGGREGATION                                      
CURRENT TERM: SIGN_FLIPPING, AGGREGATION LEVEL: CONV_FILTER                                         
CURRENT TERM: MOMENTUM_BASED, AGGREGATION LEVEL: NO_AGGREGATION                                     
CURRENT TERM: MOMENTUM_BASED, AGGREGATION LEVEL: CONV_FILTER                                        
CURRENT TERM: WEIGHT_CHANGE, AGGREGATION LEVEL: NO_AGGREGATION                                      
CURRENT TERM: WEIGHT_CHANGE, AGGREGATION LEVEL: CONV_FILTER                                         
CURRENT TERM: RANDOM, AGGREGATION LEVEL: NO_AGGREGATION                                             
CURRENT TERM: RANDOM, AGGREGATION LEVEL: CONV_FILTER                                     



INFO:tensorflow:Assets written to: models/CIFAR10_medium_model/assets


INFO:tensorflow:Assets written to: models/CIFAR10_medium_model/assets


Epoch 1/5




Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
CURRENT TERM: FISHER_MATRIX, AGGREGATION LEVEL: NO_AGGREGATION
CURRENT TERM: FISHER_MATRIX, AGGREGATION LEVEL: CONV_FILTER                                         
CURRENT TERM: SIGN_FLIPPING, AGGREGATION LEVEL: NO_AGGREGATION                                      
CURRENT TERM: SIGN_FLIPPING, AGGREGATION LEVEL: CONV_FILTER                                         
CURRENT TERM: MOMENTUM_BASED, AGGREGATION LEVEL: NO_AGGREGATION                                     
CURRENT TERM: MOMENTUM_BASED, AGGREGATION LEVEL: CONV_FILTER                                        
CURRENT TERM: WEIGHT_CHANGE, AGGREGATION LEVEL: NO_AGGREGATION                                      
CURRENT TERM: WEIGHT_CHANGE, AGGREGATION LEVEL: CONV_FILTER                                         
CURRENT TERM: RANDOM, AGGREGATION LEVEL: NO_AGGREGATION                                             
CURRENT TERM: RANDOM, AGGREGATION LEVEL: CONV_FILTER                                     



INFO:tensorflow:Assets written to: models/CIFAR10_medium_model/assets


INFO:tensorflow:Assets written to: models/CIFAR10_medium_model/assets


In [10]:
# Display the accumulated measurements for all sampled epochs (the rendered table is truncated).
all_results_dataframe

Unnamed: 0,Epoch,EWC Method,Aggregation Level,Threshold Value,Loss,Validation Loss
0,5,FISHER_MATRIX,NO_AGGREGATION,0.00,0.866979,0.979571
1,5,FISHER_MATRIX,NO_AGGREGATION,0.04,0.866580,0.979861
2,5,FISHER_MATRIX,NO_AGGREGATION,0.08,0.863429,0.978296
3,5,FISHER_MATRIX,NO_AGGREGATION,0.12,0.861961,0.973872
4,5,FISHER_MATRIX,NO_AGGREGATION,0.16,0.860800,0.969991
...,...,...,...,...,...,...
1295,25,RANDOM,CONV_FILTER,0.84,2.308533,2.308608
1296,25,RANDOM,CONV_FILTER,0.88,2.308389,2.308364
1297,25,RANDOM,CONV_FILTER,0.92,2.305713,2.305665
1298,25,RANDOM,CONV_FILTER,0.96,2.302805,2.302810
