## Heuristic Relation Matcher

In [1]:
import pandas as pd
from relation_modeling_utils import load_data, explode_labels

# Dev split: load the TSV with multi_label=True, then pass the result through
# explode_labels (presumably one row per label afterwards -- confirm in
# relation_modeling_utils).
val_df = explode_labels(load_data("data/atomic2020_data-feb2021/dev.tsv", multi_label=True))

# Pre-built train/test CSVs for the n1 split under data/atomic_split.
train_n1_df = pd.read_csv("data/atomic_split/n1/train_n1.csv")
test_n1_df = pd.read_csv("data/atomic_split/n1/test_n1.csv")

# Same layout for the n3 split.
train_n3_df = pd.read_csv("data/atomic_split/n3/train_n3.csv")
test_n3_df = pd.read_csv("data/atomic_split/n3/test_n3.csv")

# Same layout for the n5 split.
train_n5_df = pd.read_csv("data/atomic_split/n5/train_n5.csv")
test_n5_df = pd.read_csv("data/atomic_split/n5/test_n5.csv")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import numpy as np

from relation_modeling_utils import report_metrics

def convert_str_to_list(x):
    """Parse a stringified integer list such as ``"[1, 0, 1]"`` into a list of ints.

    ``evaluate`` applies this to every cell of a frame's ``label`` column, so
    both real lists and their string form are accepted.

    Args:
        x: Either a list (returned unchanged, same object) or a string of
            comma-separated integers wrapped in one pair of delimiters.

    Returns:
        A list of ints. ``"[]"`` yields an empty list.

    Raises:
        ValueError: If a non-blank item is not a valid integer literal.
    """
    if isinstance(x, list):
        return x
    # Trim outer whitespace first so the slice removes the brackets themselves
    # rather than a stray space.
    inner = x.strip()[1:-1]
    # "".split(",") yields [""], and int("") raises; skipping blank items lets
    # "[]" (and a trailing comma) parse cleanly instead of crashing.
    return [int(item) for item in inner.split(",") if item.strip()]


def evaluate(model, df):
    """Run ``model`` on ``df`` and pass predictions and labels to ``report_metrics``.

    Args:
        model: Object exposing ``predict(df)``.
        df: Frame with a ``label`` column whose cells are lists or stringified
            lists; each cell is parsed with ``convert_str_to_list``.

    Returns:
        None. Whatever ``report_metrics`` returns is discarded.
    """
    # Predictions are cast to Python ``float`` dtype before scoring.
    predictions = torch.tensor(model.predict(df), dtype=float)

    # Normalise every label cell to a list, then stack into one array.
    parsed_labels = df.label.apply(convert_str_to_list)
    targets = torch.tensor(np.asarray(parsed_labels.to_list()))

    report_metrics(predictions, targets)

In [3]:
from relation_modeling_utils import BaseClassifier
# Single shared classifier instance, built with no constructor arguments and
# reused by every evaluate() call below.
base_model = BaseClassifier()

### Validation results

In [4]:
# Dev-set frame loaded from dev.tsv in the first cell.
evaluate(base_model, val_df)

100%|██████████| 2962/2962 [00:00<00:00, 385225.69it/s]

Accuracy=0.543, precision=0.714, recall=1.000, f1=0.817





### N1 results

In [5]:
# n1 split, training portion (train_n1.csv).
evaluate(base_model, train_n1_df)

100%|██████████| 40777/40777 [00:00<00:00, 839861.79it/s]


Accuracy=0.504, precision=0.518, recall=1.000, f1=0.679


In [6]:
# n1 split, test portion (test_n1.csv).
evaluate(base_model, test_n1_df)

100%|██████████| 810/810 [00:00<00:00, 635381.75it/s]

Accuracy=0.402, precision=0.479, recall=1.000, f1=0.625





### N3 results

In [7]:
# n3 split, training portion (train_n3.csv).
evaluate(base_model, train_n3_df)

100%|██████████| 40516/40516 [00:00<00:00, 956719.93it/s]

Accuracy=0.504, precision=0.518, recall=1.000, f1=0.678





In [8]:
# n3 split, test portion (test_n3.csv).
evaluate(base_model, test_n3_df)

100%|██████████| 1071/1071 [00:00<00:00, 561582.65it/s]

Accuracy=0.430, precision=0.441, recall=1.000, f1=0.609





### N5 results

In [9]:
# n5 split, training portion (train_n5.csv).
evaluate(base_model, train_n5_df)

100%|██████████| 40395/40395 [00:00<00:00, 245928.37it/s]


Accuracy=0.503, precision=0.517, recall=1.000, f1=0.678


In [10]:
# n5 split, test portion (test_n5.csv).
evaluate(base_model, test_n5_df)

100%|██████████| 1192/1192 [00:00<00:00, 627067.65it/s]

Accuracy=0.449, precision=0.451, recall=1.000, f1=0.621



