In [9]:
# interpreting BERT Models with Captum library
import os
import json
import numpy as np
import pandas as pd
import seaborn as sns 
import matplotlib.pyplot as plt 

import torch 
import torch.nn as nn
from torch.utils.data import DataLoader, random_split

from transformers import DistilBertTokenizer

from captum.attr import visualization as viz
from captum.attr import LayerIntegratedGradients

In [10]:
# module
from models.dkvmn_text import SUBJ_DKVMN

from data_loaders.assist2009 import ASSIST2009
from data_loaders.assist2012 import ASSIST2012
from data_loaders.csedm import CSEDM

In [11]:
# load config
model_name = "dkvmn+"
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
dataset_name = "ASSIST2009"
ckpts = f"ckpts/{model_name}/{dataset_name}/"

with open("config.json") as f:
    config = json.load(f)
    model_config = config[model_name]
    train_config = config["train_config"]
    
batch_size = train_config["batch_size"]
num_epochs = train_config["num_epochs"]
train_ratio = train_config["train_ratio"]
learning_rate = train_config["learning_rate"]
optimizer = train_config["optimizer"] # can be sgd, adam
seq_len = train_config["seq_len"] # 샘플링 할 갯수

In [12]:
from torch.nn.utils.rnn import pad_sequence

if torch.cuda.is_available():
    from torch.cuda import FloatTensor, CharTensor, LongTensor
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
else:
    from torch import FloatTensor, CharTensor, LongTensor
# tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')

tokenizer = DistilBertTokenizer.from_pretrained('./ckpts/tokenizer/ASSIST2009/DISTIL_A2009')

def custom_collate(batch, pad_val=-1):
    '''
    This function for torch.utils.data.DataLoader

    Returns:
        q_seqs: the question(KC) sequences with the size of \
            [batch_size, maximum_sequence_length_in_the_batch]
        r_seqs: the response sequences with the size of \
            [batch_size, maximum_sequence_length_in_the_batch]
        qshft_seqs: the question(KC) sequences which were shifted \
            one step to the right with the size of \
            [batch_size, maximum_sequence_length_in_the_batch]
        rshft_seqs: the response sequences which were shifted \
            one step to the right with the size of \
            [batch_size, maximum_sequence_length_in_the_batch]
        mask_seqs: the mask sequences indicating where \
            the padded entry is with the size of \
            [batch_size, maximum_sequence_length_in_the_batch]
    '''

    q_seqs = []
    r_seqs = []
    qshft_seqs = []
    rshft_seqs = []
    at_seqs = []
    atshft_seqs = []
    q2diff_seqs = []
    pid_seqs = []
    pidshft_seqs = []
    hint_seqs = []

    # q_seq와 r_seq는 마지막 전까지만 가져옴 (마지막은 padding value)
    # q_shft와 rshft는 처음 값 이후 가져옴 (우측 시프트 값이므로..)
    for q_seq, r_seq, at_seq, q2diff, pid_seq, hint_seq in batch:
        q_seqs.append(FloatTensor(q_seq[:-1])) 
        r_seqs.append(FloatTensor(r_seq[:-1]))
        at_seqs.append(at_seq[:-1])
        atshft_seqs.append(at_seq[1:])
        qshft_seqs.append(FloatTensor(q_seq[1:]))
        rshft_seqs.append(FloatTensor(r_seq[1:]))
        q2diff_seqs.append(FloatTensor(q2diff[:-1]))
        pid_seqs.append(FloatTensor(pid_seq[:-1]))
        pidshft_seqs.append(FloatTensor(pid_seq[1:]))
        hint_seqs.append(FloatTensor(hint_seq[:-1]))

    # pad_sequence, 첫번째 인자는 sequence, 두번째는 batch_size가 첫 번째로 인자로 오게 하는 것이고, 3번째 인자의 경우 padding된 요소의 값
    # 시퀀스 내 가장 길이가 긴 시퀀스를 기준으로 padding이 됨, 길이가 안맞는 부분은 늘려서 padding_value 값으로 채워줌
    q_seqs = pad_sequence(
        q_seqs, batch_first=True, padding_value=pad_val
    )
    r_seqs = pad_sequence(
        r_seqs, batch_first=True, padding_value=pad_val
    )
    q2diff_seqs = pad_sequence(
        q2diff_seqs, batch_first=True, padding_value=pad_val
    )
    qshft_seqs = pad_sequence(
        qshft_seqs, batch_first=True, padding_value=pad_val
    )
    rshft_seqs = pad_sequence(
        rshft_seqs, batch_first=True, padding_value=pad_val
    )
    pid_seqs = pad_sequence(
        pid_seqs, batch_first=True, padding_value=pad_val
    )
    pidshft_seqs = pad_sequence(
        pidshft_seqs, batch_first=True, padding_value=pad_val
    )
    hint_seqs = pad_sequence(
        hint_seqs, batch_first=True, padding_value=pad_val
    )

    # 마스킹 시퀀스 생성 
    # 일반 question 시퀀스: 패딩 밸류와 다른 값들은 모두 1로 처리, 패딩 처리된 값들은 0으로 처리.
    # 일반 question padding 시퀀스: 한 칸 옆으로 시프팅 된 시퀀스 값들이 패딩 값과 다를 경우 1로 처리, 패딩 처리 된 값들은 0으로 처리.
    # 마스킹 시퀀스: 패딩 처리 된 시퀀스 밸류들은 모두 0, 두 값 모두 패딩처리 되지 않았을 경우 1로 처리. (원본 시퀀스와 shift 시퀀스 모두의 값)
    # 예를 들어, 현재 값과 다음 값이 패딩 값이 아닐 경우 1, 현재 값과 다음 값 둘 중 하나라도 패딩일 경우 0으로 처리함.
    mask_seqs = (q_seqs != pad_val) * (qshft_seqs != pad_val)

    # 원본 값의 다음 값이(shift value) 패딩이기만 해도 마스킹 시퀀스에 의해 값이 0로 변함. 아닐경우 원본 시퀀스 데이터를 가짐.
    q_seqs, r_seqs, qshft_seqs, rshft_seqs, q2diff_seqs, pid_seqs, pidshft_seqs, hint_seqs = \
        q_seqs * mask_seqs, r_seqs * mask_seqs, qshft_seqs * mask_seqs, \
        rshft_seqs * mask_seqs, q2diff_seqs * mask_seqs, pid_seqs * mask_seqs, \
        pidshft_seqs * mask_seqs, hint_seqs * mask_seqs
    

    # Word2vec

    # BERT preprocessing
    bert_details = []

    # def mapmax(data):
    #     return max(data, key=len)

    # 2차원에서 가장 긴 문장 추출
    # SENT_LEN = len(max(map(mapmax, at_seqs), key=len))

    for answer_text in at_seqs:
        text = list(map(str, answer_text))
        # print(f"============= text: {text} ================")
        
        encoded_bert_sent = tokenizer.encode_plus(
            text, add_special_tokens=False, padding="max_length", max_len=200, truncation=True
        )
        bert_details.append(encoded_bert_sent)
    
    # 정답지 추가
    # proc_atshft_seqs = []
    # # SENT_LEN = q_seqs.size(0)
    # for answer_text in atshft_seqs:
    #     text = " ".join(map(str, answer_text))
    #     encoded_bert_sent = bert_tokenizer.encode_plus(
    #         text, add_special_tokens=True, padding='max_length', truncation=True
    #     )
    #     proc_atshft_seqs.append(encoded_bert_sent)

    bert_sentences = LongTensor([text["input_ids"] for text in bert_details])
    # bert_sentence_types = LongTensor([text["token_type_ids"] for text in bert_details])
    bert_sentence_att_mask = LongTensor([text["attention_mask"] for text in bert_details])
    # proc_atshft_sentences = LongTensor([text["input_ids"] for text in proc_atshft_seqs])

    return q_seqs, r_seqs, qshft_seqs, rshft_seqs, mask_seqs, bert_sentences, [], bert_sentence_att_mask, q2diff_seqs, pid_seqs, pidshft_seqs, at_seqs


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [13]:
# load dataset
# 데이터셋 추가 가능
collate_pt = custom_collate
dataset = ASSIST2009(seq_len, 'datasets/ASSIST2009/')
if dataset_name == "ASSIST2012":
    dataset = ASSIST2012(seq_len, 'datasets/ASSIST2012/')
elif dataset_name == "CSEDM":
    dataset = CSEDM(seq_len, 'datasets/CSEDM/')
    
# 데이터셋 분할
data_size = len(dataset)
train_size = int(data_size * train_ratio) 
valid_size = int(data_size * ((1.0 - train_ratio) / 2.0))
test_size = data_size - train_size - valid_size

train_dataset, valid_dataset, test_dataset = random_split(
    dataset, [train_size, valid_size, test_size], generator=torch.Generator(device=device)
)

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=collate_pt, generator=torch.Generator(device=device)
)
valid_loader = DataLoader(
    valid_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=collate_pt, generator=torch.Generator(device=device)
)
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=collate_pt, generator=torch.Generator(device=device)
)

In [14]:
# 

In [15]:
# define help function
model_config

{'dim_s': 200, 'size_m': 20}

In [16]:
# load model / tokenizer in collate function
print(ckpts)
model = torch.nn.DataParallel(SUBJ_DKVMN(dataset.num_q, num_qid=dataset.num_pid, **model_config)).to(device)
model.load_state_dict(torch.load(os.path.join(ckpts, "model.ckpt"), map_location=device))
model.eval()
model.zero_grad()


ckpts/dkvmn+/ASSIST2009/


In [17]:
# predict answer (testing model)
def predict(q, r, at_s, at_t, at_m):
    output, Mv = model(q, r, at_s, at_t, at_m)
    return output

# custom forward function
def custom_forward(inputs, token_type_ids=None, position_ids=None, attention_mask=None, position=0):
    pred = predict(inputs, position_ids, token_type_ids, attention_mask)
    pred = pred[position]
    return pred.max(1).values

In [18]:
ref_token_id = tokenizer.pad_token_id
sep_token_id = tokenizer.sep_token_id
cls_token_id = tokenizer.cls_token_id

In [19]:
# define helper function for constructing references / baseline for word tokens
# 이게 원래 인풋으로 들어가는 Question과 아웃풋의 Answering이 존재했는데 나는 Answering 안써서 Answering 제거함
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    text_ids = tokenizer.encode(text, add_special_tokens=False, max_length=200, truncation=True, padding="max_length")
    
    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    
    # construct reference token ids
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]
    
    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)


def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device) # * -1
    return token_type_ids, ref_token_type_ids


def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with 'torch.randperm(seq_length, device=device)'
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)
    
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids


# 어텐션 마스크 구성
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)


# 버트 임베딩 구성
def construct_whole_bert_embeddings(input_ids, ref_input_ids, \
                                    token_type_ids=None, ref_token_type_ids=None, \
                                    position_ids=None, ref_position_ids=None):
    input_embeddings = model.module.bertmodel.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
    ref_input_embeddings = model.module.bertmodel.embeddings(ref_input_ids, token_type_ids=ref_token_type_ids, position_ids=ref_position_ids)
    
    return input_embeddings, ref_input_embeddings


# 시퀀스에서 각 워드 토큰에 대해 속성을 요약해주는 헬퍼 함수
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [20]:
# 일단 Interpreting text models tutorial 보고 진행
def add_attributions_to_visualizer(attributions, text, pred, pred_ind, label, delta):
    attributions = attributions.sum(dim=2).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.cpu().detach().numpy()

    print(attributions)

    end_position_vis = viz.VisualizationDataRecord(
        attributions, # word_attributions
        pred, # pred_prob
        pred_ind, # pred_class
        label, # true_class
        str(pred_ind),  # attr_class
        attributions.sum(), # attr_score
        text, # row_input_ids
        delta # convergence_score
    )
    viz.visualize_text([end_position_vis])

def interpret_sentence(model, q, r, input_ids, token_type_ids, attention_mask):
    model.zero_grad()

    input_indices = input_ids
    # input_indices = input_indices.unsqueeze(0)
    print(input_indices.shape)
    all_tokens = tokenizer.convert_ids_to_tokens(input_indices[1])

    # predict
    pred = predict(q.long(), r.long(), input_ids, token_type_ids, attention_mask)
    pred_ind = torch.round(pred)
    
    # compute attributions and approximation delta using layer integrated gradients
    lig = LayerIntegratedGradients(predict, model.module.bertmodel.embeddings)
    ig, delta = lig.attribute((q.long()), additional_forward_args=(r.long(), input_ids, token_type_ids, attention_mask), \
                                    return_convergence_delta=True, n_steps=10, internal_batch_size=16, target=1)
    print(pred_ind.shape, pred.shape, delta.shape)
    print(f"pred: {pred_ind[0]} pred %: {pred[0]} delta: {abs(delta[0])} ")

    add_attributions_to_visualizer(ig, all_tokens, pred[0], pred_ind[0], r, delta[0])

In [21]:
with torch.no_grad():
    for i, data in enumerate(test_loader):
        q, r, qshft_seqs, rshft_seqs, mask_seqs, input_ids, token_type_ids, attention_mask, q2diff_seqs, pid_seqs, pidshft_seqs, at_seqs \
            = data
        
        interpret_sentence(model, q, r, input_ids, token_type_ids, attention_mask)

        break

torch.Size([16, 512])
torch.Size([16, 200]) torch.Size([16, 200]) torch.Size([16])
pred: tensor([1., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.]) pred %: tensor

TypeError: unsupported format string passed to Tensor.__format__

In [29]:
with torch.no_grad():
    for i, data in enumerate(test_loader):
        q, r, qshft_seqs, rshft_seqs, mask_seqs, input_ids, token_type_ids, attention_mask, q2diff_seqs, pid_seqs, pidshft_seqs, at_seqs \
            = data
        print(input_ids.shape)
        indices = input_ids[0]
        all_tokens = tokenizer.convert_ids_to_tokens(indices)
        score = predict(q.long(), r.long(), input_ids, token_type_ids, attention_mask)[0] # sequence 하나만 해보자
        print(score.shape, all_tokens[0], )
        # acc = [1 if p >= 0.5 else 0 for p in score.detach().cpu().numpy()]
        # print(acc, score, r[0][0], torch.argmax(score[0]), input_ids.shape)
        attr = LayerIntegratedGradients(predict, model.module.qr_emb_layer)
        attributions, delta = attr.attribute((q.long()), additional_forward_args=(r.long(), input_ids, token_type_ids, attention_mask), \
                                    return_convergence_delta=True, n_steps=50, internal_batch_size=16, target=1)
        # Assuming attributions is a tensor
        print(len(attributions))
        attributions_sum = summarize_attributions(attributions)

        # attributions_end_sum = summarize_attributions(attributions_end)

        end_position_vis = viz.VisualizationDataRecord(
                        attributions_sum, # word_attributions
                        torch.max(torch.softmax(score[0], dim=0)), # pred_prob
                        torch.argmax(score[0]), # pred_class
                        torch.argmax(r[0]), # true_class
                        str(torch.argmax(score[0])), # attr_class
                        attributions_sum.sum(), # attr_score
                        all_tokens[0], # row_input_ids
                        delta) # convergence_score
        print(f"end_position_vis: {indices}")
        viz.visualize_text([end_position_vis])
        break

torch.Size([16, 512])
torch.Size([200]) 96
16
end_position_vis: tensor([ 5306, 39955, 35496, 39962, 29332, 31982, 30962, 32001,  2539,  1367,
          129,  1492, 30014, 35219, 29040, 30398,   125, 34207, 32099,   130,
        32185,  1429,  1367,   129, 29149, 36278, 17683,  3731, 29355, 29173,
        29023, 29163, 29257, 40305, 29161,   130,   130,   129, 40308,   129,
          100,   100,   100,   100,   100,   100,   100,   100,   100,   100,
          100,   100,   100,   100,   100,   100,   100,   100,   100,   100,
          100,   100,   100,   100,   100,   100,   100,   100,   100,   100,
          100,   100,   100,   100,   100,   100,   100,   100,   100,   100,
          100,   100,   100,   100,   100,   100,   100,   100,   100,   100,
          100,   100,   100,   100,   100,   100,   100,   100,   100,   100,
          100,   100,   100,   100,   100,   100,   100,   100,   100,   100,
          100,   100,   100,   100,   100,   100,   100,   100,   100,   100,


RuntimeError: Boolean value of Tensor with more than one value is ambiguous

In [33]:
index = 12
end_position_vis = viz.VisualizationDataRecord(
                attributions_sum[index], # word_attributions
                score[index], # pred_prob
                torch.round(score[index]), # pred_class
                r[0][index], # true_class
                str(r[0][index]), # attr_class
                attributions_sum[index].sum(), # attr_score
                all_tokens[index], # row_input_ids
                delta) # convergence_score
print(f"{len(attributions_sum)} {r.shape}")
print(f"{len(all_tokens)}")
viz.visualize_text([end_position_vis])

16 torch.Size([16, 200])
512


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1.0 (1.00),tensor(1.),-0.01,9 . 5
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1.0 (1.00),tensor(1.),-0.01,9 . 5
,,,,


In [None]:
tokenz = ' '.join(all_tokens)
tok_data = tokenizer.tokenize(tokenz)
print(tokenizer.decode(tokenizer.convert_tokens_to_ids(all_tokens[5])))
if '[UNK]' in tok_data:
    new_data = tokenizer.convert_tokens_to_ids(tok_data)
    print(tokenizer.decode(new_data))


[ U N K ]
$1.19  60 99 98 [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [U

In [None]:
with torch.no_grad():
    for i, data in enumerate(test_loader):
        q, r, qshft_seqs, rshft_seqs, mask_seqs, input_ids, token_type_ids, attention_mask, q2diff_seqs, pid_seqs, pidshft_seqs, at_seqs \
            = data
        
        input_ids = []
        ref_input_ids = []
        token_type_ids = []
        ref_token_type_ids = []
        position_ids = []
        ref_position_ids = []
        attention_masks = []
        
                
        for answer_text in at_seqs:
            text = ' '.join(map(str, answer_text))
            # print(f"============= text: {text} ================")
            
            input_id, ref_input_id, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
            token_type_id, ref_token_type_id = construct_input_ref_token_type_pair(input_id, sep_id)
            position_id, ref_position_id = construct_input_ref_pos_id_pair(input_id)
            attention_mask = construct_attention_mask(input_id)
            
            input_ids.append(input_id)
            ref_input_ids.append(ref_input_id)
            token_type_ids.append(token_type_id)
            ref_token_type_ids.append(ref_token_type_id)
            attention_masks.append(attention_mask)
            
        # print(input_ids)
        input_ids = torch.tensor([t.numpy() for t in input_ids])
        print(input_ids)
        ref_input_ids = [t.numpy() for t in ref_input_ids]
        token_type_ids = [t.numpy() for t in token_type_ids]
        ref_token_type_ids = [t.numpy() for t in ref_token_type_ids]
        position_ids = [t.numpy() for t in position_ids]
        ref_position_ids = [t.numpy() for t in ref_position_ids]
        attention_masks = [t.numpy() for t in attention_masks]
            
        score = predict(q.long(), r.long(), input_ids, token_type_ids, attention_mask)
            
        indices = input_ids[0].detach().tolist()
        all_tokens = tokenizer.convert_ids_to_tokens(indices)
        # 워드 임베딩 변화량 볼 수 있음
        lig = LayerIntegratedGradients(model, model.module.bertmodel.embeddings)
        # attributions_start, delta_start = lig.attribute(inputs=input_ids,
        #                         baselines=ref_input_ids,
        #                         additional_forward_args=(token_type_ids, position_ids, attention_mask, 0),
        #                         return_convergence_delta=True)
        # attributions_end, delta_end = lig.attribute(inputs=input_ids, baselines=ref_input_ids,
        #                         additional_forward_args=(token_type_ids, position_ids, attention_mask, 1),
        #                         return_convergence_delta=True)
        
        attributions, delta = lig.attribute(inputs=input_ids,
                        baselines=ref_input_ids,
                        n_steps=700,
                        internal_batch_size=3,
                        return_convergence_delta=True)
        attributions_sum = summarize_attributions(attributions)
        # attributions_end_sum = summarize_attributions(attributions_end)
        
        end_position_vis = viz.VisualizationDataRecord(
                        attributions_sum,
                        torch.max(torch.softmax(score, dim=0)),
                        torch.argmax(score),
                        torch.argmax(score),
                        str(indices.index(input_ids)),
                        attributions_sum.sum(),       
                        all_tokens,
                        delta)
        viz.visualize_text([end_position_vis])
            

        
        

        
        # print(f"Affected Text: {all_tokens[torch.argmax(torch.tensor(score[0])) : torch.argmax(torch.tensor(score[-1])) + 1]}")
        


        break

NameError: name 'test_loader' is not defined

In [None]:
attributions[0, attributions[0] > 0]

tensor([], dtype=torch.float64)