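## Heuristic Relation Matcher

Evaluates `HeuristicClassifier` (from `relation_modeling_utils`) on the ATOMIC 2020 dev set and on the train/test CSVs of the n1, n3 and n5 splits, reporting accuracy, precision, recall and F1 for each.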

In [18]:
import pandas as pd
from relation_modeling_utils import load_data, explode_labels

# Validation set: the ATOMIC 2020 dev TSV, loaded with multi_label=True and then
# passed through explode_labels. Both helpers come from relation_modeling_utils;
# their output schema is not visible here — TODO confirm it matches the CSVs below.
val_df = explode_labels(load_data("data/atomic2020_data-feb2021/dev.tsv", multi_label=True))

# Pre-built train/test CSVs for the n1 split. evaluate() reads a `label` column
# from every frame passed to it, so each CSV must provide one.
train_n1_df = pd.read_csv("data/atomic_split/n1/train_n1.csv")
test_n1_df = pd.read_csv("data/atomic_split/n1/test_n1.csv")

# Same layout for the n3 split.
train_n3_df = pd.read_csv("data/atomic_split/n3/train_n3.csv")
test_n3_df = pd.read_csv("data/atomic_split/n3/test_n3.csv")

# Same layout for the n5 split.
train_n5_df = pd.read_csv("data/atomic_split/n5/train_n5.csv")
test_n5_df = pd.read_csv("data/atomic_split/n5/test_n5.csv")

In [23]:
import torch
import numpy as np

from relation_modeling_utils import report_metrics

def convert_str_to_list(x):
    """Turn a serialized integer list such as ``"[1, 0, 1]"`` into a list of ints.

    Args:
        x: Either an actual ``list`` (returned unchanged, same object) or a
            string whose first and last characters are delimiters (e.g. the
            brackets of a list repr) wrapping comma-separated integers.

    Returns:
        list: ``x`` itself when it is already a list; otherwise a new list of
        ``int`` parsed from the comma-separated interior of the string. An
        empty or whitespace-only interior (``"[]"``, ``"[ ]"``) yields ``[]``.

    Raises:
        ValueError: If any comma-separated token is not a valid integer
            literal (including an empty token, as in ``"[1,,2]"``).
    """
    if isinstance(x, list):
        return x

    # Drop the surrounding delimiter characters, e.g. "[1, 0]" -> "1, 0".
    inner = x[1:-1]

    # ''.split(',') returns [''], and int('') raises ValueError, so an empty
    # serialized list has to be handled before tokenizing.
    if not inner.strip():
        return []

    # int() tolerates surrounding whitespace, so " 0" parses fine.
    return [int(token) for token in inner.split(',')]


def evaluate(model, df):
    """Run ``model`` on ``df`` and hand predictions and gold labels to ``report_metrics``.

    Args:
        model: Object exposing ``predict(df)``; its result must be convertible
            by ``torch.tensor``.
        df: DataFrame with a ``label`` column whose entries are lists or
            list-like strings accepted by ``convert_str_to_list``.

    Returns:
        None. Whatever ``report_metrics`` returns is discarded; the notebook
        output suggests it prints an accuracy/precision/recall/F1 line
        (TODO confirm in relation_modeling_utils).
    """
    # Predictions first, cast to a floating-point tensor (Python ``float`` dtype).
    predicted = torch.tensor(model.predict(df), dtype=float)

    # Normalise every label cell to a list, then stack the rows into one array.
    gold_rows = df.label.apply(convert_str_to_list).to_list()
    gold = torch.tensor(np.asarray(gold_rows))

    report_metrics(predicted, gold)

In [20]:
from relation_modeling_utils import HeuristicClassifier
# Built with default constructor arguments. No fit/training call on this object
# appears in the notebook; the same instance is reused for every evaluate() call.
heuristic_model = HeuristicClassifier()

### Validation results

In [24]:
# Score the heuristic on the dev set (val_df, loaded from dev.tsv).
evaluate(heuristic_model, val_df)

100%|██████████| 2962/2962 [00:10<00:00, 273.85it/s]

Accuracy=0.765, precision=0.806, recall=0.816, f1=0.803





### N1 results

In [25]:
# n1 train CSV. No training step for heuristic_model is visible in this notebook,
# so this is presumably just a larger evaluation set — verify against HeuristicClassifier.
evaluate(heuristic_model, train_n1_df)

100%|██████████| 40777/40777 [02:13<00:00, 305.62it/s]


Accuracy=0.833, precision=0.835, recall=0.841, f1=0.835


In [26]:
# n1 test CSV (test_n1.csv).
evaluate(heuristic_model, test_n1_df)

100%|██████████| 810/810 [00:02<00:00, 286.92it/s]

Accuracy=0.788, precision=0.719, recall=0.748, f1=0.728





### N3 results

In [27]:
# n3 train CSV. No training step for heuristic_model is visible in this notebook,
# so this is presumably just a larger evaluation set — verify against HeuristicClassifier.
evaluate(heuristic_model, train_n3_df)

100%|██████████| 40516/40516 [02:07<00:00, 316.90it/s]


Accuracy=0.834, precision=0.837, recall=0.842, f1=0.837


In [28]:
# n3 test CSV (test_n3.csv).
evaluate(heuristic_model, test_n3_df)

100%|██████████| 1071/1071 [00:03<00:00, 300.66it/s]

Accuracy=0.739, precision=0.696, recall=0.702, f1=0.687





### N5 results

In [29]:
# n5 train CSV. No training step for heuristic_model is visible in this notebook,
# so this is presumably just a larger evaluation set — verify against HeuristicClassifier.
evaluate(heuristic_model, train_n5_df)

100%|██████████| 40395/40395 [02:07<00:00, 316.70it/s]


Accuracy=0.835, precision=0.837, recall=0.843, f1=0.837


In [30]:
# n5 test CSV (test_n5.csv).
evaluate(heuristic_model, test_n5_df)

100%|██████████| 1192/1192 [00:03<00:00, 301.16it/s]

Accuracy=0.728, precision=0.711, recall=0.693, f1=0.687



