In [None]:
import pandas as pd
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import random
from datasets import load_from_disk
import json

# 결과 값 분석 - Valid

In [None]:
def find_ans(txt):
    """Return the first gold answer string of one ``answers`` record.

    ``txt`` is a mapping whose ``'text'`` entry is a sequence of answer strings;
    only its first element is returned.
    """
    answer_texts = txt['text']
    first_answer = answer_texts[0]
    return first_answer
# Length columns analysed per example, paired with the label used in plot titles.
_LENGTH_COLUMNS = (("con_len", "Context"), ("qu_len", "Question"), ("ans_len", "Answer"))
# Columns kept in the prediction-vs-gold comparison tables.
_RESULT_COLUMNS = ['title', 'context', 'question', 'answer_text', 'pred_ans', 'con_len', 'qu_len', 'ans_len']


def _plot_length_grid(correct_df, incorrect_df, method, xlabel, ylabel, title_prefix):
    """Draw a 3x2 grid: rows are context/question/answer length, columns correct/incorrect.

    ``method`` names the ``Axes`` plotting method to call (``"plot"`` or ``"hist"``);
    each panel receives the sorted values of one length column of one frame.
    """
    fig, axes = plt.subplots(3, 2, figsize=(10, 5))
    fig.subplots_adjust(left=0.125, bottom=0.1, right=0.9, top=0.9, wspace=0.2, hspace=1)
    outcomes = (("Correct", correct_df), ("Incorrect", incorrect_df))
    for row, (column, part) in enumerate(_LENGTH_COLUMNS):
        for col, (outcome, frame) in enumerate(outcomes):
            ax = axes[row, col]
            getattr(ax, method)(sorted(frame[column]))
            ax.set(xlabel=xlabel, ylabel=ylabel, title=f"{title_prefix} of {outcome} {part} Data")
    plt.show()


def valid_result(dataset_path, pred_path, print_result, return_result):
    """Split validation predictions into exact-match hits and misses.

    Loads the ``validation`` split saved at ``dataset_path`` and joins it on ``id``
    with the predictions in ``pred_path``. Each prediction is compared, by exact
    string equality, against the FIRST gold answer text only.

    Args:
        dataset_path: directory passed to ``load_from_disk``; its ``validation``
            split must provide ``id``, ``title``, ``context``, ``question`` and
            ``answers`` (a mapping with a ``'text'`` list) columns.
        pred_path: JSON file holding an object that maps example id -> predicted answer.
        print_result: if truthy, show length line plots and histograms for correct vs.
            incorrect rows, then print their counts and ``describe()`` summaries.
        return_result: if truthy, return the two frames.

    Returns:
        ``(correct_df, incorrect_df)`` when ``return_result`` is truthy, else None.
        Both frames have the columns listed in ``_RESULT_COLUMNS``.
    """
    dataset = load_from_disk(dataset_path)
    valid_df = pd.DataFrame(dataset["validation"])
    valid_df['con_len'] = valid_df['context'].str.len()
    valid_df['qu_len'] = valid_df['question'].str.len()
    # Only the first gold answer is used, both for matching and for its length.
    valid_df['answer_text'] = valid_df["answers"].apply(find_ans)
    valid_df['ans_len'] = valid_df['answer_text'].str.len()

    # JSON is UTF-8 by spec; being explicit avoids locale-dependent decoding of Korean text.
    with open(pred_path, encoding="utf-8") as f:
        prediction = json.load(f)
    pred_df = pd.DataFrame(list(prediction.items()), columns=['id', 'pred_ans'])

    valid_pred_df = pd.merge(valid_df, pred_df, on='id')[_RESULT_COLUMNS]
    is_correct = valid_pred_df['answer_text'] == valid_pred_df['pred_ans']
    correct_df = valid_pred_df.loc[is_correct]
    incorrect_df = valid_pred_df.loc[~is_correct]

    if print_result:
        _plot_length_grid(correct_df, incorrect_df, "plot", "Data index", "Length of Data", "Length")
        _plot_length_grid(correct_df, incorrect_df, "hist", "Length", "Frequency", "Frequency")
        # '맞춘 개수' = number of correct predictions, '틀린 개수' = number of incorrect ones.
        print('맞춘 개수: ', len(correct_df))
        print(correct_df[["con_len", "qu_len", "ans_len"]].describe())
        print('틀린 개수: ', len(incorrect_df))
        print(incorrect_df[["con_len", "qu_len", "ans_len"]].describe())
    if return_result:
        return correct_df, incorrect_df
    return None

In [None]:
# Analyse predictions for the validation split stored under train_dataset.
train_dataset_path = "/opt/ml/input/data/train_dataset"
train_pred_path = "/opt/ml/input/code/outputs/train_dataset/predictions.json"
corect_df, incorect_df = valid_result(
    dataset_path=train_dataset_path,
    pred_path=train_pred_path,
    print_result=True,
    return_result=True,
)

In [None]:
# Display the validation rows whose prediction exactly matched the first gold answer.
corect_df

In [None]:
# Display the validation rows whose prediction did not match the first gold answer.
incorect_df

In [None]:
# Same analysis for the data_wiki_korquad dataset (the name suggests wiki + KorQuAD data).
# NOTE: this rebinds corect_df / incorect_df, replacing the train_dataset results.
korquad_dataset_path = "/opt/ml/input/data/data_wiki_korquad"
korquad_pred_path = "/opt/ml/input/code/outputs/data_wiki_korquad/predictions.json"
corect_df, incorect_df = valid_result(
    dataset_path=korquad_dataset_path,
    pred_path=korquad_pred_path,
    print_result=True,
    return_result=True,
)

## SHAP

# 결과 값 분석 - Test

In [None]:
import transformers
import shap
import torch

# load the model
# NOTE(review): no `model=` argument is given, so transformers falls back to its
# default question-answering checkpoint (presumably an English model) — pin an
# explicit model name for reproducible explanations.
pmodel = transformers.pipeline('question-answering')

# define two predictions, one that outputs the logits for the range start,
# and the other for the range end
def f(questions, start):
    """Return per-token span logits for each "question[SEP]context" string.

    Args:
        questions: iterable of strings, each a question and its context joined by
            the literal marker "[SEP]" (only the first marker is used as the split).
        start: if True return start-of-span logits, otherwise end-of-span logits.

    Returns:
        list with one flattened numpy array of logits per input string.
    """
    outs = []
    for q in questions:
        # maxsplit=1 keeps any further "[SEP]" inside the context instead of raising.
        question, context = q.split("[SEP]", 1)
        # Batch of one, placed on the same device as the pipeline's model.
        d = pmodel.tokenizer(question, context, return_tensors="pt").to(pmodel.model.device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            out = pmodel.model(**d)
        logits = out.start_logits if start else out.end_logits
        # .cpu() so .numpy() also works when the pipeline runs on a GPU.
        outs.append(logits.reshape(-1).cpu().numpy())
    return outs

def f_start(questions):
    """Start-of-span logits for each "question[SEP]context" string (see ``f``)."""
    return f(questions, True)

def f_end(questions):
    """End-of-span logits for each "question[SEP]context" string (see ``f``)."""
    return f(questions, False)

In [None]:
# Toy SHAP input: each string is "<question>[SEP]<context>", the format the scoring
# functions in this notebook split apart on "[SEP]".
data = ["What is on the table?[SEP]When I got home today I saw my cat on the table, and my frog on the floor."]

In [None]:
def make_answer_scorer(answers):
    """Build a SHAP-compatible scoring function for a fixed list of candidate answers.

    Args:
        answers: list of answer strings whose scores should be tracked.

    Returns:
        A function ``f(questions)``. For every "question[SEP]context" string it runs
        the global ``pmodel`` QA pipeline (up to 20 answers) and produces one row.
        The row holds, for each candidate in ``answers``, the score of the first
        pipeline result whose ``"answer"`` text equals it exactly, or 0 when none
        does. ``f.output_names`` is set to ``answers`` so SHAP can label outputs.
    """
    def f(questions):
        out = []
        for q in questions:
            # maxsplit=1 keeps any further "[SEP]" inside the context instead of raising.
            question, context = q.split("[SEP]", 1)
            # `top_k` replaces the QA pipeline's deprecated `topk` keyword.
            results = pmodel(question, context, top_k=20)
            # The QA pipeline returns a bare dict instead of a list when it produces a
            # single answer; iterating that dict would yield its keys, so wrap it.
            if isinstance(results, dict):
                results = [results]
            first_scores = {}
            for result in results:
                # setdefault keeps the first occurrence of each answer text.
                first_scores.setdefault(result["answer"], result["score"])
            out.append([first_scores.get(answer, 0) for answer in answers])
        return out
    f.output_names = answers
    return f

# Score three candidate answers and visualise which input tokens drive each score.
f_answers = make_answer_scorer(["my cat", "cat", "my frog"])
explainer_answers = shap.Explainer(model=f_answers, masker=pmodel.tokenizer)
shap_values_answers = explainer_answers(data)
shap.plots.text(shap_values_answers)

In [None]:
import argparse

from datamodule.base_data import *
from utils.data_utils import *
from utils.util import *
from omegaconf import OmegaConf
from models.base_model import *
# parse_known_args tolerates unrecognised argv entries (e.g. ones a notebook kernel
# may pass), so only --config is read here.
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="base_config")
args, _ = parser.parse_known_args()

cfg = OmegaConf.load(f"/opt/ml/input/code/pl/config/{args.config}.yaml")
# NOTE(review): `pl` is not imported explicitly here; presumably it comes from one of
# the star imports above (e.g. pytorch_lightning) — confirm and import it directly.
pl.seed_everything(cfg.train.seed, workers=True)


# Create the dataloader and the model.
dataloader = Dataloader(
    cfg.model.model_name,
    cfg.train.batch_size,
    cfg.data.shuffle,
    cfg.path.train_path,
    cfg.path.test_path,
    cfg.train.seed,
    cfg.retrieval,
)

# ckpt_path = "/opt/ml/input/code/pl/output/klue_roberta-large/epoch=3_val_em=70.00_korquad.ckpt"
pt_path = "/opt/ml/input/code/pl/output/large_78.pt"

# for checkpoint
# model = Model(cfg).load_from_checkpoint(checkpoint_path=ckpt_path)

# for pt: build the model from the config, then copy the saved weights in.
model = Model(cfg)
# map_location="cpu" lets the file load even where the device it was saved from is
# unavailable; load_state_dict copies the values into the model's own parameters.
# torch.load unpickles the file — only load weight files from trusted sources.
model.load_state_dict(torch.load(pt_path, map_location="cpu"))

# Single-GPU trainer. Without a GPU set accelerator="cpu"; for several GPUs raise
# `devices` (e.g. devices=4).
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=cfg.train.max_epoch,
    log_every_n_steps=cfg.train.logging_step,
    deterministic=True,
)

In [None]:
# load the model
# NOTE(review): this reloads the same default pipeline and redefines f / f_start /
# f_end identically to the earlier SHAP cell. No `model=` is given, so transformers
# uses its default QA checkpoint (presumably English) — pin the intended model.
pmodel = transformers.pipeline('question-answering')

# define two predictions, one that outputs the logits for the range start,
# and the other for the range end
def f(questions, start):
    """Return per-token span logits for each "question[SEP]context" string.

    Args:
        questions: iterable of strings, each a question and its context joined by
            the literal marker "[SEP]" (only the first marker is used as the split).
        start: if True return start-of-span logits, otherwise end-of-span logits.

    Returns:
        list with one flattened numpy array of logits per input string.
    """
    outs = []
    for q in questions:
        # maxsplit=1 keeps any further "[SEP]" inside the context instead of raising.
        question, context = q.split("[SEP]", 1)
        # Batch of one, placed on the same device as the pipeline's model.
        d = pmodel.tokenizer(question, context, return_tensors="pt").to(pmodel.model.device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            out = pmodel.model(**d)
        logits = out.start_logits if start else out.end_logits
        # .cpu() so .numpy() also works when the pipeline runs on a GPU.
        outs.append(logits.reshape(-1).cpu().numpy())
    return outs

def f_start(questions):
    """Start-of-span logits for each "question[SEP]context" string (see ``f``)."""
    return f(questions, True)

def f_end(questions):
    """End-of-span logits for each "question[SEP]context" string (see ``f``)."""
    return f(questions, False)


# Korean version of the toy SHAP example (rebinds data / f_answers / explainer_answers).
# NOTE(review): `pmodel` is the default pipeline loaded above, presumably an English
# model, so its scores on Korean text may be meaningless — confirm the intended model.
# NOTE(review): the context contains "개구리가", not "나의 개구리"; since matching is
# exact, that candidate can presumably never score above 0.
data = ["테이블 위에 무엇이 있나요?[SEP]내가 집에 돌아왔을 때 나의 고양이가 테이블 위에 있는 것을 보았고, 개구리가 바닥에 있는 것을 보았다"]
f_answers = make_answer_scorer(["나의 고양이", "고양이", "나의 개구리"])
explainer_answers = shap.Explainer(model=f_answers, masker=pmodel.tokenizer)
shap_values_answers = explainer_answers(data)
shap.plots.text(shap_values_answers)