# SR-FBAM Inference Demo

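This notebook loads a trained SR-FBAM checkpoint, runs it on a query from the evaluation split, and visualises the reasoning hops. Before running it, train the model and save the best checkpoint as `../checkpoints/sr_fbam_train_best.pt`. The notebook resolves `checkpoints/` and `data/` relative to the parent of its working directory, so launch it from a folder one level below the project root.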

In [None]:
from pathlib import Path
import sys

# Make project-level packages (such as `src`) importable from the notebook.
# The project root is taken to be the parent of the current working directory,
# and it is appended to the module search path only if it is not already there.
PROJECT_ROOT = Path.cwd().parent
_project_root_entry = str(PROJECT_ROOT)
if _project_root_entry not in sys.path:
    sys.path.append(_project_root_entry)

from src.data import load_dataset
from src.training.train import load_checkpoint

# Input locations, resolved relative to PROJECT_ROOT (the parent of the
# notebook's working directory).
CHECKPOINT_PATH = PROJECT_ROOT / "checkpoints" / "sr_fbam_train_best.pt"
DATA_DIR = PROJECT_ROOT / "data"
DEVICE = "cpu"  # device passed to load_checkpoint; change to run elsewhere

# Fail early with an actionable message when the checkpoint is missing (model
# not trained yet, or notebook launched from the wrong directory), rather than
# relying on whatever error load_checkpoint would raise for a missing file.
if not CHECKPOINT_PATH.is_file():
    raise FileNotFoundError(
        f"Checkpoint not found at {CHECKPOINT_PATH}. Train the model first, or "
        "launch this notebook from a directory one level below the project root."
    )

model = load_checkpoint(CHECKPOINT_PATH, device=DEVICE)
kg, queries = load_dataset(DATA_DIR, split="eval")
print(f"Loaded {len(queries)} evaluation queries")

In [None]:
def show_reasoning_trace(query, output):
    """Print a query, its ground truth, each reasoning hop, and the final verdict.

    Args:
        query: an evaluation query exposing `natural_language` and `answer_id`.
        output: the result of `model.reason(...)`, exposing `hop_traces` and
            `prediction_id`; each hop exposes `hop_number`, `action`, `result`
            and `confidence`.
    """
    print(f"Query: {query.natural_language}")
    print(f"Ground truth: {query.answer_id}\n")

    for hop in output.hop_traces:
        print(f"[{hop.hop_number}] {hop.action:6s} -> {hop.result} (confidence={hop.confidence:.2f})")

    print(f"\nPrediction: {output.prediction_id}")
    print(f"Correct: {output.prediction_id == query.answer_id}")
    print(f"Hops: {len(output.hop_traces)}")


# Reason over the first evaluation query; `query` and `output` stay at module
# level because the plotting cell below reads `output`.
query = queries[0]
output = model.reason(query, kg)
show_reasoning_trace(query, output)

In [None]:
import networkx as nx
import matplotlib.pyplot as plt

# Build a directed chain of reasoning hops (one node per hop, with edges linking
# consecutive hops in the order they appear in the trace) and draw it.
G = nx.DiGraph()
previous_hop_number = None
for hop in output.hop_traces:
    G.add_node(hop.hop_number, action=hop.action, label=f"{hop.hop_number}: {hop.action}")
    # Link to the preceding trace entry instead of assuming node `hop_number - 1`
    # exists. That assumption only holds for contiguous 1-based numbering. With
    # 0-based numbering it drops the first edge, and with gaps add_edge creates
    # phantom nodes that have no "label" attribute (drawn unlabeled).
    if previous_hop_number is not None and previous_hop_number != hop.hop_number:
        G.add_edge(previous_hop_number, hop.hop_number)
    previous_hop_number = hop.hop_number

pos = nx.spring_layout(G, seed=42)  # fixed seed keeps the layout reproducible
labels = nx.get_node_attributes(G, "label")
fig, ax = plt.subplots(figsize=(6, 4))
nx.draw(G, pos, ax=ax, labels=labels, with_labels=True, node_color="#8ecae6", node_size=1200, font_size=10)
ax.set_title("Reasoning Hop Graph")
plt.show()