# EWC (Elastic Weight Consolidation)

In [1]:
import tensorflow as tf
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import BinaryAccuracy
from tensorflow.keras.metrics import Mean
from tensorflow.keras.utils import to_categorical

In [2]:
import import_ipynb
import pandas as pd
import numpy as np

In [3]:
from copy import deepcopy

In [4]:
import utils
from utils import *

In [5]:
import Model_DomainIL
from Model_DomainIL import *

In [6]:
def evaluate(model, test_set):
    """Return the binary accuracy of `model` accumulated over `test_set`.

    Args:
        model: Object exposing `predict_on_batch(inputs)`.
        test_set: Iterable yielding `(inputs, labels)` pairs, one per batch.

    Returns:
        The final `BinaryAccuracy` metric value converted with `.numpy()`.
    """
    metric = tf.keras.metrics.BinaryAccuracy(name='accuracy')
    for batch_inputs, batch_labels in test_set:
        batch_preds = model.predict_on_batch(batch_inputs)
        # The metric keeps running totals, so the result covers all batches.
        metric.update_state(batch_labels, batch_preds)
    accuracy = metric.result()
    return accuracy.numpy()

In [7]:
def compute_ewc_penalty(model, fisher_matrix, optimal_weights, lamb):
    """Compute the EWC quadratic penalty for the model's current weights.

    The value returned is::

        (lamb / 2) * sum_k reduce_sum(F_k * (current_k - optimal_k) ** 2)

    Args:
        model: Object whose `trainable_weights` are the current parameters.
        fisher_matrix: Per-weight importance tensors, in the same order as
            `model.trainable_weights`.
        optimal_weights: Per-weight reference values, same order.
        lamb: Penalty strength.

    Returns:
        The scaled penalty. When the inputs are empty no term is added and
        the result is the plain number `0 * (lamb / 2)`.
    """
    # zip stops at the shortest of the three sequences.
    per_weight_terms = [
        tf.reduce_sum(importance * ((weight - anchor) ** 2))
        for importance, weight, anchor in zip(
            fisher_matrix, model.trainable_weights, optimal_weights
        )
    ]
    return sum(per_weight_terms) * (lamb / 2)

In [8]:
def ewc_loss(model, fisher_matrix, lamb):
    """Build a loss function combining binary cross-entropy with an EWC penalty.

    Args:
        model: Model whose `trainable_weights` are regularised.
        fisher_matrix: Per-weight importance tensors passed through to
            `compute_ewc_penalty`.
        lamb: Penalty strength passed through to `compute_ewc_penalty`.

    Returns:
        A callable `loss_fn(y_true, y_pred)` returning the cross-entropy plus
        the EWC penalty.
    """
    # Deep-copied once, when the loss is built, so the penalty is measured
    # against the weights as they were at that moment.
    anchor_weights = deepcopy(model.trainable_weights)

    def loss_fn(y_true, y_pred):
        bce = BinaryCrossentropy(from_logits=False)
        data_term = bce(y_true, y_pred)
        penalty_term = compute_ewc_penalty(
            model, fisher_matrix, anchor_weights, lamb=lamb
        )
        return data_term + penalty_term

    return loss_fn

In [9]:
def compute_fisher_matrix(model, data, num_sample=10):
    """Estimate a diagonal Fisher matrix from gradients of log(model output).

    For `num_sample` randomly chosen samples, the log of the model output is
    differentiated with respect to every trainable weight and the squared
    gradients are averaged.

    Args:
        model: Callable model exposing `trainable_weights`.
        data: Indexable collection of single samples (no batch axis; one is
            added here with `tf.expand_dims`).
        num_sample: Number of samples drawn without replacement.

    Returns:
        A list with one tensor per trainable weight, each shaped like that
        weight.

    Raises:
        ValueError: From `np.random.choice` if `num_sample > len(data)`.
    """
    weights = model.trainable_weights
    variance = [tf.zeros_like(tensor) for tensor in weights]

    indices = np.random.choice(len(data), size=num_sample, replace=False)

    for i in indices:
        with tf.GradientTape() as tape:
            tape.watch(weights)
            x = tf.expand_dims(data[i], axis=0)
            output = model(x, training=False)
            log_likelihood = tf.math.log(output)

        gradients = tape.gradient(log_likelihood, weights)
        # tape.gradient yields None for weights that do not influence the
        # output; leave their accumulator untouched instead of crashing.
        variance = [
            var if grad is None else var + tf.square(grad)
            for var, grad in zip(variance, gradients)
        ]

    return [tensor / num_sample for tensor in variance]

In [10]:
def compute_fisher_matrix_empirical(model, data, label, num_sample=100):
    """Estimate a diagonal Fisher matrix from gradients of the labelled loss.

    For `num_sample` randomly chosen samples, the categorical cross-entropy
    between the true label and the model output is differentiated with
    respect to every trainable weight and the squared gradients are averaged.

    Args:
        model: Callable model exposing `trainable_weights`; its output is
            treated as probabilities (`from_logits=False`).
        data: Indexable collection of single samples (no batch axis).
        label: Indexable collection of labels aligned with `data`.
        num_sample: Number of samples drawn without replacement.

    Returns:
        A list with one tensor per trainable weight, each shaped like that
        weight.

    Raises:
        ValueError: From `np.random.choice` if `num_sample > len(data)`.
    """
    indices = np.random.choice(len(data), size=num_sample, replace=False)
    weights = model.trainable_weights
    variance = [tf.zeros_like(w) for w in weights]

    for i in indices:
        with tf.GradientTape() as tape:
            tape.watch(weights)
            x = tf.expand_dims(data[i], axis=0)
            y = tf.expand_dims(label[i], axis=0)
            # Inference mode: the Fisher estimate should describe the trained
            # model itself, without dropout noise or batch-norm statistic
            # updates, matching the other Fisher helpers in this file.
            probs = model(x, training=False)
            loss = tf.keras.losses.categorical_crossentropy(y, probs, from_logits=False)

        gradients = tape.gradient(loss, weights)
        for j, grad in enumerate(gradients):
            # Weights that do not influence the loss have no gradient.
            if grad is not None:
                variance[j] += tf.square(grad)

    fisher_matrix = [v / num_sample for v in variance]

    # Debug output: shape and summary statistics of each Fisher tensor.
    print("\nempirical : [DEBUG] Fisher matrix shapes:")
    for i, f in enumerate(fisher_matrix):
        print(f" - Fisher {i}: {f.shape}, mean={tf.reduce_mean(f):.4f}, std={tf.math.reduce_std(f):.4f}")

    return fisher_matrix


In [11]:
def compute_fisher_matrix_logits(model, data, label, num_sample=100):
    """Estimate a diagonal Fisher matrix from gradients of the top output.

    For `num_sample` randomly chosen samples, the highest-scoring entry of the
    model output (along axis 1) is differentiated with respect to every
    trainable weight and the squared gradients are averaged.

    Args:
        model: Callable model exposing `trainable_weights`.
        data: Indexable collection of single samples (no batch axis).
        label: Unused; kept so the signature matches
            `compute_fisher_matrix_empirical`.
        num_sample: Number of samples drawn without replacement.

    Returns:
        A list with one tensor per trainable weight, each shaped like that
        weight.

    Raises:
        ValueError: From `np.random.choice` if `num_sample > len(data)`.
    """
    indices = np.random.choice(len(data), size=num_sample, replace=False)
    weights = model.trainable_weights
    variance = [tf.zeros_like(w) for w in weights]

    for i in indices:
        with tf.GradientTape() as tape:
            tape.watch(weights)
            x = tf.expand_dims(data[i], axis=0)
            logits = model(x, training=False)
            # tf.argmax returns int64 by default while tf.range over
            # tf.shape(...) is int32; tf.stack needs matching dtypes, so
            # request int32 indices explicitly.
            class_idx = tf.argmax(logits, axis=1, output_type=tf.int32)
            row_idx = tf.range(tf.shape(class_idx)[0])
            selected_logit = tf.gather_nd(logits, tf.stack([row_idx, class_idx], axis=1))

        gradients = tape.gradient(selected_logit, weights)
        for j, grad in enumerate(gradients):
            # Weights that do not influence the output have no gradient.
            if grad is not None:
                variance[j] += tf.square(grad)

    fisher_matrix = [v / num_sample for v in variance]

    # Debug output: shape and summary statistics of each Fisher tensor.
    print("\nlogits : [DEBUG] Fisher matrix shapes:")
    for i, f in enumerate(fisher_matrix):
        print(f" - Fisher {i}: {f.shape}, mean={tf.reduce_mean(f):.4f}, std={tf.math.reduce_std(f):.4f}")

    return fisher_matrix
