## Imports

In [None]:
import torch
import pandas as pd

from data_generation import (
    generate_objects,
    build_binary_relation_dataset,
    build_inbetween_dataset,
    left_of,
    right_of,
    close_to,
    above,
    below,
)

from predicates import LeftOf, RightOf, CloseTo, Above, Below, InBetween
from axioms import kb
from trainer import train_ltn
from metrics import compute_metrics

## Configuração inicial

In [None]:
# Experiment settings: how many repetitions to run and how many training
# epochs each repetition uses.
NUM_RUNS, EPOCHS = 5, 500

# Collects one record (dict) per (run, predicate) pair during evaluation.
all_results = list()

## Loop das 5 execuções

In [None]:
# Predicates scored after every training run, in the order they appear in the
# results table. Each entry is:
#   (name written to the "Predicate" column,
#    callable objects -> (X, y_true) that builds the evaluation set,
#    trained LTN predicate to score).
# InBetween is ternary, so it uses its own dataset builder instead of
# build_binary_relation_dataset.
PREDICATES_TO_EVALUATE = [
    ("LeftOf",
     lambda objs: build_binary_relation_dataset(objs, left_of), LeftOf),
    ("RightOf",
     lambda objs: build_binary_relation_dataset(objs, right_of), RightOf),
    ("CloseTo",
     lambda objs: build_binary_relation_dataset(objs, close_to), CloseTo),
    ("Above",
     lambda objs: build_binary_relation_dataset(objs, above), Above),
    ("Below",
     lambda objs: build_binary_relation_dataset(objs, below), Below),
    ("InBetween", build_inbetween_dataset, InBetween),
]


def _evaluate_predicate(predicate, X, y_true):
    """Score ``predicate`` on inputs ``X`` against the labels ``y_true``.

    The predicate is run without gradient tracking and its output is
    squeezed before being handed to ``compute_metrics``. The result is
    whatever ``compute_metrics`` returns; the caller unpacks it as
    ``(accuracy, precision, recall, f1)``.
    """
    with torch.no_grad():
        y_pred = predicate(X).squeeze()
    return compute_metrics(y_true, y_pred)


for run in range(NUM_RUNS):
    print("\n==============================")
    print(f" Execução {run + 1}")
    print("==============================")

    # Seed torch's global RNG so any stochastic step inside training is
    # reproducible per run (the object generator below is seeded with `run`).
    # NOTE(review): the predicates are imported once and never re-created in
    # this loop, so unless train_ltn resets them, each run presumably keeps
    # training from the previous run's weights. Confirm, and re-initialize the
    # predicates here if the runs are meant to be independent.
    torch.manual_seed(run)

    # 1. Generate the scene objects for this run.
    objects = generate_objects(n=25, seed=run)

    # 2. Train the LTN on the knowledge base and report its satisfaction level.
    sat = train_ltn(kb, objects, epochs=EPOCHS, verbose=False)
    print(f"satAgg = {sat:.4f}")

    # 3. Score every predicate on this run's objects and record the metrics.
    for predicate_name, build_dataset, predicate in PREDICATES_TO_EVALUATE:
        X, y_true = build_dataset(objects)
        acc, prec, rec, f1 = _evaluate_predicate(predicate, X, y_true)

        all_results.append({
            "Run": run + 1,
            "Predicate": predicate_name,
            "satAgg": sat,
            "Accuracy": acc,
            "Precision": prec,
            "Recall": rec,
            "F1": f1,
        })

## Dataframe para o relatório

In [None]:
# Build the report table from the per-(run, predicate) records collected in
# `all_results` (the original referenced an undefined name `results`).
# The bare `df` on the last line makes the notebook display the table.
df = pd.DataFrame(all_results)
df

In [None]:
# Average each classification metric per predicate across all runs; the
# final expression is displayed by the notebook.
metric_columns = ["Accuracy", "Precision", "Recall", "F1"]
df.groupby("Predicate")[metric_columns].mean()