In [3]:
import tensorflow as tf
from tensorflow.keras.optimizers.legacy import Adamax
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import CategoricalAccuracy
from tensorflow.keras.metrics import Mean

In [None]:
import pandas as pd
import numpy as np

In [4]:
from copy import deepcopy

In [8]:
def compute_ewc_penalty(model, fisher_matrix, optimal_weights, lamb):
    """Return the Elastic Weight Consolidation (EWC) quadratic penalty.

    Computes ``(lamb / 2) * sum_i reduce_sum(F_i * (current_i - optimal_i) ** 2)``
    where the i-th entries of ``fisher_matrix``, ``model.trainable_weights``
    and ``optimal_weights`` are paired positionally.

    Args:
        model: Object exposing ``trainable_weights`` (the current weights).
        fisher_matrix: Sequence of per-weight importance terms, one entry per
            trainable weight.
        optimal_weights: Sequence of reference weight values that the current
            weights are compared against.
        lamb: Scalar strength of the penalty.

    Returns:
        The accumulated penalty scaled by ``lamb / 2``. If any of the three
        sequences is empty the result is ``0 * (lamb / 2)``.

    Note:
        ``zip`` stops at the shortest sequence, so extra entries in a longer
        sequence are silently ignored.
    """
    # The attribute is `trainable_weights` (plural); the singular spelling
    # used previously does not exist and raised AttributeError.
    current_weights = model.trainable_weights

    penalty = 0
    for fisher, current, optimal in zip(fisher_matrix, current_weights, optimal_weights):
        penalty += tf.reduce_sum(fisher * ((current - optimal) ** 2))

    return penalty * (lamb / 2)

In [9]:
def ewc_loss(model, fisher_matrix, lamb):
    """Build a loss callable combining categorical cross-entropy with an EWC penalty.

    A deep copy of ``model.trainable_weights`` is taken once, when this
    factory is called; the returned callable compares later weights against
    that snapshot.

    Args:
        model: Object exposing ``trainable_weights``; only used here to take
            the snapshot.
        fisher_matrix: Per-weight importance terms forwarded to
            ``compute_ewc_penalty``.
        lamb: Penalty strength forwarded to ``compute_ewc_penalty``.

    Returns:
        ``loss_fn(model, y_true, y_pred)``. Note that it takes the model as
        its first argument, so it is a three-argument callable rather than a
        plain ``(y_true, y_pred)`` loss.
    """
    # Frozen reference point for the penalty term.
    anchor_weights = deepcopy(model.trainable_weights)

    def loss_fn(model, y_true, y_pred):
        """Cross-entropy of (y_true, y_pred) plus the EWC penalty for ``model``."""
        # `model` here is the argument of loss_fn, not the factory's `model`.
        cross_entropy = CategoricalCrossentropy(from_logits=False)(y_true, y_pred)
        penalty = compute_ewc_penalty(model, fisher_matrix, anchor_weights, lamb)
        return cross_entropy + penalty

    return loss_fn

In [None]:
def compute_fisher_matrix(model, data, num_sample=10):
    """Estimate a diagonal Fisher-style importance term for each trainable weight.

    For ``num_sample`` iterations the model is evaluated on ``data`` (with a
    leading axis added), the gradient of the log of the model output with
    respect to every trainable weight is taken, and the squared gradients are
    accumulated. The accumulated values are then averaged over ``num_sample``.

    Args:
        model: Callable model exposing ``trainable_weights``.
        data: Input passed to the model after a leading axis is added.
            Assumes this is a single un-batched example -- TODO confirm
            against the caller.
        num_sample: Number of gradient evaluations to average over
            (must be >= 1).

    Returns:
        A list with one tensor per trainable weight, each holding the mean
        squared gradient for that weight.

    Raises:
        ValueError: If ``num_sample`` is smaller than 1 (the average would
            otherwise divide by zero).

    Note:
        ``tf.math.log`` is applied directly to the model output, so the
        output is presumably a probability vector (e.g. softmax) -- verify.
        NOTE(review): every iteration evaluates the same input; if ``data``
        was meant to be a dataset, a sample should be drawn per iteration.
    """
    if num_sample < 1:
        raise ValueError(f"num_sample must be >= 1, got {num_sample}")

    weights = model.trainable_weights
    variance = [tf.zeros_like(tensor) for tensor in weights]

    # Add axis 0 exactly once. Doing this inside the loop (and rebinding
    # `data`) added one more leading axis on every iteration.
    batch = tf.expand_dims(data, axis=0)

    for _ in range(num_sample):
        with tf.GradientTape() as tape:
            output = model(batch)
            log_likelihood = tf.math.log(output)

        # For a non-scalar target, tape.gradient differentiates the sum of
        # its elements.
        gradients = tape.gradient(log_likelihood, weights)
        variance = [var + (grad ** 2) for var, grad in zip(variance, gradients)]

    fisher_matrix = [tensor / num_sample for tensor in variance]

    return fisher_matrix

In [None]:
def train_loop(model, data, F):
    """Placeholder training loop: ignores all arguments and returns 0."""
    return 0

In [None]:
# Disabled draft: this whole cell is a single string literal, so
# compute_fisher_info_matrix is never defined when the cell runs.
# The draft body inside the string is only a stub that returns 0.
'''
def compute_fisher_info_matrix(model, task_set, num_batches=197, batch_size=32):
    
    matrix = 0
    return matrix

'''