# Test Code
: Computes Top@1 accuracy from the pretrained weights and checks it against the results reported in the paper.
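The metric computed in the inference cell is Top@1 accuracy restricted to masked positions. At every position where the MLM mask is 1, the argmax over the vocabulary logits is compared with the original token. Below is a minimal NumPy sketch of that computation. It assumes logits of shape `(batch, seq_len, vocab_size)` and masks of shape `(batch, seq_len)`, matching the shapes noted in the loop. `masked_top1_accuracy` is a hypothetical helper and is not part of the `color_palette_completion` package.

```python
import numpy as np


def masked_top1_accuracy(logits, origin_tokens, mlm_mask):
    """Top@1 accuracy over masked positions only (hypothetical helper).

    logits:        (batch, seq_len, vocab_size) model scores
    origin_tokens: (batch, seq_len) ground-truth token ids
    mlm_mask:      (batch, seq_len), 1 where a token was masked
    Returns (accuracy, number_of_masked_positions); accuracy is nan
    when no position is masked.
    """
    logits = np.asarray(logits)
    origin_tokens = np.asarray(origin_tokens)
    masked = np.asarray(mlm_mask) == 1
    n_masked = int(masked.sum())
    if n_masked == 0:
        return float("nan"), 0
    pred_ids = logits.argmax(axis=-1)  # (batch, seq_len) predicted token ids
    correct = pred_ids[masked] == origin_tokens[masked]
    return float(correct.mean()), n_masked


# Toy example: 4 positions, 3-token vocabulary, positions 1 and 3 masked.
logits = np.zeros((1, 4, 3))
logits[0, 1, 2] = 1.0  # predicts id 2 at position 1 (ground truth 2)
logits[0, 3, 0] = 1.0  # predicts id 0 at position 3 (ground truth 1)
origin = np.array([[5, 2, 5, 1]])
mask = np.array([[0, 1, 0, 1]])
acc, n = masked_top1_accuracy(logits, origin, mask)
print(acc, n)
```

This sketch compares token ids directly instead of going through `id2vocab`. The result matches the notebook's string comparison as long as `id2vocab` maps distinct ids to distinct tokens. Returning `nan` when nothing is masked avoids silently scoring an empty set. That case is relevant here, because the printed `dataset[1]` sample has an all-zero MLM mask.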

In [None]:
# !pip install tf-keras

In [1]:
import tensorflow as tf
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import sys

sys.path.append('../src')

from color_palette_completion.text_color_model.input_data_generator import DataGenerator
from color_palette_completion.text_color_model.model_config import Config

Config_pred = Config.copy()
Config_pred['Dataset_Type'] = 'test'
Config_pred['Batch_Size'] = 1
representation = Config_pred['representation']

# Model path and loading
model_path = '/workspace/text2palette/color_palette_completion/data/trained_model/t2p_ca1_mca1_i10t_stop30_lr0.0002__clip_512d_lab_bins_16_0.4_0.5_0'

print("[INFO] Loading model...")
model = tf.saved_model.load(model_path)

serving_signature = model.signatures['serving_default']
# print("Output tensors:", list(serving_signature.structured_outputs.keys()))

# # Inspect the signatures after loading the model
# print("Model signatures:", list(model.signatures.keys()))
# serving_key = list(model.signatures.keys())[0]  # usually "serving_default"
# print("Input tensors:", list(model.signatures[serving_key].structured_input_signature[1].keys()))

print("[INFO] Loading test dataset...")
dataset = DataGenerator(Config_pred)
id2vocab = dataset.corpus.id2vocab
vocab2id = dataset.corpus.vocab2id

gt_color_list = []
pred_color_list = []

print("[INFO] Starting inference...")
for i in tqdm(range(len(dataset)), desc="Inference"):
    (
        batch_x,
        batch_mlm_mask,
        batch_mcc_mask,
        origin_x,
        batch_segment,
        batch_padding_mask,
        batch_text_embed,
        batch_image_embed,
    ) = dataset[i]

    # The SavedModel is a TensorFlow model, so its inputs must be tensors.
    batch_x_tf = tf.convert_to_tensor(batch_x, dtype=tf.int64)
    batch_mlm_mask_tf = tf.convert_to_tensor(batch_mlm_mask, dtype=tf.float32)
    batch_segment_tf = tf.convert_to_tensor(batch_segment, dtype=tf.int64)
    batch_text_embed_tf = tf.convert_to_tensor(batch_text_embed, dtype=tf.float32)
    batch_image_embed_tf = tf.convert_to_tensor(batch_image_embed, dtype=tf.float32)

    # Call the model through its serving signature
    output = serving_signature(
        input_1=batch_x_tf,
        input_2=batch_mlm_mask_tf,
        input_3=batch_segment_tf,
        input_4=batch_text_embed_tf,
        input_5=batch_image_embed_tf
    )

    mlm_logits = output['output_1'].numpy()  # (batch=1, seq_len=18, vocab_size)
    mlm_mask = batch_mlm_mask[0]             # (18,)
    origin_tokens = origin_x[0]              # (18,)

    for pos in np.where(mlm_mask == 1)[0]:
        pred_token_id = np.argmax(mlm_logits[0][pos])
        gt_token_id = origin_tokens[pos]

        pred_token = id2vocab[pred_token_id]
        gt_token = id2vocab[gt_token_id]

        pred_color_list.append(pred_token)
        gt_color_list.append(gt_token)

# ✅ Compute accuracy
acc = accuracy_score(gt_color_list, pred_color_list)
print(f"\n✅ MLM prediction Accuracy@1 on masked tokens: {acc * 100:.2f}%")

2025-04-02 00:32:18.079722: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743553938.096784 3687814 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743553938.101961 3687814 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1743553938.116981 3687814 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.

[INFO] Loading model...


W0000 00:00:1743553941.597471 3687814 gpu_device.cc:2341] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


[INFO] Loading test dataset...
[INFO] Starting inference...


Inference:   0%|          | 0/1712 [00:00<?, ?it/s]


KeyError: 'Mask_num'
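The inference cell stops on the first sample with `KeyError: 'Mask_num'`. The traceback is not shown, so the exact source is uncertain. The failure happens during the first loop iteration, and `'Mask_num'` looks like a configuration key, so one plausible reading is that the data generator looks up a `'Mask_num'` entry that `Config_pred` does not contain. Below is a minimal sketch of a pre-flight check. It assumes the config behaves like a dict. `missing_config_keys`, `REQUIRED_KEYS`, and the placeholder values in `example_config` are all hypothetical.

```python
def missing_config_keys(config, required_keys):
    """Return the keys from `required_keys` that are absent in `config`."""
    return [key for key in required_keys if key not in config]


# Hypothetical list of keys the data generator is assumed to read;
# adjust it to whatever DataGenerator actually accesses.
REQUIRED_KEYS = ["Dataset_Type", "Batch_Size", "representation", "Mask_num"]

# Stand-in for Config_pred with placeholder values (hypothetical).
example_config = {"Dataset_Type": "test", "Batch_Size": 1, "representation": "placeholder"}

missing = missing_config_keys(example_config, REQUIRED_KEYS)
if missing:
    print(f"[WARN] Config is missing keys: {missing}")
```

In the notebook, this would be called as `missing_config_keys(Config_pred, REQUIRED_KEYS)` before the inference loop. The correct value for `Mask_num` cannot be determined from this notebook. It would have to come from the configuration used when the model was trained.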

In [2]:
print(dataset[1])

(array([[ 12,  13,  64,   8,  10,   0,  80, 141,  18,   2,   2,   0, 141,
         18,   2,   2,   2,   0]]), array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.]]), array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), array([[ 12,  13,  64,   8,  10,   0,  80, 141,  18,   2,   2,   0, 141,
         18,   2,   2,   2,   0]]), array([[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2]]), array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0]]), array([[[ 0.03081548, -0.00844604,  0.00122443, ..., -0.01502745,
          0.04747449,  0.00935382],
        [ 0.01672275, -0.02535274, -0.00559235, ...,  0.02505554,
          0.00573809,  0.02059767],
        [-0.00879186, -0.03690744,  0.02558524, ..., -0.04849922,
         -0.01688528, -0.02458045],
        ...,
        [ 0.        ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        , ...,  0.        ,
        