In [2]:
import os
import json
import pandas as pd
import spacy
from tqdm import tqdm
from collections import defaultdict, Counter

In [3]:
# Load the spaCy pipeline 'en_core_web_sm' once at module level.
# ppo_vocab_cnt() relies on this global `nlp` to tokenize the generated text and
# to read each token's is_alpha / is_stop / tag_ / lemma_ attributes.
nlp = spacy.load('en_core_web_sm')

In [4]:
def ppo_vocab_cnt(
    eval_type,
    evaluate_id,
    evaluated_dpath="/data/group1/z44383r/dev/rl-nlg/experiments/evaluate_model/outputs",
    cefrj_wl_fpath="/data/group1/z44383r/dev/rl-nlg/experiments/vocabulary_level/olp-en-cefrj/cefrj-vocabulary-profile-1.5.csv",
):
    """Count the tokens of one evaluation run per CEFR-J vocabulary label.

    Reads ``<evaluated_dpath>/<eval_type>/<evaluate_id>/result.json``, runs the
    module-level spaCy pipeline ``nlp`` over the ``"gen_text"`` field of every
    record, and assigns each token exactly one label.

    Args:
        eval_type: Sub-directory of ``evaluated_dpath`` (e.g. ``"ppo"``,
            ``"baselines"``).
        evaluate_id: Run directory below ``eval_type`` that contains
            ``result.json``.
        evaluated_dpath: Root directory of the evaluation outputs. Defaults to
            the location that was previously hard-coded.
        cefrj_wl_fpath: CSV word list with at least the columns ``CEFR`` and
            ``headword``. Defaults to the location that was previously
            hard-coded.

    Returns:
        collections.Counter mapping a label to its token count. Possible labels:
        ``'NonAlpha+Stop'`` (non-alphabetic tokens and stop words),
        ``'NNP(S)'`` (proper nouns), ``'A1'``, ``'A2'``, ``'B1'``, ``'B2'``
        (lower-cased lemma found in that level's word list) and ``'OOV'``
        (lemma found in none of the four lists).

    Raises:
        FileNotFoundError: If ``result.json`` or the word-list CSV is missing.
        KeyError: If a record has no ``"gen_text"`` key or the CSV lacks the
            ``CEFR`` / ``headword`` columns.
    """
    evaluated_result_fpath = os.path.join(evaluated_dpath, eval_type, evaluate_id, "result.json")
    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    with open(evaluated_result_fpath) as result_file:
        result = json.load(result_file)

    df = pd.read_csv(cefrj_wl_fpath)

    # Ordered from easiest to hardest: a lemma that appears under several levels
    # is attributed to the lowest one, because the first match wins below.
    level_order = ("A1", "A2", "B1", "B2")
    cefrj_wl = {}
    for level in level_order:
        vocabulary = set()
        for word in df.loc[df["CEFR"] == level, "headword"]:
            # A headword cell may hold several variants separated by "/".
            vocabulary.update(word.lower().split("/"))
        cefrj_wl[level] = vocabulary

    # Tally labels directly instead of first materializing one string per token.
    vocab_level_dist = Counter()
    for r in tqdm(result):
        doc = nlp(r["gen_text"])
        for token in doc:
            if not token.is_alpha or token.is_stop:
                # Non-alphabetic tokens and stop words share one excluded bucket.
                vocab_level_dist["NonAlpha+Stop"] += 1
            elif token.tag_ in ("NNP", "NNPS"):
                # Proper nouns are kept apart from the vocabulary levels.
                vocab_level_dist["NNP(S)"] += 1
            else:
                lemma = token.lemma_.lower()
                for level in level_order:
                    if lemma in cefrj_wl[level]:
                        vocab_level_dist[level] += 1
                        break
                else:
                    # Lemma is in none of the A1-B2 lists.
                    vocab_level_dist["OOV"] += 1
    return vocab_level_dist

In [8]:
def sum_cnt(cnt, levels):
    """Return the combined count of the given labels.

    Args:
        cnt: Mapping from label to count; it is indexed with ``cnt[label]``.
        levels: Iterable of labels whose counts are added up.

    Returns:
        The sum of ``cnt[label]`` over ``levels`` (``0`` when ``levels`` is
        empty).
    """
    running_total = 0
    for label in levels:
        running_total = running_total + cnt[label]
    return running_total

In [55]:
# Label counts of the "it100-milu" PPO runs, pooled over the five seeds of each
# target level. Runs are processed in the listed order and their Counters added.
cnt_A1_it100_milu = sum(
    (
        ppo_vocab_cnt("ppo", run_id)
        for run_id in (
            "it100-milu-cefrjA1-sys--seed12",
            "it100-milu-cefrjA1-sys--seed34",
            "it100-milu-cefrjA1-sys--seed56",
            "it100-milu-cefrjA1-sys--seed78",
            "it100-milu-cefrjA1-sys--seed90",
        )
    ),
    Counter(),
)

cnt_A2_it100_milu = sum(
    (
        ppo_vocab_cnt("ppo", run_id)
        for run_id in (
            "it100-milu-cefrjA2-sys--seed12",
            "it100-milu-cefrjA2-sys--seed34",
            "it100-milu-cefrjA2-sys--seed56",
            "it100-milu-cefrjA2-sys--seed78",
            "it100-milu-cefrjA2-sys--seed90",
        )
    ),
    Counter(),
)

cnt_B1_it100_milu = sum(
    (
        ppo_vocab_cnt("ppo", run_id)
        for run_id in (
            "it100-milu-cefrjB1-sys--seed12",
            "it100-milu-cefrjB1-sys--seed34",
            "it100-milu-cefrjB1-sys--seed56",
            "it100-milu-cefrjB1-sys--seed78",
            "it100-milu-cefrjB1-sys--seed90",
        )
    ),
    Counter(),
)

cnt_B2_it100_milu = sum(
    (
        ppo_vocab_cnt("ppo", run_id)
        for run_id in (
            "it100-milu-cefrjB2-sys--seed12",
            "it100-milu-cefrjB2-sys--seed34",
            "it100-milu-cefrjB2-sys--seed56",
            "it100-milu-cefrjB2-sys--seed78",
            "it100-milu-cefrjB2-sys--seed90",
        )
    ),
    Counter(),
)

100%|██████████| 7372/7372 [00:41<00:00, 177.51it/s]
100%|██████████| 7372/7372 [00:42<00:00, 174.91it/s]
100%|██████████| 7372/7372 [00:41<00:00, 176.90it/s]
100%|██████████| 7372/7372 [00:41<00:00, 176.16it/s]
100%|██████████| 7372/7372 [00:41<00:00, 176.54it/s]
100%|██████████| 7372/7372 [00:41<00:00, 177.41it/s]
100%|██████████| 7372/7372 [00:41<00:00, 177.74it/s]
100%|██████████| 7372/7372 [00:41<00:00, 178.00it/s]
100%|██████████| 7372/7372 [00:41<00:00, 179.06it/s]
100%|██████████| 7372/7372 [00:41<00:00, 179.35it/s]
100%|██████████| 7372/7372 [00:41<00:00, 178.91it/s]
100%|██████████| 7372/7372 [00:41<00:00, 178.75it/s]
100%|██████████| 7372/7372 [00:40<00:00, 180.48it/s]
100%|██████████| 7372/7372 [00:41<00:00, 179.18it/s]
100%|██████████| 7372/7372 [00:40<00:00, 179.83it/s]
100%|██████████| 7372/7372 [00:40<00:00, 179.85it/s]
100%|██████████| 7372/7372 [00:40<00:00, 180.65it/s]
100%|██████████| 7372/7372 [00:41<00:00, 179.37it/s]
100%|██████████| 7372/7372 [00:41<00:00, 178.5

In [22]:
# Label counts of the "milu" PPO runs, pooled over the five seeds of each
# target level. Runs are processed in the listed order and their Counters added.
cnt_A1_milu = sum(
    (
        ppo_vocab_cnt("ppo", run_id)
        for run_id in (
            "milu-cefrjA1-sys--seed12",
            "milu-cefrjA1-sys--seed34",
            "milu-cefrjA1-sys--seed56",
            "milu-cefrjA1-sys--seed78",
            "milu-cefrjA1-sys--seed90",
        )
    ),
    Counter(),
)

cnt_A2_milu = sum(
    (
        ppo_vocab_cnt("ppo", run_id)
        for run_id in (
            "milu-cefrjA2-sys--seed12",
            "milu-cefrjA2-sys--seed34",
            "milu-cefrjA2-sys--seed56",
            "milu-cefrjA2-sys--seed78",
            "milu-cefrjA2-sys--seed90",
        )
    ),
    Counter(),
)

cnt_B1_milu = sum(
    (
        ppo_vocab_cnt("ppo", run_id)
        for run_id in (
            "milu-cefrjB1-sys--seed12",
            "milu-cefrjB1-sys--seed34",
            "milu-cefrjB1-sys--seed56",
            "milu-cefrjB1-sys--seed78",
            "milu-cefrjB1-sys--seed90",
        )
    ),
    Counter(),
)

cnt_B2_milu = sum(
    (
        ppo_vocab_cnt("ppo", run_id)
        for run_id in (
            "milu-cefrjB2-sys--seed12",
            "milu-cefrjB2-sys--seed34",
            "milu-cefrjB2-sys--seed56",
            "milu-cefrjB2-sys--seed78",
            "milu-cefrjB2-sys--seed90",
        )
    ),
    Counter(),
)

100%|██████████| 7372/7372 [00:41<00:00, 177.23it/s]
100%|██████████| 7372/7372 [00:42<00:00, 175.14it/s]
100%|██████████| 7372/7372 [00:42<00:00, 173.00it/s]
100%|██████████| 7372/7372 [00:43<00:00, 171.41it/s]
100%|██████████| 7372/7372 [00:42<00:00, 175.31it/s]
100%|██████████| 7372/7372 [00:41<00:00, 176.34it/s]
100%|██████████| 7372/7372 [00:41<00:00, 177.83it/s]
100%|██████████| 7372/7372 [00:41<00:00, 177.14it/s]
100%|██████████| 7372/7372 [00:41<00:00, 176.69it/s]
100%|██████████| 7372/7372 [00:41<00:00, 176.71it/s]
100%|██████████| 7372/7372 [00:41<00:00, 177.23it/s]
100%|██████████| 7372/7372 [00:41<00:00, 177.18it/s]
100%|██████████| 7372/7372 [00:41<00:00, 177.33it/s]
100%|██████████| 7372/7372 [00:41<00:00, 176.82it/s]
100%|██████████| 7372/7372 [00:41<00:00, 178.52it/s]
100%|██████████| 7372/7372 [00:41<00:00, 178.90it/s]
100%|██████████| 7372/7372 [00:41<00:00, 178.95it/s]
100%|██████████| 7372/7372 [00:41<00:00, 178.34it/s]
100%|██████████| 7372/7372 [00:41<00:00, 178.4

In [40]:
# Label counts of the "bert" PPO runs, pooled over the five seeds of each
# target level. Runs are processed in the listed order and their Counters added.
cnt_A1_bert = sum(
    (
        ppo_vocab_cnt("ppo", run_id)
        for run_id in (
            "bert-cefrjA1-sys--seed12",
            "bert-cefrjA1-sys--seed34",
            "bert-cefrjA1-sys--seed56",
            "bert-cefrjA1-sys--seed78",
            "bert-cefrjA1-sys--seed90",
        )
    ),
    Counter(),
)

cnt_A2_bert = sum(
    (
        ppo_vocab_cnt("ppo", run_id)
        for run_id in (
            "bert-cefrjA2-sys--seed12",
            "bert-cefrjA2-sys--seed34",
            "bert-cefrjA2-sys--seed56",
            "bert-cefrjA2-sys--seed78",
            "bert-cefrjA2-sys--seed90",
        )
    ),
    Counter(),
)

cnt_B1_bert = sum(
    (
        ppo_vocab_cnt("ppo", run_id)
        for run_id in (
            "bert-cefrjB1-sys--seed12",
            "bert-cefrjB1-sys--seed34",
            "bert-cefrjB1-sys--seed56",
            "bert-cefrjB1-sys--seed78",
            "bert-cefrjB1-sys--seed90",
        )
    ),
    Counter(),
)

cnt_B2_bert = sum(
    (
        ppo_vocab_cnt("ppo", run_id)
        for run_id in (
            "bert-cefrjB2-sys--seed12",
            "bert-cefrjB2-sys--seed34",
            "bert-cefrjB2-sys--seed56",
            "bert-cefrjB2-sys--seed78",
            "bert-cefrjB2-sys--seed90",
        )
    ),
    Counter(),
)

100%|██████████| 7372/7372 [00:41<00:00, 176.36it/s]
100%|██████████| 7372/7372 [00:41<00:00, 177.72it/s]
100%|██████████| 7372/7372 [00:41<00:00, 177.49it/s]
100%|██████████| 7372/7372 [00:41<00:00, 177.92it/s]
100%|██████████| 7372/7372 [00:40<00:00, 181.05it/s]
100%|██████████| 7372/7372 [00:40<00:00, 180.08it/s]
100%|██████████| 7372/7372 [00:41<00:00, 178.69it/s]
100%|██████████| 7372/7372 [00:41<00:00, 178.97it/s]
100%|██████████| 7372/7372 [00:41<00:00, 179.78it/s]
100%|██████████| 7372/7372 [00:40<00:00, 180.85it/s]
100%|██████████| 7372/7372 [00:41<00:00, 179.45it/s]
100%|██████████| 7372/7372 [00:40<00:00, 181.92it/s]
100%|██████████| 7372/7372 [00:40<00:00, 183.09it/s]
100%|██████████| 7372/7372 [00:40<00:00, 180.67it/s]
100%|██████████| 7372/7372 [00:40<00:00, 180.15it/s]
100%|██████████| 7372/7372 [00:40<00:00, 182.56it/s]
100%|██████████| 7372/7372 [00:40<00:00, 179.98it/s]
100%|██████████| 7372/7372 [00:40<00:00, 183.94it/s]
100%|██████████| 7372/7372 [00:41<00:00, 179.0

In [25]:
# Label counts for the "baselines" run "gpt2-bert-cefrjA1-sys".
# NOTE(review): only the A1 run is active although the variable is named "full";
# confirm that this is intended.
# The disabled alternatives below use `=`, so enabling one would replace the
# counts rather than accumulate them (the scgpt/sclstm cells use `+=`).
cnt_full_gpt2 = ppo_vocab_cnt("baselines", "gpt2-bert-cefrjA1-sys")
# cnt_full_gpt2 = ppo_vocab_cnt("baselines", "gpt2-bert-cefrjA2-sys")
# cnt_full_gpt2 = ppo_vocab_cnt("baselines", "gpt2-bert-cefrjB1-sys")
# cnt_full_gpt2 = ppo_vocab_cnt("baselines", "gpt2-bert-cefrjB2-sys")

100%|██████████| 7372/7372 [00:42<00:00, 175.10it/s]


In [5]:
# Label counts for the "baselines" run "scgpt-bert-cefrjA1-sys".
# NOTE(review): the `+=` lines that would add the A2/B1/B2 runs are disabled, so
# only the A1 run is counted although the variable is named "full"; confirm that
# this is intended.
cnt_full_scgpt = ppo_vocab_cnt("baselines", "scgpt-bert-cefrjA1-sys")
# cnt_full_scgpt += ppo_vocab_cnt("baselines", "scgpt-bert-cefrjA2-sys")
# cnt_full_scgpt += ppo_vocab_cnt("baselines", "scgpt-bert-cefrjB1-sys")
# cnt_full_scgpt += ppo_vocab_cnt("baselines", "scgpt-bert-cefrjB2-sys")

100%|██████████| 7372/7372 [00:55<00:00, 133.68it/s]


In [17]:
# Label counts for the "baselines" run "sclstm-bert-cefrjA1-sys".
# NOTE(review): the `+=` lines that would add the A2/B1/B2 runs are disabled, so
# only the A1 run is counted although the variable is named "full"; confirm that
# this is intended.
cnt_full_sclstm = ppo_vocab_cnt("baselines", "sclstm-bert-cefrjA1-sys")
# cnt_full_sclstm += ppo_vocab_cnt("baselines", "sclstm-bert-cefrjA2-sys")
# cnt_full_sclstm += ppo_vocab_cnt("baselines", "sclstm-bert-cefrjB1-sys")
# cnt_full_sclstm += ppo_vocab_cnt("baselines", "sclstm-bert-cefrjB2-sys")

100%|██████████| 7372/7372 [00:42<00:00, 174.01it/s]


In [72]:
# Labels that make up the denominator of the ratios printed below: the four
# word-list levels plus 'OOV'. The 'NonAlpha+Stop' and 'NNP(S)' labels produced
# by ppo_vocab_cnt() are not listed here and so are excluded from the ratios.
total = ["A1", "A2", "B1", "B2", "OOV"]

In [75]:
# Share of tokens labelled A1 among all tokens carrying a label from `total`.
# PPO systems use the counters of the runs that targeted A1.
levels = ["A1"]
for system_label, system_cnt in (
    ("sclstm: ", cnt_full_sclstm),
    ("scgpt: ", cnt_full_scgpt),
    ("gpt2: ", cnt_full_gpt2),
    ("gpt2ppo-milu: ", cnt_A1_milu),
    ("gpt2ppo-it100-milu: ", cnt_A1_it100_milu),
    ("gpt2ppo-bert: ", cnt_A1_bert),
):
    print(system_label, sum_cnt(system_cnt, levels) / sum_cnt(system_cnt, total))

sclstm:  0.672660413225716
scgpt:  0.6270900058003177
gpt2:  0.647630762021906
gpt2ppo-milu:  0.6809635648089374
gpt2ppo-it100-milu:  0.68213765511802
gpt2ppo-bert:  0.6695149503338425


In [76]:
# Share of tokens labelled A1 or A2 among all tokens carrying a label from
# `total`. PPO systems use the counters of the runs that targeted A2.
levels = ["A1", "A2"]
for system_label, system_cnt in (
    ("sclstm: ", cnt_full_sclstm),
    ("scgpt: ", cnt_full_scgpt),
    ("gpt2: ", cnt_full_gpt2),
    ("gpt2ppo-milu: ", cnt_A2_milu),
    ("gpt2ppo-it100-milu: ", cnt_A2_it100_milu),
    ("gpt2ppo-bert: ", cnt_A2_bert),
):
    print(system_label, sum_cnt(system_cnt, levels) / sum_cnt(system_cnt, total))

sclstm:  0.7664308641590576
scgpt:  0.7361864171689405
gpt2:  0.7723305804486617
gpt2ppo-milu:  0.7982278622888348
gpt2ppo-it100-milu:  0.794923707471236
gpt2ppo-bert:  0.7956764950920514


In [77]:
# Share of tokens labelled A1, A2 or B1 among all tokens carrying a label from
# `total`. PPO systems use the counters of the runs that targeted B1.
levels = ["A1", "A2", "B1"]
for system_label, system_cnt in (
    ("sclstm: ", cnt_full_sclstm),
    ("scgpt: ", cnt_full_scgpt),
    ("gpt2: ", cnt_full_gpt2),
    ("gpt2ppo-milu: ", cnt_B1_milu),
    ("gpt2ppo-it100-milu: ", cnt_B1_it100_milu),
    ("gpt2ppo-bert: ", cnt_B1_bert),
):
    print(system_label, sum_cnt(system_cnt, levels) / sum_cnt(system_cnt, total))

sclstm:  0.8836049736669886
scgpt:  0.8442488588505284
gpt2:  0.8721372928014994
gpt2ppo-milu:  0.8800404311056352
gpt2ppo-it100-milu:  0.8823642283637846
gpt2ppo-bert:  0.8804230755267266


In [78]:
# Share of tokens labelled A1, A2, B1 or B2 among all tokens carrying a label
# from `total`. PPO systems use the counters of the runs that targeted B2.
levels = ["A1", "A2", "B1", "B2"]
for system_label, system_cnt in (
    ("sclstm: ", cnt_full_sclstm),
    ("scgpt: ", cnt_full_scgpt),
    ("gpt2: ", cnt_full_gpt2),
    ("gpt2ppo-milu: ", cnt_B2_milu),
    ("gpt2ppo-it100-milu: ", cnt_B2_it100_milu),
    ("gpt2ppo-bert: ", cnt_B2_bert),
):
    print(system_label, sum_cnt(system_cnt, levels) / sum_cnt(system_cnt, total))

sclstm:  0.9165757736295927
scgpt:  0.8674501298766802
gpt2:  0.9025361682188251
gpt2ppo-milu:  0.9100049757885411
gpt2ppo-it100-milu:  0.9113969557746444
gpt2ppo-bert:  0.9104262401966009
