In [1]:
# Automatically reload edited modules (e.g. the local casa package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import pickle
import json
import random
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from icecream import ic
from import_casa import casa
from matplotlib import pyplot as plt
from casa import Cadence, Cadet, Crystal, MTBert
from sklearn.metrics import classification_report, accuracy_score

In [3]:
# Cadence config; per the load log, the Cadet/Crystal model files are resolved relative to this file's directory.
CADENCE_CONFIG = "../../data/cadence/config.json"
cadence = Cadence.load(CADENCE_CONFIG)

[INFO] 2021-09-01 12:29:57,434 casa.Cadence: Loading Cadet
[INFO] 2021-09-01 12:29:57,447 gensim.utils: loading KeyedVectors object from ..\..\data\cadence\..\cadet\op20.3\ft-2020.kv
[INFO] 2021-09-01 12:29:57,729 gensim.utils: setting ignored attribute vectors_norm to None
[INFO] 2021-09-01 12:29:57,730 gensim.utils: setting ignored attribute vectors_vocab_norm to None
[INFO] 2021-09-01 12:29:57,731 gensim.utils: setting ignored attribute vectors_ngrams_norm to None
[INFO] 2021-09-01 12:29:57,731 gensim.utils: setting ignored attribute buckets_word to None
[INFO] 2021-09-01 12:29:57,846 gensim.utils: FastTextKeyedVectors lifecycle event {'fname': '..\\..\\data\\cadence\\..\\cadet\\op20.3\\ft-2020.kv', 'datetime': '2021-09-01T12:29:57.846572', 'gensim': '4.0.0', 'python': '3.8.2 (tags/v3.8.2:7b3ab59, Feb 25 2020, 23:03:10) [MSC v.1916 64 bit (AMD64)]', 'platform': 'Windows-10-10.0.18362-SP0', 'event': 'loaded'}
[INFO] 2021-09-01 12:29:57,872 casa.Cadence: Loading Crystal
[INFO] 2021-09

In [4]:
from casa.cadence.resolvers import CadenceBertOnlyResolver, CadenceSimpleResolver, CadenceMultiResolver
from casa.cadence import visualize_tokens

In [94]:
# Analyze one sample sentence ("Asia Pacific's network is awful; Chunghwa's reception is great").
# The result exposes .cadet, .crystal and .mt_bert outputs, which the cells below consume.
out = cadence.analyze("亞太網路超差，中華收訊就很好", strategy="simple")

In [95]:
# NOTE(review): 0.005 is presumably pn_thres (a later cell passes pn_thres=0.2); in the output,
# characters whose pn_prob exceeds it are marked red and get a non-negative pn_idx.
visualize_tokens(out, 0.005)

亞太網路[31m超[0m[31m差[0m，中華[31m收[0m訊[31m就[0m[31m很[0m[31m好[0m


{'pn_prob': array([3.1391025e-04, 4.5604211e-05, 3.2047320e-03, 3.9974255e-03,
        9.9427634e-01, 9.7850168e-01, 3.3351569e-04, 2.8090037e-05,
        2.4490140e-05, 5.6001269e-03, 6.7131169e-04, 2.3341034e-01,
        2.5550574e-01, 1.6362540e-01], dtype=float32),
 'pn_idx': array([-1, -1, -1, -1,  1,  1, -1, -1, -1,  1, -1,  1,  1,  1],
       dtype=int64)}

In [96]:
# Inspect Crystal's raw output; its 'word_attr_map' drives the service/polarity extraction below.
out.crystal

{'result': ('[通訊]網速', 1.0),
 'word_attr_map': {'超差': ('[通訊]網速', 1, 0.5), '很好': ('[通訊]涵蓋', 5, 0.3)},
 'CxG': [],
 'onto': [('超差', [('[通訊]網速', 1, 0.5), ('[通訊]涵蓋', 1, 0.5)]),
  ('很好',
   [('[其他]手機', 4, 0.3),
    ('[通訊]網速', 4, 0.3),
    ('[通訊]涵蓋', 5, 0.3),
    ('[通訊]涵蓋', 4, 0.1)])]}

In [97]:
# Inspect Cadet's raw output; its 'tokens' and 'tokens_attrib' drive the entity extraction below.
out.cadet

{'entity': ['中華電信', '亞太電信', '台灣大哥大', '遠傳電信', '台灣之星', '無框行動'],
 'entity_probs': array([0.48782193, 0.48782193, 0.00635621, 0.00620852, 0.00619568,
        0.00559573]),
 'service': [('通訊品質', '網速'),
  ('通訊品質', '涵蓋'),
  ('資費方案', '低資費方案'),
  ('加值服務', 'vowifi'),
  ('加值服務', '電信APP')],
 'service_probs': array([0.44906384, 0.44906384, 0.00546045, 0.00518785, 0.0051812 ]),
 'seeds': ['網速', '覆蓋率', '訊號', '0月租', '免月租'],
 'seed_probs': array([0.3714933 , 0.3714933 , 0.00598352, 0.00451722, 0.00451722]),
 'tokens': ['亞太', '網路', '超', '差,', '中華', '收訊', '就', '很好'],
 'tokens_attrib': {'亞太電信': [0], '網速': [1], '中華電信': [4], '覆蓋率': [5]}}

In [98]:
def find_all_pos(text, target, start=0):
    """Return every offset at which `target` occurs in `text`, searching from `start`.

    Matches may overlap: the search resumes one character after each hit, so
    find_all_pos("aaa", "aa") == [0, 1].  Implemented iteratively so texts with
    many matches neither exceed the recursion limit nor pay O(k^2) list copying.

    Args:
        text: string to search.
        target: substring to look for. An empty target yields [] instead of
            "matching" at every offset.
        start: first offset to consider (same semantics as str.find).

    Returns:
        Ascending list of int offsets; [] when there is no match.
    """
    if not target:
        return []
    positions = []
    pos = text.find(target, start)
    while pos != -1:
        positions.append(pos)
        pos = text.find(target, pos + 1)
    return positions

In [101]:
from itertools import groupby


def _cadet_attr_positions(cadet, tokens, text, want_entities):
    """Map Cadet attributes to the character offsets in `text` where their tokens occur.

    Args:
        cadet: Cadet result dict; reads its "tokens_attrib" and "entity" keys.
        tokens: Cadet token list that the "tokens_attrib" indices point into.
        text: raw text searched for each token's surface form.
        want_entities: True keeps only attributes listed in cadet["entity"];
            False keeps only the remaining (non-entity) attributes.

    Returns:
        dict of attribute -> list of character offsets; attributes whose
        token-index list is empty are omitted.
    """
    entities = set(cadet.get("entity", []))  # hoisted: constant across the loop
    positions = {}
    for attrib, tok_idxs in cadet.get("tokens_attrib", {}).items():
        if (attrib in entities) != want_entities:
            continue
        for tok_idx in tok_idxs:
            # NOTE(review): Cadet tokens appear normalized (e.g. '差,' vs '差，' in the
            # raw text), so such tokens are not found verbatim and add no offsets.
            positions.setdefault(attrib, []).extend(find_all_pos(text, tokens[tok_idx]))
    return positions


cadet_res = out.cadet
crystal_res = out.crystal
mtbert_res = out.mt_bert

raw_text = mtbert_res.get("text", "")
cadet_tokens = cadet_res.get("tokens", [])

# get entity tokens
ent_tokens = _cadet_attr_positions(cadet_res, cadet_tokens, raw_text, want_entities=True)

# get service and polarity tokens; use Crystal if available
srv_tokens = {}
pol_tokens = {}
# tolerate an abstaining Crystal (no result or no map) instead of raising KeyError
word_attr_map = (crystal_res or {}).get("word_attr_map", {})
for word, attr in word_attr_map.items():
    # attr[0] is the service label, attr[1] a polarity score (see the out.crystal output);
    # scores above 3 count as Positive, below 3 as Negative, and exactly 3 is ignored
    indices = find_all_pos(raw_text, word)
    srv_tokens.setdefault(attr[0], []).extend(indices)
    pol_score = attr[1]
    if pol_score > 3:
        pol_tokens.setdefault("Positive", []).extend(indices)
    elif pol_score < 3:
        pol_tokens.setdefault("Negative", []).extend(indices)

# if Crystal abstained, fall back to Cadet's non-entity (service) attributes
if not srv_tokens:
    srv_tokens = _cadet_attr_positions(cadet_res, cadet_tokens, raw_text, want_entities=False)

# if Crystal gave no polarity, fall back to MTBert's per-character polarity codes,
# recording the first offset of each run of consecutive equal codes
if not pol_tokens:
    pn_out = visualize_tokens(out, pn_thres=0.2, quiet=True)
    for pn_code, run in groupby(enumerate(pn_out["pn_idx"]), key=lambda x: x[1]):
        if pn_code < 0:
            # -1 marks characters below pn_thres (per the sample visualize_tokens output)
            continue
        # NOTE(review): assumes code 0 = positive and any other code = negative; the sample
        # output shows code 1 on both '超差' and '很好' -- confirm the encoding.
        polarity = "Positive" if pn_code == 0 else "Negative"
        first_idx = next(run)[0]
        # groupby yields maximal runs, so two runs with the same code are never
        # adjacent; each run contributes exactly one anchor offset.
        pol_tokens.setdefault(polarity, []).append(first_idx)

In [102]:
# Show the raw text next to the character offsets collected for each label type.
for caption, value in (
    ("raw_text: ", raw_text),
    ("entity: ", ent_tokens),
    ("services: ", srv_tokens),
    ("polarities: ", pol_tokens),
):
    print(caption, value)

raw_text:  亞太網路超差，中華收訊就很好
entity:  {'亞太電信': [0], '中華電信': [7]}
services:  {'[通訊]網速': [4], '[通訊]涵蓋': [12]}
polarities:  {'Negative': [4], 'Positive': [12]}


In [103]:
# char offset -> list of (labtype, label) pairs gathered from every extractor
ch_labels = {}


def update_ch_labels(new_dict):
    """Append each (labtype, label) pair in `new_dict` under its offset in ch_labels.

    Mutates the module-level ch_labels dict in place.
    """
    for offset, pair in new_dict.items():
        if offset not in ch_labels:
            ch_labels[offset] = []
        ch_labels[offset].append(pair)


def index_positions(labtype, label_map):
    """Invert {label: [offsets]} into {offset: (labtype, label)}.

    If several labels share an offset, the one iterated last wins.
    """
    inverted = {}
    for label, offsets in label_map.items():
        for offset in offsets:
            inverted[offset] = (labtype, label)
    return inverted

# Merge entity, service and polarity offsets into ch_labels:
# {char offset: [(labtype, label), ...]} with labtype in {"ent", "srv", "pol"}.
update_ch_labels(index_positions("ent", ent_tokens))
update_ch_labels(index_positions("srv", srv_tokens))
update_ch_labels(index_positions("pol", pol_tokens))

In [104]:
# Character offset -> labels found at that offset.
ch_labels

{0: [('ent', '亞太電信')],
 7: [('ent', '中華電信')],
 4: [('srv', '[通訊]網速'), ('pol', 'Negative')],
 12: [('srv', '[通訊]涵蓋'), ('pol', 'Positive')]}

In [107]:
# Walk the text left to right, remembering the latest entity, service and polarity
# seen. Once all three are known, emit a triple and drop the polarity, so the next
# triple needs a fresh polarity cue while entity/service carry over.
buf = {}
aspects = []
for ch_i in range(len(raw_text)):
    labels_here = ch_labels.get(ch_i)
    if labels_here is None:
        continue
    for label, data in labels_here:
        buf[label] = data
    if all(key in buf for key in ("ent", "srv", "pol")):
        aspects.append((buf["ent"], buf["srv"], buf["pol"]))
        del buf["pol"]


In [108]:
# Extracted (entity, service, polarity) triples.
aspects

[('亞太電信', '[通訊]網速', 'Negative'), ('中華電信', '[通訊]涵蓋', 'Positive')]