In [None]:
import os

# Set the Hugging Face cache folders. This must be done before loading the
# modules that read these variables; adjust the location to your liking.
# Environment variables receive no shell tilde expansion, so "~" is resolved
# here instead of being handed over as a literal directory name.
__HF_CACHE = os.path.expanduser("~/disk1/huggingface_cache/")
os.environ["TRANSFORMERS_CACHE"] = __HF_CACHE
os.environ["HF_DATASETS_CACHE"] = __HF_CACHE


from datasets import load_dataset

# Project layout, resolved from the current working directory: the source
# root is its parent directory.
__SRC = os.path.abspath(os.pardir)
# os.path.join supplies the separator that plain string concatenation left
# out (the concatenated form produced ".../..data" and ".../..notebooks").
__DATA = os.path.join(__SRC, "data")
__NOTEBOOKS = os.path.join(__SRC, "notebooks")

# Mining beliefs

An example of mining beliefs from `roberta-base` on a random sample of LAMA entries.

In [None]:
from datasets import load_dataset
import numpy

from miners.baert import BaertMiner
from miners.mine import LAMA_BAERT_MINER as mining_config

# load mining dataset
lama_dataset = load_dataset("lama")
# The object returned above is indexed by split name below, and len() of that
# split container counts splits rather than rows. Measure the split that is
# actually mined so the random indexes span all of its entries.
dataset_size = len(lama_dataset["train"])

# mining
nr_random_entries = 1000
# Row positions to mine, drawn from [0, dataset_size) with replacement.
# NOTE(review): no seed is set in this cell, so the sample differs between runs.
random_entries_indexes = numpy.random.randint(low=0, high=dataset_size,
                                              size=nr_random_entries)
K = 100
# Updates the imported configuration in place with the cutoff and the sample.
mining_config.update({"K": K,
                      "indexes": random_entries_indexes})
miner = BaertMiner("roberta-base", "roberta", device="cuda")
predictions = miner.mine(lama_dataset["train"], mining_config)

## Belief precision

In [None]:
from miners.validation.precision import precisions_at

# Collect, in matching order, the expected answer and the model's answer of
# every mined instance. Two separate lists are created explicitly: a single
# `list()` is empty and cannot be unpacked into two names.
ground_truths, model_predictions = [], []
for instance in predictions:
    ground_truths.append(instance["ground_truth_prediction"])
    model_predictions.append(instance["prediction"])

# Precision over all mined instances, using the same cutoff K as the mining step.
precisions = precisions_at(model_predictions, ground_truths, K=K)

## Belief precision by predicate

In [None]:
# extract predicates
# One predicate id per sampled entry, in the same order as the sampled indexes.
# int(i) hands the dataset a plain Python int rather than a numpy integer.
train_split = lama_dataset["train"]
predicates = numpy.array([train_split[int(i)]["predicate_id"] for i in random_entries_indexes])
unique_predicates = numpy.unique(predicates)
# flatnonzero always returns a 1-D array of positions. The previous
# argwhere(...).squeeze() collapsed to a 0-d array for a predicate that occurs
# exactly once, which cannot be iterated over below.
predicate_indexes = [(predicate, numpy.flatnonzero(predicates == predicate))
                     for predicate in unique_predicates]

# Precision computed separately on the instances of each predicate.
precisions_on_predicate = dict()
for predicate, indexes in predicate_indexes:
    precisions_on_predicate[predicate] = precisions_at([model_predictions[i] for i in indexes],
                                                       [ground_truths[i] for i in indexes],
                                                       K=K)

# Results visualization

## Precisions
Precision curve over all sampled entries.

In [None]:
from bokeh.plotting import figure, show
from bokeh.io import output_notebook
# Line colour taken from this palette:
# https://coolors.co/palette/ffbe0b-fb5607-ff006e-8338ec-3a86ff
output_notebook()

precision_plot = figure(
    title=f"Precision@{K}",
    x_axis_label="K",
    y_axis_label="Precision",
)

# The first element of `precisions` is drawn on the x axis (cast to int),
# the second on the y axis.
k_values = precisions[0].astype(int)
precision_values = precisions[1]
precision_plot.line(
    k_values,
    precision_values,
    legend_label="precision",
    line_width=3,
    line_color="#FFBE0B",
)

precision_plot.legend.location = "bottom_right"
show(precision_plot)

### Precision on specific relations
Precision curves computed separately for each relation type (predicate).

In [None]:
# Draw one precision chart per predicate. The legend placement and the show()
# call sit inside the loop: placed after it, they would only ever apply to the
# figure built in the last iteration and every other figure would be discarded.
for predicate, _ in predicate_indexes:
    predicate_precisions = precisions_on_predicate[predicate]
    precision_plot = figure(title=f"Precision@K on predicate: {predicate}", x_axis_label="K", y_axis_label="Precision")
    precision_plot.line(predicate_precisions[0].astype(int),
                        predicate_precisions[1], legend_label=f"{predicate}",
                        line_width=3, line_color="#FFBE0B")
    precision_plot.legend.location = "bottom_right"
    show(precision_plot)