In [None]:
# Enable IPython's autoreload extension so edited modules are picked up live.
%load_ext autoreload
%autoreload 2
# NOTE(review): this only binds a Python variable named CUDA_VISIBLE_DEVICES; it does
# not change the process environment. Use os.environ to actually restrict visible GPUs.
CUDA_VISIBLE_DEVICES=1

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import argparse
import torch

from tqdm.notebook import tqdm, trange

from diffmask.models.sentiment_classification_sst import (
    BertSentimentClassificationSST,
    MyDataset,
    my_collate_fn,
    my_collate_fn_rationale,
    load_sst
)

def _str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"``) becomes ``True``. This converter
    maps the usual spellings explicitly instead.

    Args:
        value: Raw option value (a string), or an actual ``bool``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu", type=str, default="0")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--val_filename",
        type=str,
        # Alternative evaluation files:
        # default="datasets/eSNLI/esnli_test.csv"
        # default="datasets/coco_test/train_api.txt"
        # default="datasets/sci_chatgpt/test_data/500_mnli_mis.txt"
        default="datasets/sci_chatgpt/test_data/snli_1.0_dev_output_edit.txt",
    )
    parser.add_argument(
        "--val_rationale",
        type=str,
        default="datasets/eSNLI/esnli_test.rationale_idx",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        # Alternative checkpoint:
        # default="outputs/coco-bert-hjy_snli_ml_19rules_for_distribution_acc/epoch=7-val_acc=0.9448-val_f1=0.7821.ckpt"
        default="outputs/coco-bert-hjy_snli_ml_15rules_for_distribution_acc/epoch=14-val_acc=0.9637-val_f1=0.8513.ckpt",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="coco",  # coco, esnli
    )

    # Parsed with _str2bool so that "--token_cls False" really yields False.
    parser.add_argument("--token_cls", type=_str2bool, default=False, help="Enable token classification")

    # parse_known_args ignores arguments this parser does not define
    # (e.g. the ones the notebook kernel is launched with).
    hparams, _ = parser.parse_known_args()

    torch.manual_seed(hparams.seed)

    # NOTE(review): this replaces any CUDA_VISIBLE_DEVICES value already present in the
    # environment with --gpu (default "0"); confirm that is the intended GPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = hparams.gpu

# Index into the visible devices, i.e. the first GPU listed in CUDA_VISIBLE_DEVICES.
device = "cuda:0"

model = BertSentimentClassificationSST.load_from_checkpoint(hparams.model_path).to(device)

model.freeze()
model.prepare_data()

In [None]:
import pandas as pd

# Load the validation data; the second value returned by load_sst is not used here.
val_dataset, _ = load_sst(
    hparams.val_filename,
    None,
    hparams.dataset,
    model.hparams.num_labels,
    hparams.val_rationale,
    hparams.token_cls,
)


def _has_dash_label(example):
    """Return True when the example's label is the string '-' (ignoring surrounding whitespace)."""
    candidate = example[1]  # the second tuple element is assumed to be the label
    return isinstance(candidate, str) and candidate.strip() == '-'


# Filter the data: drop every row whose label is '-'.
all_examples = list(val_dataset)
filtered_dataset = [example for example in all_examples if not _has_dash_label(example)]
skipped_count = len(all_examples) - len(filtered_dataset)

print(f"Skipped rows with '-' label: {skipped_count}")

# Build the DataLoader over the rows that survived the filter.
val_dataloader = torch.utils.data.DataLoader(
    filtered_dataset,
    batch_size=model.hparams.batch_size,
    collate_fn=my_collate_fn_rationale,
    num_workers=8,
)

# Per-label counters indexed by integer label id:
#   val_acc[k] -> number of correct predictions among examples whose gold label is k
#   num[k]     -> number of examples whose gold label is k
# Sized from the checkpoint's label count rather than a hard-coded 15, so the same
# cell works for checkpoints trained with a different number of rules.
num_labels = model.hparams.num_labels
val_acc, num = [0] * num_labels, [0] * num_labels
results = []

# Inference only: no autograd bookkeeping is needed.
with torch.no_grad():
    # len(val_dataloader) includes the final partial batch, unlike
    # len(dataset) // batch_size, so the progress bar ends at 100%.
    for batch in tqdm(val_dataloader, total=len(val_dataloader)):
        inputs = model.tokenizer.batch_encode_plus(
            batch[0], pad_to_max_length=True, return_tensors='pt'
        ).to(device)
        input_ids = inputs['input_ids']
        mask = inputs['attention_mask']
        token_type_ids = inputs['token_type_ids']

        logits = model.forward(input_ids, mask, token_type_ids)[0]

        # Convert once per batch to plain Python ints. This avoids one GPU sync per
        # example and keeps the counters as ints, so the accuracy below is an ordinary
        # float division instead of integer-tensor division (which floors or raises
        # on older PyTorch versions).
        predictions = logits.argmax(-1).tolist()  # predicted label ids
        labels = batch[1].tolist()  # gold label ids

        for predicted, label, sentence in zip(predictions, labels, batch[0]):
            val_acc[label] += int(predicted == label)
            num[label] += 1

            # Store premise / hypothesis / gold label / predicted label.
            premise, hypothesis = sentence
            results.append({
                "Premise": premise,
                "Hypothesis": hypothesis,
                "Original Label": label,
                "Predicted Label": predicted
            })

# Report the overall accuracy (NaN when no example was evaluated).
total_examples = sum(num)
overall_accuracy = sum(val_acc) / total_examples if total_examples else float("nan")
print(f"Overall Accuracy: {overall_accuracy}")

# Convert the collected records to a DataFrame. Columns are given explicitly so the
# frame has the expected columns even when `results` is empty.
results_df = pd.DataFrame(
    results,
    columns=["Premise", "Hypothesis", "Original Label", "Predicted Label"],
)

# Label name -> integer id for the 15-rule setup.
label_idx = {
    'entailment_HS': 0, 'entailment_PS': 1, 'entailment_COUNT': 2, 'entailment_PA': 3, 'entailment_ES': 4,
    'contradiction_CW_adj': 5, 'contradiction_CW_noun': 6, 'contradiction_CV': 7, 'contradiction_NS': 8,
    'contradiction_SOS': 9, 'contradiction_IH': 10, 'contradiction_NI': 11,
    'neutral_AM': 12, 'neutral_CON': 13, 'neutral_SSNCV': 14
}

# Alternative mapping for the 19-rule setup:
# label_idx = {
#     'entailment_HS': 0, 'entailment_PS': 1, 'entailment_COUNT': 2, 'entailment_PA': 3, 'entailment_ES': 4,
#     'contradiction_CW_adj': 5, 'contradiction_CW_noun': 6, 'contradiction_CV': 7, 'contradiction_NS': 8,
#     'contradiction_SOS': 9, 'contradiction_IH': 10, 'contradiction_NI': 11,
#     'neutral_AM': 12, 'neutral_CON': 13, 'neutral_SSNCV': 14, 'neutral_CA': 15, 'neutral_EI' : 16,
#             'entailment_RG': 17, 'neutral_VS' : 18
# }

# Inverse mapping: integer id -> label name.
reverse_label_idx = {v: k for k, v in label_idx.items()}

# Series.map turns ids missing from the mapping into NaN without any error, so
# report them explicitly instead of silently writing blank cells.
seen_ids = set(results_df["Original Label"]) | set(results_df["Predicted Label"])
unknown_ids = sorted(seen_ids - set(reverse_label_idx))
if unknown_ids:
    print(f"Warning: label ids without a name in label_idx (exported as NaN): {unknown_ids}")

# Replace integer ids with their label names.
results_df["Original Label"] = results_df["Original Label"].map(reverse_label_idx)
results_df["Predicted Label"] = results_df["Predicted Label"].map(reverse_label_idx)

# Save as an Excel file, creating the output directory first so a missing folder
# does not discard the finished evaluation.
output_file = "ml_outputs/devset_10000_15rules_distribution2.xlsx"
os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)
results_df.to_excel(output_file, index=False)

print(f"Results saved to {output_file}")
