In [7]:
import tensorflow as tf
from tensorflow.keras.optimizers.legacy import Adamax
from tensorflow.keras.losses import CategoricalCrossentropy



In [8]:
def evaluate(model, X_test, y_test, batch_size=128):
  """Return the categorical accuracy of `model` over a test set.

  The inputs are walked in consecutive slices of `batch_size` items (the final
  slice may be shorter). Each slice is scored with `model.predict_on_batch` and
  folded into a single running `CategoricalAccuracy` metric.

  Args:
    model: Object exposing `predict_on_batch(inputs)`.
    X_test: Sized, sliceable collection of inputs.
    y_test: Labels aligned with `X_test`; sliced with the same bounds.
    batch_size: Number of samples handed to each prediction call.

  Returns:
    The accumulated metric result, converted with `.numpy()`.
  """
  metric = tf.keras.metrics.CategoricalAccuracy(name='accuracy')

  # Process the data one batch at a time.
  for start in range(0, len(X_test), batch_size):
    stop = start + batch_size
    # Carve out the current batch of inputs and labels.
    batch_inputs, batch_labels = X_test[start:stop], y_test[start:stop]
    # Predict, then fold the batch into the running accuracy.
    metric.update_state(batch_labels, model.predict_on_batch(batch_inputs))

  # Final accuracy over everything seen.
  return metric.result().numpy()

In [9]:
def compute_precision_matrices(model, X_train, y_train, num_batches=10, batch_size=128):
  """Accumulate squared log-softmax gradients for every trainable variable.

  For `num_batches` batches drawn from `(X_train, y_train)`, the function
  differentiates `tf.nn.log_softmax(model(batch))` with respect to
  `model.trainable_variables` and adds the element-wise squared gradient,
  divided by `num_batches`, to a per-variable accumulator.

  Args:
    model: Callable Keras-style model exposing `trainable_variables`.
    X_train: Features, sliced along the first axis into a `tf.data.Dataset`.
    y_train: Labels paired with `X_train` (batched alongside the features but
      not used by the estimate itself, see the note below).
    num_batches: Number of batches to average over.
    batch_size: Samples per batch.

  Returns:
    Dict mapping the variable's index in `model.trainable_variables` to a
    tensor with the same shape as that variable.

  NOTE(review): `tape.gradient` on a non-scalar target differentiates the sum
  of its elements, so this is the gradient of the log-softmax summed over all
  classes and all samples of the batch; the labels are ignored. A
  label-conditioned, per-sample estimate would be closer to the Fisher
  information. `log_softmax` also presumes the model emits logits - confirm
  against the model's final layer.
  """
  # Pair features with labels, batch them, and repeat so `take` can always draw
  # `num_batches` batches even when the data holds fewer.
  dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
  task_set = dataset.batch(batch_size).repeat()

  variables = model.trainable_variables
  precision_matrices = {n: tf.zeros_like(p.value()) for n, p in enumerate(variables)}

  for direction, _labels in task_set.take(num_batches):
    # Record the forward pass so parameter gradients can be taken.
    with tf.GradientTape() as tape:
      preds = model(direction)
      ll = tf.nn.log_softmax(preds)
    ll_grads = tape.gradient(ll, variables)

    for idx, g in enumerate(ll_grads):
      # A variable that does not feed the output has no gradient; it
      # contributes nothing instead of crashing on `None ** 2`.
      if g is None:
        continue
      # `g` already has the variable's shape (the batch is summed out by the
      # tape), so square element-wise. Reducing over axis 0 here would average
      # across a parameter axis and blur the estimate between weights.
      precision_matrices[idx] += g ** 2 / num_batches

  return precision_matrices

In [10]:
def compute_elastic_penalty(F, theta, theta_A, alpha=1000):
  """Return the EWC quadratic penalty `0.5 * alpha * sum_i sum(F_i * (theta_i - theta_A_i)^2)`.

  Args:
    F: Per-variable weights, indexable by the position of each entry of `theta`.
    theta: Iterable of current parameter values.
    theta_A: Reference parameter values, indexable by the same positions as `F`.
    alpha: Scale applied to the summed penalty.

  Returns:
    The scaled penalty; `0.0` when `theta` is empty.
  """
  # One scalar per variable: its weighted squared distance from the reference.
  weighted_drift = (
      tf.math.reduce_sum(F[idx] * (current - theta_A[idx]) ** 2)
      for idx, current in enumerate(theta)
  )
  return 0.5*alpha*sum(weighted_drift)

In [11]:
def ewc_loss(labels, preds, model, F, theta_A, alpha=1000):
  """Return the model's own loss on the batch plus the EWC penalty.

  Args:
    labels: Ground-truth labels for the batch, passed to `model.loss`.
    preds: Model outputs for the batch, passed to `model.loss`.
    model: Model exposing a callable `loss` and `trainable_variables`.
    F: Per-variable penalty weights, forwarded to `compute_elastic_penalty`.
    theta_A: Reference parameter values, forwarded to `compute_elastic_penalty`.
    alpha: Penalty strength, forwarded to `compute_elastic_penalty`.

  Returns:
    `model.loss(labels, preds)` plus the elastic penalty.
  """
  # No logging here: this runs once per training batch, and printing would
  # flood the output and break the caller's single-line progress display.
  loss_b = model.loss(labels, preds)
  penalty = compute_elastic_penalty(F, model.trainable_variables, theta_A, alpha=alpha)
  return loss_b + penalty

In [12]:
def train_with_ewc(model, X_A_train, X_A_test, y_A_train, y_A_test, X_B_train, X_B_test, y_B_train, y_B_test, epochs):
  """Train on task A, then continue on task B under an EWC penalty.

  Steps:
    1. Build a network via `model.build(...)`, compile it, and fit it on task A.
    2. Snapshot the trained parameters and compute the per-variable penalty
       weights from task A data.
    3. Run a manual training loop on task B whose loss is `ewc_loss(...)`.
    4. Print test accuracy on task B and on task A.

  Args:
    model: Object exposing `build(input_shape=..., classes=...)` that returns a
      Keras model. The name is rebound to the built model inside the function.
    X_A_train, y_A_train: Task A training inputs and labels.
    X_A_test, y_A_test: Task A test inputs and labels.
    X_B_train, y_B_train: Task B training inputs and labels.
    X_B_test, y_B_test: Task B test inputs and labels.
    epochs: Number of epochs, used for both the task A fit and the task B loop.

  Returns:
    The model after task B training (previously nothing was returned, so the
    trained model was unreachable by the caller).
  """
  LENGTH = 10000 # Packet sequence length
  # `learning_rate` replaces the deprecated `lr` alias; the value is unchanged.
  OPTIMIZER = Adamax(learning_rate=0.002, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0) # Optimizer
  BATCH_SIZE = 128 # Batch size
  VERBOSE = 2 # Output display mode
  NB_CLASSES = 95 # number of outputs = number of classes
  INPUT_SHAPE = (LENGTH,1)

  print ("Building and training DF model")

  model = model.build(input_shape=INPUT_SHAPE, classes=NB_CLASSES)

  model.compile(loss=CategoricalCrossentropy(from_logits=False), optimizer=OPTIMIZER,
                metrics=["accuracy"])
  print ("Model compiled")

  model.fit(X_A_train, y_A_train,
            batch_size=BATCH_SIZE, epochs=epochs,
            verbose=VERBOSE)

  # Snapshot of the task A solution; the penalty pulls task B training back to it.
  theta_A = {n: p.value() for n, p in enumerate(model.trainable_variables)}
  # The penalty weights are computed once, from task A data.
  F = compute_precision_matrices(model, X_A_train, y_A_train, num_batches=128)

  print("Task A accuracy after training on Task A: {}".format(evaluate(model, X_A_test, y_A_test)))

  # Running metrics for the task B loop. Both are fed (labels, preds) only, so
  # the displayed loss is the plain cross-entropy without the EWC penalty.
  accuracy = tf.keras.metrics.CategoricalAccuracy('accuracy')
  loss = tf.keras.metrics.CategoricalCrossentropy('loss')

  num_samples = len(X_B_train)

  for epoch in range(epochs):
    accuracy.reset_state()
    loss.reset_state()

    # Step through the data in BATCH_SIZE strides so the trailing partial batch
    # is trained on too, and a dataset smaller than one batch is not skipped.
    for batch, start in enumerate(range(0, num_samples, BATCH_SIZE)):
      direction = X_B_train[start:start + BATCH_SIZE]
      labels = y_B_train[start:start + BATCH_SIZE]

      with tf.GradientTape() as tape:
        # NOTE(review): the model is called without `training=True`, so any
        # Dropout/BatchNorm layers (if the model has them) run in inference
        # mode here - confirm this is intended.
        preds = model(direction)
        # Task B loss plus the EWC penalty.
        total_loss = ewc_loss(labels, preds, model, F, theta_A)

      # Back-propagate and apply the update.
      grads = tape.gradient(total_loss, model.trainable_variables)
      model.optimizer.apply_gradients(zip(grads, model.trainable_variables))

      # Update the running loss and accuracy.
      accuracy.update_state(labels, preds)
      loss.update_state(labels, preds)

      print("\rEpoch: {}, Batch: {}, Loss: {:.3f}, Accuracy: {:.3f}".format(
          epoch + 1, batch + 1, loss.result().numpy(), accuracy.result().numpy()), flush=True, end='')

    print("")


  print("Task B accuracy after training trained model on Task B: {}".format(evaluate(model, X_B_test, y_B_test)))
  print("Task A accuracy after training trained model on Task B: {}".format(evaluate(model, X_A_test, y_A_test)))

  return model