In [1]:
# interpreting BERT Models with Captum library
import os
import json
import numpy as np
import pandas as pd
import seaborn as sns 
import matplotlib.pyplot as plt 

import torch 
import torch.nn as nn
from torch.utils.data import DataLoader, random_split

from transformers import DistilBertTokenizer

from captum.attr import visualization as viz
from captum.attr import LayerConductance, LayerIntegratedGradients

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# module
from models.dkvmn_text import SUBJ_DKVMN

from data_loaders.assist2009 import ASSIST2009
from data_loaders.assist2012 import ASSIST2012
from models.utils import collate_fn

In [3]:
# load config
model_name = "dkvmn+"
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
dataset_name = "ASSIST2009"
seq_len = 200

ckpts = f"ckpts/{model_name}2009-performance/{dataset_name}/"

with open("config.json") as f:
    config = json.load(f)
    model_config = config[model_name]
    train_config = config["train_config"]
    
batch_size = train_config["batch_size"]
num_epochs = train_config["num_epochs"]
train_ratio = train_config["train_ratio"]
learning_rate = train_config["learning_rate"]
optimizer = train_config["optimizer"] # can be sgd, adam
seq_len = train_config["seq_len"] # 샘플링 할 갯수

In [4]:
from torch.nn.utils.rnn import pad_sequence

if torch.cuda.is_available():
    from torch.cuda import FloatTensor, CharTensor, LongTensor
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
else:
    from torch import FloatTensor, CharTensor, LongTensor


# custom collate function
def custom_collate(batch, pad_val=-1):
    '''
    This function for torch.utils.data.DataLoader

    Returns:
        q_seqs: the question(KC) sequences with the size of \
            [batch_size, maximum_sequence_length_in_the_batch]
        r_seqs: the response sequences with the size of \
            [batch_size, maximum_sequence_length_in_the_batch]
        qshft_seqs: the question(KC) sequences which were shifted \
            one step to the right with the size of \
            [batch_size, maximum_sequence_length_in_the_batch]
        rshft_seqs: the response sequences which were shifted \
            one step to the right with the size of \
            [batch_size, maximum_sequence_length_in_the_batch]
        mask_seqs: the mask sequences indicating where \
            the padded entry is with the size of \
            [batch_size, maximum_sequence_length_in_the_batch]
    '''

    q_seqs = []
    r_seqs = []
    qshft_seqs = []
    rshft_seqs = []
    at_seqs = []
    atshft_seqs = []
    q2diff_seqs = []
    pid_seqs = []
    pidshft_seqs = []
    hint_seqs = []

    # q_seq와 r_seq는 마지막 전까지만 가져옴 (마지막은 padding value)
    # q_shft와 rshft는 처음 값 이후 가져옴 (우측 시프트 값이므로..)
    for q_seq, r_seq, at_seq, q2diff, pid_seq, hint_seq in batch:
        q_seqs.append(FloatTensor(q_seq[:-1])) 
        r_seqs.append(FloatTensor(r_seq[:-1]))
        at_seqs.append(at_seq[:-1])
        atshft_seqs.append(at_seq[1:])
        qshft_seqs.append(FloatTensor(q_seq[1:]))
        rshft_seqs.append(FloatTensor(r_seq[1:]))
        q2diff_seqs.append(FloatTensor(q2diff[:-1]))
        pid_seqs.append(FloatTensor(pid_seq[:-1]))
        pidshft_seqs.append(FloatTensor(pid_seq[1:]))
        hint_seqs.append(FloatTensor(hint_seq[:-1]))

    # pad_sequence, 첫번째 인자는 sequence, 두번째는 batch_size가 첫 번째로 인자로 오게 하는 것이고, 3번째 인자의 경우 padding된 요소의 값
    # 시퀀스 내 가장 길이가 긴 시퀀스를 기준으로 padding이 됨, 길이가 안맞는 부분은 늘려서 padding_value 값으로 채워줌
    q_seqs = pad_sequence(
        q_seqs, batch_first=True, padding_value=pad_val
    )
    r_seqs = pad_sequence(
        r_seqs, batch_first=True, padding_value=pad_val
    )
    q2diff_seqs = pad_sequence(
        q2diff_seqs, batch_first=True, padding_value=pad_val
    )
    qshft_seqs = pad_sequence(
        qshft_seqs, batch_first=True, padding_value=pad_val
    )
    rshft_seqs = pad_sequence(
        rshft_seqs, batch_first=True, padding_value=pad_val
    )
    pid_seqs = pad_sequence(
        pid_seqs, batch_first=True, padding_value=pad_val
    )
    pidshft_seqs = pad_sequence(
        pidshft_seqs, batch_first=True, padding_value=pad_val
    )
    at_seqs = pad_sequence(
        at_seqs, batch_first=True, padding_value=pad_val
    )

    # 마스킹 시퀀스 생성 
    # 일반 question 시퀀스: 패딩 밸류와 다른 값들은 모두 1로 처리, 패딩 처리된 값들은 0으로 처리.
    # 일반 question padding 시퀀스: 한 칸 옆으로 시프팅 된 시퀀스 값들이 패딩 값과 다를 경우 1로 처리, 패딩 처리 된 값들은 0으로 처리.
    # 마스킹 시퀀스: 패딩 처리 된 시퀀스 밸류들은 모두 0, 두 값 모두 패딩처리 되지 않았을 경우 1로 처리. (원본 시퀀스와 shift 시퀀스 모두의 값)
    # 예를 들어, 현재 값과 다음 값이 패딩 값이 아닐 경우 1, 현재 값과 다음 값 둘 중 하나라도 패딩일 경우 0으로 처리함.
    mask_seqs = (q_seqs != pad_val) * (qshft_seqs != pad_val)

    # 원본 값의 다음 값이(shift value) 패딩이기만 해도 마스킹 시퀀스에 의해 값이 0로 변함. 아닐경우 원본 시퀀스 데이터를 가짐.
    q_seqs, r_seqs, qshft_seqs, rshft_seqs, q2diff_seqs, pid_seqs, pidshft_seqs, at_seqs = \
        q_seqs * mask_seqs, r_seqs * mask_seqs, qshft_seqs * mask_seqs, \
        rshft_seqs * mask_seqs, q2diff_seqs * mask_seqs, pid_seqs * mask_seqs, \
        pidshft_seqs * mask_seqs, at_seqs * mask_seqs
    
    return q_seqs, r_seqs, qshft_seqs, rshft_seqs, mask_seqs, at_seqs, q2diff_seqs, pid_seqs, pidshft_seqs

In [5]:
# load dataset
# 데이터셋 추가 가능
collate_pt = custom_collate
if dataset_name == "ASSIST2009":
    dataset = ASSIST2009(seq_len, 'datasets/ASSIST2009/')
elif dataset_name == "ASSIST2012":
    dataset = ASSIST2012(seq_len, 'datasets/ASSIST2012/')
    
# 데이터셋 분할
data_size = len(dataset)
train_size = int(data_size * train_ratio) 
valid_size = int(data_size * ((1.0 - train_ratio) / 2.0))
test_size = data_size - train_size - valid_size

train_dataset, valid_dataset, test_dataset = random_split(
    dataset, [train_size, valid_size, test_size], generator=torch.Generator(device=device)
)

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=collate_pt, generator=torch.Generator(device=device)
)
valid_loader = DataLoader(
    valid_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=collate_pt, generator=torch.Generator(device=device)
)
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=collate_pt, generator=torch.Generator(device=device)
)

In [6]:
# 

In [7]:
# define help function

In [9]:
# load model / tokenizer
model = SUBJ_DKVMN(dataset.num_q, num_qid=dataset.num_pid, **model_config).to(device)
model.load_state_dict(torch.load(os.path.join(ckpts, "model.ckpt"), map_location=device))
model.eval()
model.zero_grad()

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


RuntimeError: Error(s) in loading state_dict for SUBJ_DKVMN:
	Missing key(s) in state_dict: "Mk", "Mv0", "k_emb_layer.weight", "d_emb_layer.weight", "transformer_encoder.layers.0.self_attn.in_proj_weight", "transformer_encoder.layers.0.self_attn.in_proj_bias", "transformer_encoder.layers.0.self_attn.out_proj.weight", "transformer_encoder.layers.0.self_attn.out_proj.bias", "transformer_encoder.layers.0.linear1.weight", "transformer_encoder.layers.0.linear1.bias", "transformer_encoder.layers.0.linear2.weight", "transformer_encoder.layers.0.linear2.bias", "transformer_encoder.layers.0.norm1.weight", "transformer_encoder.layers.0.norm1.bias", "transformer_encoder.layers.0.norm2.weight", "transformer_encoder.layers.0.norm2.bias", "bertmodel.embeddings.word_embeddings.weight", "bertmodel.embeddings.position_embeddings.weight", "bertmodel.embeddings.LayerNorm.weight", "bertmodel.embeddings.LayerNorm.bias", "bertmodel.transformer.layer.0.attention.q_lin.weight", "bertmodel.transformer.layer.0.attention.q_lin.bias", "bertmodel.transformer.layer.0.attention.k_lin.weight", "bertmodel.transformer.layer.0.attention.k_lin.bias", "bertmodel.transformer.layer.0.attention.v_lin.weight", "bertmodel.transformer.layer.0.attention.v_lin.bias", "bertmodel.transformer.layer.0.attention.out_lin.weight", "bertmodel.transformer.layer.0.attention.out_lin.bias", "bertmodel.transformer.layer.0.sa_layer_norm.weight", "bertmodel.transformer.layer.0.sa_layer_norm.bias", "bertmodel.transformer.layer.0.ffn.lin1.weight", "bertmodel.transformer.layer.0.ffn.lin1.bias", "bertmodel.transformer.layer.0.ffn.lin2.weight", "bertmodel.transformer.layer.0.ffn.lin2.bias", "bertmodel.transformer.layer.0.output_layer_norm.weight", "bertmodel.transformer.layer.0.output_layer_norm.bias", "bertmodel.transformer.layer.1.attention.q_lin.weight", "bertmodel.transformer.layer.1.attention.q_lin.bias", "bertmodel.transformer.layer.1.attention.k_lin.weight", "bertmodel.transformer.layer.1.attention.k_lin.bias", "bertmodel.transformer.layer.1.attention.v_lin.weight", "bertmodel.transformer.layer.1.attention.v_lin.bias", "bertmodel.transformer.layer.1.attention.out_lin.weight", "bertmodel.transformer.layer.1.attention.out_lin.bias", "bertmodel.transformer.layer.1.sa_layer_norm.weight", "bertmodel.transformer.layer.1.sa_layer_norm.bias", "bertmodel.transformer.layer.1.ffn.lin1.weight", "bertmodel.transformer.layer.1.ffn.lin1.bias", "bertmodel.transformer.layer.1.ffn.lin2.weight", "bertmodel.transformer.layer.1.ffn.lin2.bias", "bertmodel.transformer.layer.1.output_layer_norm.weight", "bertmodel.transformer.layer.1.output_layer_norm.bias", "bertmodel.transformer.layer.2.attention.q_lin.weight", "bertmodel.transformer.layer.2.attention.q_lin.bias", "bertmodel.transformer.layer.2.attention.k_lin.weight", "bertmodel.transformer.layer.2.attention.k_lin.bias", "bertmodel.transformer.layer.2.attention.v_lin.weight", "bertmodel.transformer.layer.2.attention.v_lin.bias", "bertmodel.transformer.layer.2.attention.out_lin.weight", "bertmodel.transformer.layer.2.attention.out_lin.bias", "bertmodel.transformer.layer.2.sa_layer_norm.weight", "bertmodel.transformer.layer.2.sa_layer_norm.bias", "bertmodel.transformer.layer.2.ffn.lin1.weight", "bertmodel.transformer.layer.2.ffn.lin1.bias", "bertmodel.transformer.layer.2.ffn.lin2.weight", "bertmodel.transformer.layer.2.ffn.lin2.bias", "bertmodel.transformer.layer.2.output_layer_norm.weight", "bertmodel.transformer.layer.2.output_layer_norm.bias", "bertmodel.transformer.layer.3.attention.q_lin.weight", "bertmodel.transformer.layer.3.attention.q_lin.bias", "bertmodel.transformer.layer.3.attention.k_lin.weight", "bertmodel.transformer.layer.3.attention.k_lin.bias", "bertmodel.transformer.layer.3.attention.v_lin.weight", "bertmodel.transformer.layer.3.attention.v_lin.bias", "bertmodel.transformer.layer.3.attention.out_lin.weight", "bertmodel.transformer.layer.3.attention.out_lin.bias", "bertmodel.transformer.layer.3.sa_layer_norm.weight", "bertmodel.transformer.layer.3.sa_layer_norm.bias", "bertmodel.transformer.layer.3.ffn.lin1.weight", "bertmodel.transformer.layer.3.ffn.lin1.bias", "bertmodel.transformer.layer.3.ffn.lin2.weight", "bertmodel.transformer.layer.3.ffn.lin2.bias", "bertmodel.transformer.layer.3.output_layer_norm.weight", "bertmodel.transformer.layer.3.output_layer_norm.bias", "bertmodel.transformer.layer.4.attention.q_lin.weight", "bertmodel.transformer.layer.4.attention.q_lin.bias", "bertmodel.transformer.layer.4.attention.k_lin.weight", "bertmodel.transformer.layer.4.attention.k_lin.bias", "bertmodel.transformer.layer.4.attention.v_lin.weight", "bertmodel.transformer.layer.4.attention.v_lin.bias", "bertmodel.transformer.layer.4.attention.out_lin.weight", "bertmodel.transformer.layer.4.attention.out_lin.bias", "bertmodel.transformer.layer.4.sa_layer_norm.weight", "bertmodel.transformer.layer.4.sa_layer_norm.bias", "bertmodel.transformer.layer.4.ffn.lin1.weight", "bertmodel.transformer.layer.4.ffn.lin1.bias", "bertmodel.transformer.layer.4.ffn.lin2.weight", "bertmodel.transformer.layer.4.ffn.lin2.bias", "bertmodel.transformer.layer.4.output_layer_norm.weight", "bertmodel.transformer.layer.4.output_layer_norm.bias", "bertmodel.transformer.layer.5.attention.q_lin.weight", "bertmodel.transformer.layer.5.attention.q_lin.bias", "bertmodel.transformer.layer.5.attention.k_lin.weight", "bertmodel.transformer.layer.5.attention.k_lin.bias", "bertmodel.transformer.layer.5.attention.v_lin.weight", "bertmodel.transformer.layer.5.attention.v_lin.bias", "bertmodel.transformer.layer.5.attention.out_lin.weight", "bertmodel.transformer.layer.5.attention.out_lin.bias", "bertmodel.transformer.layer.5.sa_layer_norm.weight", "bertmodel.transformer.layer.5.sa_layer_norm.bias", "bertmodel.transformer.layer.5.ffn.lin1.weight", "bertmodel.transformer.layer.5.ffn.lin1.bias", "bertmodel.transformer.layer.5.ffn.lin2.weight", "bertmodel.transformer.layer.5.ffn.lin2.bias", "bertmodel.transformer.layer.5.output_layer_norm.weight", "bertmodel.transformer.layer.5.output_layer_norm.bias", "at_emb_layer.weight", "at_emb_layer.bias", "at2_emb_layer.weight", "at2_emb_layer.bias", "qr_emb_layer.weight", "v_emb_layer.weight", "v_emb_layer.bias", "e_layer.weight", "e_layer.bias", "a_layer.weight", "a_layer.bias", "f_layer.weight", "f_layer.bias", "p_layer.weight", "p_layer.bias". 
	Unexpected key(s) in state_dict: "module.Mk", "module.Mv0", "module.k_emb_layer.weight", "module.d_emb_layer.weight", "module.transformer_encoder.layers.0.self_attn.in_proj_weight", "module.transformer_encoder.layers.0.self_attn.in_proj_bias", "module.transformer_encoder.layers.0.self_attn.out_proj.weight", "module.transformer_encoder.layers.0.self_attn.out_proj.bias", "module.transformer_encoder.layers.0.linear1.weight", "module.transformer_encoder.layers.0.linear1.bias", "module.transformer_encoder.layers.0.linear2.weight", "module.transformer_encoder.layers.0.linear2.bias", "module.transformer_encoder.layers.0.norm1.weight", "module.transformer_encoder.layers.0.norm1.bias", "module.transformer_encoder.layers.0.norm2.weight", "module.transformer_encoder.layers.0.norm2.bias", "module.bertmodel.embeddings.position_ids", "module.bertmodel.embeddings.word_embeddings.weight", "module.bertmodel.embeddings.position_embeddings.weight", "module.bertmodel.embeddings.token_type_embeddings.weight", "module.bertmodel.embeddings.LayerNorm.weight", "module.bertmodel.embeddings.LayerNorm.bias", "module.bertmodel.encoder.layer.0.attention.self.query.weight", "module.bertmodel.encoder.layer.0.attention.self.query.bias", "module.bertmodel.encoder.layer.0.attention.self.key.weight", "module.bertmodel.encoder.layer.0.attention.self.key.bias", "module.bertmodel.encoder.layer.0.attention.self.value.weight", "module.bertmodel.encoder.layer.0.attention.self.value.bias", "module.bertmodel.encoder.layer.0.attention.output.dense.weight", "module.bertmodel.encoder.layer.0.attention.output.dense.bias", "module.bertmodel.encoder.layer.0.attention.output.LayerNorm.weight", "module.bertmodel.encoder.layer.0.attention.output.LayerNorm.bias", "module.bertmodel.encoder.layer.0.intermediate.dense.weight", "module.bertmodel.encoder.layer.0.intermediate.dense.bias", "module.bertmodel.encoder.layer.0.output.dense.weight", "module.bertmodel.encoder.layer.0.output.dense.bias", "module.bertmodel.encoder.layer.0.output.LayerNorm.weight", "module.bertmodel.encoder.layer.0.output.LayerNorm.bias", "module.bertmodel.encoder.layer.1.attention.self.query.weight", "module.bertmodel.encoder.layer.1.attention.self.query.bias", "module.bertmodel.encoder.layer.1.attention.self.key.weight", "module.bertmodel.encoder.layer.1.attention.self.key.bias", "module.bertmodel.encoder.layer.1.attention.self.value.weight", "module.bertmodel.encoder.layer.1.attention.self.value.bias", "module.bertmodel.encoder.layer.1.attention.output.dense.weight", "module.bertmodel.encoder.layer.1.attention.output.dense.bias", "module.bertmodel.encoder.layer.1.attention.output.LayerNorm.weight", "module.bertmodel.encoder.layer.1.attention.output.LayerNorm.bias", "module.bertmodel.encoder.layer.1.intermediate.dense.weight", "module.bertmodel.encoder.layer.1.intermediate.dense.bias", "module.bertmodel.encoder.layer.1.output.dense.weight", "module.bertmodel.encoder.layer.1.output.dense.bias", "module.bertmodel.encoder.layer.1.output.LayerNorm.weight", "module.bertmodel.encoder.layer.1.output.LayerNorm.bias", "module.bertmodel.encoder.layer.2.attention.self.query.weight", "module.bertmodel.encoder.layer.2.attention.self.query.bias", "module.bertmodel.encoder.layer.2.attention.self.key.weight", "module.bertmodel.encoder.layer.2.attention.self.key.bias", "module.bertmodel.encoder.layer.2.attention.self.value.weight", "module.bertmodel.encoder.layer.2.attention.self.value.bias", "module.bertmodel.encoder.layer.2.attention.output.dense.weight", "module.bertmodel.encoder.layer.2.attention.output.dense.bias", "module.bertmodel.encoder.layer.2.attention.output.LayerNorm.weight", "module.bertmodel.encoder.layer.2.attention.output.LayerNorm.bias", "module.bertmodel.encoder.layer.2.intermediate.dense.weight", "module.bertmodel.encoder.layer.2.intermediate.dense.bias", "module.bertmodel.encoder.layer.2.output.dense.weight", "module.bertmodel.encoder.layer.2.output.dense.bias", "module.bertmodel.encoder.layer.2.output.LayerNorm.weight", "module.bertmodel.encoder.layer.2.output.LayerNorm.bias", "module.bertmodel.encoder.layer.3.attention.self.query.weight", "module.bertmodel.encoder.layer.3.attention.self.query.bias", "module.bertmodel.encoder.layer.3.attention.self.key.weight", "module.bertmodel.encoder.layer.3.attention.self.key.bias", "module.bertmodel.encoder.layer.3.attention.self.value.weight", "module.bertmodel.encoder.layer.3.attention.self.value.bias", "module.bertmodel.encoder.layer.3.attention.output.dense.weight", "module.bertmodel.encoder.layer.3.attention.output.dense.bias", "module.bertmodel.encoder.layer.3.attention.output.LayerNorm.weight", "module.bertmodel.encoder.layer.3.attention.output.LayerNorm.bias", "module.bertmodel.encoder.layer.3.intermediate.dense.weight", "module.bertmodel.encoder.layer.3.intermediate.dense.bias", "module.bertmodel.encoder.layer.3.output.dense.weight", "module.bertmodel.encoder.layer.3.output.dense.bias", "module.bertmodel.encoder.layer.3.output.LayerNorm.weight", "module.bertmodel.encoder.layer.3.output.LayerNorm.bias", "module.bertmodel.encoder.layer.4.attention.self.query.weight", "module.bertmodel.encoder.layer.4.attention.self.query.bias", "module.bertmodel.encoder.layer.4.attention.self.key.weight", "module.bertmodel.encoder.layer.4.attention.self.key.bias", "module.bertmodel.encoder.layer.4.attention.self.value.weight", "module.bertmodel.encoder.layer.4.attention.self.value.bias", "module.bertmodel.encoder.layer.4.attention.output.dense.weight", "module.bertmodel.encoder.layer.4.attention.output.dense.bias", "module.bertmodel.encoder.layer.4.attention.output.LayerNorm.weight", "module.bertmodel.encoder.layer.4.attention.output.LayerNorm.bias", "module.bertmodel.encoder.layer.4.intermediate.dense.weight", "module.bertmodel.encoder.layer.4.intermediate.dense.bias", "module.bertmodel.encoder.layer.4.output.dense.weight", "module.bertmodel.encoder.layer.4.output.dense.bias", "module.bertmodel.encoder.layer.4.output.LayerNorm.weight", "module.bertmodel.encoder.layer.4.output.LayerNorm.bias", "module.bertmodel.encoder.layer.5.attention.self.query.weight", "module.bertmodel.encoder.layer.5.attention.self.query.bias", "module.bertmodel.encoder.layer.5.attention.self.key.weight", "module.bertmodel.encoder.layer.5.attention.self.key.bias", "module.bertmodel.encoder.layer.5.attention.self.value.weight", "module.bertmodel.encoder.layer.5.attention.self.value.bias", "module.bertmodel.encoder.layer.5.attention.output.dense.weight", "module.bertmodel.encoder.layer.5.attention.output.dense.bias", "module.bertmodel.encoder.layer.5.attention.output.LayerNorm.weight", "module.bertmodel.encoder.layer.5.attention.output.LayerNorm.bias", "module.bertmodel.encoder.layer.5.intermediate.dense.weight", "module.bertmodel.encoder.layer.5.intermediate.dense.bias", "module.bertmodel.encoder.layer.5.output.dense.weight", "module.bertmodel.encoder.layer.5.output.dense.bias", "module.bertmodel.encoder.layer.5.output.LayerNorm.weight", "module.bertmodel.encoder.layer.5.output.LayerNorm.bias", "module.bertmodel.encoder.layer.6.attention.self.query.weight", "module.bertmodel.encoder.layer.6.attention.self.query.bias", "module.bertmodel.encoder.layer.6.attention.self.key.weight", "module.bertmodel.encoder.layer.6.attention.self.key.bias", "module.bertmodel.encoder.layer.6.attention.self.value.weight", "module.bertmodel.encoder.layer.6.attention.self.value.bias", "module.bertmodel.encoder.layer.6.attention.output.dense.weight", "module.bertmodel.encoder.layer.6.attention.output.dense.bias", "module.bertmodel.encoder.layer.6.attention.output.LayerNorm.weight", "module.bertmodel.encoder.layer.6.attention.output.LayerNorm.bias", "module.bertmodel.encoder.layer.6.intermediate.dense.weight", "module.bertmodel.encoder.layer.6.intermediate.dense.bias", "module.bertmodel.encoder.layer.6.output.dense.weight", "module.bertmodel.encoder.layer.6.output.dense.bias", "module.bertmodel.encoder.layer.6.output.LayerNorm.weight", "module.bertmodel.encoder.layer.6.output.LayerNorm.bias", "module.bertmodel.encoder.layer.7.attention.self.query.weight", "module.bertmodel.encoder.layer.7.attention.self.query.bias", "module.bertmodel.encoder.layer.7.attention.self.key.weight", "module.bertmodel.encoder.layer.7.attention.self.key.bias", "module.bertmodel.encoder.layer.7.attention.self.value.weight", "module.bertmodel.encoder.layer.7.attention.self.value.bias", "module.bertmodel.encoder.layer.7.attention.output.dense.weight", "module.bertmodel.encoder.layer.7.attention.output.dense.bias", "module.bertmodel.encoder.layer.7.attention.output.LayerNorm.weight", "module.bertmodel.encoder.layer.7.attention.output.LayerNorm.bias", "module.bertmodel.encoder.layer.7.intermediate.dense.weight", "module.bertmodel.encoder.layer.7.intermediate.dense.bias", "module.bertmodel.encoder.layer.7.output.dense.weight", "module.bertmodel.encoder.layer.7.output.dense.bias", "module.bertmodel.encoder.layer.7.output.LayerNorm.weight", "module.bertmodel.encoder.layer.7.output.LayerNorm.bias", "module.bertmodel.encoder.layer.8.attention.self.query.weight", "module.bertmodel.encoder.layer.8.attention.self.query.bias", "module.bertmodel.encoder.layer.8.attention.self.key.weight", "module.bertmodel.encoder.layer.8.attention.self.key.bias", "module.bertmodel.encoder.layer.8.attention.self.value.weight", "module.bertmodel.encoder.layer.8.attention.self.value.bias", "module.bertmodel.encoder.layer.8.attention.output.dense.weight", "module.bertmodel.encoder.layer.8.attention.output.dense.bias", "module.bertmodel.encoder.layer.8.attention.output.LayerNorm.weight", "module.bertmodel.encoder.layer.8.attention.output.LayerNorm.bias", "module.bertmodel.encoder.layer.8.intermediate.dense.weight", "module.bertmodel.encoder.layer.8.intermediate.dense.bias", "module.bertmodel.encoder.layer.8.output.dense.weight", "module.bertmodel.encoder.layer.8.output.dense.bias", "module.bertmodel.encoder.layer.8.output.LayerNorm.weight", "module.bertmodel.encoder.layer.8.output.LayerNorm.bias", "module.bertmodel.encoder.layer.9.attention.self.query.weight", "module.bertmodel.encoder.layer.9.attention.self.query.bias", "module.bertmodel.encoder.layer.9.attention.self.key.weight", "module.bertmodel.encoder.layer.9.attention.self.key.bias", "module.bertmodel.encoder.layer.9.attention.self.value.weight", "module.bertmodel.encoder.layer.9.attention.self.value.bias", "module.bertmodel.encoder.layer.9.attention.output.dense.weight", "module.bertmodel.encoder.layer.9.attention.output.dense.bias", "module.bertmodel.encoder.layer.9.attention.output.LayerNorm.weight", "module.bertmodel.encoder.layer.9.attention.output.LayerNorm.bias", "module.bertmodel.encoder.layer.9.intermediate.dense.weight", "module.bertmodel.encoder.layer.9.intermediate.dense.bias", "module.bertmodel.encoder.layer.9.output.dense.weight", "module.bertmodel.encoder.layer.9.output.dense.bias", "module.bertmodel.encoder.layer.9.output.LayerNorm.weight", "module.bertmodel.encoder.layer.9.output.LayerNorm.bias", "module.bertmodel.encoder.layer.10.attention.self.query.weight", "module.bertmodel.encoder.layer.10.attention.self.query.bias", "module.bertmodel.encoder.layer.10.attention.self.key.weight", "module.bertmodel.encoder.layer.10.attention.self.key.bias", "module.bertmodel.encoder.layer.10.attention.self.value.weight", "module.bertmodel.encoder.layer.10.attention.self.value.bias", "module.bertmodel.encoder.layer.10.attention.output.dense.weight", "module.bertmodel.encoder.layer.10.attention.output.dense.bias", "module.bertmodel.encoder.layer.10.attention.output.LayerNorm.weight", "module.bertmodel.encoder.layer.10.attention.output.LayerNorm.bias", "module.bertmodel.encoder.layer.10.intermediate.dense.weight", "module.bertmodel.encoder.layer.10.intermediate.dense.bias", "module.bertmodel.encoder.layer.10.output.dense.weight", "module.bertmodel.encoder.layer.10.output.dense.bias", "module.bertmodel.encoder.layer.10.output.LayerNorm.weight", "module.bertmodel.encoder.layer.10.output.LayerNorm.bias", "module.bertmodel.encoder.layer.11.attention.self.query.weight", "module.bertmodel.encoder.layer.11.attention.self.query.bias", "module.bertmodel.encoder.layer.11.attention.self.key.weight", "module.bertmodel.encoder.layer.11.attention.self.key.bias", "module.bertmodel.encoder.layer.11.attention.self.value.weight", "module.bertmodel.encoder.layer.11.attention.self.value.bias", "module.bertmodel.encoder.layer.11.attention.output.dense.weight", "module.bertmodel.encoder.layer.11.attention.output.dense.bias", "module.bertmodel.encoder.layer.11.attention.output.LayerNorm.weight", "module.bertmodel.encoder.layer.11.attention.output.LayerNorm.bias", "module.bertmodel.encoder.layer.11.intermediate.dense.weight", "module.bertmodel.encoder.layer.11.intermediate.dense.bias", "module.bertmodel.encoder.layer.11.output.dense.weight", "module.bertmodel.encoder.layer.11.output.dense.bias", "module.bertmodel.encoder.layer.11.output.LayerNorm.weight", "module.bertmodel.encoder.layer.11.output.LayerNorm.bias", "module.bertmodel.pooler.dense.weight", "module.bertmodel.pooler.dense.bias", "module.at_emb_layer.weight", "module.at_emb_layer.bias", "module.at2_emb_layer.weight", "module.at2_emb_layer.bias", "module.qr_emb_layer.weight", "module.v_emb_layer.weight", "module.v_emb_layer.bias", "module.e_layer.weight", "module.e_layer.bias", "module.a_layer.weight", "module.a_layer.bias", "module.f_layer.weight", "module.f_layer.bias", "module.p_layer.weight", "module.p_layer.bias". 

In [None]:
# predict answer (testing model)
def predict(q, r, at_s, at_t, at_m):
    output = model(q, r, at_s, at_t, at_m)
    return output.start_logits, output.end_logits

# custom forward function
def custom_forward(inputs, token_type_ids=None, position_ids=None, attention_mask=None, position=0):
    pred = predict(inputs, position_ids, token_type_ids, attention_mask)
    pred = pred[position]
    return pred.max(1).values

In [None]:
ref_token_id = tokenizer.pad_token_id
sep_token_id = tokenizer.sep_token_id
cls_token_id = tokenizer.cls_token_id

NameError: name 'tokenizer' is not defined

In [None]:
# define helper function for constructing references / baseline for word tokens
# 이게 원래 인풋으로 들어가는 Question과 아웃풋의 Answering이 존재했는데 나는 Answering 안써서 Answering 제거함
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    text_ids = tokenizer.encode(text, add_special_tokens=False)
    
    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    
    # construct reference token ids
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids)
    
    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)


def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device) # * -1
    return token_type_ids, ref_token_type_ids


def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with 'torch.randperm(seq_length, device=device)'
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)
    
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids


# 어텐션 마스크 구성
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)


# 버트 임베딩 구성
def construct_whole_bert_embeddings(input_ids, ref_input_ids, \
                                    token_type_ids=None, ref_token_type_ids=None, \
                                    position_ids=None, ref_position_ids=None):
    input_embeddings = model.module.bertmodel.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
    ref_input_embeddings = model.module.bertmodel.embeddings(ref_input_ids, token_type_ids=ref_token_type_ids, position_ids=ref_position_ids)
    
    return input_embeddings, ref_input_embeddings


# 시퀀스에서 각 워드 토큰에 대해 속성을 요약해주는 헬퍼 함수
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [None]:
with torch.no_grad():
    for i, data in enumerate(test_loader):
        q, r, qshft_seqs, rshft_seqs, m, at_seqs, q2diff_seqs, pid_seqs, pidshift = data
        
        print(at_seqs)
        input_ids, ref_input_ids, sep_id = construct_input_ref_pair(at_seqs, ref_token_id, sep_token_id, cls_token_id)
        token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
        position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
        attention_mask = construct_attention_mask(input_ids)
        
        indices = input_ids[0].detach().tolist()
        all_tokens = tokenizer.convert_ids_to_tokens(indices)
        
        
        # make prediction
        start_scores, end_scores = predict(q.long(), r.long(), input_ids, token_type_ids, attention_mask)
        print(f"Affected Text: {all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores) + 1]}")
        
        # 워드 임베딩 변화량 볼 수 있음
        lig = LayerIntegratedGradients()
        attribution_start, delta_start = lig
        
            
        break

In [None]:
def 