In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pickle
import random
import numpy as np
from import_casa import casa
from casa.cadet import get_thread_aspect


In [3]:
# Read the merged thread collection from the project's data directory.
# NOTE: unpickling runs arbitrary code embedded in the file; only load
# pickles produced by a trusted pipeline.
threads_pkl_path = casa.get_data_path() / "threads/cht-2020-merged-cadet-op20.pkl"
with threads_pkl_path.open("rb") as pkl_file:
    threads = pickle.load(pkl_file)

In [4]:
# Sanity check: how many threads were loaded from the pickle.
len(threads)

56064

In [5]:
# One record per thread: (position in `threads`, *values returned by
# get_thread_aspect). Later cells read element 1 as the entity label and
# element 2 as the attribute label.
threads_aspects = [
    (thread_no, *get_thread_aspect(thread))
    for thread_no, thread in enumerate(threads)
]

In [6]:
# Sanity check: one aspect record per thread.
len(threads_aspects)

56064

In [7]:
from collections import Counter
def _rounded_shares(labels):
    """Return ``{label: share}`` where share is the label's relative
    frequency in ``labels``, rounded to 2 decimals.

    Keys appear in order of first occurrence. Because of the rounding the
    shares are not guaranteed to sum to exactly 1.
    """
    counts = Counter(labels)
    total = sum(counts.values())
    return {label: round(count / total, 2) for label, count in counts.items()}


# Observed distribution of entity labels (element 1) and attribute labels
# (element 2) across all threads.
ent_probs = _rounded_shares(record[1] for record in threads_aspects)
aspect_probs = _rounded_shares(record[2] for record in threads_aspects)
print(ent_probs)
print(aspect_probs)

{'中華電信': 0.29, '台灣大哥大': 0.13, None: 0.19, '遠傳電信': 0.13, '其他電信': 0.01, '台灣之星': 0.16, '亞太電信': 0.08}
{'其他': 0.05, '資費方案': 0.49, '通訊品質': 0.26, None: 0.07, '加值服務': 0.13}


**抽樣設計（Sampling design）**

- 電信（entity）：中華電信佔樣本的 80%；其餘電信依其原始比例分配剩餘的 20%。
- 屬性（attribute）：依原始比例抽取（資費）。

In [8]:
# Target entity distribution for sampling: 中華電信 is fixed at 0.8 and the
# remaining 0.2 is split among the other labels in proportion to their
# observed shares.
#
# NOTE(review): `ent_sample_probs` is bound to the *same dict object* as
# `ent_probs` (no copy is made), so the observed shares in `ent_probs` are
# overwritten in place and anything that reads `ent_probs` after this cell
# sees the sampling weights instead.
ent_sample_probs = ent_probs
ent_sample_probs.update({"中華電信": 0.8})

# Total mass of every label except 中華電信 (computed after the override
# above, so the 0.8 is added and subtracted again).
Z_ent_chtx = sum(ent_sample_probs.values()) - ent_sample_probs["中華電信"]

# The comprehension is fully evaluated before `update` writes anything, so
# every label is rescaled from its pre-update value.
ent_sample_probs.update({
    ent_x: (prob_x / Z_ent_chtx) * .2
    for ent_x, prob_x in ent_probs.items()
    if ent_x != "中華電信"
})
ent_sample_probs

{'中華電信': 0.8,
 '台灣大哥大': 0.037142857142857144,
 None: 0.0542857142857143,
 '遠傳電信': 0.037142857142857144,
 '其他電信': 0.0028571428571428576,
 '台灣之星': 0.04571428571428572,
 '亞太電信': 0.02285714285714286}

## Build Samples

In [9]:
import random
# Dedicated, explicitly seeded RNG instance; it is used to shuffle the
# thread indices, keeping that shuffle independent of the global `random` state.
rng = random.Random(12345)

In [10]:
# Annotation workload: `n_batch` batches of (at most) `batch_size` threads.
n_batch = 20
batch_size = 1000
N = n_batch * batch_size  # total number of threads to draw

# Visit threads in random order so that filling the quotas amounts to
# sampling without replacement.
thread_idxs = list(range(len(threads)))
rng.shuffle(thread_idxs)

In [11]:
# --- 1. Quotas per entity ---------------------------------------------------
# Read the re-weighted distribution through its own name (`ent_sample_probs`)
# instead of `ent_probs`, so this cell does not depend on the two names
# happening to point at the same dict.
# Each label gets int(p * N + 1) slots; the concatenated slots are then
# truncated to N, so any excess is removed from the tail (the last labels).
ent_labels = list(ent_sample_probs.keys())
ent_slots = [int(p * N + 1) for p in ent_sample_probs.values()]
sample_ents = np.repeat(ent_labels, ent_slots)[:N]
sample_ents_freq = Counter(sample_ents)

# --- 2. Quotas per attribute within each entity -----------------------------
# Same allocation rule, applied to each entity's quota. Entities are visited
# in the order they occur in `sample_ents`, so `sample_attrs` lines up
# element-by-element with `sample_ents`.
attr_labels = list(aspect_probs.keys())
sample_attrs = []
for ent_x, ent_freq in sample_ents_freq.items():
    attr_slots = [int(p * ent_freq + 1) for p in aspect_probs.values()]
    sample_attrs.extend(np.repeat(attr_labels, attr_slots)[:ent_freq])

# (entity, attribute) -> number of threads still wanted for that cell.
sample_scheme = Counter(zip(sample_ents, sample_attrs))
sample_scheme_ori = sample_scheme.copy()

# --- 3. Fill the quotas from the shuffled thread order ----------------------
thread_sample = []
n_unfilled = sum(sample_scheme.values())
for tidx in thread_idxs:
    if n_unfilled == 0:
        # Every quota is satisfied; the remaining threads cannot be selected.
        break
    # Positional access (1 = entity, 2 = attribute) avoids hard-coding the
    # total length of the aspect record.
    record = threads_aspects[tidx]
    tent, tattr = record[1], record[2]
    if sample_scheme[(tent, tattr)] > 0:
        thread_sample.append((tent, tattr, tidx))
        sample_scheme[(tent, tattr)] -= 1
        n_unfilled -= 1

# --- 4. Report cells whose quota exceeds the available threads --------------
if n_unfilled > 0:
    for (ent, att), freq in sample_scheme.items():
        if freq == 0:
            continue
        ori_freq = sample_scheme_ori[(ent, att)]
        print((ent, att), "cannot be fully sampled. Expected/Actual: ",
              ori_freq, ori_freq-freq)


    

('中華電信', '資費方案') cannot be fully sampled. Expected/Actual:  7841 7090
('中華電信', None) cannot be fully sampled. Expected/Actual:  1121 855


In [12]:
import pandas as pd

In [13]:
# Sort by (entity, attribute) first, then deal batch numbers 0..n_batch-1
# round-robin over the sorted rows, so consecutive rows of the same cell go
# to different batches.
sorted_sample = (
    pd.DataFrame(thread_sample, columns=["entity", "attribute", "thread_idx"])
    .sort_values(["entity", "attribute"])
)
batch_cycle = np.tile(np.arange(n_batch), batch_size)
sample_table = (
    sorted_sample
    .assign(batch=batch_cycle[:sorted_sample.shape[0]])
    # Missing entity/attribute labels become the visible placeholder "[NA]".
    .fillna("[NA]")
)
sample_table.shape[0]

18983

In [14]:
# Display the sampled table for a visual check.
sample_table

Unnamed: 0,entity,attribute,thread_idx,batch
23,中華電信,其他,10317,0
43,中華電信,其他,45195,1
55,中華電信,其他,38134,2
119,中華電信,其他,22896,3
298,中華電信,其他,44469,4
...,...,...,...,...
2629,[NA],[NA],54274,18
2631,[NA],[NA],42001,19
2667,[NA],[NA],5505,0
2680,[NA],[NA],7652,1


In [15]:
# Entity x attribute thread counts over the whole sample, with totals.
(
    sample_table
    .pivot_table(
        values="thread_idx",
        index=["entity"],
        columns=["attribute"],
        aggfunc="count",
        margins=True,
    )
)

attribute,[NA],其他,加值服務,資費方案,通訊品質,All
entity,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
[NA],77,55,138,533,283,1086
中華電信,855,801,2077,7090,4161,14984
亞太電信,32,23,57,223,119,454
其他電信,5,3,5,29,16,58
台灣之星,65,46,117,449,238,915
台灣大哥大,53,38,93,365,194,743
遠傳電信,53,38,93,365,194,743
All,1140,1004,2580,9054,5205,18983


In [16]:
# Entity x attribute thread counts within batch 0, with totals.
(
    sample_table[sample_table["batch"] == 0]
    .pivot_table(
        values="thread_idx",
        index=["entity"],
        columns=["attribute"],
        aggfunc="count",
        margins=True,
    )
)

attribute,[NA],其他,加值服務,資費方案,通訊品質,All
entity,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
[NA],4.0,3.0,7.0,27.0,14.0,55
中華電信,43.0,41.0,103.0,355.0,208.0,750
亞太電信,1.0,1.0,3.0,11.0,6.0,22
其他電信,,1.0,,1.0,1.0,3
台灣之星,3.0,3.0,5.0,23.0,12.0,46
台灣大哥大,2.0,2.0,5.0,18.0,10.0,37
遠傳電信,2.0,2.0,5.0,18.0,10.0,37
All,55.0,53.0,128.0,453.0,261.0,950


In [17]:
# Entity x attribute thread counts within batch 5, with totals.
(
    sample_table[sample_table["batch"] == 5]
    .pivot_table(
        values="thread_idx",
        index=["entity"],
        columns=["attribute"],
        aggfunc="count",
        margins=True,
    )
)

attribute,[NA],其他,加值服務,資費方案,通訊品質,All
entity,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
[NA],3.0,3.0,7.0,26.0,15.0,54
中華電信,42.0,40.0,104.0,355.0,208.0,749
亞太電信,1.0,2.0,2.0,12.0,6.0,23
其他電信,,,1.0,1.0,1.0,3
台灣之星,3.0,2.0,6.0,23.0,12.0,46
台灣大哥大,3.0,2.0,4.0,19.0,9.0,37
遠傳電信,3.0,2.0,4.0,19.0,9.0,37
All,55.0,51.0,128.0,455.0,260.0,949


## Output JSON

In [18]:
import json
def write_to_json(batch_idx):
    """Write the threads assigned to one batch to a JSON file.

    Selects the rows of the global ``sample_table`` whose ``batch`` equals
    ``batch_idx``, orders them by ``thread_idx`` and writes a list of records
    to ``<data path>/annot_data/thread_batch_{batch_idx}.json``.

    Each record holds ``batch_idx``, ``serial`` (0-based position inside the
    batch after sorting), ``thread_idx``, the entity and attribute labels
    taken from ``threads_aspects`` (elements 1 and 2), and ``html`` as
    produced by ``casa.format_thread_html(thread, serial)``.

    Args:
        batch_idx: Batch number to export.

    Returns:
        None. The function only writes the JSON file.
    """
    batch_rows = (
        sample_table
        .loc[sample_table.batch == batch_idx, :]
        .sort_values("thread_idx")
    )

    json_data = []
    for serial, row in enumerate(batch_rows.itertuples()):
        thread_x = threads[row.thread_idx]
        aspect_x = threads_aspects[row.thread_idx]
        json_data.append({
            "batch_idx": batch_idx,
            "serial": serial,
            "thread_idx": row.thread_idx,
            "cadet_ent": aspect_x[1],
            "cadet_attr": aspect_x[2],
            "html": casa.format_thread_html(thread_x, serial),
        })

    json_path = casa.get_data_path()/f"annot_data/thread_batch_{batch_idx}.json"
    # Create the output folder on first use so a fresh data directory does not
    # raise FileNotFoundError.
    # NOTE(review): assumes get_data_path() returns a pathlib.Path-like
    # object (it is already combined with `/` here) -- confirm.
    json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(json_path, "w", encoding="UTF-8") as fout:
        # ensure_ascii=False keeps non-ASCII text readable in the file.
        json.dump(json_data, fout, ensure_ascii=False, indent=2)

In [19]:
# Export every batch (0 .. n_batch-1) to its own JSON file.
for batch_no in range(n_batch):
    write_to_json(batch_idx=batch_no)