In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pickle
import json
import random
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from icecream import ic
from import_casa import casa
from matplotlib import pyplot as plt
from sklearn.metrics import classification_report, accuracy_score

In [3]:
data_dir = casa.get_data_path() / ""
with open(data_dir / f"threads/cht2021-JanMay-op-every20-attr.pkl", "rb") as fin:
    op_sample = pickle.load(fin)

In [4]:
label_list = ['B-VN', 'B-VP', 'I-VN', 'I-VP', 'O']

## Finding a threshold

In [5]:
norm_vec = []
prob_vec = []
data = []
ratings = []
preds = []
rating_map = {"Y": "Neutral", "G": "Positive", "R": "Negative"}
polarity_map = {0: "Neutral", 1: "Positive", 2: "Negative"}

rng = np.random.RandomState(1234)
for op_x in tqdm(op_sample):
    logits = getattr(op_x, "tok_logits", None)
    if logits is None:
        continue
    
    logits_n = logits[:, [0, 2]].sum(axis=-1)
    logits_p = logits[:, [1, 3]].sum(axis=-1)
    
    T = 4    
    norm = np.log(np.linalg.norm(np.vstack(np.exp([logits_p, logits_n])), axis=0))
    prob_pn = np.exp(logits_p/T) / (np.exp(logits_p/T)+np.exp(logits_n/T))
    norm_vec.extend(norm.tolist())
    prob_vec.extend(prob_pn.tolist())
    
    op_id = op_x.id
    entity = getattr(op_x, "cadet_entity", None)
    attr = getattr(op_x, "cadet_service", None)
    rating = rating_map.get(getattr(op_x, "sentence_sentiment", None))        
    polarity = getattr(op_x, "pred_polarity", None)
           
    if rating is not None and polarity is not None:
        ratings.append(rating)
        preds.append(polarity)
    
    entity = entity or "無特定業者"
    attr = attr or ("無特定類別", "")
    
    seq = []
    for ch_i, ch in enumerate(op_x.text):
        if ch_i >= 500: continue
        seq.append((ch, prob_pn[ch_i], norm[ch_i]))
    data.append((op_id, entity, attr, rating, polarity, seq))

HBox(children=(FloatProgress(value=0.0, max=5405.0), HTML(value='')))




In [6]:
print(classification_report(ratings, preds))

              precision    recall  f1-score   support

    Negative       0.37      0.49      0.42       698
     Neutral       0.86      0.76      0.81      4170
    Positive       0.35      0.54      0.42       463

    accuracy                           0.71      5331
   macro avg       0.53      0.60      0.55      5331
weighted avg       0.75      0.71      0.73      5331



In [7]:
len(op_sample), len(ratings)

(5405, 5331)

## find threshold

In [62]:
from io import StringIO

predictions = []
ratings = []
suffix = "sample-every20"
thres = np.quantile(norm_vec, [.9])[0]
hot_cm = plt.get_cmap("coolwarm")
def rgb2css(rgba, a=None):
    r, g, b, aval = rgba
    if not a:
        a = 0.5
    return f"rgba({int(r*255)}, {int(g*255)}, {int(b*255)}, {a})"

sio = StringIO()
sio.write("<!DOCTYPE HTML>\n")
sio.write("<html>\n")
sio.write("<head>\n")
sio.write("<title>cadence visualization</title>\n")
sio.write(f"""<style>
.seq-wrap {{
    margin: 1% 1%; padding: 1% 1%;
    display:block;
    border-left: 10px solid black;
}}
.text{{max-width: 80%; font-size: 16pt; line-height:125%}}
.supp{{margin-left: 1%; font-size: 12pt; align-self: flex-end}}
.pos {{border-color: {rgb2css(hot_cm(0.8), 1)}}}
.neg {{border-color: {rgb2css(hot_cm(0.2), 1)}}}
.unk {{border-color: {rgb2css(hot_cm(0.5), 1)}}}
.det {{margin: 1% 6pt 1% 0pt; background-color: #EEE;
    color: black; display: inline-block; padding: 2pt;
    font-size: 12pt; border-radius: 5px;}}
.badge {{margin: 1% 6pt 1% 0pt; background-color: #AAA;
    color: white; display: inline-block; padding: 2pt;
    font-size: 12pt;}}
.human-rate {{display: inline-block; background-color: #AAA}}
.model-rate {{display: inline-block; background-color: #AAA}}
.pos-rate {{background-color: {rgb2css(hot_cm(0.8), 1)}}}
.neg-rate {{background-color: {rgb2css(hot_cm(0.2), 1)}}}
</style>\n""")
sio.write("</head>\n")
sio.write("<body>\n")
sio.write("<div style='width: 90%; margin:auto'>")

counter = 0
for op_id, ent, attr, rating, pred, seq in tqdm(data):        
    seq_iter = filter(lambda tok: tok[2] > thres, seq)
    prob_pn = np.array(list(map(lambda tok: tok[1], seq_iter)))
    pred_polarity = "Neutral"
    if len(prob_pn) > 3 and not (ent.startswith("無特定") and \
        attr[0].startswith("無特定")):
        p_score = prob_pn[prob_pn > 0.9].sum()
        n_score = prob_pn[prob_pn < 0.1].sum()
        if p_score > n_score * 1:
            pred_polarity = "Positive"
        elif n_score > p_score * 1:
            pred_polarity = "Negative"
    if rating:
        predictions.append(pred_polarity)
        ratings.append(rating)
    
    if counter > 1000:
        continue
        
    if rating == "Positive":
        sio.write("<div class='seq-wrap pos'>\n")
    elif rating == "Negative":
        sio.write("<div class='seq-wrap neg'>\n")
    else:
        sio.write("<div class='seq-wrap unk'>\n")
        
    mark = "*" if rating != pred_polarity else ""
    sio.write("<div class='header'>")    
    sio.write(f"<div class='title'>{mark}{op_id}</div>")        
    sio.write(f"<div class='det entity'>{ent}</div>")
    sio.write(f"<div class='det attribute'>{attr[0]} &nbsp; {attr[1]}</div>")        
    sio.write(f"<div class='badge human-rate {rating[:3].lower()}-rate'>人工</div>")
    sio.write(f"<div class='badge model-rate {pred_polarity[:3].lower()}-rate'>模型</div>")
    sio.write("</div>")
    sio.write("<div class='text'>")
    scores = [0, 0]
    for tok in seq:
        if tok[2] > thres:            
            sio.write(f"<span style='background-color:")
            sio.write(f"{rgb2css(hot_cm(tok[1]))}'>")
        else:
            hval = 0
            sio.write(f"<span style=''>")
        sio.write(f"{tok[0]}</span>")
    sio.write("</div>")    
    sio.write("</div>\n")
    counter += 1
    
sio.write("<div> <!-- div.wrapper -->")
sio.write("</body>\n</html>\n")
with open(f"h:/cadence_visualization_{suffix}.html", "w", encoding="utf-8") as fout:
    fout.write(sio.getvalue())

HBox(children=(FloatProgress(value=0.0, max=5387.0), HTML(value='')))




In [59]:
print(classification_report(ratings, predictions))

              precision    recall  f1-score   support

    Negative       0.35      0.47      0.40       698
     Neutral       0.85      0.80      0.82      4170
    Positive       0.36      0.36      0.36       463

    accuracy                           0.72      5331
   macro avg       0.52      0.54      0.53      5331
weighted avg       0.74      0.72      0.73      5331

