In [1]:
from torch import nn
from train.model.model import SensitiveClassifier
from ultils import *
import pickle

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
#Load tokenizer
with open('D:/AI Project/sensitive_filter/tokenizer.pickle', 'rb') as handle:
    tokenizer = pickle.load(handle)
    
model = SensitiveClassifier(n_classes=2)
model.load_state_dict(torch.load('./checkpoints/phobert_fold3.pth', map_location=torch.device(device)))

cuda:0


Some weights of the model checkpoint at vinai/phobert-base were not used when initializing RobertaModel: ['lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.decoder.bias', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

In [5]:
def construct_input_and_baseline(text, tokenizer):
    max_length = 510
    baseline_token_id = tokenizer.pad_token_id 
    sep_token_id = tokenizer.sep_token_id 
    cls_token_id = tokenizer.cls_token_id 

    text_ids = tokenizer.encode(text, max_length=max_length, truncation=True, add_special_tokens=False)
   
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    token_list = tokenizer.convert_ids_to_tokens(input_ids)
  

    baseline_input_ids = [cls_token_id] + [baseline_token_id] * len(text_ids) + [sep_token_id]
    return torch.tensor([input_ids], device='cpu'), torch.tensor([baseline_input_ids], device='cpu'), token_list

def summarize_attributions(attributions):

    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    
    return attributions

def interpret_text(text, model, tokenizer, true_class = 'unknow'):

    input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text, tokenizer)
# Define model output
    def model_output(inputs):
        return model(inputs)[0]

    # Define model input
    model_input = model.bert.embeddings
    lig = LayerIntegratedGradients(model_output, model_input)
    attributions, delta = lig.attribute(inputs= input_ids,
                                    baselines= baseline_input_ids,
                                    return_convergence_delta=True,
                                    internal_batch_size=1
                                    )
    attributions_sum = summarize_attributions(attributions)
    score_vis = viz.VisualizationDataRecord(
                        word_attributions = attributions_sum,
                        pred_prob = torch.max(model(input_ids)[0]),
                        pred_class = torch.argmax(model(input_ids)[0]).numpy(),
                        true_class = true_class,
                        attr_class = text,
                        attr_score = attributions_sum.sum(),       
                        raw_input_ids = all_tokens,
                        convergence_score = delta)
    viz.visualize_text([score_vis])


In [11]:
text = "ngu thì mất tiền thôi than lên đây làm gì"
true_class = 1
interpret_text(text, model, tokenizer)
text = "chứng khoán dạo này chán lắm, tất tay giờ là toang"
true_class = 0
interpret_text(text, model, tokenizer)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
unknow,1 (2.46),ngu thì mất tiền thôi than lên đây làm gì,-0.38,#s ngu thì mất tiền thôi than lên đây làm gì #/s
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
unknow,0 (3.59),chứng khoán dạo này chán lắm,1.38,#s chứng khoán dạo này chán lắm #/s
,,,,
