In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pickle
import numpy as np
from import_casa import casa
from casa import caprice
from icecream import ic



In [3]:
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Microsoft JhengHei"
# plt.rcParams["font.family"] = "Heiti TC"

In [4]:
with open("../../../data/caprice/seq_shapley_data_rev.pkl", "rb") as fin:
    data = pickle.load(fin)

In [5]:
with open("../../../data/caprice/pos_list.txt", "r") as fin:
    pos_list = fin.readlines()
pos_list = [x.strip() for x in pos_list]

In [6]:
n_correct = sum(x[1] == x[3] for x in data)
print("Correctly classified: ", n_correct)
print("All instances: ", len(data))
print("Accuracy: ", n_correct/len(data))

Correctly classified:  2429
All instances:  2518
Accuracy:  0.9646544876886418


In [7]:
data_x = data[0]
assert data_x[1] == data_x[3]

In [8]:
shap_data = data_x[2]

In [9]:
list(shap_data.keys())

['raw_tokens',
 'merged_tokens',
 'values',
 'group_sizes',
 'upper_values',
 'lower_values',
 'group_values',
 'max_values',
 'token_id_to_node_id_mapping',
 'collapsed_node_ids',
 'pos_probs']

In [10]:
type(np.zeros(2))

numpy.ndarray

In [11]:
len(data[10][2]["raw_tokens"]), data[10][2]["pos_probs"].shape

(15, (15, 80))

In [12]:
from collections import defaultdict
from dataclasses import dataclass
@dataclass
class ItemRecord:
    value: float
    freq: int
    pos: np.ndarray
        
merged_values = defaultdict(lambda: ItemRecord(0,0,None))

for data_x in data:
    if data_x[1] != data_x[3]:
        continue
    shap_data = data_x[2]
    tok2nd = shap_data["token_id_to_node_id_mapping"]
    group_values = shap_data["group_values"]
    raw_tokens = shap_data["raw_tokens"]    
    pos_probs = shap_data["pos_probs"]
    buf = ""
    pos_buf = []
    last_id = 0
    
    if data_x[1] == 1:
        polarity = -1
    elif data_x[1] == 2:
        polarity = 2
    else:
        continue
        
    for tok_id, nd_id in enumerate(tok2nd):    
        raw_tok = raw_tokens[tok_id]
        if last_id != nd_id:
            nd_value = polarity*group_values[int(last_id)]
            merged_values[buf].value += nd_value
            merged_values[buf].freq += 1          
            
            if buf:
                pos_vec = np.vstack(pos_buf)
            else:
                pos_vec = np.zeros((1, len(pos_list)), dtype=np.float32)
                
            if merged_values[buf].pos is not None:
                if merged_values[buf].pos.shape[0] != pos_vec.shape[0]:
                    breakpoint()
                merged_values[buf].pos += pos_vec
            else:
                merged_values[buf].pos = pos_vec
            pos_buf = []
            buf = ""        
                
        last_id = nd_id
        buf += raw_tok
        if buf:
            pos_buf.append(pos_probs[tok_id])
        
    if buf:
        merged_values[buf].value += polarity*nd_value
        merged_values[buf].freq += 1
        if merged_values[buf].pos is not None:
            merged_values[buf].pos += np.vstack(pos_buf)
        else:
            merged_values[buf].pos = np.vstack(pos_buf)
    

In [13]:
from collections import Counter
import re
func_mask = [int(re.match(r"N.+|V.+|.*CATEGORY", x) is None) for x in pos_list]
merged_pats = {}
pos_pats = {}
for pat, rec in merged_values.items():
    if not pat: continue
    merged_pats[pat] = rec.value / rec.freq
    pos_pats[pat] = rec.pos


In [14]:
[(x, merged_pats[x]) for x in sorted(merged_pats.keys(), key=merged_pats.get, reverse=True)][:10]

[('推一個中華電信', 15.741275991218721),
 ('中華電信好棒', 15.526976172301513),
 ('五g只信亞太', 14.780610259067839),
 ('推薦中華', 14.548671594675053),
 ('中華好', 14.419612054633166),
 ('比較好👍', 13.759743712127339),
 ('還好中華', 13.325444134881362),
 ('看好中華', 12.632075464269988),
 ('只推台哥', 12.459128194538861),
 ('比較期待中華', 12.121670681717855)]

In [15]:
[(x, merged_pats[x]) for x in sorted(merged_pats.keys(), key=merged_pats.get, reverse=False)][:10]

[('在北大武三角', -8.82558056486867),
 ('剛剛中華出問題', -8.240402194087654),
 ('台哥大日常斷線', -8.08578173037489),
 ('大真的夠爛', -8.019043184761369),
 ('台哥大靠限速就', -7.922424416351588),
 ('中華電信不考慮', -7.821020848328299),
 ('中華學生方案', -7.715885746963345),
 ('中華明顯變慢', -7.700923131146217),
 ('台哥-7 呵', -7.670989988222811),
 ('很快就爆了699收費太貴...', -7.652803886230502)]

In [16]:
pos_weights = {}
for pat, pos_probs in pos_pats.items():
    pos_weights[pat] = (pos_probs * np.array(func_mask)).mean(axis=0).sum()
    

In [17]:
from DistilTag import DistilTag
tagger = DistilTag()

In [45]:
tagger.print_soft_tag(*tagger.soft_tag("這一個比較好"))

這_0.50 Nep_0.42/ Dk_0.04/Nes_0.03/Cbb_0.02/SHI_0.02
一_0.51 Neu_0.63/Cbb_0.02/ Nb_0.02/ FW_0.02/ Nd_0.01
個_0.48  Nf_0.41/ Nh_0.04/ Na_0.03/ Nc_0.03/ Nd_0.02
比_0.59 Dfa_0.57/ VC_0.03/ Na_0.02/V_2_0.02/ Nv_0.02
較_0.39 Dfa_0.64/ VH_0.02/ VC_0.02/ Nf_0.02/ Nv_0.01
好_0.49  VH_0.59/ VL_0.03/ VC_0.03/ VA_0.02/ Nv_0.02



In [47]:
re.match("N[^ef]+", "Neu")

In [62]:
V = [int(re.match(r"V.+", x) is not None) for x in pos_list]
D = [int(re.match(r"D.+", x) is not None) for x in pos_list]
NDet = [int(re.match(r"Ne.|Nf", x) is not None) for x in pos_list]
NN = [int(re.match(r"N[^ef]+", x) is not None) for x in pos_list]
def make_mask(*masks):
    cons_mask = np.vstack([*masks]).astype(np.float32)
    cons_mask /= cons_mask.sum()
    return cons_mask
cons_masks = {
    "DDVV": make_mask([D, D, V, V]),
    "DDV": make_mask([D, D, V]),
    "DVV": make_mask([D, V, V]),
    "Nd2N": make_mask([NDet, NDet, NN]),
    "NdN2": make_mask([NDet, NN, NN])}

In [77]:
def apply_cross_product(pos_mat, cons_mask):
    mask_M = cons_mask.shape[0]
    res_M = pos_mat.shape[0]-mask_M+1
    if res_M <= 0:
        return np.zeros((1, 1), dtype=np.float32)
    res_mat = np.zeros(shape=(res_M, 1))
    for i in range(res_M):
        res_mat[i] = (pos_mat[i:i+mask_M] * cons_mask).sum()
    return res_mat

In [78]:
def apply_masks(pos_mat, masks):
    scores = np.zeros(len(masks), dtype=np.float32)
    for mask_i, mask_x in enumerate(masks):
        scores[mask_i] = apply_cross_product(pos_mat, mask_x).max()
    return scores

In [79]:
apply_masks(pos_pats["比較期待中華"], cons_masks.values())

array([0.06581572, 0.07206085, 0.06067116, 0.03702857, 0.06047242],
      dtype=float32)

In [80]:
len(merged_pats)

3297

## Shapley values quantiles

In [114]:
np.quantile(np.abs([x for pat, x in merged_pats.items() if 3<=len(pat)<=10]), np.arange(0,1,0.1))

array([1.85621582e-03, 6.86807174e-02, 1.50942372e-01, 2.40439319e-01,
       3.49729472e-01, 4.96462432e-01, 7.58105585e-01, 1.17797188e+00,
       2.38808450e+00, 5.77778638e+00])

## Pattern weights distribution

In [93]:
pat_weights_1 = {}
pat_weights_2 = {}
for pat in merged_pats:
    pol_value = merged_pats[pat]
    pos_score = pos_weights[pat].mean()    
    # pat_weights_1[pat] = pol_value/abs(pol_value) * (len(pat)*20 + pos_score*.5 + abs(pol_value)*.5)
    pat_weights_1[pat] = pol_value/abs(pol_value) * (len(pat)*20 + abs(pol_value))
    pat_weights_2[pat] = (pol_value/abs(pol_value) * 
                          (apply_masks(pos_pats[pat], cons_masks.values()).max()))

In [103]:
np.abs(list(pat_weights_2.values()))

array([0.        , 0.        , 0.        , ..., 0.04639961, 0.05843211,
       0.        ])

In [107]:
np.quantile(np.abs([x for pat, x in pat_weights_2.items() if 3<=len(pat)<=10]), np.arange(0,1,0.1))

array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.02019217, 0.03824183, 0.05375126, 0.06033709, 0.0606324 ])

In [109]:
def sort_dict(dict_x, key_func, reverse=True, topn=10):
    return [(x, dict_x[x]) for x in sorted(dict_x.keys(), key=key_func, reverse=reverse)][:topn]

In [110]:
with open("/tmp/pat.txt", "w", encoding="UTF-8") as fout:
# with open("h:/pat.txt", "w", encoding="UTF-8") as fout:
    for x in sort_dict(merged_pats, 
                       key_func=lambda x: (pat_weights_1.get(x)), 
                       reverse=True, topn=-1):
        if not (3 <= len(x[0]) <= 10): continue
        if abs(x[1]) <= 2: continue
        fout.write(f"{x[0]}, {pos_weights[x[0]].sum():.4f}, {pat_weights_1[x[0]]:.4f}")
        fout.write("\n")

In [117]:
with open("/tmp/pat2.txt", "w", encoding="UTF-8") as fout:
# with open("h:/pat.txt", "w", encoding="UTF-8") as fout:
    for x in sort_dict(merged_pats, 
                       key_func=lambda x: (pat_weights_2.get(x)), 
                       reverse=True, topn=-1):
        if not (3 <= len(x[0]) <= 10): continue
        weight = pat_weights_2[x[0]]
        if abs(x[1]) <= 0.496: continue
        if abs(weight) <= .0202: continue
        fout.write(f"{x[0]};{x[1]:.4f};{weight:.4f}")
        fout.write("\n")