In [8]:
# interpreting BERT Models with Captum library
import json
import numpy as np
import pandas as pd
import seaborn as sns 
import matplotlib.pyplot as plt 

import torch 
import torch.nn as nn
from torch.utils.data import DataLoader, random_split

from transformers import DistilBertTokenizer

from captum.attr import visualization as viz
from captum.attr import LayerConductance, LayerIntegratedGradients

In [9]:
# module
from models.dkvmn_text import SUBJ_DKVMN

from data_loaders.assist2009 import ASSIST2009
from data_loaders.assist2012 import ASSIST2012
from models.utils import collate_fn

In [10]:
# Run configuration: which model / dataset to interpret and on which device.
model_name = "dkvmn+"
dataset_name = "ASSIST2009"
device = "cuda" if torch.cuda.is_available() else "cpu"
seq_len = 200  # provisional value; replaced by the config.json setting below

# Checkpoint directory for the selected model / dataset pair.
ckpts = f"ckpts/{model_name}2009-performance/{dataset_name}/"

# Read the model-specific section and the shared training section.
with open("config.json") as f:
    config = json.load(f)
    model_config, train_config = config[model_name], config["train_config"]

# Training hyper-parameters taken from the "train_config" section.
batch_size, num_epochs, train_ratio, learning_rate = (
    train_config["batch_size"],
    train_config["num_epochs"],
    train_config["train_ratio"],
    train_config["learning_rate"],
)
optimizer = train_config["optimizer"]  # can be sgd, adam
seq_len = train_config["seq_len"]  # number of items to sample

In [5]:
# Load the dataset selected by `dataset_name`.
# Further datasets can be added as extra branches below.
collate_pt = collate_fn
if dataset_name == "ASSIST2009":
    dataset = ASSIST2009(seq_len, 'datasets/ASSIST2009/')
elif dataset_name == "ASSIST2012":
    dataset = ASSIST2012(seq_len, 'datasets/ASSIST2012/')
else:
    # Fail here with a clear message instead of a NameError on `dataset` below.
    raise ValueError(
        f"Unsupported dataset_name: {dataset_name!r} "
        "(expected 'ASSIST2009' or 'ASSIST2012')"
    )

# Split the dataset: `train_ratio` goes to training, the remainder is divided
# between validation and test (test absorbs any rounding leftovers).
data_size = len(dataset)
train_size = int(data_size * train_ratio)
valid_size = int(data_size * ((1.0 - train_ratio) / 2.0))
test_size = data_size - train_size - valid_size

# NOTE(review): the generators are created on `device` and are unseeded, so the
# split differs between runs — confirm whether a fixed seed is wanted here.
train_dataset, valid_dataset, test_dataset = random_split(
    dataset, [train_size, valid_size, test_size], generator=torch.Generator(device=device)
)

# One loader per split; each gets its own generator for shuffling.
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=collate_pt, generator=torch.Generator(device=device)
)
valid_loader = DataLoader(
    valid_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=collate_pt, generator=torch.Generator(device=device)
)
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=collate_pt, generator=torch.Generator(device=device)
)

In [None]:
# 

In [None]:
# define helper functions

In [2]:
# Build the model (wrapped in DataParallel, moved to `device`) and the tokenizer.
# NOTE(review): no checkpoint is loaded in this cell even though `ckpts` is
# defined in the config cell — confirm the weights are restored elsewhere.
model = torch.nn.DataParallel(
    SUBJ_DKVMN(dataset.num_q, **model_config)
).to(device)
model.eval()       # switch to evaluation mode
model.zero_grad()  # clear any accumulated gradients

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

In [12]:
# Forward pass through the model (used to test it).
def predict(q, r, at_s, at_t, at_m, q2diff):
    """Run the global `model` on one batch and return two logit fields.

    The six arguments are forwarded to `model` positionally, followed by a
    trailing `None`.

    Returns:
        A `(start_logits, end_logits)` tuple read from the model output.

    NOTE(review): this requires the model output to expose `start_logits` and
    `end_logits` attributes — confirm the model wrapped here returns such an
    object.
    """
    prediction = model(q, r, at_s, at_t, at_m, q2diff, None)
    start_scores = prediction.start_logits
    end_scores = prediction.end_logits
    return start_scores, end_scores

In [13]:
# Special-token ids read from the tokenizer: the pad id serves as the
# reference (baseline) token, alongside the [SEP] and [CLS] ids.
ref_token_id, sep_token_id, cls_token_id = (
    tokenizer.pad_token_id,
    tokenizer.sep_token_id,
    tokenizer.cls_token_id,
)

NameError: name 'tokenizer' is not defined

In [None]:
# Helper functions for constructing references / baselines for word tokens.
# The original version took both a question and an answer text; the answer
# part is not used here, so only a single text is encoded.
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    """Encode `text` and build a same-length reference (baseline) sequence.

    Args:
        text: Raw text, encoded with the global `tokenizer` without special tokens.
        ref_token_id: Token id that replaces every word token in the reference.
        sep_token_id: Id of the separator token appended to both sequences.
        cls_token_id: Id of the classification token prepended to both sequences.

    Returns:
        A tuple `(input_ids, ref_input_ids, num_text_tokens)` where both tensors
        have shape `(1, num_text_tokens + 2)` and live on the global `device`.
    """
    text_ids = tokenizer.encode(text, add_special_tokens=False)

    # Input: [CLS] + word tokens + [SEP]
    input_ids = [cls_token_id] + text_ids + [sep_token_id]

    # Reference: same layout with every word token replaced by `ref_token_id`.
    # The trailing [SEP] is required so the baseline has the same length as the
    # input; without it the two tensors differ by one position.
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]

    return (
        torch.tensor([input_ids], device=device),
        torch.tensor([ref_input_ids], device=device),
        len(text_ids),
    )


def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    """Build token-type ids for `input_ids` plus an all-zero reference.

    Positions up to and including `sep_ind` get type 0, later positions type 1.

    Args:
        input_ids: Tensor whose second dimension is the sequence length.
        sep_ind: Index of the last position that belongs to segment 0.

    Returns:
        `(token_type_ids, ref_token_type_ids)`, each of shape `(1, seq_len)`
        on the global `device`.
    """
    num_tokens = input_ids.size(1)
    segment_row = [int(idx > sep_ind) for idx in range(num_tokens)]
    token_type_ids = torch.tensor([segment_row], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)
    return token_type_ids, ref_token_type_ids


def construct_input_ref_pos_id_pair(input_ids):
    """Build position ids for `input_ids` plus an all-zero reference.

    Args:
        input_ids: Tensor whose second dimension is the sequence length.

    Returns:
        `(position_ids, ref_position_ids)`, both expanded to the shape of
        `input_ids`; the first counts 0..seq_len-1, the second is all zeros.
    """
    num_positions = input_ids.size(1)

    position_ids = (
        torch.arange(num_positions, dtype=torch.long, device=device)
        .unsqueeze(0)
        .expand_as(input_ids)
    )
    # A random permutation (torch.randperm(num_positions, device=device))
    # would be another possible choice for the reference.
    ref_position_ids = (
        torch.zeros(num_positions, dtype=torch.long, device=device)
        .unsqueeze(0)
        .expand_as(input_ids)
    )
    return position_ids, ref_position_ids


# Build the attention mask.
def construct_attention_mask(input_ids):
    """Return a mask of ones with the same shape and dtype as `input_ids`."""
    attention_mask = torch.ones_like(input_ids)
    return attention_mask


# Build the BERT embeddings.
def construct_whole_bert_embeddings(
    input_ids,
    ref_input_ids,
    token_type_ids=None,
    ref_token_type_ids=None,
    position_ids=None,
    ref_position_ids=None,
):
    """Embed the input and reference sequences with the model's embedding layer.

    Both sequences go through `model.module.bertmodel.embeddings`, the input
    first and the reference second.

    NOTE(review): the embedding module must accept `token_type_ids` and
    `position_ids` keyword arguments — confirm this for the wrapped model.

    Returns:
        `(input_embeddings, ref_input_embeddings)`.
    """
    input_embeddings = model.module.bertmodel.embeddings(
        input_ids,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
    )
    ref_input_embeddings = model.module.bertmodel.embeddings(
        ref_input_ids,
        token_type_ids=ref_token_type_ids,
        position_ids=ref_position_ids,
    )
    return input_embeddings, ref_input_embeddings

In [14]:
# visualization