In [7]:
import import_ipynb
import ewc
from ewc import *

import utils
from utils import *

In [9]:
# joint.ipynb (EWC 기반 Joint 학습)

import numpy as np
import pandas as pd
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.optimizers.legacy import Adamax
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import to_categorical



# Configuration
#   lamb       - passed to ewc_loss(..., lamb=lamb); presumably the EWC penalty weight
#   num_sample - not referenced in this cell; kept for other cells / experiments
#   epochs     - number of epochs given to model.fit
#   MAX_LABEL  - passed to to_input() when building model inputs and labels
lamb, num_sample, epochs, MAX_LABEL = 100, 1000, 30, 95

# Data loading
def load_mon_data(path_a='mon_data_A.pkl', path_b='mon_data_B.pkl'):
    """Load the pickled A and B datasets.

    Args:
        path_a: Path of the pickle holding dataset A (default ``'mon_data_A.pkl'``).
        path_b: Path of the pickle holding dataset B (default ``'mon_data_B.pkl'``).

    Returns:
        Tuple ``(data_a, data_b)`` holding whatever objects ``pd.read_pickle``
        returns for each file (presumably DataFrames -- confirm against the
        producer of these files).

    Note:
        ``pd.read_pickle`` unpickles arbitrary objects, which can execute code;
        only point it at files you trust.
    """
    data_a = pd.read_pickle(path_a)
    data_b = pd.read_pickle(path_b)
    return data_a, data_b


# Load and preprocess the data
data_a, data_b = load_mon_data()

# Load the model and EWC state saved by the earlier joint_A run:
# the trained model, the Fisher matrix, and the reference ("optimal") weights.
# NOTE(review): allow_pickle=True unpickles arbitrary objects -- only load trusted files.
model = load_model("joint_model.h5")
fisher_matrix = np.load("fisher_matrix.npy", allow_pickle=True)
optimal_weights = np.load("optimal_weights.npy", allow_pickle=True)

# Build inputs for A and B, then train on their union with the EWC penalty.
x_a, y_a = to_input(data_a, MAX_LABEL)
x_b, y_b = to_input(data_b, MAX_LABEL)
x_total = np.concatenate([x_a, x_b], axis=0)
y_total = np.concatenate([y_a, y_b], axis=0)
assert len(x_total) == len(y_total), "inputs and labels must have the same number of rows"

# Tag each row with its source dataset (0 = A, 1 = B) so the test split can also be
# scored per dataset. The split depends only on the row count and random_state,
# so passing this extra array leaves the x/y split unchanged.
source_total = np.concatenate([np.zeros(len(x_a), dtype=int),
                               np.ones(len(x_b), dtype=int)])

x_train, x_test, y_train, y_test, source_train, source_test = train_test_split(
    x_total, y_total, source_total, test_size=0.2, random_state=42)

# Recompile the model with the EWC loss
OPTIMIZER = Adamax(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
model.compile(loss=ewc_loss(model, fisher_matrix, lamb=lamb, optimal_weights=optimal_weights),
              optimizer=OPTIMIZER, metrics=["accuracy"])

# Training
model.fit(x_train, y_train,
          epochs=epochs,
          verbose=2)

# Evaluation on the held-out 20% split of mon_data_A + mon_data_B.
# The returned loss is the compiled ewc_loss, i.e. it includes the EWC term.
loss, acc = model.evaluate(x_test, y_test, verbose=0)
print(f"[Joint EWC] Accuracy on A+B after training on A+B: {acc:.4f}")

# Per-dataset test accuracy, to see how well A is retained alongside B.
for source_id, source_name in enumerate(["A", "B"]):
    source_mask = source_test == source_id
    if not source_mask.any():
        continue  # this dataset has no rows in the test split
    _, source_acc = model.evaluate(x_test[source_mask], y_test[source_mask], verbose=0)
    print(f"[Joint EWC]   Accuracy on {source_name} test subset: {source_acc:.4f}")


Epoch 1/30
CE loss: 12.1431122 EWC loss: 0
CE loss: 9.28301525 EWC loss: 6.54419231
CE loss: 8.45622063 EWC loss: 3.85088205
CE loss: 10.470211 EWC loss: 3.15782976
CE loss: 16.3025379 EWC loss: 3.92340422
CE loss: 12.6091843 EWC loss: 4.66902828
CE loss: 10.7978916 EWC loss: 4.69159508
CE loss: 9.93146324 EWC loss: 4.19870853
CE loss: 9.31717587 EWC loss: 3.65431976
CE loss: 10.7038059 EWC loss: 3.3407948
CE loss: 9.6424017 EWC loss: 3.31379771
CE loss: 6.74866867 EWC loss: 3.43090367
CE loss: 11.782 EWC loss: 3.53229427
CE loss: 10.986433 EWC loss: 3.53626919
CE loss: 12.3716507 EWC loss: 3.41526818
CE loss: 9.06216145 EWC loss: 3.21316242
CE loss: 9.42837334 EWC loss: 3.00231528
CE loss: 8.32787228 EWC loss: 2.8415997
CE loss: 11.0711975 EWC loss: 2.75730371
CE loss: 8.14468765 EWC loss: 2.73578286
CE loss: 9.58660698 EWC loss: 2.73672724
CE loss: 9.9973774 EWC loss: 2.72778
CE loss: 7.50578165 EWC loss: 2.68887663
CE loss: 8.02502918 EWC loss: 2.62123179
CE loss: 11.2111492 EWC los