In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pickle
import numpy as np
from import_casa import casa
from casa import caprice
from icecream import ic

In [3]:
with open("../../../data/caprice/caprice_seq_data_20210430.pkl", "rb") as fin:
    data = pickle.load(fin)

In [4]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
model_path = "../../../data/caprice/seq-model-ep9"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)

In [5]:
from matplotlib import font_manager

In [6]:
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Microsoft JhengHei"
# plt.rcParams["font.family"] = "Heiti TC"

In [7]:
import shap

In [8]:
from transformers import pipeline
caprice_pipeline = pipeline('sentiment-analysis', model=model, tokenizer=tokenizer, return_all_scores=True)

In [9]:
explainer = shap.Explainer(caprice_pipeline, algorithm="partition")

In [10]:
_masker = explainer.masker

In [11]:
import torch
from DistilTag import DistilTag
tagger = DistilTag()

In [12]:
tagger.tag("我才不要去辦吃到飽")

[[('我', 'Nh'),
  ('才', 'Da'),
  ('不要', 'D'),
  ('去', 'D'),
  ('辦', 'VC'),
  ('吃', 'VC'),
  ('到', 'P'),
  ('飽', 'VH')]]

In [13]:
out = tagger.soft_tag("我才不要去辦吃到飽")

In [14]:
len(tagger.pos_list)

80

In [15]:
out[1][0].shape

torch.Size([9, 80])

## Try hierarchical shapley

In [16]:
data[0]

('台星的態度就是在等宿主台哥，逸以待勞，準備寄生。', 1)

In [17]:
shapley_data = []

In [22]:
shap_values = explainer(["台星的態度就是在等宿主台哥，逸以待勞，準備寄生。"], fixed_context=0, max_evals=1000)

In [25]:
from types import MethodType
from functools import partial
from tqdm.auto import tqdm

clustering_wrapper = partial(caprice.custom_clustering, pat="", tagger_inst=tagger)
explainer.masker.clustering = MethodType(clustering_wrapper, explainer.masker)
for txt, label in tqdm(data):
    if len(txt) > 100: continue
    try:
        shap_values = explainer([txt], fixed_context=0, max_evals=200)
    except Exception as ex:
        print(ex)
        continue
    pos_probs = explainer.masker.pos_probs
    probs = [x["score"] for x in caprice_pipeline(txt)[0]]
    max_label = np.argmax(probs)
    proc_vals = caprice.process_shap_values(shap_values[0, :, max_label])
    proc_vals["pos_probs"] = pos_probs
    
    shapley_data.append((
        txt, label, proc_vals, max_label
    ))

HBox(children=(FloatProgress(value=0.0, max=2795.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.63s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.20s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.48s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.43s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.59s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.08s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.35s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.58s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.53s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.06s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:12,  6.15s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.46s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.20s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:12,  6.00s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.17s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.00s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.90s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.73s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.91s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.20s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:12,  6.02s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.89s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:12,  6.35s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.09s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.15s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.28s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.93s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.01s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.31s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

list index out of range


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.46s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.54s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.63s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.57s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.19s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.29s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.39s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.05s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

list index out of range


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.81s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.22s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.28s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.33s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.32s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.65s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.22s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.21s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.68s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.33s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

list index out of range


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.38s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.91s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.72s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:12,  6.17s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.46s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.30s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:12,  6.09s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.33s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.15s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.24s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.20s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.97s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.36s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.15s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.72s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.33s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.17s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))




In [26]:
with open("seq_shapley_data.pkl", "wb") as fout:
    pickle.dump(shapley_data, fout)

In [29]:
[i for i, x in enumerate(shapley_data) if x[0] == shapley_data[0][0]]

[0, 22, 104]

In [31]:
print(len(shapley_data))
with open("seq_shapley_data_rev.pkl", "wb") as fout:
    pickle.dump(shapley_data[104:], fout)

2622


In [29]:
shap_values.data

array([['', '他', '貴', '到', '一', '個', '爆', '炸', '']], dtype='<U1')

In [25]:
proc_vals["pos_probs"] = pos_probs

In [28]:
with open("pos_list.txt", "w") as fout:
    fout.write("\n".join(tagger.pos_list))

In [23]:
group_values = proc_vals["group_values"]
for node_i, node_x in enumerate(clust_nodes):
    leaf_ids = node_x.pre_order(lambda x: x.id)
    print("% .2e" % group_values[node_i], "".join(ex_tokens[x] for x in sorted(leaf_ids)))

-4.75e-02 
-4.75e-02 他
 5.15e-01 貴
 2.57e-01 到
 1.29e-01 一
 1.29e-01 個
 3.42e-02 爆
 3.42e-02 炸
 1.52e-07 
 2.57e-01 一個
 5.15e-01 到一個
 6.85e-02 爆炸
 1.03e+00 貴到一個
 1.10e+00 貴到一個爆炸
-9.50e-02 他
 1.00e+00 他貴到一個爆炸
 1.00e+00 他貴到一個爆炸


In [24]:
shap.plots.text(shap_values[0, :, max_label])