In [3]:
def compute_macro_PRF(predicted_idx, gold_idx, i=-1, empty_label=None):
    '''
    This evaluation function follows work from Sorokin and Gurevych(https://www.aclweb.org/anthology/D17-1188.pdf)
    code borrowed from the following link:
    https://github.com/UKPLab/emnlp2017-relation-extraction/blob/master/relation_extraction/evaluation/metrics.py
    '''
    if i == -1:
        i = len(predicted_idx)

    complete_rel_set = set(gold_idx) - {empty_label}
    avg_prec = 0.0
    avg_rec = 0.0

    for r in complete_rel_set:
        r_indices = (predicted_idx[:i] == r)
        tp = len((predicted_idx[:i][r_indices] == gold_idx[:i][r_indices]).nonzero()[0])
        tp_fp = len(r_indices.nonzero()[0])
        tp_fn = len((gold_idx == r).nonzero()[0])
        prec = (tp / tp_fp) if tp_fp > 0 else 0
        rec = tp / tp_fn
        avg_prec += prec
        avg_rec += rec
    f1 = 0
    avg_prec = avg_prec / len(set(predicted_idx[:i]))
    avg_rec = avg_rec / len(complete_rel_set)
    if (avg_rec+avg_prec) > 0:
        f1 = 2.0 * avg_prec * avg_rec / (avg_prec + avg_rec)

    return avg_prec, avg_rec, f1

In [3]:
# Eval of the RE-QA using the Concat Templates on the dev data over all the folds.
import pandas as pd
import numpy as np

mean_f1 = 0.0
for fold_i in range(1, 11, 1):
    gold_file = "./zero-shot-extraction/relation_splits/dev.{}.concat.relation_data.csv".format(str(fold_i-1))
    gold_indices = []
    df = pd.read_csv(gold_file, sep=',')
    correct_indices = df["correct_indices"].tolist()
    for i, index in enumerate(correct_indices):
        if index:
            gold_indices.append(int(i % 12))

    num_examples = len(correct_indices) // 12
    gold_indices = np.array(gold_indices)

    max_file = None
    max_f1 = 0.0
    for checkpoint_i in range(1, 200, 1):
        prediction_file = "~/may-20/fold_{}/concat/relation.concat.dev.predictions.fold.{}.step.{}.csv".format(str(fold_i), str(fold_i), str(100 * checkpoint_i))
        pred_log_ps = pd.read_csv(prediction_file, sep=',')["relation_log_p"].tolist()
        pred_ids = np.argmax(np.reshape(np.array(pred_log_ps), (num_examples, 12)), axis=1)
        avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
        if f1 >= max_f1:
            max_f1 = f1
            max_file = prediction_file

    print(fold_i)
    print(max_f1)
    print(max_file)
    print("\r\n")
    mean_f1 += max_f1

print(mean_f1/10.0)

1
0.6645079529607614
~/may-20/fold_1/concat/relation.concat.dev.predictions.fold.1.step.3600.csv


2
0.7355692801363015
~/may-20/fold_2/concat/relation.concat.dev.predictions.fold.2.step.4300.csv


3
0.816466070295921
~/may-20/fold_3/concat/relation.concat.dev.predictions.fold.3.step.5200.csv


4
0.820538067780762
~/may-20/fold_4/concat/relation.concat.dev.predictions.fold.4.step.1600.csv


5
0.7970665456384882
~/may-20/fold_5/concat/relation.concat.dev.predictions.fold.5.step.2900.csv


6
0.9100498471715361
~/may-20/fold_6/concat/relation.concat.dev.predictions.fold.6.step.1400.csv


7
0.7789862082105365
~/may-20/fold_7/concat/relation.concat.dev.predictions.fold.7.step.2500.csv


8
0.7585498509710629
~/may-20/fold_8/concat/relation.concat.dev.predictions.fold.8.step.400.csv


9
0.7690369496615103
~/may-20/fold_9/concat/relation.concat.dev.predictions.fold.9.step.2600.csv


10
0.7394406760740089
~/may-20/fold_10/concat/relation.concat.dev.predictions.fold.10.step.800.csv


0.779021144

In [3]:
# Eval of the RE-QA using the Gold Templates on the dev data over all the folds.
import pandas as pd
import numpy as np

mean_f1 = 0.0
for fold_i in range(1, 11, 1):
    gold_file = "./zero-shot-extraction/relation_splits/dev.{}.relation_data.csv".format(str(fold_i-1))
    gold_indices = []
    df = pd.read_csv(gold_file, sep=',')
    correct_indices = df["correct_indices"].tolist()
    for i, index in enumerate(correct_indices):
        if index:
            gold_indices.append(int(i % 12))

    num_examples = len(correct_indices) // 12
    gold_indices = np.array(gold_indices)

    max_file = None
    max_f1 = 0.0
    for checkpoint_i in range(1, 200, 1):
        try:
            prediction_file = "~/may-20/fold_{}/gold/relation.gold.dev.predictions.fold.{}.step.{}.csv".format(str(fold_i), str(fold_i), str(100 * checkpoint_i))
            pred_log_ps = pd.read_csv(prediction_file, sep=',')["relation_log_p"].tolist()
            pred_ids = np.argmax(np.reshape(np.array(pred_log_ps), (num_examples, 12)), axis=1)
            avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
            if f1 >= max_f1:
                max_f1 = f1
                max_file = prediction_file
        except:
            print(checkpoint_i, fold_i)

    print(fold_i)
    print(max_f1)
    print(max_file)
    print("\r\n")
    mean_f1 += max_f1

print(mean_f1/10.0)

157 1
159 1
162 1
163 1
164 1
165 1
167 1
168 1
169 1
171 1
1
0.7958029874558785
~/may-20/fold_1/gold/relation.gold.dev.predictions.fold.1.step.600.csv


2
0.8055480874453981
~/may-20/fold_2/gold/relation.gold.dev.predictions.fold.2.step.1900.csv


3
0.8088784230714176
~/may-20/fold_3/gold/relation.gold.dev.predictions.fold.3.step.200.csv


4
0.7950547270773886
~/may-20/fold_4/gold/relation.gold.dev.predictions.fold.4.step.9500.csv


5
0.8218222460881007
~/may-20/fold_5/gold/relation.gold.dev.predictions.fold.5.step.15300.csv


6
0.9343882793208368
~/may-20/fold_6/gold/relation.gold.dev.predictions.fold.6.step.1100.csv


7
0.7977715930389874
~/may-20/fold_7/gold/relation.gold.dev.predictions.fold.7.step.2600.csv


8
0.8869445616734918
~/may-20/fold_8/gold/relation.gold.dev.predictions.fold.8.step.1000.csv


9
0.8353253359450881
~/may-20/fold_9/gold/relation.gold.dev.predictions.fold.9.step.1900.csv


10
0.8742971188094225
~/may-20/fold_10/gold/relation.gold.dev.predictions.fold.10.step

In [4]:
# MML-OFF-PGG performance for Relation Extraction on all the dev folds.
import pandas as pd
import numpy as np

mean_f1 = 0.0
for fold_i in range(1, 11, 1):
    gold_file = "./zero-shot-extraction/relation_splits/dev.{}.qq.relation_data.csv".format(str(fold_i-1))
    gold_indices = []
    df = pd.read_csv(gold_file, sep=',')
    correct_indices = df["correct_indices"].tolist()
    for i, index in enumerate(correct_indices):
        if index:
            gold_indices.append(int(i % 12))

    num_examples = len(correct_indices) // 12
    gold_indices = np.array(gold_indices)

    max_file = None
    max_f1 = 0.0
    for checkpoint_i in range(1, 200, 1):
        try:
            prediction_file = "~/may-20/fold_{}/relation.mml-pgg-off-sim.run.fold_{}.dev.predictions.step.{}.csv".format(str(fold_i), str(fold_i), str(100 * checkpoint_i))
            pred_log_ps = pd.read_csv(prediction_file, sep=',')["relation_log_p"].tolist()
            pred_log_ps = np.log(np.mean(np.reshape(np.exp(np.array(pred_log_ps)), (num_examples, 12, 8)), axis=2))
            pred_ids = np.argmax(pred_log_ps, axis=1)
            avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
            if f1 >= max_f1:
                max_f1 = f1
                max_file = prediction_file
        except:
            print(checkpoint_i)

    print(fold_i)
    print(max_f1)
    print(max_file)
    print("\r\n")
    mean_f1 += max_f1

print(mean_f1/10.0)

1
0.7161432185415761
~/may-20/fold_1/relation.mml-pgg-off-sim.run.fold_1.dev.predictions.step.4700.csv


2
0.7213443950023146
~/may-20/fold_2/relation.mml-pgg-off-sim.run.fold_2.dev.predictions.step.10300.csv


3
0.7769444992682082
~/may-20/fold_3/relation.mml-pgg-off-sim.run.fold_3.dev.predictions.step.1400.csv


4
0.8369994733568273
~/may-20/fold_4/relation.mml-pgg-off-sim.run.fold_4.dev.predictions.step.800.csv


5
0.8133015604114285
~/may-20/fold_5/relation.mml-pgg-off-sim.run.fold_5.dev.predictions.step.14700.csv


6
0.8921353519246782
~/may-20/fold_6/relation.mml-pgg-off-sim.run.fold_6.dev.predictions.step.1100.csv


197
198
199
7
0.782335978986689
~/may-20/fold_7/relation.mml-pgg-off-sim.run.fold_7.dev.predictions.step.2100.csv


198
199
8
0.7395061775208734
~/may-20/fold_8/relation.mml-pgg-off-sim.run.fold_8.dev.predictions.step.800.csv


185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
9
0.776268833348949
~/may-20/fold_9/relation.mml-pgg-off-sim.run.fold_9.dev.predi

In [5]:
# Test set performance over the 10 folds of the RE-QA dataset for the concat and gold models.
import pandas as pd
import numpy as np

gold_files = {
    1: "relation.gold.test.predictions.fold.1.step.600.csv",
    2: "relation.gold.test.predictions.fold.2.step.1900.csv",
    3: "relation.gold.test.predictions.fold.3.step.200.csv",
    4: "relation.gold.test.predictions.fold.4.step.9500.csv",
    5: "relation.gold.test.predictions.fold.5.step.15300.csv",
    6: "relation.gold.test.predictions.fold.6.step.1100.csv",
    7: "relation.gold.test.predictions.fold.7.step.2600.csv",
    8: "relation.gold.test.predictions.fold.8.step.1000.csv",
    9: "relation.gold.test.predictions.fold.9.step.1900.csv",
    10: "relation.gold.test.predictions.fold.10.step.4000.csv"
}

concat_files = {
    1: "relation.concat.test.predictions.fold.1.step.3600.csv",
    2: "relation.concat.test.predictions.fold.2.step.4300.csv",
    3: "relation.concat.test.predictions.fold.3.step.5200.csv",
    4: "relation.concat.test.predictions.fold.4.step.1600.csv",
    5: "relation.concat.test.predictions.fold.5.step.2900.csv",
    6: "relation.concat.test.predictions.fold.6.step.1400.csv",
    7: "relation.concat.test.predictions.fold.7.step.2500.csv",
    8: "relation.concat.test.predictions.fold.8.step.400.csv",
    9: "relation.concat.test.predictions.fold.9.step.2600.csv",
    10: "relation.concat.test.predictions.fold.10.step.800.csv"
}

gold_alone_p_r_f = {'f': [], 'r': [], 'p': []}
concat_alone_p_r_f = {'f': [], 'r': [], 'p': []}

for fold_i in range(1, 11, 1):
    gold_file = "./zero-shot-extraction/relation_splits/test.{}.relation_data.csv".format(str(fold_i-1))
    gold_indices = []
    df = pd.read_csv(gold_file, sep=',')
    correct_indices = df["correct_indices"].tolist()
    for i, index in enumerate(correct_indices):
        if index:
            gold_indices.append(int(i % 24))

    num_examples = len(correct_indices) // 24
    gold_indices = np.array(gold_indices)
    
    concat_prediction_file = "~/may-20/fold_{}/concat/{}".format(fold_i, concat_files[fold_i])
    concat_pred_log_ps = pd.read_csv(concat_prediction_file, sep=',')["relation_log_p"].tolist()
    concat_pred_log_ps = np.reshape(np.array(concat_pred_log_ps), (num_examples, 24))
    concat_pred_ids = np.argmax(concat_pred_log_ps, axis=1)
    avg_prec, avg_rec, f1 = compute_macro_PRF(concat_pred_ids, gold_indices)
    concat_alone_p_r_f["f"].append(f1)
    concat_alone_p_r_f["r"].append(avg_rec)
    concat_alone_p_r_f["p"].append(avg_prec)
    print(fold_i, "concat alone", f1, avg_prec, avg_rec)
    
    gold_prediction_file = "~/may-20/fold_{}/gold/{}".format(fold_i, gold_files[fold_i])
    gold_pred_log_ps = pd.read_csv(gold_prediction_file, sep=',')["relation_log_p"].tolist()
    gold_pred_log_ps = np.reshape(np.array(gold_pred_log_ps), (num_examples, 24))
    gold_pred_ids = np.argmax(gold_pred_log_ps, axis=1)
    avg_prec, avg_rec, f1 = compute_macro_PRF(gold_pred_ids, gold_indices)
    print(fold_i, "gold alone", f1, avg_prec, avg_rec)
    gold_alone_p_r_f["f"].append(f1)
    gold_alone_p_r_f["r"].append(avg_rec)
    gold_alone_p_r_f["p"].append(avg_prec)
    print("\n")

print("gold alone p:", np.mean(np.array(gold_alone_p_r_f["p"])))
print("gold alone r:", np.mean(np.array(gold_alone_p_r_f["r"])))
print("gold alone f:", np.mean(np.array(gold_alone_p_r_f["f"])))

print("concat alone p:", np.mean(np.array(concat_alone_p_r_f["p"])))
print("concat alone r:", np.mean(np.array(concat_alone_p_r_f["r"])))
print("concat alone f:", np.mean(np.array(concat_alone_p_r_f["f"])))

1 concat alone 0.6851779135044868 0.689066127748157 0.6813333333333333
1 gold alone 0.7433432938934399 0.7562889888849941 0.7308333333333333


2 concat alone 0.5901669782971629 0.5989194167218461 0.5816666666666667
2 gold alone 0.6642954078337273 0.6842181964463796 0.6455000000000001


3 concat alone 0.6373039703281808 0.6626408187363189 0.6138333333333333
3 gold alone 0.681873757976071 0.6871616100666138 0.6766666666666667


4 concat alone 0.5858742474657468 0.6150594189021233 0.5593333333333333
4 gold alone 0.6824352310400019 0.7000704484368208 0.6656666666666667


5 concat alone 0.622864337200164 0.6281502269876816 0.6176666666666667
5 gold alone 0.5005601199498438 0.5060723602043087 0.4951666666666666


6 concat alone 0.5752377444829445 0.5954721054962936 0.5563333333333333
6 gold alone 0.6651754396604047 0.6719876078130175 0.6585


7 concat alone 0.6683183924580341 0.6834597978850874 0.6538333333333333
7 gold alone 0.6543720994745615 0.6683140048618551 0.641


8 concat alone 0.597

In [6]:
import pandas as pd
import numpy as np

mml_files = {
    1: "relation.mml-pgg-off-sim.run.fold_1.test.predictions.step.4700.csv",
    2: "relation.mml-pgg-off-sim.run.fold_2.test.predictions.step.10300.csv",
    3: "relation.mml-pgg-off-sim.run.fold_3.test.predictions.step.1400.csv",
    4: "relation.mml-pgg-off-sim.run.fold_4.test.predictions.step.800.csv",
    5: "relation.mml-pgg-off-sim.run.fold_5.test.predictions.step.14700.csv",
    6: "relation.mml-pgg-off-sim.run.fold_6.test.predictions.step.1100.csv",
    7: "relation.mml-pgg-off-sim.run.fold_7.test.predictions.step.2100.csv",
    8: "relation.mml-pgg-off-sim.run.fold_8.test.predictions.step.800.csv",
    9: "relation.mml-pgg-off-sim.run.fold_9.test.predictions.step.1300.csv",
    10: "relation.mml-pgg-off-sim.run.fold_10.test.predictions.step.1600.csv"
}

mml_mml_p_r_f = {'f': [], 'r': [], 'p': []}

for fold_i in range(1, 11, 1):
    gold_file = "./zero-shot-extraction/relation_splits/test.{}.qq.relation_data.csv".format(str(fold_i-1))
    gold_indices = []
    df = pd.read_csv(gold_file, sep=',')
    correct_indices = df["correct_indices"].tolist()
    for i, index in enumerate(correct_indices):
        if index:
            gold_indices.append(int(i % 24))

    num_examples = len(correct_indices) // 24
    gold_indices = np.array(gold_indices)
    
    mml_prediction_file = "~/may-20/fold_{}/{}".format(fold_i, mml_files[fold_i])
    mml_pred_log_ps = pd.read_csv(mml_prediction_file, sep=',')["relation_log_p"].tolist()

    pred_log_ps = np.log(np.mean(np.reshape(np.exp(np.array(mml_pred_log_ps)), (num_examples, 24, 8)), axis=2))
    pred_ids = np.argmax(pred_log_ps, axis=1)
    avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
    mml_mml_p_r_f["f"].append(f1)
    mml_mml_p_r_f["r"].append(avg_rec)
    mml_mml_p_r_f["p"].append(avg_prec)
    print(fold_i, "mml", f1, avg_prec, avg_rec)
    print("\n")

print("mml p:", np.mean(np.array(mml_mml_p_r_f["p"])))
print("mml r:", np.mean(np.array(mml_mml_p_r_f["r"])))
print("mml f:", np.mean(np.array(mml_mml_p_r_f["f"])))

1 mml 0.7025567406794597 0.6983315866121078 0.7068333333333334


2 mml 0.5689271319175269 0.5657240743052863 0.5721666666666667


3 mml 0.6092602516345123 0.617047962660044 0.6016666666666667


4 mml 0.647147583914774 0.6718854375442748 0.6241666666666666


5 mml 0.5420071935904519 0.5580810510027813 0.5268333333333332


6 mml 0.5655752604114211 0.5916919157383963 0.5416666666666666


7 mml 0.716283184224437 0.7214738572850482 0.7111666666666666


8 mml 0.7130307526383761 0.7087794781499889 0.717333333333333


9 mml 0.5758312752481141 0.596937432188357 0.5561666666666666


10 mml 0.633968289007302 0.6515245781177056 0.6173333333333333


mml p: 0.638147737360399
mml r: 0.6175333333333333
mml f: 0.6274587663266374


In [18]:
import pandas as pd
import numpy as np

id_file = "~/codes/QA-ZRE/fewrl_data/val_ids_12321.csv"
gold_file = "~/codes/QA-ZRE/fewrl_data/val_data_12321.csv"

prediction_files = ["~/may-29/fewrl/concat_run_1/relation.concat.run.0.dev.predictions.step.{}.csv".format(200 * i) for i in range(1, 13, 1)]
prediction_files = ["~/may-29/fewrl/concat_run_1/relation.concat.run.1.dev.predictions.step.{}.csv".format(200 * i) for i in range(1, 13, 1)]
prediction_files = ["~/may-29/fewrl/concat_run_1/relation.concat.run.2.dev.predictions.step.{}.csv".format(200 * i) for i in range(1, 13, 1)]
prediction_files = ["~/may-29/fewrl/concat_run_1/relation.concat.run.3.dev.predictions.step.{}.csv".format(200 * i) for i in range(1, 13, 1)]

df = pd.read_csv(gold_file, sep=',')
ids = {val:i for i, val in enumerate(pd.read_csv(id_file, sep=',')["relation_ids"].tolist())}
actual_ids = df["actual_ids"].tolist()
num_examples = len(actual_ids) // 5

gold_indices = []
for each_relation_id in actual_ids:
    gold_indices.append(ids[each_relation_id])

gold_indices = np.max(np.reshape(np.array(gold_indices), (num_examples, 5)), axis=1)

max_f1 = 0.0
max_file = "None"
for prediction_file in prediction_files:
    try:
        pred_log_ps = pd.read_csv(prediction_file, sep=',')["relation_log_p"].tolist()
        pred_ids = np.argmax(np.reshape(np.array(pred_log_ps), (num_examples, 5)), axis=1)
        avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
        if max_f1 <= f1:
            max_f1 = f1
            max_file = prediction_file
    except:
        print(prediction_file)

print(max_file, max_f1)

~/may-29/fewrl/concat_run_1/relation.concat.run.3.dev.predictions.step.400.csv
~/may-29/fewrl/concat_run_1/relation.concat.run.3.dev.predictions.step.1200.csv
~/may-29/fewrl/concat_run_1/relation.concat.run.3.dev.predictions.step.1600.csv
~/may-29/fewrl/concat_run_1/relation.concat.run.3.dev.predictions.step.2000.csv
~/may-29/fewrl/concat_run_1/relation.concat.run.3.dev.predictions.step.2200.csv 0.5150131421572911


In [16]:
import pandas as pd
import numpy as np

id_file = "~/codes/QA-ZRE/fewrl_data/val_ids_12321.csv"
gold_file = "~/codes/QA-ZRE/fewrl_data/val_data_12321.csv"

prediction_files = ["~/may-29/fewrl/run_1/relation.mml-pgg-off-sim.0.dev.predictions.step.{}.csv".format(200 * i) for i in range(1, 13, 1)]

df = pd.read_csv(gold_file, sep=',')
ids = {val:i for i, val in enumerate(pd.read_csv(id_file, sep=',')["relation_ids"].tolist())}
actual_ids = df["actual_ids"].tolist()
num_examples = len(actual_ids) // 5

gold_indices = []
for each_relation_id in actual_ids:
    gold_indices.append(ids[each_relation_id])

gold_indices = np.max(np.reshape(np.array(gold_indices), (num_examples, 5)), axis=1)

max_f1 = 0.0
max_file = "None"
for prediction_file in prediction_files:
    try:
        mml_pred_log_ps = pd.read_csv(prediction_file, sep=',')["relation_log_p"].tolist()
        pred_log_ps = np.log(np.mean(np.reshape(np.exp(np.array(mml_pred_log_ps)), (num_examples, 5, 8)), axis=2))
        pred_ids = np.argmax(pred_log_ps, axis=1)

        avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
        if max_f1 <= f1:
            max_f1 = f1
            max_file = prediction_file
    except:
        print(prediction_file)

print(max_file, max_f1)

  interactivity=interactivity, compiler=compiler, result=result)


~/may-29/fewrl/run_1/relation.mml-pgg-off-sim.0.dev.predictions.step.200.csv
~/may-29/fewrl/run_1/relation.mml-pgg-off-sim.0.dev.predictions.step.1000.csv 0.5675497565736127


In [17]:
import pandas as pd
import numpy as np

id_file = "~/codes/QA-ZRE/fewrl_data/test_ids_12321.csv"
gold_file = "~/codes/QA-ZRE/fewrl_data/test_data_12321.csv"

prediction_files = ["~/may-29/fewrl/concat_run_1/relation.concat.run.0.test.predictions.step.2200.csv"]

df = pd.read_csv(gold_file, sep=',')
ids = {val:i for i, val in enumerate(pd.read_csv(id_file, sep=',')["relation_ids"].tolist())}
actual_ids = df["actual_ids"].tolist()
num_examples = len(actual_ids) // 15

gold_indices = []
for each_relation_id in actual_ids:
    gold_indices.append(ids[each_relation_id])

gold_indices = np.max(np.reshape(np.array(gold_indices), (num_examples, 15)), axis=1)

max_f1 = 0.0
max_file = "None"
for prediction_file in prediction_files:
    try:
        pred_log_ps = pd.read_csv(prediction_file, sep=',')["relation_log_p"].tolist()
        pred_ids = np.argmax(np.reshape(np.array(pred_log_ps), (num_examples, 15)), axis=1)
        avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
        if max_f1 <= f1:
            max_f1 = f1
            max_file = prediction_file
    except:
        print(prediction_file)

print(max_file, max_f1)

~/may-29/fewrl/concat_run_1/relation.concat.run.0.test.predictions.step.2200.csv 0.31786250884957634
