In [1]:
import json
import pickle
import tiktoken
import numpy as np
from xopen import xopen
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import defaultdict
import string
from typing import List
import regex
chatgpt_tok = tiktoken.encoding_for_model("gpt-3.5-turbo")

In [2]:
def normalize_answer(s: str) -> str:
    """Normalization from the SQuAD evaluation script.

    See https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
    """

    def remove_articles(text):
        return regex.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def best_subspan_em(prediction: str, ground_truths: List[str]) -> float:
    normalized_prediction = normalize_answer(prediction)

    for ground_truth in ground_truths:
        normalized_ground_truth = normalize_answer(ground_truth)
        if normalized_ground_truth.lower() in normalized_prediction.lower():
            return 1.0
    return 0.0

In [3]:
for ctx_score_cumsum in [0.3, 0.4, 0.5, 0.6]:
    for sent_low in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:
        with xopen(f"compressed_qa_predictions/nq_20/dpr_20_fid_doc{ctx_score_cumsum}_sl{sent_low}_sh1.0_tl1.0-Llama-2-13b-chat-hf-predictions.jsonl.gz") as f:
            preds = [json.loads(l) for l in f]
            len_list = []
            for pred in preds:
                len_list.append(len(chatgpt_tok.encode(pred['compressed_prompt'])))
            print(ctx_score_cumsum, sent_low, np.mean(len_list))


FileNotFoundError: [Errno 2] No such file or directory: 'compressed_qa_predictions/nq_20/dpr_20_fid_doc0.3_sl0.1_sh1.0_tl1.0-Llama-2-13b-chat-hf-predictions.jsonl.gz'

In [None]:
with open('token_scores/token_scores_list_20_documents_gold_at_0_oneContextFalse.pkl', 'rb') as f:
    token_scores_list = pickle.load(f)

In [None]:
np.array(token_scores_list_9[0][0])

In [None]:
def gini(x):
    mean_absolute_diff = np.abs(np.subtract.outer(x, x)).mean()
    relative_mean_absolute_diff = mean_absolute_diff/np.mean(x)
    g = 0.5 * relative_mean_absolute_diff
    return g


In [None]:
# with xopen('compressed_qa_predictions/dpr_nq_20_dev_20/fid_500_ctxTrue_sentTrue0.15_tokFalse1.0_orgidxTrue_Llama-2-13b-chat-hf.jsonl.gz') as f:
#     preds = [json.loads(l) for l in f]
# FiDcomp/
with xopen('compressed_qa_predictions/dpr_nq_20/dpr_20_fid_doc0.4_sl0.3_sh1.0_tl1.0-Llama-2-13b-chat-hf-predictions.jsonl.gz') as f:
    preds_old = [json.loads(l) for l in f]
with xopen(f'compressed_qa_predictions/base/0527_adaptive_ctx_score_v3_power4/dpr_nq_test_20/fid_500_ctxTrue_sentTrue0.15_4_orgidxTrue_Llama-2-13b-chat-hf.jsonl.gz') as f:
    preds_new = [json.loads(l) for l in f]

In [None]:
qas_i = 5

gold_answers = preds_old[qas_i]["answers"]
print(preds_old[qas_i]['compressed_prompt'], preds_old[qas_i]['model_answer'], best_subspan_em(prediction=preds_old[qas_i]['model_answer'], ground_truths=preds_old[qas_i]["answers"]))
print()
print(preds_new[qas_i]['compressed_prompt'], preds_new[qas_i]['model_answer'], best_subspan_em(prediction=preds_new[qas_i]['model_answer'], ground_truths=preds_old[qas_i]["answers"]))

In [None]:
print(preds_new[0]['compressed_prompt'])

In [None]:
diff_preds = []
span_em_list_old, span_em_list_new = [], []
for qas_i, pred_old, pred_new in enumerate(zip(preds_old, preds_new)):
    gold_answers = pred_old["answers"]
    model_answer = pred_old["model_answer"]
    span_em_old = best_subspan_em(prediction=model_answer, ground_truths=gold_answers)
    span_em_list_old.append(span_em_old)

    gold_answers = pred_new["answers"]
    model_answer = pred_new["model_answer"]
    span_em_new = best_subspan_em(prediction=model_answer, ground_truths=gold_answers)
    span_em_list_new.append(span_em_new)
    if span_em_old and not span_em_new:
        diff_preds.append(qas_i)
    # span_em_list.append(span_em)

print("OLD:", f"{np.mean(span_em_list_old)*100:.2f}", "NEW:", f"{np.mean(span_em_list_new)*100:.2f}")

In [None]:
diff_preds[]

In [None]:
diff_preds[0][0]['ctxs']

In [None]:
print(diff_preds[0][0]['compressed_prompt']), print(diff_preds[0][1]['compressed_prompt'])

In [99]:
low_indices = {}
for gini in [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 'mean']:
    with open(f'gini_low_{gini}_indices.pkl', 'rb') as f:
        gini_low_indices = pickle.load(f)
        low_indices[gini] = gini_low_indices

low_indices_entropy = {}
for entropy in ['mean']:
    with open(f'entropy_low_{gini}_indices.pkl', 'rb') as f:
        entropy_low_indices = pickle.load(f)
        low_indices_entropy[entropy] = set(entropy_low_indices)

In [None]:
q_indices = set()
with open('qa_data/dpr_nq_test.json') as f:
    test2655 = json.load(f)
with open('../eun_FiD/open_domain_data/nq/test.json') as f:
    test3610 = json.load(f)
for qas in test2655:
    for qas_i in range(len(test3610)):
        if qas['question'] == test3610[qas_i]['question']:
            q_indices.add(qas_i)
            break


In [None]:
len(q_indices)

## NQ test 3610 평가

In [None]:
################# Test 평가
version='0601_reader_base'
for power in [2, 3, 4]:
    sent_ratio_list = [0.1, 0.15, 0.2]
    for sent_ratio in sent_ratio_list:
        with xopen(f'compressed_qa_predictions/base/{version}/dpr_nq_test3610_20/fid_500_ctxTrue_sentTrue{sent_ratio}_{power}_sent1False_orgidxTrue_Llama-2-13b-chat-hf.jsonl.gz') as f:
            preds = [json.loads(l) for l in f]
            span_em_list = []
            span_em_list_2655 = []
            for pred_i, pred in enumerate(preds):
                gold_answers = pred["answers"]
                model_answer = pred["model_answer"]
                span_em = best_subspan_em(prediction=model_answer, ground_truths=gold_answers)
                span_em_list.append(span_em)
                if pred_i in q_indices:
                    span_em_list_2655.append(span_em)
            
            print(f"3610 sent_ratio{sent_ratio} {power} {np.mean(span_em_list)*100:.2f} {np.mean(span_em_list_2655)*100:.2f}")
    ctx_cumsum_list= [0.35, 0.4, 0.45, 0.5, 0.55, 0.6]
    for cumsum in ctx_cumsum_list:
        with xopen(f'compressed_qa_predictions/base/{version}/dpr_nq_test3610_20/fid_500_ctxTrue{cumsum}_sentTrue_{power}_sent1False_orgidxTrue_Llama-2-13b-chat-hf.jsonl.gz') as f:
            preds = [json.loads(l) for l in f]
            span_em_list = []
            span_em_list_2655 = []
            for pred_i, pred in enumerate(preds):
                gold_answers = pred["answers"]
                model_answer = pred["model_answer"]
                span_em = best_subspan_em(prediction=model_answer, ground_truths=gold_answers)
                span_em_list.append(span_em)
                if pred_i in q_indices:
                    span_em_list_2655.append(span_em)
            
            print(f"3610 cumsum{cumsum} {power} {np.mean(span_em_list)*100:.2f} {np.mean(span_em_list_2655)*100:.2f}")

In [102]:
################# Test2655 평가
gini2span_em_list_low_mean = defaultdict(list)
gini2span_em_list_high_mean = defaultdict(list)
version='0601_reader_base_solvedtitleinsentsissue'
dataset = 'dpr_nq_test_20'
for power in [2]:
    sent_ratio_list = [0.1, 0.15, 0.2, 0.25]
    for sent_ratio in sent_ratio_list:
        with xopen(f'compressed_qa_predictions/base/{version}/{dataset}/fid_500_ctxTrue_sentTrue{sent_ratio}_{power}_sent1False_orgidxTrue_Llama-2-13b-chat-hf.jsonl.gz') as f:
            preds = [json.loads(l) for l in f]
            span_em_list = []
            gini2span_em_list_low = defaultdict(list)
            gini2span_em_list_high = defaultdict(list)

            for pred_i, pred in enumerate(preds):
                gold_answers = pred["answers"]
                model_answer = pred["model_answer"]
                span_em = best_subspan_em(prediction=model_answer, ground_truths=gold_answers)
                span_em_list.append(span_em)
            
                for gini, gini_indices in low_indices.items():
                    if pred_i in gini_indices:
                        gini2span_em_list_low[gini].append(span_em)
                    else:
                        gini2span_em_list_high[gini].append(span_em)
                        
            for gini, span_em_list_low in gini2span_em_list_low.items():
                gini2span_em_list_low_mean[gini].append(np.mean(span_em_list_low))
            for gini, span_em_list_high in gini2span_em_list_high.items():
                gini2span_em_list_high_mean[gini].append(np.mean(span_em_list_high))

            print(f"sent_ratio{sent_ratio} {power} {np.mean(span_em_list)*100:.6f}")
            
# for power in [2, 3, 4]:
#     ctx_cumsum_list= [0.35, 0.4, 0.45, 0.5, 0.55, 0.6]
#     for cumsum in ctx_cumsum_list:
#         with xopen(f'compressed_qa_predictions/base/{version}/{dataset}/fid_500_ctxTrue{cumsum}_sentTrue_{power}_sent1False_orgidxTrue_Llama-2-13b-chat-hf.jsonl.gz') as f:
#             preds = [json.loads(l) for l in f]
#             span_em_list = []
#             for pred_i, pred in enumerate(preds):
#                 gold_answers = pred["answers"]
#                 model_answer = pred["model_answer"].split('\n')[0]
#                 span_em = best_subspan_em(prediction=model_answer, ground_truths=gold_answers)
#                 span_em_list.append(span_em)
            
#             print(f"cumsum{cumsum} {power} {np.mean(span_em_list)*100:.6f}")

sent_ratio0.1 2 67.721281
sent_ratio0.15 2 67.495292
sent_ratio0.2 2 67.645951
sent_ratio0.25 2 67.645951


In [103]:
gini2span_em_list_low_mean

defaultdict(list,
            {0.3: [0.6846573681018799,
              0.6858702243784112,
              0.6858702243784112,
              0.6852637962401456],
             0.35: [0.6809787626962143,
              0.6800554016620498,
              0.6805170821791321,
              0.6791320406278855],
             0.4: [0.6778846153846154,
              0.6754807692307693,
              0.6778846153846154,
              0.6770833333333334],
             0.25: [0.688780487804878,
              0.6878048780487804,
              0.6897560975609756,
              0.686829268292683],
             'mean': [0.6881331403762663,
              0.6895803183791607,
              0.6895803183791607,
              0.6888567293777135],
             0.15: [0.7476635514018691,
              0.7476635514018691,
              0.7757009345794392,
              0.7757009345794392],
             0.2: [0.6997690531177829,
              0.6951501154734411,
              0.7043879907621247,
              0.706

In [None]:
q_indices_recall20 = set()
for qas_i, qas in enumerate(dev):
    has_answers = [ctx['has_answer'] for ctx in qas['ctxs']]
    if sum(has_answers) > 0:
        q_indices_recall20.add(qas_i)

In [4]:
with open("qa_data/dpr_nq_dev.json") as f:
    dev = json.load(f)
with open('../eun_FiD/open_domain_data/nq/biencoder-nq-dev.json') as f:
    dev_dpr = json.load(f)

In [5]:
for qas in dev_dpr:
    qas['question_'] = ''.join(qas['question'].replace('--', '_').split())

In [6]:
q_indices_dev_dpr = set()
for qas_i, qas in enumerate(tqdm(dev)):
    cur_q = ''.join(qas['question'].split())
    for i in range(len(dev_dpr)):
        if cur_q == dev_dpr[i]['question_']:
            if qas_i not in q_indices_dev_dpr:
                q_indices_dev_dpr.add(qas_i)
                break
            else:
                continue

100%|██████████| 8757/8757 [00:06<00:00, 1432.46it/s]


In [7]:
questions = [''.join(dev[qas_i]['question'].split()) for qas_i in q_indices_dev_dpr]

In [8]:
for qas in dev_dpr:
    if qas['question_'] not in questions:
        print(qas['question'])

the meridian opposite of earth 's prime meridian ( 0 ° longitude ) is called
how often does spermatogeneis -- the production of sperm -- occur
st. peter 's basilica the head of the catholic religion is located in
-- composer arnold schoenberg was highly influential in the movement called --
who won the sprint 15km men 's cross country skiing event in sochi in 2014
the popular sculpture shiva nataraja depicts shiva -- or siva -- as
what was national louis university 's initial area of specialization
if two organism 's have the same name in their binomial name then they are in the same
which movie has won the best animation award at the 89th academy awards -- 2017
name 2 art movements that occurred during the 1920 's
what are hare 's two levels of moral thinking


In [19]:
for qas_i, qas in enumerate(dev):
    if 'two levels of moral thinking' in qas['question']:
        print(qas['question'])
        q_indices_dev_dpr.add(qas_i)

what are hare’s two levels of moral thinking


In [20]:
dev_q_indices_list = [q_indices_dev_dpr, q_indices_recall20]

NameError: name 'q_indices_recall20' is not defined

In [107]:
with xopen('compressed_qa_data/base/0601_reader_base_solvedtitleinsentsissue/dpr_nq_test_20/fid_500_ctxTrue0.35_sentTrue_2_sent1False_orgidxTrue.jsonl.gz_v1') as f:
    data_v1 = [json.loads(l) for l in f]

In [108]:
with xopen('compressed_qa_data/base/0601_reader_base_solvedtitleinsentsissue/dpr_nq_test_20/fid_500_ctxTrue0.35_sentTrue_2_sent1False_orgidxTrue.jsonl.gz') as f:
    data_v2 = [json.loads(l) for l in f]

In [112]:
for qas_i, (qas_1, qas_2) in enumerate(zip(data_v1, data_v2)):
    if qas_1['compressed_prompt'] != qas_2['compressed_prompt']:
        print(qas_i)

## DEV 평가

In [25]:
version = '0601_reader_base_solvedtitleinsentsissue'
dataset = 'dpr_nq_dev_20'
powers = [2, 3, 4]
for power in powers:
    sent_ratio_list = [0.1, 0.15, 0.2, 0.25]
    for sent_ratio in sent_ratio_list:
        with xopen(f'compressed_qa_predictions/base/{version}/{dataset}/fid_500_ctxTrue_sentTrue{sent_ratio}_{power}_sent1False_orgidxTrue_Llama-2-13b-chat-hf.jsonl.gz') as f:
            preds = [json.loads(l) for l in f]
            span_em_list = []
            recall_list = []
            span_em_list2 = []
            for pred_i, pred in enumerate(preds):
                gold_answers = pred["answers"]
                model_answer = pred["model_answer"]
                span_em = best_subspan_em(prediction=model_answer, ground_truths=gold_answers)
                recall = best_subspan_em(prediction=pred['model_prompt'], ground_truths=gold_answers)
                span_em_list.append(span_em)
                recall_list.append(recall)
                if pred_i in q_indices_dev_dpr:
                    span_em_list2.append(span_em)
            
            print(f"sent_ratio{sent_ratio} {power} {np.mean(span_em_list)*100:.6f} {np.mean(span_em_list2)*100:.6f}, {np.mean(recall_list)*100:.6f}")
            
for power in powers:
    ctx_cumsum_list= [0.35, 0.4, 0.45, 0.5, 0.55, 0.6]
    for cumsum in ctx_cumsum_list:
        with xopen(f'compressed_qa_predictions/base/{version}/{dataset}/fid_500_ctxTrue{cumsum}_sentTrue_{power}_sent1False_orgidxTrue_Llama-2-13b-chat-hf.jsonl.gz') as f:
            preds = [json.loads(l) for l in f]
            span_em_list = []
            recall_list = []
            span_em_list2 = []
            for pred_i, pred in enumerate(preds):
                gold_answers = pred["answers"]
                model_answer = pred["model_answer"]
                span_em = best_subspan_em(prediction=model_answer, ground_truths=gold_answers)
                recall = best_subspan_em(prediction=pred['model_prompt'], ground_truths=gold_answers)
                span_em_list.append(span_em)
                recall_list.append(recall)
                if pred_i in q_indices_dev_dpr:
                    span_em_list2.append(span_em)
            
            print(f"cumsum {cumsum} {power} {np.mean(span_em_list)*100:.6f} {np.mean(span_em_list2)*100:.6f}, {np.mean(recall_list)*100:.6f}")

sent_ratio0.1 2 47.790339 53.998465, 77.001256
sent_ratio0.15 2 47.824597 54.044513, 76.955578
sent_ratio0.2 2 47.801759 54.013814, 76.864223
sent_ratio0.25 2 47.756081 54.013814, 76.887062
sent_ratio0.1 3 47.698984 53.844973, 76.704351
sent_ratio0.15 3 47.721823 53.875672, 76.624415
sent_ratio0.2 3 47.824597 54.105909, 76.521640
sent_ratio0.25 3 47.858856 54.244052, 76.510220
sent_ratio0.1 4 47.995889 54.351497, 76.350348
sent_ratio0.15 4 47.836017 54.198005, 76.304671
sent_ratio0.2 4 47.607628 53.860322, 76.213315
sent_ratio0.25 4 47.539112 53.752878, 76.201896
cumsum 0.35 2 59.655133 67.413661, 72.547676
cumsum 0.4 2 60.134749 68.089025, 75.151308
cumsum 0.45 2 60.351719 68.104375, 76.350348
cumsum 0.5 2 60.705721 68.442057, 76.795706
cumsum 0.55 2 60.682882 68.273216, 76.875642
cumsum 0.6 2 60.214685 67.828089, 76.818545
cumsum 0.35 3 59.735069 67.475058, 72.547676
cumsum 0.4 3 60.260363 68.288565, 75.117049
cumsum 0.45 3 60.214685 68.042978, 76.201896
cumsum 0.5 3 60.443074 68.119

: 

In [None]:
for pow in [4]:
    gini2span_em_list_low_mean = defaultdict(list)
    gini2span_em_list_high_mean = defaultdict(list)
            gini2span_em_list_low = defaultdict(list)
            gini2span_em_list_high = defaultdict(list)
    entropy2span_em_list_low_mean = defaultdict(list)
    entropy2span_em_list_high_mean = defaultdict(list)
    rate2num_tokens = defaultdict(list)
    # for rate in [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]:
    for rate in [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]:
        if rate == 0.0:
            sent_comp = False
        else:
            sent_comp = True
        with xopen(f'compressed_qa_predictions/base/0530_ctx_cumsum_pow{pow}/dpr_nq_test_20/fid_500_ctxTrue{rate}_sentTrue_{pow}_sent1False_orgidxTrue_Llama-2-13b-chat-hf.jsonl.gz') as f:
        # with xopen(f'compressed_qa_predictions/base/0527_adaptive_ctx_score_v3_power{pow}_q_include/dpr_nq_test_20/fid_500_ctxTrue_sent{sent_comp}{rate}_{pow}_orgidxTrue_Llama-2-13b-chat-hf.jsonl.gz') as f:
            preds = [json.loads(l) for l in f]
            doc_num_list = []
            span_em_list = []
            gini2span_em_list_low = defaultdict(list)
            gini2span_em_list_high = defaultdict(list)

            entropy2span_em_list_low = defaultdict(list)
            entropy2span_em_list_high = defaultdict(list)

            for p_i, pred in enumerate(preds):
                rate2num_tokens[str(rate)].append(len(chatgpt_tok.encode(pred['compressed_prompt'])))
                cnt = pred['model_prompt'].count('Document [')
                doc_num_list.append(cnt)
                gold_answers = pred["answers"]
                model_answer = pred["model_answer"]
                span_em = best_subspan_em(prediction=model_answer, ground_truths=gold_answers)
                span_em_list.append(span_em)

                for gini, gini_indices in low_indices.items():
                    if p_i in gini_indices:
                        gini2span_em_list_low[gini].append(span_em)
                    else:
                        gini2span_em_list_high[gini].append(span_em)

                for entropy, entropy_indices in low_indices_entropy.items():
                    if p_i in entropy_indices:
                        entropy2span_em_list_low[entropy].append(span_em)
                    else:
                        entropy2span_em_list_high[entropy].append(span_em)
                        
            for gini, span_em_list_low in gini2span_em_list_low.items():
                gini2span_em_list_low_mean[gini].append(np.mean(span_em_list_low))
            for gini, span_em_list_high in gini2span_em_list_high.items():
                gini2span_em_list_high_mean[gini].append(np.mean(span_em_list_high))

            for entropy, span_em_list_low in entropy2span_em_list_low.items():
                entropy2span_em_list_low_mean[entropy].append(np.mean(span_em_list_low))
            for entropy, span_em_list_high in entropy2span_em_list_high.items():
                entropy2span_em_list_high_mean[entropy].append(np.mean(span_em_list_high))

            print(f'rate: {rate} num_docs: {np.mean(doc_num_list):.2f} num_tokens: {np.round(np.mean(rate2num_tokens[str(rate)]), 1)} span_em: {100 * np.mean(span_em_list):.4f}')
            
            # print(f'low ({len(span_em_list_low)}): {100 * np.mean(span_em_list_low):.2f}', end=' ')
            # print(f'high ({len(span_em_list_high)}): {100 * np.mean(span_em_list_high):.2f}')

In [None]:
for pow in [4]:
    gini2span_em_list_low_mean = defaultdict(list)
    gini2span_em_list_high_mean = defaultdict(list)
    entropy2span_em_list_low_mean = defaultdict(list)
    entropy2span_em_list_high_mean = defaultdict(list)
    rate2num_tokens = defaultdict(list)
    # for rate in [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]:
    for rate in [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]:
        if rate == 0.0:
            sent_comp = False
        else:
            sent_comp = True
        with xopen(f'compressed_qa_predictions/base/0530_ctx_cumsum_pow{pow}/dpr_nq_test_20/fid_500_ctxTrue{rate}_sentTrue_{pow}_sent1False_orgidxTrue_Llama-2-13b-chat-hf.jsonl.gz') as f:
        # with xopen(f'compressed_qa_predictions/base/0527_adaptive_ctx_score_v3_power{pow}_q_include/dpr_nq_test_20/fid_500_ctxTrue_sent{sent_comp}{rate}_{pow}_orgidxTrue_Llama-2-13b-chat-hf.jsonl.gz') as f:
            preds = [json.loads(l) for l in f]
            doc_num_list = []
            span_em_list = []
            gini2span_em_list_low = defaultdict(list)
            gini2span_em_list_high = defaultdict(list)

            entropy2span_em_list_low = defaultdict(list)
            entropy2span_em_list_high = defaultdict(list)

            for p_i, pred in enumerate(preds):
                rate2num_tokens[str(rate)].append(len(chatgpt_tok.encode(pred['compressed_prompt'])))
                cnt = pred['model_prompt'].count('Document [')
                doc_num_list.append(cnt)
                gold_answers = pred["answers"]
                model_answer = pred["model_answer"]
                span_em = best_subspan_em(prediction=model_answer, ground_truths=gold_answers)
                span_em_list.append(span_em)

                for gini, gini_indices in low_indices.items():
                    if p_i in gini_indices:
                        gini2span_em_list_low[gini].append(span_em)
                    else:
                        gini2span_em_list_high[gini].append(span_em)

                for entropy, entropy_indices in low_indices_entropy.items():
                    if p_i in entropy_indices:
                        entropy2span_em_list_low[entropy].append(span_em)
                    else:
                        entropy2span_em_list_high[entropy].append(span_em)
                        
            for gini, span_em_list_low in gini2span_em_list_low.items():
                gini2span_em_list_low_mean[gini].append(np.mean(span_em_list_low))
            for gini, span_em_list_high in gini2span_em_list_high.items():
                gini2span_em_list_high_mean[gini].append(np.mean(span_em_list_high))

            for entropy, span_em_list_low in entropy2span_em_list_low.items():
                entropy2span_em_list_low_mean[entropy].append(np.mean(span_em_list_low))
            for entropy, span_em_list_high in entropy2span_em_list_high.items():
                entropy2span_em_list_high_mean[entropy].append(np.mean(span_em_list_high))

            print(f'rate: {rate} num_docs: {np.mean(doc_num_list):.2f} num_tokens: {np.round(np.mean(rate2num_tokens[str(rate)]), 1)} span_em: {100 * np.mean(span_em_list):.4f}')
            
            # print(f'low ({len(span_em_list_low)}): {100 * np.mean(span_em_list_low):.2f}', end=' ')
            # print(f'high ({len(span_em_list_high)}): {100 * np.mean(span_em_list_high):.2f}')

In [None]:
67.5 --> 67.8은 갈 수 있는데 이게 맞나?

In [None]:
rate: 0.05 num_docs: 3.84 num_tokens: 484.0 span_em: 66.52
rate: 0.1 num_docs: 4.21 num_tokens: 483.0 span_em: 67.38
rate: 0.15 num_docs: 4.34 num_tokens: 483.0 span_em: 67.53
rate: 0.2 num_docs: 4.34 num_tokens: 482.0 span_em: 67.19
rate: 0.25 num_docs: 4.31 num_tokens: 483.0 span_em: 66.97
rate: 0.3 num_docs: 4.26 num_tokens: 482.0 span_em: 67.04
rate: 0.35 num_docs: 4.23 num_tokens: 482.0 span_em: 67.08
rate: 0.4 num_docs: 4.18 num_tokens: 482.0 span_em: 66.97
rate: 0.45 num_docs: 4.17 num_tokens: 482.0 span_em: 67.01
rate: 0.5 num_docs: 4.13 num_tokens: 482.0 span_em: 66.82

In [None]:
entropy2span_em_list_high_mean['mean']

In [None]:
entropy2span_em_list_low_mean['mean']

In [None]:
len(entropy2span_em_list_high['mean']), len(entropy2span_em_list_low['mean'])

In [None]:
len(gini2span_em_list_high['mean'])

In [104]:
rate = [0.1, 0.15, 0.2, 0.25]
for gini in [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 'mean']:
    score_for_low = max(gini2span_em_list_low_mean[gini]) * len(gini2span_em_list_low[gini])
    score_for_high = max(gini2span_em_list_high_mean[gini]) * len(gini2span_em_list_high[gini])
    rate_for_low = rate[np.argmax(gini2span_em_list_low_mean[gini])]
    rate_for_high = rate[np.argmax(gini2span_em_list_high_mean[gini])]
    print(f"gini: {gini} ({rate_for_low}, {rate_for_high}), {100 * (score_for_low + score_for_high) / 2655:.2f}")

gini: 0.1 (0.1, 0.1), 67.72
gini: 0.15 (0.2, 0.1), 67.83
gini: 0.2 (0.25, 0.1), 67.83
gini: 0.25 (0.2, 0.1), 67.76
gini: 0.3 (0.15, 0.1), 67.80
gini: 0.35 (0.1, 0.25), 67.80
gini: 0.4 (0.1, 0.1), 67.72
gini: mean (0.15, 0.1), 67.80


In [None]:
gini2span_em_list_high_mean[0.3], gini2span_em_list_low_mean[0.3]

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Assuming you have already defined 'gini2span_em_list_low_mean' and 'gini2span_em_list_high_mean'
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(20, 20))
axes = axes.flatten()

ginis = [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 'mean']
x = [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]
fontsize = 20
for i, gini in enumerate(ginis):
    low_scores = gini2span_em_list_low_mean[gini]
    high_scores = gini2span_em_list_high_mean[gini]
    ax = axes[i]

    # Plotting the data
    ax.plot(x, high_scores, label=f'High coefficient ({len(gini2span_em_list_high[gini])})')
    ax.plot(x, low_scores, label=f'Low coefficient ({len(gini2span_em_list_low[gini])})')

    # Customizing the y-axis limits to close the gap
    high_min = min(high_scores)
    high_max = max(high_scores)
    low_min = min(low_scores)
    low_max = max(low_scores)
    
    # Find the common range, if overlapping, or close the gap if not
    combined_min = min(low_min, high_min)
    combined_max = max(low_max, high_max)

    # Check if there is overlap
    if high_min < low_max and high_max > low_min:
        # There is overlap
        plot_min = combined_min
        plot_max = combined_max
    else:
        # No overlap, minimize the gap
        plot_min = min(high_min, low_min)
        plot_max = max(high_max, low_max)

    # Reducing vertical space between the two sets of data
    buffer = (plot_max - plot_min) * 0.03  # 10% buffer to avoid cutting off peaks
    ax.set_ylim(plot_min - buffer, plot_max + buffer)
    ## y_tick, x_tick font size
    ax.tick_params(axis='both', which='major', labelsize=fontsize*.6)

    # Setting labels and title
    ax.set_xlabel(r'$\mathcal{R}_H$', fontsize=fontsize)
    # ax.set_ylabel('Span EM', fontsize=14)
    if gini == 'mean':
        ax.set_title(f'Gini: {gini} (NQ dev)', fontsize=fontsize)
    else:
        ax.set_title(f'Gini: {gini}', fontsize=fontsize)
    ax.legend(loc='upper left', bbox_to_anchor=(0.32, 0.84), fontsize=fontsize*.6)
    # ax.legend()

plt.tight_layout()
plt.show()