In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import json
import spacy
from tqdm import tqdm
from common_utils.multiwoz_data import make_tags, flatten_da, MultiWOZData

In [33]:
part = "train"
levels = ["A1", "A2", "B1", "B2", "FULL"]
nlp = spacy.load('en_core_web_sm')
multiwoz_data = MultiWOZData(rm_ws_before_punc=True)
vocab_level = json.load(open(f"./outputs/cefrj/{part}.json"))



In [34]:
level_tolerances = {
    "A1": ["NonAlpha+Stop", "A1"],
    "A2": ["NonAlpha+Stop", "A1", "A2"],
    "B1": ["NonAlpha+Stop", "A1", "A2", "B1"],
    "B2": ["NonAlpha+Stop", "A1", "A2", "B1", "B2"],
    "FULL": ["NonAlpha+Stop", "A1", "A2", "B1", "B2", "OOV"]
}

In [35]:
stats = {}
for dial_name in tqdm(multiwoz_data[part]):
    for i, side, turn in multiwoz_data.iter_dialog_log(part=part, dial_name=dial_name):
        if side != "sys":
            continue
        words = [token.text for token in nlp(turn["text"])]
        das = ["-".join(da) for da in flatten_da(turn["dialog_act"])]
        for level in levels:
            if vocab_level[dial_name][i] not in level_tolerances[level]:
                continue
            if level not in stats:
                stats[level] = {
                    "turns": 0,
                    "words": [],
                    "das": []
                }
            stats[level]["turns"] += 1
            stats[level]["words"] += words
            stats[level]["das"] += das
            

100%|██████████| 8434/8434 [05:45<00:00, 24.42it/s]


In [36]:
def print_stats(level_stats):
    print("turns: ", level_stats["turns"])
    print("vocab: ", len(set(level_stats["words"])))
    print("das: ", len(level_stats["das"]))
    print("uniq das: ", len(set(level_stats["das"])))

In [40]:
print_stats(stats["FULL"])

turns:  56750
vocab:  15914
das:  144634
uniq das:  22253
