In [13]:
def compute_macro_PRF(predicted_idx, gold_idx, i=-1, empty_label=None):
    '''
    This evaluation function follows work from Sorokin and Gurevych(https://www.aclweb.org/anthology/D17-1188.pdf)
    code borrowed from the following link:
    https://github.com/UKPLab/emnlp2017-relation-extraction/blob/master/relation_extraction/evaluation/metrics.py
    '''
    if i == -1:
        i = len(predicted_idx)

    complete_rel_set = set(gold_idx) - {empty_label}
    avg_prec = 0.0
    avg_rec = 0.0

    for r in complete_rel_set:
        r_indices = (predicted_idx[:i] == r)
        tp = len((predicted_idx[:i][r_indices] == gold_idx[:i][r_indices]).nonzero()[0])
        tp_fp = len(r_indices.nonzero()[0])
        tp_fn = len((gold_idx == r).nonzero()[0])
        prec = (tp / tp_fp) if tp_fp > 0 else 0
        rec = tp / tp_fn
        avg_prec += prec
        avg_rec += rec
    f1 = 0
    avg_prec = avg_prec / len(set(predicted_idx[:i]))
    avg_rec = avg_rec / len(complete_rel_set)
    if (avg_rec+avg_prec) > 0:
        f1 = 2.0 * avg_prec * avg_rec / (avg_prec + avg_rec)

    return avg_prec, avg_rec, f1

In [5]:
# Eval of the RE-QA using the Concat Templates on the dev data over all the folds.
import pandas as pd
import numpy as np

mean_f1 = 0.0
for fold_i in range(1, 11, 1):
    gold_file = "./zero-shot-extraction/relation_splits/dev.{}.concat.relation_data.csv".format(str(fold_i-1))
    gold_indices = []
    df = pd.read_csv(gold_file, sep=',')
    correct_indices = df["correct_indices"].tolist()
    for i, index in enumerate(correct_indices):
        if index:
            gold_indices.append(int(i % 12))

    num_examples = len(correct_indices) // 12
    gold_indices = np.array(gold_indices)

    max_file = None
    max_f1 = 0.0
    for checkpoint_i in range(1, 200, 1):
        prediction_file = "~/may-20/fold_{}/concat/relation.concat.dev.predictions.fold.{}.step.{}.csv".format(str(fold_i), str(fold_i), str(100 * checkpoint_i))
        pred_log_ps = pd.read_csv(prediction_file, sep=',')["relation_log_p"].tolist()
        pred_ids = np.argmax(np.reshape(np.array(pred_log_ps), (num_examples, 12)), axis=1)
        avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
        if f1 >= max_f1:
            max_f1 = f1
            max_file = prediction_file

    print(fold_i)
    print(max_f1)
    print(max_file)
    print("\r\n")
    mean_f1 += max_f1

print(mean_f1/10.0)

1
0.6645079529607614
~/may-20/fold_1/concat/relation.concat.dev.predictions.fold.1.step.3600.csv


2
0.7355692801363015
~/may-20/fold_2/concat/relation.concat.dev.predictions.fold.2.step.4300.csv


3
0.816466070295921
~/may-20/fold_3/concat/relation.concat.dev.predictions.fold.3.step.5200.csv


4
0.820538067780762
~/may-20/fold_4/concat/relation.concat.dev.predictions.fold.4.step.1600.csv


5
0.7970665456384882
~/may-20/fold_5/concat/relation.concat.dev.predictions.fold.5.step.2900.csv


6
0.9100498471715361
~/may-20/fold_6/concat/relation.concat.dev.predictions.fold.6.step.1400.csv


7
0.7789862082105365
~/may-20/fold_7/concat/relation.concat.dev.predictions.fold.7.step.2500.csv


8
0.7585498509710629
~/may-20/fold_8/concat/relation.concat.dev.predictions.fold.8.step.400.csv


9
0.7690369496615103
~/may-20/fold_9/concat/relation.concat.dev.predictions.fold.9.step.2600.csv


10
0.7394406760740089
~/may-20/fold_10/concat/relation.concat.dev.predictions.fold.10.step.800.csv


0.779021144

In [12]:
# Eval of the RE-QA using the Gold Templates on the dev data over all the folds.
import pandas as pd
import numpy as np

mean_f1 = 0.0
for fold_i in range(1, 11, 1):
    gold_file = "./zero-shot-extraction/relation_splits/dev.{}.relation_data.csv".format(str(fold_i-1))
    gold_indices = []
    df = pd.read_csv(gold_file, sep=',')
    correct_indices = df["correct_indices"].tolist()
    for i, index in enumerate(correct_indices):
        if index:
            gold_indices.append(int(i % 12))

    num_examples = len(correct_indices) // 12
    gold_indices = np.array(gold_indices)

    max_file = None
    max_f1 = 0.0
    for checkpoint_i in range(1, 200, 1):
        try:
            prediction_file = "~/may-20/fold_{}/gold/relation.gold.dev.predictions.fold.{}.step.{}.csv".format(str(fold_i), str(fold_i), str(100 * checkpoint_i))
            pred_log_ps = pd.read_csv(prediction_file, sep=',')["relation_log_p"].tolist()
            pred_ids = np.argmax(np.reshape(np.array(pred_log_ps), (num_examples, 12)), axis=1)
            avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
            if f1 >= max_f1:
                max_f1 = f1
                max_file = prediction_file
        except:
            print(checkpoint_i, fold_i)

    print(fold_i)
    print(max_f1)
    print(max_file)
    print("\r\n")
    mean_f1 += max_f1

print(mean_f1/10.0)

157 1
159 1
162 1
163 1
164 1
165 1
167 1
168 1
169 1
171 1
1
0.7958029874558785
~/may-20/fold_1/gold/relation.gold.dev.predictions.fold.1.step.600.csv


2
0.8055480874453981
~/may-20/fold_2/gold/relation.gold.dev.predictions.fold.2.step.1900.csv


3
0.8088784230714176
~/may-20/fold_3/gold/relation.gold.dev.predictions.fold.3.step.200.csv


4
0.7950547270773886
~/may-20/fold_4/gold/relation.gold.dev.predictions.fold.4.step.9500.csv


5
0.8218222460881007
~/may-20/fold_5/gold/relation.gold.dev.predictions.fold.5.step.15300.csv


6
0.9343882793208368
~/may-20/fold_6/gold/relation.gold.dev.predictions.fold.6.step.1100.csv


7
0.7977715930389874
~/may-20/fold_7/gold/relation.gold.dev.predictions.fold.7.step.2600.csv


8
0.8869445616734918
~/may-20/fold_8/gold/relation.gold.dev.predictions.fold.8.step.1000.csv


9
0.8353253359450881
~/may-20/fold_9/gold/relation.gold.dev.predictions.fold.9.step.1900.csv


10
0.8742971188094225
~/may-20/fold_10/gold/relation.gold.dev.predictions.fold.10.step

In [9]:
# MML-OFF-PGG performance for Relation Extraction on all the dev folds.
import pandas as pd
import numpy as np

mean_f1 = 0.0
for fold_i in range(1, 11, 1):
    gold_file = "./zero-shot-extraction/relation_splits/dev.{}.qq.relation_data.csv".format(str(fold_i-1))
    gold_indices = []
    df = pd.read_csv(gold_file, sep=',')
    correct_indices = df["correct_indices"].tolist()
    for i, index in enumerate(correct_indices):
        if index:
            gold_indices.append(int(i % 12))

    num_examples = len(correct_indices) // 12
    gold_indices = np.array(gold_indices)

    max_file = None
    max_f1 = 0.0
    for checkpoint_i in range(1, 200, 1):
        try:
            prediction_file = "~/may-20/fold_{}/relation.mml-pgg-off-sim.run.fold_{}.dev.predictions.step.{}.csv".format(str(fold_i), str(fold_i), str(100 * checkpoint_i))
            pred_log_ps = pd.read_csv(prediction_file, sep=',')["relation_log_p"].tolist()
            pred_log_ps = np.log(np.mean(np.reshape(np.exp(np.array(pred_log_ps)), (num_examples, 12, 8)), axis=2))
            pred_ids = np.argmax(pred_log_ps, axis=1)
            avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
            if f1 >= max_f1:
                max_f1 = f1
                max_file = prediction_file
        except:
            print(checkpoint_i)

    print(fold_i)
    print(max_f1)
    print(max_file)
    print("\r\n")
    mean_f1 += max_f1

print(mean_f1/10.0)

1
0.7161432185415761
~/may-20/fold_1/relation.mml-pgg-off-sim.run.fold_1.dev.predictions.step.4700.csv


2
0.7213443950023146
~/may-20/fold_2/relation.mml-pgg-off-sim.run.fold_2.dev.predictions.step.10300.csv


3
0.7769444992682082
~/may-20/fold_3/relation.mml-pgg-off-sim.run.fold_3.dev.predictions.step.1400.csv


4
0.8369994733568273
~/may-20/fold_4/relation.mml-pgg-off-sim.run.fold_4.dev.predictions.step.800.csv


5
0.8133015604114285
~/may-20/fold_5/relation.mml-pgg-off-sim.run.fold_5.dev.predictions.step.14700.csv


6
0.8921353519246782
~/may-20/fold_6/relation.mml-pgg-off-sim.run.fold_6.dev.predictions.step.1100.csv


197
198
199
7
0.782335978986689
~/may-20/fold_7/relation.mml-pgg-off-sim.run.fold_7.dev.predictions.step.2100.csv


198
199
8
0.7395061775208734
~/may-20/fold_8/relation.mml-pgg-off-sim.run.fold_8.dev.predictions.step.800.csv


185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
9
0.776268833348949
~/may-20/fold_9/relation.mml-pgg-off-sim.run.fold_9.dev.predi

In [44]:
import pandas as pd
import numpy as np

gold_files = {
    1: "relation.gold.test.predictions.fold.1.step.600.csv",
    2: "relation.gold.test.predictions.fold.2.step.1900.csv",
    3: "relation.gold.test.predictions.fold.3.step.200.csv",
    4: "relation.gold.test.predictions.fold.4.step.9500.csv",
    5: "relation.gold.test.predictions.fold.5.step.15300.csv",
    6: "relation.gold.test.predictions.fold.6.step.1100.csv",
    7: "relation.gold.test.predictions.fold.7.step.2600.csv",
    8: "relation.gold.test.predictions.fold.8.step.1000.csv",
    9: "relation.gold.test.predictions.fold.9.step.1900.csv",
    10: "relation.gold.test.predictions.fold.10.step.4000.csv"
}

concat_files = {
    1: "relation.concat.test.predictions.fold.1.step.3600.csv",
    2: "relation.concat.test.predictions.fold.2.step.4300.csv",
    3: "relation.concat.test.predictions.fold.3.step.5200.csv",
    4: "relation.concat.test.predictions.fold.4.step.1600.csv",
    5: "relation.concat.test.predictions.fold.5.step.2900.csv",
    6: "relation.concat.test.predictions.fold.6.step.1400.csv",
    7: "relation.concat.test.predictions.fold.7.step.2500.csv",
    8: "relation.concat.test.predictions.fold.8.step.400.csv",
    9: "relation.concat.test.predictions.fold.9.step.2600.csv",
    10: "relation.concat.test.predictions.fold.10.step.800.csv"
}

gold_alone_p_r_f = {'f': [], 'r': [], 'p': []}
concat_alone_p_r_f = {'f': [], 'r': [], 'p': []}

for fold_i in range(1, 11, 1):
    gold_file = "./zero-shot-extraction/relation_splits/test.{}.relation_data.csv".format(str(fold_i-1))
    gold_indices = []
    df = pd.read_csv(gold_file, sep=',')
    correct_indices = df["correct_indices"].tolist()
    for i, index in enumerate(correct_indices):
        if index:
            gold_indices.append(int(i % 24))

    num_examples = len(correct_indices) // 24
    gold_indices = np.array(gold_indices)
    
    concat_prediction_file = "/Users/saeed/Desktop/codes/repos/QA-ZRE/relation_predictions/{}".format(concat_files[fold_i])
    concat_pred_log_ps = pd.read_csv(concat_prediction_file, sep=',')["relation_log_p"].tolist()
    concat_pred_log_ps = np.reshape(np.array(concat_pred_log_ps), (num_examples, 24))
    concat_pred_ids = np.argmax(concat_pred_log_ps, axis=1)
    avg_prec, avg_rec, f1 = compute_macro_PRF(concat_pred_ids, gold_indices)
    concat_alone_p_r_f["f"].append(f1)
    concat_alone_p_r_f["r"].append(avg_rec)
    concat_alone_p_r_f["p"].append(avg_prec)
    print(fold_i, "concat alone", f1, avg_prec, avg_rec)
    
    gold_prediction_file = "/Users/saeed/Desktop/codes/repos/QA-ZRE/relation_predictions/{}".format(gold_files[fold_i])
    gold_pred_log_ps = pd.read_csv(gold_prediction_file, sep=',')["relation_log_p"].tolist()
    gold_pred_log_ps = np.reshape(np.array(gold_pred_log_ps), (num_examples, 24))
    gold_pred_ids = np.argmax(gold_pred_log_ps, axis=1)
    avg_prec, avg_rec, f1 = compute_macro_PRF(gold_pred_ids, gold_indices)
    print(fold_i, "gold alone", f1, avg_prec, avg_rec)
    gold_alone_p_r_f["f"].append(f1)
    gold_alone_p_r_f["r"].append(avg_rec)
    gold_alone_p_r_f["p"].append(avg_prec)
    print("\n")

    prior_prediction_file = "/Users/saeed/Desktop/codes/repos/QA-ZRE/relation_predictions/{}".format(lm_prior_files[fold_i])
    prior_log_ps = pd.read_csv(prior_prediction_file, sep=',')["relation_log_p"].tolist()
    prior_log_ps = np.reshape(np.array(prior_log_ps), (num_examples, 24))

    prior_pred_ids = np.argmax(prior_log_ps, axis=1)
    avg_prec, avg_rec, f1 = compute_macro_PRF(prior_pred_ids, gold_indices)
    print(fold_i, "prior alone", f1, avg_prec, avg_rec)
    print("\n")
    
    n_prior_log_ps = (prior_log_ps - np.min(prior_log_ps, axis=1, keepdims=True)) / (np.max(prior_log_ps, axis=1, keepdims=True) - np.min(prior_log_ps, axis=1, keepdims=True))
    n_gold_pred_log_ps = (gold_pred_log_ps - np.min(gold_pred_log_ps, axis=1, keepdims=True)) / (np.max(gold_pred_log_ps, axis=1, keepdims=True) - np.min(gold_pred_log_ps, axis=1, keepdims=True))
    sum_prior_gold_log_ps = 0.03 * n_prior_log_ps + n_gold_pred_log_ps
    sum_pred_ids = np.argmax(sum_prior_gold_log_ps, axis=1)
    avg_prec, avg_rec, f1 = compute_macro_PRF(sum_pred_ids, gold_indices)
    print(fold_i, "gold + prior", f1, avg_prec, avg_rec)
    gold_p_r_f["f"].append(f1)
    gold_p_r_f["r"].append(avg_rec)
    gold_p_r_f["p"].append(avg_prec)
    print("\n")
    
    
    n_prior_log_ps = (prior_log_ps - np.min(prior_log_ps, axis=1, keepdims=True)) / (np.max(prior_log_ps, axis=1, keepdims=True) - np.min(prior_log_ps, axis=1, keepdims=True))
    n_concat_pred_log_ps = (concat_pred_log_ps - np.min(concat_pred_log_ps, axis=1, keepdims=True)) / (np.max(concat_pred_log_ps, axis=1, keepdims=True) - np.min(concat_pred_log_ps, axis=1, keepdims=True))
    sum_prior_concat_log_ps = 0.03 * n_prior_log_ps + n_concat_pred_log_ps
    sum_pred_ids = np.argmax(sum_prior_concat_log_ps, axis=1)
    avg_prec, avg_rec, f1 = compute_macro_PRF(sum_pred_ids, gold_indices)
    print(fold_i, "concat + prior", f1, avg_prec, avg_rec)
    concat_p_r_f["f"].append(f1)
    concat_p_r_f["r"].append(avg_rec)
    concat_p_r_f["p"].append(avg_prec)
    print("\n")

print("gold alone p:", np.mean(np.array(gold_alone_p_r_f["p"])))
print("gold alone r:", np.mean(np.array(gold_alone_p_r_f["r"])))
print("gold alone f:", np.mean(np.array(gold_alone_p_r_f["f"])))

print("concat alone p:", np.mean(np.array(concat_alone_p_r_f["p"])))
print("concat alone r:", np.mean(np.array(concat_alone_p_r_f["r"])))
print("concat alone f:", np.mean(np.array(concat_alone_p_r_f["f"])))

print("gold p:", np.mean(np.array(gold_p_r_f["p"])))
print("gold r:", np.mean(np.array(gold_p_r_f["r"])))
print("gold f:", np.mean(np.array(gold_p_r_f["f"])))

print("concat p:", np.mean(np.array(concat_p_r_f["p"])))
print("concat r:", np.mean(np.array(concat_p_r_f["r"])))
print("concat f:", np.mean(np.array(concat_p_r_f["f"])))

1 concat alone 0.6851779135044868 0.689066127748157 0.6813333333333333
1 gold alone 0.7480997866613807 0.7437507343003492 0.7525


1 prior alone 0.3374536793487677 0.5286055943723491 0.24783333333333335


1 gold + prior 0.7464523973862477 0.7535367611573248 0.7395


1 concat + prior 0.6709794511481647 0.680029978673471 0.6621666666666667


2 concat alone 0.5901669782971629 0.5989194167218461 0.5816666666666667
2 gold alone 0.5682775771606646 0.5928433586883303 0.5456666666666667


2 prior alone 0.30836258096198693 0.40983299948646995 0.24716666666666662


2 gold + prior 0.5823453696013038 0.6044119119791929 0.5618333333333333


2 concat + prior 0.5943873047882167 0.6010890693513659 0.5878333333333333


3 concat alone 0.6373039703281808 0.6626408187363189 0.6138333333333333
3 gold alone 0.681873757976071 0.6871616100666138 0.6766666666666667


3 prior alone 0.27247005550823505 0.4351156914702811 0.19833333333333333


3 gold + prior 0.6323186088139521 0.6433419458457007 0.621666666666666

In [None]:
import pandas as pd
import numpy as np

gold_file = "/Users/saeed/Desktop/codes/repos/QA-ZRE/zero-shot-extraction/relation_splits/test.0.relation_data.csv"
gold_indices = []
df = pd.read_csv(gold_file, sep=',')
correct_indices = df["correct_indices"].tolist()
for i, index in enumerate(correct_indices):
    if index:
        gold_indices.append(int(i % 24))

num_examples = len(correct_indices) // 24
gold_indices = np.array(gold_indices)

prior_prediction_file = "/Users/saeed/Desktop/codes/repos/QA-ZRE/relation_predictions/{}".format(lm_prior_files[fold_i])
prior_log_ps = pd.read_csv(prior_prediction_file, sep=',')["relation_log_p"].tolist()
prior_log_ps = np.reshape(np.array(prior_log_ps), (num_examples, 24))

prior_pred_ids = np.argmax(prior_log_ps, axis=1)
avg_prec, avg_rec, f1 = compute_macro_PRF(prior_pred_ids, gold_indices)
print(fold_i, "prior alone", f1, avg_prec, avg_rec)
print("\n")

In [87]:
# Check MML-MML-OFF-SIM on the fold 1 dev file.
import pandas as pd
import numpy as np

gold_file = "/Users/saeed/Desktop/codes/repos/QA-ZRE/zero-shot-extraction/relation_splits/dev.0.qq.relation_data.csv"
gold_indices = []
df = pd.read_csv(gold_file, sep=',')
correct_indices = df["correct_indices"].tolist()
for i, index in enumerate(correct_indices):
    if index:
        gold_indices.append(int(i % 12))

num_examples = len(correct_indices) // 12
gold_indices = np.array(gold_indices)

'''
max_f1 = 0.0
max_file = None
for step_i in range(1, 93, 1):
    prior_prediction_file = "/Users/saeed/Desktop/codes/repos/QA-ZRE/march-29-re-predictions/relation_dev_predictions/relation.mml-mml-off-sim.run.0.dev.predictions.step.{}.csv".format(str(100 * step_i))
    prior_log_ps = pd.read_csv(prior_prediction_file, sep=',')["relation_log_p"].tolist()
    prior_log_ps = np.reshape(np.array(prior_log_ps), (num_examples, 12, 8))

    prior_pred_ids = np.argmax(np.mean(prior_log_ps, axis=2), axis=1)
    avg_prec, avg_rec, f1 = compute_macro_PRF(prior_pred_ids, gold_indices)
    if f1 >= max_f1:
        max_f1 = f1
        max_file = prior_prediction_file

print(max_file, max_f1)
print("\n")
'''

prior_prediction_file = "/Users/saeed/Desktop/codes/repos/QA-ZRE/march-29-re-predictions/relation.mml-pgg-off-sim.run.0.dev.predictions.step.9300.csv"
prior_answer_log_ps = pd.read_csv(prior_prediction_file, sep=',')["answer_log_p"].tolist()
prior_question_log_ps = pd.read_csv(prior_prediction_file, sep=',')["question_log_p"].tolist()
prior_gen_questions = pd.read_csv(prior_prediction_file, sep=',')["generated_question"].tolist()
prior_gen_questions_len = [len(each.split()) for each in prior_gen_questions]


prior_answer_log_ps = np.reshape(np.array(prior_answer_log_ps), (num_examples, 12, 8))
prior_question_log_ps = np.reshape(np.array(prior_question_log_ps), (num_examples, 12, 8))
prior_gen_questions_len = np.reshape(np.array(prior_gen_questions_len), (num_examples, 12, 8))

prior_question_log_ps = np.mean(prior_question_log_ps / prior_gen_questions_len, axis=2)
prior_answer_log_ps = np.mean(prior_answer_log_ps, axis=2)

n_prior_answer_log_ps = (prior_answer_log_ps - np.min(prior_answer_log_ps, axis=1, keepdims=True)) / (np.max(prior_answer_log_ps, axis=1, keepdims=True) - np.min(prior_answer_log_ps, axis=1, keepdims=True))
n_prior_question_log_ps = (prior_question_log_ps - np.min(prior_question_log_ps, axis=1, keepdims=True)) / (np.max(prior_question_log_ps, axis=1, keepdims=True) - np.min(prior_question_log_ps, axis=1, keepdims=True))

prior_pred_ids = np.argmax(n_prior_answer_log_ps, axis=1)


mml_wrongs = []
for index in range(num_examples):
    if prior_pred_ids[index] != gold_indices[index]:
        mml_wrongs.append(index)

avg_prec, avg_rec, f1 = compute_macro_PRF(prior_pred_ids, gold_indices)
print(avg_prec, avg_rec, f1)


'''
concat_prediction_file = "/Users/saeed/Desktop/codes/repos/QA-ZRE/relation_predictions/relation.concat.dev.predictions.fold.1.step.3600.csv"
concat_log_ps = pd.read_csv(concat_prediction_file, sep=',')["relation_log_p"].tolist()
concat_log_ps = np.reshape(np.array(concat_log_ps), (num_examples, 12))

concat_pred_ids = np.argmax(concat_log_ps, axis=1)


concat_wrongs = []
for index in range(num_examples):
    if concat_pred_ids[index] != gold_indices[index]:
        concat_wrongs.append(index)

avg_prec, avg_rec, f1 = compute_macro_PRF(concat_pred_ids, gold_indices)
print(avg_prec, avg_rec, f1)

print(mml_wrongs)

print("\n")

print(concat_wrongs)

print("\n")

print("Shared:", set(mml_wrongs) & set(concat_wrongs))

print("\n")

print("mml - concat", set(mml_wrongs) - set(concat_wrongs))

print("\n")

print("concat - mml", set(concat_wrongs) - set(mml_wrongs))
'''

0.6535216783123196 0.65 0.6517560819594281


'\nconcat_prediction_file = "/Users/saeed/Desktop/codes/repos/QA-ZRE/relation_predictions/relation.concat.dev.predictions.fold.1.step.3600.csv"\nconcat_log_ps = pd.read_csv(concat_prediction_file, sep=\',\')["relation_log_p"].tolist()\nconcat_log_ps = np.reshape(np.array(concat_log_ps), (num_examples, 12))\n\nconcat_pred_ids = np.argmax(concat_log_ps, axis=1)\n\n\nconcat_wrongs = []\nfor index in range(num_examples):\n    if concat_pred_ids[index] != gold_indices[index]:\n        concat_wrongs.append(index)\n\navg_prec, avg_rec, f1 = compute_macro_PRF(concat_pred_ids, gold_indices)\nprint(avg_prec, avg_rec, f1)\n\nprint(mml_wrongs)\n\nprint("\n")\n\nprint(concat_wrongs)\n\nprint("\n")\n\nprint("Shared:", set(mml_wrongs) & set(concat_wrongs))\n\nprint("\n")\n\nprint("mml - concat", set(mml_wrongs) - set(concat_wrongs))\n\nprint("\n")\n\nprint("concat - mml", set(concat_wrongs) - set(mml_wrongs))\n'