In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pickle
import numpy as np
from import_casa import casa
from casa import caprice
from icecream import ic

In [3]:
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Microsoft JhengHei"
# plt.rcParams["font.family"] = "Heiti TC"

In [4]:
with open("../../../data/caprice/seq_shapley_data_rev.pkl", "rb") as fin:
    data = pickle.load(fin)

In [5]:
with open("../../../data/caprice/pos_list.txt", "r") as fin:
    pos_list = fin.readlines()
pos_list = [x.strip() for x in pos_list]

In [6]:
n_correct = sum(x[1] == x[3] for x in data)
print("Correctly classified: ", n_correct)
print("All instances: ", len(data))
print("Accuracy: ", n_correct/len(data))

Correctly classified:  2429
All instances:  2518
Accuracy:  0.9646544876886418


In [7]:
data_x = data[0]
assert data_x[1] == data_x[3]

In [8]:
shap_data = data_x[2]

In [9]:
list(shap_data.keys())

['raw_tokens',
 'merged_tokens',
 'values',
 'group_sizes',
 'upper_values',
 'lower_values',
 'group_values',
 'max_values',
 'token_id_to_node_id_mapping',
 'collapsed_node_ids',
 'pos_probs']

In [10]:
type(np.zeros(2))

numpy.ndarray

In [11]:
len(data[10][2]["raw_tokens"]), data[10][2]["pos_probs"].shape

(15, (15, 80))

In [12]:
from collections import defaultdict
from dataclasses import dataclass
@dataclass
class ItemRecord:
    value: float
    freq: int
    pos: np.ndarray
        
merged_values = defaultdict(lambda: ItemRecord(0,0,None))

for data_x in data:
    if data_x[1] != data_x[3]:
        continue
    shap_data = data_x[2]
    tok2nd = shap_data["token_id_to_node_id_mapping"]
    group_values = shap_data["group_values"]
    raw_tokens = shap_data["raw_tokens"]    
    pos_probs = shap_data["pos_probs"]
    buf = ""
    pos_buf = []
    last_id = 0
    
    if data_x[1] == 1:
        polarity = -1
    elif data_x[1] == 2:
        polarity = 2
    else:
        continue
        
    for tok_id, nd_id in enumerate(tok2nd):    
        raw_tok = raw_tokens[tok_id]
        if last_id != nd_id:
            nd_value = polarity*group_values[int(last_id)]
            merged_values[buf].value += nd_value
            merged_values[buf].freq += 1          
            
            if buf:
                pos_vec = np.vstack(pos_buf)
            else:
                pos_vec = np.zeros((1, len(pos_list)), dtype=np.float32)
                
            if merged_values[buf].pos is not None:
                if merged_values[buf].pos.shape[0] != pos_vec.shape[0]:
                    breakpoint()
                merged_values[buf].pos += pos_vec
            else:
                merged_values[buf].pos = pos_vec
            pos_buf = []
            buf = ""        
                
        last_id = nd_id
        buf += raw_tok
        if buf:
            pos_buf.append(pos_probs[tok_id])
        
    if buf:
        merged_values[buf].value += polarity*nd_value
        merged_values[buf].freq += 1
        if merged_values[buf].pos is not None:
            merged_values[buf].pos += np.vstack(pos_buf)
        else:
            merged_values[buf].pos = np.vstack(pos_buf)
    

In [13]:
from collections import Counter
import re
func_mask = [int(re.match(r"N.+|V.+|.*CATEGORY", x) is None) for x in pos_list]
merged_pats = {}
pos_pats = {}
for pat, rec in merged_values.items():
    if not pat: continue
    merged_pats[pat] = rec.value / rec.freq
    pos_pats[pat] = rec.pos


In [14]:
[(x, merged_pats[x]) for x in sorted(merged_pats.keys(), key=merged_pats.get, reverse=True)][:10]

[('Êé®‰∏ÄÂÄã‰∏≠ËèØÈõª‰ø°', 15.741275991218721),
 ('‰∏≠ËèØÈõª‰ø°Â•ΩÊ£í', 15.526976172301513),
 ('‰∫îgÂè™‰ø°‰∫ûÂ§™', 14.780610259067839),
 ('Êé®Ëñ¶‰∏≠ËèØ', 14.548671594675053),
 ('‰∏≠ËèØÂ•Ω', 14.419612054633166),
 ('ÊØîËºÉÂ•Ωüëç', 13.759743712127339),
 ('ÈÇÑÂ•Ω‰∏≠ËèØ', 13.325444134881362),
 ('ÁúãÂ•Ω‰∏≠ËèØ', 12.632075464269988),
 ('Âè™Êé®Âè∞Âì•', 12.459128194538861),
 ('ÊØîËºÉÊúüÂæÖ‰∏≠ËèØ', 12.121670681717855)]

In [15]:
[(x, merged_pats[x]) for x in sorted(merged_pats.keys(), key=merged_pats.get, reverse=False)][:10]

[('Âú®ÂåóÂ§ßÊ≠¶‰∏âËßí', -8.82558056486867),
 ('ÂâõÂâõ‰∏≠ËèØÂá∫ÂïèÈ°å', -8.240402194087654),
 ('Âè∞Âì•Â§ßÊó•Â∏∏Êñ∑Á∑ö', -8.08578173037489),
 ('Â§ßÁúüÁöÑÂ§†Áàõ', -8.019043184761369),
 ('Âè∞Âì•Â§ßÈù†ÈôêÈÄüÂ∞±', -7.922424416351588),
 ('‰∏≠ËèØÈõª‰ø°‰∏çËÄÉÊÖÆ', -7.821020848328299),
 ('‰∏≠ËèØÂ≠∏ÁîüÊñπÊ°à', -7.715885746963345),
 ('‰∏≠ËèØÊòéÈ°ØËÆäÊÖ¢', -7.700923131146217),
 ('Âè∞Âì•-7 Âëµ', -7.670989988222811),
 ('ÂæàÂø´Â∞±ÁàÜ‰∫Ü699Êî∂Ë≤ªÂ§™Ë≤¥...', -7.652803886230502)]

In [16]:
pos_weights = {}
for pat, pos_probs in pos_pats.items():
    pos_weights[pat] = (pos_probs * np.array(func_mask)).mean(axis=0).sum()
    

In [23]:
from DistilTag import DistilTag
tagger = DistilTag()

In [27]:
tagger.print_soft_tag(*tagger.soft_tag("ÊØîËºÉÂ•Ω"))

ÊØî_0.59 Dfa_0.58/ VC_0.03/ Na_0.02/ VH_0.02/ Nv_0.02
ËºÉ_0.37 Dfa_0.65/ VH_0.03/ VC_0.02/ Nf_0.01/ Nv_0.01
Â•Ω_0.49  VH_0.56/ VL_0.05/ VC_0.02/Dfa_0.02/ VA_0.02



In [29]:
V_mask = [int(re.match(r"V.+", x) is not None) for x in pos_list]
D_mask = [int(re.match(r"D.+", x) is not None) for x in pos_list]
cons_mask_1 = np.vstack([D_mask, D_mask, V_mask])
cons_mask_1.shape

(3, 80)

In [31]:
def apply_cross_product(pos_mat, cons_mask):
    res_M = pos_mat.shape[0]-cons_mask.shape[0]+1
    res_mat = np.zeros(res_M, 1)
    for i in range(res_M):
        res_mat[i] = res_M
        

AttributeError: module 'numpy' has no attribute 'conv2'

In [30]:
len(merged_pats)

3297

In [18]:
pat_weights_1 = {}
for pat in merged_pats:
    pol_value = merged_pats[pat]
    pos_score = pos_weights[pat].mean()    
    # pat_weights_1[pat] = pol_value/abs(pol_value) * (len(pat)*20 + pos_score*.5 + abs(pol_value)*.5)
    pat_weights_1[pat] = pol_value/abs(pol_value) * (len(pat)*20 + abs(pol_value))

In [19]:
def sort_dict(dict_x, key_func, reverse=True, topn=10):
    return [(x, dict_x[x]) for x in sorted(dict_x.keys(), key=key_func, reverse=reverse)][:topn]

In [22]:
with open("h:/pat.txt", "w", encoding="UTF-8") as fout:
    for x in sort_dict(merged_pats, 
                       key_func=lambda x: (pat_weights_1.get(x)), 
                       reverse=True, topn=-1):
        if not (3 <= len(x[0]) <= 10): continue
        if abs(x[1]) <= 2: continue
        fout.write(f"{x[0]}, {pos_weights[x[0]].sum():.4f}, {pat_weights_1[x[0]]:.4f}")
        fout.write("\n")