In [1]:
def compute_macro_PRF(predicted_idx, gold_idx, i=-1, empty_label=None):
    '''
    This evaluation function follows work from Sorokin and Gurevych(https://www.aclweb.org/anthology/D17-1188.pdf)
    code borrowed from the following link:
    https://github.com/UKPLab/emnlp2017-relation-extraction/blob/master/relation_extraction/evaluation/metrics.py
    '''
    if i == -1:
        i = len(predicted_idx)

    complete_rel_set = set(gold_idx) - {empty_label}
    avg_prec = 0.0
    avg_rec = 0.0

    for r in complete_rel_set:
        r_indices = (predicted_idx[:i] == r)
        tp = len((predicted_idx[:i][r_indices] == gold_idx[:i][r_indices]).nonzero()[0])
        tp_fp = len(r_indices.nonzero()[0])
        tp_fn = len((gold_idx == r).nonzero()[0])
        prec = (tp / tp_fp) if tp_fp > 0 else 0
        rec = tp / tp_fn
        #print(id_to_labels[r], prec, rec, 2.0 * prec * rec / (prec + rec))
        avg_prec += prec
        avg_rec += rec
    f1 = 0
    avg_prec = avg_prec / len(set(predicted_idx[:i]))
    avg_rec = avg_rec / len(complete_rel_set)
    if (avg_rec+avg_prec) > 0:
        f1 = 2.0 * avg_prec * avg_rec / (avg_prec + avg_rec)

    return avg_prec, avg_rec, f1

In [3]:
# Dev prediction for the model on the fewrel dataset using the concat model without the negative examples.
from re import I
import pandas as pd
import numpy as np

gold_files = {
    12321: "~/codes/QA-ZRE/fewrl_data/val_data_12321.csv",
    943: "~/codes/QA-ZRE/fewrl_data/val_data_943.csv",
    111: "~/codes/QA-ZRE/fewrl_data/val_data_111.csv",
    300: "~/codes/QA-ZRE/fewrl_data/val_data_300.csv",
    1300: "~/codes/QA-ZRE/fewrl_data/val_data_1300.csv",
}
id_files = {
    12321: "~/codes/QA-ZRE/fewrl_data/val_ids_12321.csv",
    943: "~/codes/QA-ZRE/fewrl_data/val_ids_943.csv",
    111: "~/codes/QA-ZRE/fewrl_data/val_ids_111.csv",
    300: "~/codes/QA-ZRE/fewrl_data/val_ids_300.csv",
    1300: "~/codes/QA-ZRE/fewrl_data/val_ids_1300.csv",
}

seeds = [12321, 943, 111, 300, 1300]

for seed in seeds:
    predictions = ["~/sep-1/fewrel/concat_run_{}/relation.concat.run.{}.epoch.0.dev.predictions.step.{}.csv".format(seed, seed, step * 100) for step in range(1, 106, 1)]
    max_f1 = 0.0
    max_file = None
    prediction_files = predictions
    df = pd.read_csv(gold_files[seed], sep=',')
    ids = {val:i for i, val in enumerate(pd.read_csv(id_files[seed], sep=',')["relation_ids"].tolist())}
    actual_ids = df["actual_ids"].tolist()
    num_examples = len(actual_ids) // 5

    gold_indices = []
    for each_relation_id in actual_ids:
        gold_indices.append(ids[each_relation_id])

    gold_indices = np.max(np.reshape(np.array(gold_indices), (num_examples, 5)), axis=1)

    max_f1 = 0.0
    max_file = "None"
    for prediction_file in prediction_files:
        pred_log_ps = pd.read_csv(prediction_file, sep=',')["relation_log_p"].tolist()
        pred_ids = np.argmax(np.reshape(np.array(pred_log_ps), (num_examples, 5)), axis=1)
        avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
        if max_f1 <= f1:
            max_f1 = f1
            max_file = prediction_file

    print(seed, max_file, max_f1)

12321 ~/sep-1/fewrel/concat_run_12321/relation.concat.run.12321.epoch.0.dev.predictions.step.9900.csv 0.6225757946913466
943 ~/sep-1/fewrel/concat_run_943/relation.concat.run.943.epoch.0.dev.predictions.step.1300.csv 0.5099570516870796
111 ~/sep-1/fewrel/concat_run_111/relation.concat.run.111.epoch.0.dev.predictions.step.5000.csv 0.5992184520328142
300 ~/sep-1/fewrel/concat_run_300/relation.concat.run.300.epoch.0.dev.predictions.step.5500.csv 0.7290633687525613
1300 ~/sep-1/fewrel/concat_run_1300/relation.concat.run.1300.epoch.0.dev.predictions.step.1300.csv 0.6430064368535299


In [5]:
# Dev prediction for the model on the fewrel dataset using the concat model with the negative examples.
from re import I
import pandas as pd
import numpy as np

gold_files = {
    12321: "~/codes/QA-ZRE/fewrl_data_unks_more_unks/val_data_12321.csv",
    943: "~/codes/QA-ZRE/fewrl_data_unks_more_unks/val_data_943.csv",
    111: "~/codes/QA-ZRE/fewrl_data_unks_more_unks/val_data_111.csv",
    300: "~/codes/QA-ZRE/fewrl_data_unks_more_unks/val_data_300.csv",
    1300: "~/codes/QA-ZRE/fewrl_data_unks_more_unks/val_data_1300.csv",
}
id_files = {
    12321: "~/codes/QA-ZRE/fewrl_data_unks_more_unks/val_ids_12321.csv",
    943: "~/codes/QA-ZRE/fewrl_data_unks_more_unks/val_ids_943.csv",
    111: "~/codes/QA-ZRE/fewrl_data_unks_more_unks/val_ids_111.csv",
    300: "~/codes/QA-ZRE/fewrl_data_unks_more_unks/val_ids_300.csv",
    1300: "~/codes/QA-ZRE/fewrl_data_unks_more_unks/val_ids_1300.csv",
}

seeds = [12321, 943, 111, 300, 1300]

for seed in seeds:
    predictions = ["~/sep-1/fewrel/concat_run_{}_with_unks_more_unks/relation.concat.run.epoch.0.dev.predictions.step.{}.csv".format(seed, step * 100) for step in range(1, 256, 1)]
    max_f1 = 0.0
    max_file = None
    prediction_files = predictions
    df = pd.read_csv(gold_files[seed], sep=',')
    ids = {val:i for i, val in enumerate(pd.read_csv(id_files[seed], sep=',')["relation_ids"].tolist())}
    actual_ids = df["actual_ids"].tolist()
    num_examples = len(actual_ids) // 5

    gold_indices = []
    for each_relation_id in actual_ids:
        gold_indices.append(ids[each_relation_id])

    gold_indices = np.max(np.reshape(np.array(gold_indices), (num_examples, 5)), axis=1)

    max_f1 = 0.0
    max_file = "None"
    for prediction_file in prediction_files:
        pred_log_ps = pd.read_csv(prediction_file, sep=',')["relation_log_p"].tolist()
        pred_ids = np.argmax(np.reshape(np.array(pred_log_ps), (num_examples, 5)), axis=1)
        avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
        if max_f1 <= f1:
            max_f1 = f1
            max_file = prediction_file

    print(seed, max_file, max_f1)

12321 ~/sep-1/fewrel/concat_run_12321_with_unks_more_unks/relation.concat.run.epoch.0.dev.predictions.step.10400.csv 0.8633038093162008
943 ~/sep-1/fewrel/concat_run_943_with_unks_more_unks/relation.concat.run.epoch.0.dev.predictions.step.4600.csv 0.7563974270769898
111 ~/sep-1/fewrel/concat_run_111_with_unks_more_unks/relation.concat.run.epoch.0.dev.predictions.step.12000.csv 0.8697254766827304
300 ~/sep-1/fewrel/concat_run_300_with_unks_more_unks/relation.concat.run.epoch.0.dev.predictions.step.4500.csv 0.9213144024635793
1300 ~/sep-1/fewrel/concat_run_1300_with_unks_more_unks/relation.concat.run.epoch.0.dev.predictions.step.25100.csv 0.8745865113653335


In [7]:
# Dev prediction for the model on the fewrel dataset using the concat model without the negative examples.
from re import I
import pandas as pd
import numpy as np

gold_files = {
    12321: "~/codes/QA-ZRE/fewrl_data/test_data_12321.csv",
    943: "~/codes/QA-ZRE/fewrl_data/test_data_943.csv",
    111: "~/codes/QA-ZRE/fewrl_data/test_data_111.csv",
    300: "~/codes/QA-ZRE/fewrl_data/test_data_300.csv",
    1300: "~/codes/QA-ZRE/fewrl_data/test_data_1300.csv",
}
id_files = {
    12321: "~/codes/QA-ZRE/fewrl_data/test_ids_12321.csv",
    943: "~/codes/QA-ZRE/fewrl_data/test_ids_943.csv",
    111: "~/codes/QA-ZRE/fewrl_data/test_ids_111.csv",
    300: "~/codes/QA-ZRE/fewrl_data/test_ids_300.csv",
    1300: "~/codes/QA-ZRE/fewrl_data/test_ids_1300.csv",
}

prediction_files = {
    12321: "~/sep-1/fewrel/concat_run_12321/relation.concat.run.12321.epoch.0.test.predictions.step.9900.csv",
    943: "~/sep-1/fewrel/concat_run_943/relation.concat.run.943.epoch.0.test.predictions.step.1300.csv",
    111: "~/sep-1/fewrel/concat_run_111/relation.concat.run.111.epoch.0.test.predictions.step.5000.csv",
    300: "~/sep-1/fewrel/concat_run_300/relation.concat.run.300.epoch.0.test.predictions.step.5500.csv",
    1300: "~/sep-1/fewrel/concat_run_1300/relation.concat.run.1300.epoch.0.test.predictions.step.1300.csv",
}
seeds = [12321, 943, 111, 300, 1300]

avg_f1 = 0.0
avg_p = 0.0
avg_r = 0.0
for seed in seeds:
    predictions = [prediction_files[seed]]
    max_f1 = 0.0
    max_file = None
    df = pd.read_csv(gold_files[seed], sep=',')
    ids = {val:i for i, val in enumerate(pd.read_csv(id_files[seed], sep=',')["relation_ids"].tolist())}
    actual_ids = df["actual_ids"].tolist()
    num_examples = len(actual_ids) // 15

    gold_indices = []
    for each_relation_id in actual_ids:
        gold_indices.append(ids[each_relation_id])

    gold_indices = np.max(np.reshape(np.array(gold_indices), (num_examples, 15)), axis=1)

    max_f1 = 0.0
    max_file = "None"
    for prediction_file in predictions:
        pred_log_ps = pd.read_csv(prediction_file, sep=',')["relation_log_p"].tolist()
        pred_ids = np.argmax(np.reshape(np.array(pred_log_ps), (num_examples, 15)), axis=1)
        prec, rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
        avg_f1 += f1
        avg_p += prec
        avg_r += rec
        if max_f1 <= f1:
            max_f1 = f1
            max_file = prediction_file

    print(seed, max_file, max_f1)

print(avg_f1/5.0)
print(avg_p/5.0)
print(avg_r/5.0)

12321 ~/sep-1/fewrel/concat_run_12321/relation.concat.run.12321.epoch.0.test.predictions.step.9900.csv 0.34701733359825704
943 ~/sep-1/fewrel/concat_run_943/relation.concat.run.943.epoch.0.test.predictions.step.1300.csv 0.3671554049998069
111 ~/sep-1/fewrel/concat_run_111/relation.concat.run.111.epoch.0.test.predictions.step.5000.csv 0.256033773197929
300 ~/sep-1/fewrel/concat_run_300/relation.concat.run.300.epoch.0.test.predictions.step.5500.csv 0.3014579520878535
1300 ~/sep-1/fewrel/concat_run_1300/relation.concat.run.1300.epoch.0.test.predictions.step.1300.csv 0.3067879888834918
0.31569049055346765
0.3082683125936674
0.32371428571428573


In [3]:
# Dev prediction for the model on the fewrel dataset using the concat model with the negative examples.
from re import I
import pandas as pd
import numpy as np

gold_files = {
    12321: "~/codes/QA-ZRE/fewrl_data_unks_more_unks/test_data_12321.csv",
    943: "~/codes/QA-ZRE/fewrl_data_unks_more_unks/test_data_943.csv",
    111: "~/codes/QA-ZRE/fewrl_data_unks_more_unks/test_data_111.csv",
    300: "~/codes/QA-ZRE/fewrl_data_unks_more_unks/test_data_300.csv",
    1300: "~/codes/QA-ZRE/fewrl_data_unks_more_unks/test_data_1300.csv",
}
id_files = {
    12321: "~/codes/QA-ZRE/fewrl_data_unks_more_unks/test_ids_12321.csv",
    943: "~/codes/QA-ZRE/fewrl_data_unks_more_unks/test_ids_943.csv",
    111: "~/codes/QA-ZRE/fewrl_data_unks_more_unks/test_ids_111.csv",
    300: "~/codes/QA-ZRE/fewrl_data_unks_more_unks/test_ids_300.csv",
    1300: "~/codes/QA-ZRE/fewrl_data_unks_more_unks/test_ids_1300.csv",
}

prediction_files = {
    12321: "~/sep-1/fewrel/concat_run_12321_with_unks_more_unks/relation.concat.run.12321.epoch.0.test.predictions.step.10400.csv",
    943: "~/sep-1/fewrel/concat_run_943_with_unks_more_unks/relation.concat.run.943.epoch.0.test.predictions.step.4600.csv",
    111: "~/sep-1/fewrel/concat_run_111_with_unks_more_unks/relation.concat.run.111.epoch.0.test.predictions.step.12000.csv",
    300: "~/sep-1/fewrel/concat_run_300_with_unks_more_unks/relation.concat.run.300.epoch.0.test.predictions.step.4500.csv",
    1300: "~/sep-1/fewrel/concat_run_1300_with_unks_more_unks/relation.concat.run.1300.epoch.0.test.predictions.step.25100.csv",
}
seeds = [12321, 943, 111, 300, 1300]

avg_f1 = 0.0
avg_p = 0.0
avg_r = 0.0
for seed in seeds:
    predictions = [prediction_files[seed]]
    max_f1 = 0.0
    max_file = None
    df = pd.read_csv(gold_files[seed], sep=',')
    ids = {val:i for i, val in enumerate(pd.read_csv(id_files[seed], sep=',')["relation_ids"].tolist())}
    actual_ids = df["actual_ids"].tolist()
    num_examples = len(actual_ids) // 15

    gold_indices = []
    for each_relation_id in actual_ids:
        gold_indices.append(ids[each_relation_id])

    gold_indices = np.max(np.reshape(np.array(gold_indices), (num_examples, 15)), axis=1)

    max_f1 = 0.0
    max_file = "None"
    for prediction_file in predictions:
        pred_log_ps = pd.read_csv(prediction_file, sep=',')["relation_log_p"].tolist()
        pred_ids = np.argmax(np.reshape(np.array(pred_log_ps), (num_examples, 15)), axis=1)
        prec, rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
        avg_f1 += f1
        avg_p += prec
        avg_r += rec
        if max_f1 <= f1:
            max_f1 = f1
            max_file = prediction_file

    print(seed, max_file, max_f1)

print("f1", avg_f1/5.0)
print("p", avg_p/5.0)
print("r", avg_r/5.0)

12321 ~/sep-1/fewrel/concat_run_12321_with_unks_more_unks/relation.concat.run.12321.epoch.0.test.predictions.step.10400.csv 0.6637798375772415
943 ~/sep-1/fewrel/concat_run_943_with_unks_more_unks/relation.concat.run.943.epoch.0.test.predictions.step.4600.csv 0.6641959718712238
111 ~/sep-1/fewrel/concat_run_111_with_unks_more_unks/relation.concat.run.111.epoch.0.test.predictions.step.12000.csv 0.541544457581946
300 ~/sep-1/fewrel/concat_run_300_with_unks_more_unks/relation.concat.run.300.epoch.0.test.predictions.step.4500.csv 0.4925647294296078
1300 ~/sep-1/fewrel/concat_run_1300_with_unks_more_unks/relation.concat.run.1300.epoch.0.test.predictions.step.25100.csv 0.6213148402269071
f1 0.5966799673373852
p 0.619109167655467
r 0.5763238095238095


In [None]:
# Eval of the RE-QA for relation extraction using the concat model.
import pandas as pd
import numpy as np

mean_f1 = 0.0
for fold_i in range(1, 11, 1):
    gold_file = "./zero-shot-extraction/relation_splits/dev.{}.concat.relation_data.csv".format(str(fold_i-1))
    gold_indices = []
    df = pd.read_csv(gold_file, sep=',')
    correct_indices = df["correct_indices"].tolist()
    for i, index in enumerate(correct_indices):
        if index:
            gold_indices.append(int(i % 12))

    num_examples = len(correct_indices) // 12
    gold_indices = np.array(gold_indices)

    max_file = None
    max_f1 = 0.0
    for checkpoint_i in range(1, 200, 1):
        prediction_file = "~/reqa-predictions/fold_{}/concat/relation.concat.dev.predictions.fold.{}.step.{}.csv".format(str(fold_i), str(fold_i), str(100 * checkpoint_i))
        pred_log_ps = pd.read_csv(prediction_file, sep=',')["relation_log_p"].tolist()
        pred_ids = np.argmax(np.reshape(np.array(pred_log_ps), (num_examples, 12)), axis=1)
        avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
        if f1 >= max_f1:
            max_f1 = f1
            max_file = prediction_file

    print(fold_i)
    print(max_f1)
    print(max_file)
    print("\r\n")
    mean_f1 += max_f1

print(mean_f1/10.0)

In [None]:
# Test prediction for the RelationPrompt on the RE-QA dataset.
from re import I
import pandas as pd
import numpy as np

gold_files = {
    1: "~/codes/RelationPrompt/train_reqa_models/fold_1/extractor/pred_in_single.jsonl",
}

prediction_arrs = {
    1: ["~/codes/RelationPrompt/train_reqa_models/fold_1/extractor/pred_out_single.jsonl"]
}

for fold_id in range(1, 2, 1):
    prediction_files = prediction_arrs[fold_id]
    df = pd.read_csv(gold_files[fold_id], sep=',')
    answers = [ans.replace("</s>", "").strip() for ans in df["answers"].tolist()]
    all_classes = set(answers)
    ids = {val:i for i, val in enumerate(list(all_classes))}
    actual_ids = [ids[ans] for ans in answers]
    gold_indices = np.array(actual_ids)
    for prediction_file in prediction_files:
        prediction_ids = []
        for pred_class in pd.read_csv(prediction_file, sep=',')["predictions_str"].tolist():
            if pred_class.strip() in ids:
                prediction_ids.append(ids[pred_class.strip()])
            else:
                prediction_ids.append(-1)
        pred_ids = np.array(prediction_ids)
        avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
        if f1 > max_f1:
            max_f1 = f1
            max_file = prediction_file
        print(prediction_file, avg_prec, avg_rec, f1)

print(max_f1, max_file)

In [6]:
# Eval of the RE-QA for relation extraction using the concat model.
import pandas as pd
import numpy as np

mean_f1 = 0.0
for fold_i in range(1, 11, 1):
    gold_file = "./zero-shot-extraction/relation_splits/dev.{}.concat.relation_data.csv".format(str(fold_i-1))
    gold_indices = []
    df = pd.read_csv(gold_file, sep=',')
    correct_indices = df["correct_indices"].tolist()
    for i, index in enumerate(correct_indices):
        if index:
            gold_indices.append(int(i % 12))

    num_examples = len(correct_indices) // 12
    gold_indices = np.array(gold_indices)

    max_file = None
    max_f1 = 0.0
    for checkpoint_i in range(1, 200, 1):
        prediction_file = "~/reqa-predictions/fold_{}/concat/relation.concat.dev.predictions.fold.{}.step.{}.csv".format(str(fold_i), str(fold_i), str(100 * checkpoint_i))
        pred_log_ps = pd.read_csv(prediction_file, sep=',')["relation_log_p"].tolist()
        pred_ids = np.argmax(np.reshape(np.array(pred_log_ps), (num_examples, 12)), axis=1)
        avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
        if f1 >= max_f1:
            max_f1 = f1
            max_file = prediction_file

    print(fold_i)
    print(max_f1)
    print(max_file)
    print("\r\n")
    mean_f1 += max_f1

print(mean_f1/10.0)

1
0.6645079529607614
~/reqa-predictions/fold_1/concat/relation.concat.dev.predictions.fold.1.step.3600.csv


2
0.7355692801363015
~/reqa-predictions/fold_2/concat/relation.concat.dev.predictions.fold.2.step.4300.csv


3
0.816466070295921
~/reqa-predictions/fold_3/concat/relation.concat.dev.predictions.fold.3.step.5200.csv


4
0.820538067780762
~/reqa-predictions/fold_4/concat/relation.concat.dev.predictions.fold.4.step.1600.csv


5
0.7970665456384882
~/reqa-predictions/fold_5/concat/relation.concat.dev.predictions.fold.5.step.2900.csv


6
0.9100498471715361
~/reqa-predictions/fold_6/concat/relation.concat.dev.predictions.fold.6.step.1400.csv


7
0.7789862082105365
~/reqa-predictions/fold_7/concat/relation.concat.dev.predictions.fold.7.step.2500.csv


8
0.7585498509710629
~/reqa-predictions/fold_8/concat/relation.concat.dev.predictions.fold.8.step.400.csv


9
0.7690369496615103
~/reqa-predictions/fold_9/concat/relation.concat.dev.predictions.fold.9.step.2600.csv


10
0.7394406760740089


In [2]:
# Eval of the RE-QA using the Gold Templates on the dev data over all the folds.
import pandas as pd
import numpy as np

mean_f1 = 0.0
for fold_i in range(1, 11, 1):
    gold_file = "./zero-shot-extraction/relation_splits/dev.{}.relation_data.csv".format(str(fold_i-1))
    gold_indices = []
    df = pd.read_csv(gold_file, sep=',')
    correct_indices = df["correct_indices"].tolist()
    for i, index in enumerate(correct_indices):
        if index:
            gold_indices.append(int(i % 12))

    num_examples = len(correct_indices) // 12
    gold_indices = np.array(gold_indices)

    max_file = None
    max_f1 = 0.0
    for checkpoint_i in range(1, 200, 1):
        try:
            prediction_file = "~/reqa-predictions/fold_{}/gold/relation.gold.dev.predictions.fold.{}.step.{}.csv".format(str(fold_i), str(fold_i), str(100 * checkpoint_i))
            pred_log_ps = pd.read_csv(prediction_file, sep=',')["relation_log_p"].tolist()
            pred_ids = np.argmax(np.reshape(np.array(pred_log_ps), (num_examples, 12)), axis=1)
            avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
            if f1 >= max_f1:
                max_f1 = f1
                max_file = prediction_file
        except:
            print(checkpoint_i, fold_i)

    print(fold_i)
    print(max_f1)
    print(max_file)
    print("\r\n")
    mean_f1 += max_f1

print(mean_f1/10.0)

157 1
159 1
162 1
163 1
164 1
165 1
167 1
168 1
169 1
171 1
1
0.7958029874558785
~/reqa-predictions/fold_1/gold/relation.gold.dev.predictions.fold.1.step.600.csv


2
0.8055480874453981
~/reqa-predictions/fold_2/gold/relation.gold.dev.predictions.fold.2.step.1900.csv


3
0.8088784230714176
~/reqa-predictions/fold_3/gold/relation.gold.dev.predictions.fold.3.step.200.csv


4
0.7950547270773886
~/reqa-predictions/fold_4/gold/relation.gold.dev.predictions.fold.4.step.9500.csv


5
0.8218222460881007
~/reqa-predictions/fold_5/gold/relation.gold.dev.predictions.fold.5.step.15300.csv


6
0.9343882793208368
~/reqa-predictions/fold_6/gold/relation.gold.dev.predictions.fold.6.step.1100.csv


7
0.7977715930389874
~/reqa-predictions/fold_7/gold/relation.gold.dev.predictions.fold.7.step.2600.csv


8
0.8869445616734918
~/reqa-predictions/fold_8/gold/relation.gold.dev.predictions.fold.8.step.1000.csv


9
0.8353253359450881
~/reqa-predictions/fold_9/gold/relation.gold.dev.predictions.fold.9.step.1900.cs

In [3]:
# MML-OFF-PGG performance for Relation Extraction on all the dev folds.
import pandas as pd
import numpy as np

mean_f1 = 0.0
for fold_i in range(1, 11, 1):
    gold_file = "./zero-shot-extraction/relation_splits/dev.{}.qq.relation_data.csv".format(str(fold_i-1))
    gold_indices = []
    df = pd.read_csv(gold_file, sep=',')
    correct_indices = df["correct_indices"].tolist()
    for i, index in enumerate(correct_indices):
        if index:
            gold_indices.append(int(i % 12))

    num_examples = len(correct_indices) // 12
    gold_indices = np.array(gold_indices)

    max_file = None
    max_f1 = 0.0
    for checkpoint_i in range(1, 200, 1):
        try:
            prediction_file = "~/reqa-predictions/fold_{}/mml-pgg-off-sim/relation.mml-pgg-off-sim.run.fold_{}.dev.predictions.step.{}.csv".format(str(fold_i), str(fold_i), str(100 * checkpoint_i))
            pred_log_ps = pd.read_csv(prediction_file, sep=',')["answer_log_p"].tolist()
            pred_log_ps = np.log(np.mean(np.reshape(np.exp(np.array(pred_log_ps)), (num_examples, 12, 8)), axis=2))
            pred_ids = np.argmax(pred_log_ps, axis=1)
            avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
            if f1 >= max_f1:
                max_f1 = f1
                max_file = prediction_file
        except:
            print(checkpoint_i)

    print(fold_i)
    print(max_f1)
    print(max_file)
    print("\r\n")
    mean_f1 += max_f1

print(mean_f1/10.0)

1
0.7021862131500355
~/reqa-predictions/fold_1/mml-pgg-off-sim/relation.mml-pgg-off-sim.run.fold_1.dev.predictions.step.4700.csv


2
0.7318462713898469
~/reqa-predictions/fold_2/mml-pgg-off-sim/relation.mml-pgg-off-sim.run.fold_2.dev.predictions.step.400.csv


3
0.7766533200558716
~/reqa-predictions/fold_3/mml-pgg-off-sim/relation.mml-pgg-off-sim.run.fold_3.dev.predictions.step.3600.csv


4
0.8437707696480834
~/reqa-predictions/fold_4/mml-pgg-off-sim/relation.mml-pgg-off-sim.run.fold_4.dev.predictions.step.800.csv


5
0.8300206299665337
~/reqa-predictions/fold_5/mml-pgg-off-sim/relation.mml-pgg-off-sim.run.fold_5.dev.predictions.step.7900.csv


6
0.8906375171815566
~/reqa-predictions/fold_6/mml-pgg-off-sim/relation.mml-pgg-off-sim.run.fold_6.dev.predictions.step.700.csv


197
198
199
7
0.7827607798234402
~/reqa-predictions/fold_7/mml-pgg-off-sim/relation.mml-pgg-off-sim.run.fold_7.dev.predictions.step.2100.csv


198
199
8
0.795102231532206
~/reqa-predictions/fold_8/mml-pgg-off-sim/rela

In [11]:
# Test set performance over the 10 folds of the RE-QA dataset for the concat and gold models.
import pandas as pd
import numpy as np
import json

gold_files = {
    1: "relation.gold.test.predictions.fold.1.step.600.csv",
    2: "relation.gold.test.predictions.fold.2.step.1900.csv",
    3: "relation.gold.test.predictions.fold.3.step.200.csv",
    4: "relation.gold.test.predictions.fold.4.step.9500.csv",
    5: "relation.gold.test.predictions.fold.5.step.15300.csv",
    6: "relation.gold.test.predictions.fold.6.step.1100.csv",
    7: "relation.gold.test.predictions.fold.7.step.2600.csv",
    8: "relation.gold.test.predictions.fold.8.step.1000.csv",
    9: "relation.gold.test.predictions.fold.9.step.1900.csv",
    10: "relation.gold.test.predictions.fold.10.step.4000.csv"
}

concat_files = {
    1: "relation.concat.test.predictions.fold.1.step.3600.csv",
    2: "relation.concat.test.predictions.fold.2.step.4300.csv",
    3: "relation.concat.test.predictions.fold.3.step.5200.csv",
    4: "relation.concat.test.predictions.fold.4.step.1600.csv",
    5: "relation.concat.test.predictions.fold.5.step.2900.csv",
    6: "relation.concat.test.predictions.fold.6.step.1400.csv",
    7: "relation.concat.test.predictions.fold.7.step.2500.csv",
    8: "relation.concat.test.predictions.fold.8.step.400.csv",
    9: "relation.concat.test.predictions.fold.9.step.2600.csv",
    10: "relation.concat.test.predictions.fold.10.step.800.csv"
}

gold_alone_p_r_f = {'f': [], 'r': [], 'p': []}
concat_alone_p_r_f = {'f': [], 'r': [], 'p': []}

for fold_i in range(1, 11, 1):
    gold_file = "./zero-shot-extraction/relation_splits/test.{}.relation_data.csv".format(str(fold_i-1))
    fewrel_file = "./zero-shot-extraction/relation_splits/test.{}.fewrel_format.json".format(str(fold_i-1))

    example_indices_to_consider = set()
    with open(fewrel_file, 'r') as fin:
        short_examples = json.load(fin)
        for key, val in short_examples.items():
            for row in val:
                example_indices_to_consider.add(row["example_index"])


    gold_indices = []
    df = pd.read_csv(gold_file, sep=',')
    correct_indices = df["correct_indices"].tolist()
    for i, index in enumerate(correct_indices):
        if index:
            gold_indices.append(int(i % 24))

    gold_indices_to_consider = []
    for i, index in enumerate(gold_indices):
        if i in example_indices_to_consider:
            gold_indices_to_consider.append(index)

    num_examples = len(correct_indices) // 24
    gold_indices_to_consider = np.array(gold_indices_to_consider)
    
    concat_prediction_file = "~/may-20/fold_{}/concat/{}".format(fold_i, concat_files[fold_i])
    concat_pred_log_ps = pd.read_csv(concat_prediction_file, sep=',')["relation_log_p"].tolist()
    concat_pred_log_ps = np.reshape(np.array(concat_pred_log_ps), (num_examples, 24))
    concat_pred_ids = np.argmax(concat_pred_log_ps, axis=1)
    concat_pred_ids_to_consider = []
    for i, index in enumerate(concat_pred_ids):
        if i in example_indices_to_consider:
            concat_pred_ids_to_consider.append(index)
    
    concat_pred_ids_to_consider = np.array(concat_pred_ids_to_consider)
    avg_prec, avg_rec, f1 = compute_macro_PRF(concat_pred_ids_to_consider, gold_indices_to_consider)
    concat_alone_p_r_f["f"].append(f1)
    concat_alone_p_r_f["r"].append(avg_rec)
    concat_alone_p_r_f["p"].append(avg_prec)
    print(fold_i, "concat alone", f1, avg_prec, avg_rec)
    
    gold_prediction_file = "~/may-20/fold_{}/gold/{}".format(fold_i, gold_files[fold_i])
    gold_pred_log_ps = pd.read_csv(gold_prediction_file, sep=',')["relation_log_p"].tolist()
    gold_pred_log_ps = np.reshape(np.array(gold_pred_log_ps), (num_examples, 24))
    gold_pred_ids = np.argmax(gold_pred_log_ps, axis=1)

    gold_pred_ids_to_consider = []
    for i, index in enumerate(gold_pred_ids):
        if i in example_indices_to_consider:
            gold_pred_ids_to_consider.append(index)
    
    gold_pred_ids_to_consider = np.array(gold_pred_ids_to_consider)
    avg_prec, avg_rec, f1 = compute_macro_PRF(gold_pred_ids_to_consider, gold_indices_to_consider)
    print(fold_i, "gold alone", f1, avg_prec, avg_rec)
    gold_alone_p_r_f["f"].append(f1)
    gold_alone_p_r_f["r"].append(avg_rec)
    gold_alone_p_r_f["p"].append(avg_prec)
    print("\n")

print("gold alone p:", np.mean(np.array(gold_alone_p_r_f["p"])))
print("gold alone r:", np.mean(np.array(gold_alone_p_r_f["r"])))
print("gold alone f:", np.mean(np.array(gold_alone_p_r_f["f"])))

print("concat alone p:", np.mean(np.array(concat_alone_p_r_f["p"])))
print("concat alone r:", np.mean(np.array(concat_alone_p_r_f["r"])))
print("concat alone f:", np.mean(np.array(concat_alone_p_r_f["f"])))

1 concat alone 0.6868200485458184 0.6926748847477643 0.6810633587961235
1 gold alone 0.7404285767936459 0.7531358825049224 0.728142964768295


2 concat alone 0.5913023110477719 0.5975156340730365 0.585216877806697
2 gold alone 0.6586987557834644 0.6757540315314027 0.6424831986356715


3 concat alone 0.6324106523285894 0.6578301488771372 0.60888256050775
3 gold alone 0.676931710052006 0.6806536959107133 0.6732502083282599


4 concat alone 0.5794697498339418 0.6063706954690022 0.5548542716952589
4 gold alone 0.6728400404740734 0.6895175863911674 0.656950210374016


5 concat alone 0.6191377271586156 0.6268067958412001 0.6116540543838874
5 gold alone 0.49349253129168735 0.4986497121305824 0.48844093262952004


6 concat alone 0.5782347137534999 0.6001076131310151 0.5579002025330327
6 gold alone 0.6607776835391415 0.6691998506324075 0.6525648747509227


7 concat alone 0.6735910798575685 0.6934384410876097 0.6548482345074617
7 gold alone 0.6474800154241827 0.6618318212248525 0.633737435642919

In [4]:
import pandas as pd
import numpy as np

mml_files = {
    1: "relation.mml-pgg-off-sim.run.fold_1.test.predictions.step.4700.csv",
    2: "relation.mml-pgg-off-sim.run.fold_2.test.predictions.step.400.csv",
    3: "relation.mml-pgg-off-sim.run.fold_3.test.predictions.step.3600.csv",
    4: "relation.mml-pgg-off-sim.run.fold_4.test.predictions.step.800.csv",
    5: "relation.mml-pgg-off-sim.run.fold_5.test.predictions.step.7900.csv",
    6: "relation.mml-pgg-off-sim.run.fold_6.test.predictions.step.700.csv",
    7: "relation.mml-pgg-off-sim.run.fold_7.test.predictions.step.2100.csv",
    8: "relation.mml-pgg-off-sim.run.fold_8.test.predictions.step.6800.csv",
    9: "relation.mml-pgg-off-sim.run.fold_9.test.predictions.step.4300.csv",
    10: "relation.mml-pgg-off-sim.run.fold_10.test.predictions.step.1600.csv"
}

mml_pgg_p_r_f = {'f': [], 'r': [], 'p': []}

for fold_i in range(1, 11, 1):
    gold_file = "./zero-shot-extraction/relation_splits/test.{}.qq.relation_data.csv".format(str(fold_i-1))
    gold_indices = []
    df = pd.read_csv(gold_file, sep=',')
    correct_indices = df["correct_indices"].tolist()
    for i, index in enumerate(correct_indices):
        if index:
            gold_indices.append(int(i % 24))

    num_examples = len(correct_indices) // 24
    gold_indices = np.array(gold_indices)
    
    mml_prediction_file = "~/may-20/fold_{}/{}".format(fold_i, mml_files[fold_i])
    mml_pred_log_ps = pd.read_csv(mml_prediction_file, sep=',')["answer_log_p"].tolist()

    pred_log_ps = np.log(np.mean(np.reshape(np.exp(np.array(mml_pred_log_ps)), (num_examples, 24, 8)), axis=2))
    pred_ids = np.argmax(pred_log_ps, axis=1)
    avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
    mml_pgg_p_r_f["f"].append(f1)
    mml_pgg_p_r_f["r"].append(avg_rec)
    mml_pgg_p_r_f["p"].append(avg_prec)
    print(fold_i, "mml pgg off", f1, avg_prec, avg_rec)
    print("\n")

print("mml pgg off p:", np.mean(np.array(mml_pgg_p_r_f["p"])))
print("mml pgg off r:", np.mean(np.array(mml_pgg_p_r_f["r"])))
print("mml pgg off f:", np.mean(np.array(mml_pgg_p_r_f["f"])))

1 mml 0.7324313594407584 0.7347101443139104 0.7301666666666667




ValueError: cannot reshape array of size 692139 into shape (6000,24,8)

In [3]:
# Dev prediction for concat model on the fewrel dataset.
import pandas as pd
import numpy as np

id_files = {
    1: "~/codes/QA-ZRE/fewrl_data/val_ids_12321.csv",
    2: "~/codes/QA-ZRE/fewrl_data/val_ids_943.csv",
    3: "~/codes/QA-ZRE/fewrl_data/val_ids_111.csv",
    4: "~/codes/QA-ZRE/fewrl_data/val_ids_300.csv",
    5: "~/codes/QA-ZRE/fewrl_data/val_ids_1300.csv"
}
gold_files = {   
    1: "~/codes/QA-ZRE/fewrl_data/val_data_12321.csv",
    2: "~/codes/QA-ZRE/fewrl_data/val_data_943.csv",
    3: "~/codes/QA-ZRE/fewrl_data/val_data_111.csv",
    4: "~/codes/QA-ZRE/fewrl_data/val_data_300.csv",
    5: "~/codes/QA-ZRE/fewrl_data/val_data_1300.csv",
}

for run_id in range(1, 6, 1):
    prediction_files = ["~/may-29/fewrl/concat_run_{}/relation.concat.run.0.dev.predictions.step.{}.csv".format(run_id, 200 * i) for i in range(1, 13, 1)]
    prediction_files += ["~/may-29/fewrl/concat_run_{}/relation.concat.run.1.dev.predictions.step.{}.csv".format(run_id, 200 * i) for i in range(1, 13, 1)]
    prediction_files += ["~/may-29/fewrl/concat_run_{}/relation.concat.run.2.dev.predictions.step.{}.csv".format(run_id, 200 * i) for i in range(1, 13, 1)]
    prediction_files += ["~/may-29/fewrl/concat_run_{}/relation.concat.run.3.dev.predictions.step.{}.csv".format(run_id, 200 * i) for i in range(1, 13, 1)]

    df = pd.read_csv(gold_files[run_id], sep=',')
    ids = {val:i for i, val in enumerate(pd.read_csv(id_files[run_id], sep=',')["relation_ids"].tolist())}
    actual_ids = df["actual_ids"].tolist()
    num_examples = len(actual_ids) // 5

    gold_indices = []
    for each_relation_id in actual_ids:
        gold_indices.append(ids[each_relation_id])

    gold_indices = np.max(np.reshape(np.array(gold_indices), (num_examples, 5)), axis=1)

    max_f1 = 0.0
    max_file = "None"
    for prediction_file in prediction_files:
        try:
            pred_log_ps = pd.read_csv(prediction_file, sep=',')["relation_log_p"].tolist()
            pred_ids = np.argmax(np.reshape(np.array(pred_log_ps), (num_examples, 5)), axis=1)
            avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
            if max_f1 <= f1:
                max_f1 = f1
                max_file = prediction_file
        except:
            print(prediction_file)

    print(run_id, max_file, max_f1)

1 ~/may-29/fewrl/concat_run_1/relation.concat.run.0.dev.predictions.step.600.csv 0.5478163622378379
2 ~/may-29/fewrl/concat_run_2/relation.concat.run.0.dev.predictions.step.200.csv 0.4564517056171136
3 ~/may-29/fewrl/concat_run_3/relation.concat.run.0.dev.predictions.step.800.csv 0.6302340196983017
4 ~/may-29/fewrl/concat_run_4/relation.concat.run.1.dev.predictions.step.1000.csv 0.7054889125319359
5 ~/may-29/fewrl/concat_run_5/relation.concat.run.0.dev.predictions.step.800.csv 0.6725607501052342


In [24]:
# Test prediction for concat model on the fewrel dataset.
import pandas as pd
import numpy as np

id_files = {
    1: "~/codes/QA-ZRE/fewrl_data/test_ids_12321.csv",
    2: "~/codes/QA-ZRE/fewrl_data/test_ids_943.csv",
    3: "~/codes/QA-ZRE/fewrl_data/test_ids_111.csv",
    4: "~/codes/QA-ZRE/fewrl_data/test_ids_300.csv",
    5: "~/codes/QA-ZRE/fewrl_data/test_ids_1300.csv"
}
gold_files = {
    1: "~/codes/QA-ZRE/fewrl_data/test_data_12321.csv",
    2: "~/codes/QA-ZRE/fewrl_data/test_data_943.csv",
    3: "~/codes/QA-ZRE/fewrl_data/test_data_111.csv",
    4: "~/codes/QA-ZRE/fewrl_data/test_data_300.csv",
    5: "~/codes/QA-ZRE/fewrl_data/test_data_1300.csv",
}
test_files = {
    #1: "~/june-12/fewrl/concat_run_1/relation.concat.run.0.test.predictions.step.2000.csv",
    1: "~/sep-1/fewrel/concat_run_12321/relation.concat.run.12321.epoch.0.test.predictions.step.9900.csv",
    2: "~/june-12/fewrl/concat_run_2/relation.concat.run.0.test.predictions.step.400.csv",
    3: "~/june-12/fewrl/concat_run_3/relation.concat.run.1.test.predictions.step.800.csv",
    4: "~/june-12/fewrl/concat_run_4/relation.concat.run.0.test.predictions.step.2400.csv",
    5: "~/june-12/fewrl/concat_run_5/relation.concat.run.0.test.predictions.step.200.csv",
}

mean_f1 = 0.0
mean_p = 0.0
mean_r = 0.0
for run_id in range(1, 6, 1):
    prediction_files = [test_files[run_id]]

    df = pd.read_csv(gold_files[run_id], sep=',')
    ids = {val:i for i, val in enumerate(pd.read_csv(id_files[run_id], sep=',')["relation_ids"].tolist())}
    id_to_labels = {i: val for val, i in ids.items()}
    actual_ids = df["actual_ids"].tolist()
    num_examples = len(actual_ids) // 15

    gold_indices = []
    for i, each_relation_id in enumerate(actual_ids):
        gold_indices.append(ids[each_relation_id])

    gold_indices = np.max(np.reshape(np.array(gold_indices), (num_examples, 15)), axis=1)
    print(gold_indices)
    for prediction_file in prediction_files:
        pred_log_ps = pd.read_csv(prediction_file, sep=',')["relation_log_p"].tolist()
        pred_arr = np.reshape(np.array(pred_log_ps), (num_examples, 15))
        for index, i in enumerate(gold_indices):
            if i == 3:
                print(i, index, np.exp(pred_arr[index][:]))
        pred_ids = np.argmax(np.reshape(np.array(pred_log_ps), (num_examples, 15)), axis=1)
        avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices, id_to_labels)
        print(run_id, avg_prec, avg_rec, f1)
        mean_f1 += f1
        mean_p += avg_prec
        mean_r += avg_rec

mean_f1 /= 5
mean_p /= 5
mean_r /= 5
print("mean_p", mean_p)
print("mean_r", mean_r)
print("mean_f", mean_f1)

[ 0  0  0 ... 14 14 14]
3 2100 [9.80947115e-01 7.67235193e-13 9.79133572e-01 6.87376378e-01
 9.72028294e-01 9.97249902e-01 4.80552536e-04 7.37359876e-02
 7.01464205e-02 4.55554191e-01 8.23039076e-01 9.81600765e-01
 9.47021588e-01 9.86757621e-01 9.89991317e-01]
3 2101 [8.85659521e-01 4.90523373e-08 9.82911877e-01 9.91389578e-01
 9.75772242e-01 9.95607348e-01 5.83084229e-05 2.27497798e-04
 2.24397190e-05 2.21242760e-04 2.92529399e-02 7.57716139e-01
 7.69257034e-04 9.95950486e-01 9.50160936e-01]
3 2102 [9.47464754e-01 1.88632666e-04 9.54933246e-01 9.73216697e-01
 9.27676589e-01 9.60448093e-01 6.82281256e-02 3.60782967e-01
 1.88914056e-01 4.00135590e-03 4.30843614e-01 9.60290544e-01
 2.90221951e-01 9.66246762e-01 8.85933055e-01]
3 2103 [9.83078407e-01 5.15140114e-08 9.23611624e-01 1.05036224e-01
 7.86552757e-01 9.91082965e-01 1.89712501e-02 1.41888942e-05
 5.36330985e-01 8.36798934e-01 9.21041439e-01 9.77043698e-01
 6.79104250e-05 7.27380787e-01 9.86920275e-01]
3 2104 [9.95928781e-01 1.279

In [None]:
0.3326626939906449 0.36266666666666664 0.34701733359825704

In [9]:
# Dev prediction for off-mml-pgg model on the fewrel dataset.
import pandas as pd
import numpy as np

id_files = {
    1: "~/codes/QA-ZRE/small_fewrl_data/val_ids_12321.csv",
    2: "~/codes/QA-ZRE/small_fewrl_data/val_ids_943.csv",
    3: "~/codes/QA-ZRE/small_fewrl_data/val_ids_111.csv",
    4: "~/codes/QA-ZRE/small_fewrl_data/val_ids_300.csv",
    5: "~/codes/QA-ZRE/small_fewrl_data/val_ids_1300.csv"
}
gold_files = {   
    1: "~/codes/QA-ZRE/small_fewrl_data/val_data_12321.csv",
    2: "~/codes/QA-ZRE/small_fewrl_data/val_data_943.csv",
    3: "~/codes/QA-ZRE/small_fewrl_data/val_data_111.csv",
    4: "~/codes/QA-ZRE/small_fewrl_data/val_data_300.csv",
    5: "~/codes/QA-ZRE/small_fewrl_data/val_data_1300.csv",
}

for run_id in range(1, 6, 1):
    prediction_files = ["~/june-16/fewrl/run_{}/relation.mml-pgg-off-sim.run.0.dev.predictions.step.{}.csv".format(run_id, 200 * i) for i in range(1, 13, 1)]
    prediction_files += ["~/june-16/fewrl/run_{}/relation.mml-pgg-off-sim.run.1.dev.predictions.step.{}.csv".format(run_id, 200 * i) for i in range(1, 13, 1)]
    prediction_files += ["~/june-16/fewrl/run_{}/relation.mml-pgg-off-sim.run.2.dev.predictions.step.{}.csv".format(run_id, 200 * i) for i in range(1, 13, 1)]
    prediction_files += ["~/june-16/fewrl/run_{}/relation.mml-pgg-off-sim.run.3.dev.predictions.step.{}.csv".format(run_id, 200 * i) for i in range(1, 13, 1)]

    df = pd.read_csv(gold_files[run_id], sep=',')
    ids = {val:i for i, val in enumerate(pd.read_csv(id_files[run_id], sep=',')["relation_ids"].tolist())}
    actual_ids = df["actual_ids"].tolist()
    num_examples = len(actual_ids) // 5

    gold_indices = []
    for each_relation_id in actual_ids:
        gold_indices.append(ids[each_relation_id])

    gold_indices = np.max(np.reshape(np.array(gold_indices), (num_examples, 5)), axis=1)

    max_f1 = 0.0
    max_file = "None"
    for prediction_file in prediction_files:
        try:
            mml_pred_log_ps = pd.read_csv(prediction_file, sep=',')["answer_log_p"].tolist()
            pred_log_ps = np.log(np.mean(np.reshape(np.exp(np.array(mml_pred_log_ps)), (num_examples, 5, 8)), axis=2))
            pred_ids = np.argmax(pred_log_ps, axis=1)
            avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
            if max_f1 <= f1:
                max_f1 = f1
                max_file = prediction_file
        except:
            print(prediction_file)

    print(run_id, max_file, max_f1)

1 ~/june-16/fewrl/run_1/relation.mml-pgg-off-sim.run.0.dev.predictions.step.1000.csv 0.6866699556484906
~/june-16/fewrl/run_2/relation.mml-pgg-off-sim.run.3.dev.predictions.step.1000.csv
~/june-16/fewrl/run_2/relation.mml-pgg-off-sim.run.3.dev.predictions.step.1200.csv
~/june-16/fewrl/run_2/relation.mml-pgg-off-sim.run.3.dev.predictions.step.1400.csv
~/june-16/fewrl/run_2/relation.mml-pgg-off-sim.run.3.dev.predictions.step.1600.csv
~/june-16/fewrl/run_2/relation.mml-pgg-off-sim.run.3.dev.predictions.step.1800.csv
~/june-16/fewrl/run_2/relation.mml-pgg-off-sim.run.3.dev.predictions.step.2000.csv
~/june-16/fewrl/run_2/relation.mml-pgg-off-sim.run.3.dev.predictions.step.2200.csv
~/june-16/fewrl/run_2/relation.mml-pgg-off-sim.run.3.dev.predictions.step.2400.csv
2 ~/june-16/fewrl/run_2/relation.mml-pgg-off-sim.run.2.dev.predictions.step.2200.csv 0.5213218904876981
3 ~/june-16/fewrl/run_3/relation.mml-pgg-off-sim.run.2.dev.predictions.step.1200.csv 0.6660949360255928
4 ~/june-16/fewrl/run_4/

In [2]:
# Test prediction for off-mml-pgg model on the fewrel dataset.
import pandas as pd
import numpy as np

id_files = {
    1: "~/codes/QA-ZRE/small_fewrl_data/test_ids_12321.csv",
    2: "~/codes/QA-ZRE/small_fewrl_data/test_ids_943.csv",
    3: "~/codes/QA-ZRE/small_fewrl_data/test_ids_111.csv",
    4: "~/codes/QA-ZRE/small_fewrl_data/test_ids_300.csv",
    5: "~/codes/QA-ZRE/small_fewrl_data/test_ids_1300.csv"
}
gold_files = {   
    1: "~/codes/QA-ZRE/small_fewrl_data/test_data_12321.csv",
    2: "~/codes/QA-ZRE/small_fewrl_data/test_data_943.csv",
    3: "~/codes/QA-ZRE/small_fewrl_data/test_data_111.csv",
    4: "~/codes/QA-ZRE/small_fewrl_data/test_data_300.csv",
    5: "~/codes/QA-ZRE/small_fewrl_data/test_data_1300.csv",
}

test_files = {
    1: "~/june-16/fewrl/run_1/relation.mml-pgg-off-sim.run.0.test.predictions.step.1000.csv",
    2: "~/june-16/fewrl/run_2/relation.mml-pgg-off-sim.run.2.test.predictions.step.2200.csv",
    3: "~/june-16/fewrl/run_3/relation.mml-pgg-off-sim.run.2.test.predictions.step.1200.csv",
    4: "~/june-16/fewrl/run_4/relation.mml-pgg-off-sim.run.1.test.predictions.step.1400.csv",
    5: "~/june-16/fewrl/run_5/relation.mml-pgg-off-sim.run.0.test.predictions.step.2400.csv",
}

mean_f1 = 0.0
mean_p = 0.0
mean_r = 0.0
for run_id in range(1, 6, 1):
    prediction_files = [test_files[run_id]]
    df = pd.read_csv(gold_files[run_id], sep=',')
    ids = {val:i for i, val in enumerate(pd.read_csv(id_files[run_id], sep=',')["relation_ids"].tolist())}
    actual_ids = df["actual_ids"].tolist()
    num_examples = len(actual_ids) // 15

    gold_indices = []
    for each_relation_id in actual_ids:
        gold_indices.append(ids[each_relation_id])

    gold_indices = np.max(np.reshape(np.array(gold_indices), (num_examples, 15)), axis=1)
    for prediction_file in prediction_files:
        try:
            mml_pred_log_ps = pd.read_csv(prediction_file, sep=',')["answer_log_p"].tolist()
            pred_log_ps = np.log(np.sum(np.reshape(np.exp(np.array(mml_pred_log_ps)), (num_examples, 15, 8)), axis=2))
            pred_ids = np.argmax(pred_log_ps, axis=1)
            avg_prec, avg_rec, f1 = compute_macro_PRF(pred_ids, gold_indices)
            mean_f1 += f1
            mean_p += avg_prec
            mean_r += avg_rec
            print(prediction_file, avg_prec, avg_rec, f1)
        except:
            print(prediction_file)

mean_f1 /= 5
mean_p /= 5
mean_r /= 5
print("mean_p", mean_p)
print("mean_r", mean_r)
print("mean_f", mean_f1)

~/june-16/fewrl/run_1/relation.mml-pgg-off-sim.run.0.test.predictions.step.1000.csv 0.39090262621490907 0.3941904761904762 0.39253966669598334
~/june-16/fewrl/run_2/relation.mml-pgg-off-sim.run.2.test.predictions.step.2200.csv 0.38823952774766773 0.396 0.39208136684878714
~/june-16/fewrl/run_3/relation.mml-pgg-off-sim.run.2.test.predictions.step.1200.csv 0.33109424738588195 0.338 0.33451148639718353
~/june-16/fewrl/run_4/relation.mml-pgg-off-sim.run.1.test.predictions.step.1400.csv 0.29031332227659645 0.31657142857142856 0.3028743201670405
~/june-16/fewrl/run_5/relation.mml-pgg-off-sim.run.0.test.predictions.step.2400.csv 0.3921389098579965 0.42447619047619045 0.4076672854222021
mean_p 0.3585377266966104
mean_r 0.37384761904761904
mean_f 0.3659348251062393
