In [None]:
%load_ext autoreload
%autoreload 2

In [2]:
import pickle
import numpy as np
from import_casa import casa
from casa import caprice
from icecream import ic

In [3]:
# Load the pickled dataset; later cells treat it as a sequence of (text, label) tuples.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file --
# only load a file that was produced locally / comes from a trusted source.
with open("../../../data/caprice/caprice_seq_data_20210430.pkl", "rb") as fin:
    data = pickle.load(fin)

In [4]:
import csv

# Load the sentiment constructicon as a list of (regex_pattern, score) pairs.
# The first column is later used as a regular expression by find_construction;
# the second column is parsed as a float. Rows without a score are dropped.
# newline="" lets the csv module do its own newline handling, as the csv
# documentation recommends for files passed to csv.reader.
with open("../../../data/caprice/sentiment-constructicon.csv", "r", encoding="UTF-8", newline="") as fin:
    csvreader = csv.reader(fin)
    next(csvreader, None)  # skip the header record (no error on an empty file)
    constructions = [
        (row[0], float(row[1]))
        for row in csvreader
        # len check: blank or single-column rows are skipped instead of raising IndexError
        if len(row) > 1 and row[1]
    ]
len(constructions)

39

In [5]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Local checkpoint directory; the tokenizer and the sequence-classification
# model are both loaded from this same path so they stay in sync.
model_path = "../../../data/caprice/seq-model-ep9"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)

In [6]:
from matplotlib import font_manager

In [7]:
import matplotlib.pyplot as plt
# The texts in this notebook are Chinese, so figures need a font with CJK glyphs.
# font.family accepts a list of family names in priority order; listing both
# candidates lets the notebook run unchanged on machines that have either
# "Microsoft JhengHei" or "Heiti TC" installed, instead of editing this cell.
plt.rcParams["font.family"] = ["Microsoft JhengHei", "Heiti TC"]

In [8]:
import shap

In [9]:
from transformers import pipeline
# Wrap the model and tokenizer in a classification pipeline.
# return_all_scores=True is load-bearing: the extraction loop reads
# caprice_pipeline(txt)[0] as a list of per-label dicts with a "score" key,
# so do not change this flag without updating that code.
caprice_pipeline = pipeline('sentiment-analysis', model=model, tokenizer=tokenizer, return_all_scores=True)

In [10]:
# Build a SHAP explainer around the pipeline, explicitly selecting the "partition" algorithm.
explainer = shap.Explainer(caprice_pipeline, algorithm="partition")

In [11]:
# Alias (not a copy) of the explainer's masker object. The extraction loop
# assigns explainer.masker.clustering in place, so this name sees that change too.
_masker = explainer.masker

In [12]:
import torch
from DistilTag import DistilTag
# POS tagger instance; it is handed to caprice.custom_clustering as
# `tagger_inst` in the extraction loop.
tagger = DistilTag()

In [13]:
# Sanity check: for this sentence tag() returns one list of (token, POS) pairs (see cell output).
tagger.tag("我才不要去辦吃到飽")

[[('我', 'Nh'),
  ('才', 'Da'),
  ('不要', 'D'),
  ('去', 'D'),
  ('辦', 'VC'),
  ('吃', 'VC'),
  ('到', 'P'),
  ('飽', 'VH')]]

In [14]:
# Keep the soft_tag result for the same sentence; its structure is inspected in the cells below.
out = tagger.soft_tag("我才不要去辦吃到飽")

In [15]:
# Number of entries in the tagger's POS list (80 here), for comparison with the soft_tag output shape.
len(tagger.pos_list)

80

In [16]:
# Shape check on the soft_tag output: (9, 80) for this input.
# Presumably one row per input character (the string has 9 characters) and
# one column per entry of tagger.pos_list -- TODO confirm against DistilTag.
out[1][0].shape

torch.Size([9, 80])

## Try hierarchical Shapley

In [17]:
# Peek at the first record: a (text, label) tuple.
data[0]

('台星的態度就是在等宿主台哥，逸以待勞，準備寄生。', 1)

In [18]:
# Accumulator for the (txt, label, proc_vals, max_label) tuples appended by the extraction loop.
shapley_data = []

In [19]:
# Smoke test: run the explainer on a single sentence (the text of data[0])
# before the full loop. `shap_values` is reassigned inside the loop later.
shap_values = explainer(["台星的態度就是在等宿主台哥，逸以待勞，準備寄生。"], fixed_context=0, max_evals=200)

In [20]:
import re
def find_construction(intxt, constructions):
    """Return the first construction pattern found in ``intxt``.

    Args:
        intxt: Text to search.
        constructions: Iterable of ``(pattern, score)`` pairs, where
            ``pattern`` is a regular expression; the score is ignored here.

    Returns:
        The pattern string of the first pair (in iteration order) whose
        regex matches anywhere in ``intxt``, or ``""`` when none match.
    """
    # Lazy generator: patterns are tried in order and evaluation stops at the first hit.
    hits = (pattern for pattern, _score in constructions if re.search(pattern, intxt))
    return next(hits, "")

In [26]:
# Negative check: the pattern requires a word character after 秀下限, so a
# sentence ending in "！" right after it does not match (result is None, hence no output).
re.search(r"\b\w+秀下限\w+\b", "他只是秀下限！")

In [28]:
# Positive check: returns the pattern string of the matching construction.
find_construction("他只是秀下限吧", constructions)

'\\b\\w+秀下限\\w+\\b'

In [29]:
from types import MethodType
from functools import partial
from tqdm.auto import tqdm

# Tunables for the extraction pass.
MAX_TEXT_LEN = 100  # texts longer than this many characters are skipped
MAX_EVALS = 200     # max_evals budget passed to the explainer for each text

# Start from empty containers so that re-running this cell reproduces the
# same result instead of appending duplicates to a list left over in the kernel.
shapley_data = []
failed_texts = []  # (txt, repr(exception)) for texts the explainer could not process

for txt, label in tqdm(data):
    if len(txt) > MAX_TEXT_LEN:
        continue
    # Only texts containing one of the known constructions are explained.
    pat = find_construction(txt, constructions)
    if not pat:
        continue

    # Replace the masker's clustering with caprice.custom_clustering bound to
    # this text's construction pattern. MethodType makes the masker the first
    # positional argument, as for a regular method.
    # NOTE(review): the patch is not undone, so after the loop the masker keeps
    # the clustering bound to the last matched pattern.
    clustering_wrapper = partial(caprice.custom_clustering, pat=pat, tagger_inst=tagger)
    explainer.masker.clustering = MethodType(clustering_wrapper, explainer.masker)
    try:
        shap_values = explainer([txt], fixed_context=0, max_evals=MAX_EVALS)
    except Exception as ex:
        # Deliberately best-effort: one failing text must not abort the run.
        # Record which text failed and why, so failures can be investigated.
        failed_texts.append((txt, repr(ex)))
        print(f"skipped {txt!r}: {type(ex).__name__}: {ex}")
        continue
    # presumably set on the masker by caprice.custom_clustering during the call above -- confirm
    pos_probs = explainer.masker.pos_probs
    # Per-label scores for the text; Shapley values are kept for the top-scoring label only.
    probs = [x["score"] for x in caprice_pipeline(txt)[0]]
    max_label = np.argmax(probs)
    proc_vals = caprice.process_shap_values(shap_values[0, :, max_label])
    proc_vals["pos_probs"] = pos_probs

    shapley_data.append((txt, label, proc_vals, max_label))

if failed_texts:
    print(f"{len(failed_texts)} text(s) skipped because the explainer raised; see failed_texts")

HBox(children=(FloatProgress(value=0.0, max=2795.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.06s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

index 14 is out of bounds for axis 0 with size 14


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.77s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:10,  5.20s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

Partition explainer: 2it [00:11,  5.66s/it]                                                                            


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))




In [30]:
# Persist the collected tuples for later analysis.
# NOTE(review): the path is relative to the notebook's working directory,
# unlike the inputs, which are read from ../../../data/caprice.
with open("seq_shapley_data_cons.pkl", "wb") as fout:
    pickle.dump(shapley_data, fout)

In [31]:
# Number of records collected in shapley_data.
len(shapley_data)

103