In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import json
import pickle
import time
from pathlib import Path

import pandas as pd


import torch
from sklearn.metrics import f1_score, hamming_loss

import neurosym as ns
from neurosym.examples import near

LIST_LENGTH = 5

In [28]:
def bce_loss(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """
    Compute the binary cross entropy loss with class weights for the Tiny CalMS21 dataset.
    This is the same loss function used in the base NEAR implementation.

    Args:
        predictions (torch.Tensor): Model predictions with shape (B, T, O).
        targets (torch.Tensor): Ground truth labels with shape (B, T, 1).

    Returns:
        torch.Tensor: The computed binary cross-entropy loss.
    """
    targets = targets.squeeze(-1)  # (B, T, 1) -> (B, T)
    predictions = predictions.view(-1, predictions.shape[-1])
    targets = targets.view(-1)
    # pylint: disable=not-callable
    targets_one_hot = torch.nn.functional.one_hot(targets, num_classes=2)
    # pylint: enable=not-callable
    return torch.nn.functional.binary_cross_entropy_with_logits(
        predictions.float(),
        targets_one_hot.float(),
        weight=torch.tensor([1.0, 1.5], device=predictions.device),
    )

In [29]:
def compute_near_metrics(predictions, ground_truth):
    weighted_avg_f1 = f1_score(predictions, ground_truth, average="weighted")
    unweighted_avg_f1 = f1_score(predictions, ground_truth, average="macro")
    all_f1 = f1_score(predictions, ground_truth, average=None)
    hamming_accuracy = 1 - hamming_loss(ground_truth, predictions)
    return dict(
        f1_score=weighted_avg_f1,
        unweighted_f1=unweighted_avg_f1,
        all_f1s=all_f1,
        hamming_accuracy=hamming_accuracy,
    )

In [30]:
def run_crim13_experiment(
    output_path: str = "outputs/compression_near_logs/crim13_results.pkl",
):
    """
    Run the NEAR experiment on the CRIM13 dataset.

    Args:
        output_path (str): File path to save the resulting programs list as a pickle file.
    """

    # Prepare data and DSL
    datamodule = ns.datasets.crim13_data_example(train_seed=0, batch_size=1024)
    _, output_dim = datamodule.train.get_io_dims()
    original_dsl = near.simple_crim13_dsl(num_classes=output_dim, hidden_dim=16)

    # Trainer configuration
    trainer_cfg = near.NEARTrainerConfig(
        n_epochs=15,
        lr=1e-3,
        loss_callback=bce_loss,
    )

    neural_dsl = near.NeuralDSL.from_dsl(
        dsl=original_dsl,
        neural_hole_filler=near.GenericMLPRNNNeuralHoleFiller(hidden_size=16),
    )

    cost = near.default_near_cost(
        trainer_cfg=trainer_cfg,
        datamodule=datamodule,
        structural_cost_weight=0.05,
    )

    # Create the NEAR graph
    g = near.near_graph(
        neural_dsl,
        neural_dsl.valid_root_types[0],
        is_goal=lambda _: True,
        cost=cost,
    )

    # Search for programs with bounded A*
    iterator = ns.search.bounded_astar(g, max_depth=10)

    programs_list = []
    start_time = time.time()

    # Collect programs up to LIST_LENGTH
    while True:
        try:
            program = next(iterator)
        except StopIteration:
            break

        timer = time.time() - start_time
        programs_list.append({"program": program, "time": timer})

        if len(programs_list) >= LIST_LENGTH:
            print("Programs list is too long")
            break

    # Optional: Drop into an interactive shell (commented out for module usage)
    # import IPython; IPython.embed()

    # Evaluate each discovered program
    for d in programs_list:
        program = d["program"]
        initialized_program = neural_dsl.initialize(program)
        _ = cost.validation_heuristic.with_n_epochs(40).compute_cost(
            neural_dsl, initialized_program, cost.embedding
        )

        feature_data = datamodule.test.inputs
        labels = datamodule.test.outputs.flatten()

        module = ns.examples.near.TorchProgramModule(neural_dsl, initialized_program)
        predictions = (
            module(torch.tensor(feature_data), environment=())
            .argmax(-1)
            .numpy()
            .flatten()
        )
        metrics = compute_near_metrics(predictions, labels)
        d["report"] = metrics

    # Save the programs_list to a pickle file
    with open(output_path, "wb") as f:
        pickle.dump(programs_list, f)

In [None]:
run_crim13_experiment("../outputs/mice_results/crim13_results.pkl")

In [69]:
def load_obj(file_path: str) -> object:
    """Load an object from a JSON or pickle file."""
    path = Path(file_path)
    if path.suffix == ".json":
        with path.open("r", encoding="utf-8") as f:
            return json.load(f)
    elif path.suffix == ".pkl":
        with path.open("rb") as f:
            return pickle.load(f)
    else:
        raise ValueError(f"Unsupported file extension: {path.suffix}")


baseline_results = load_obj("../outputs/mice_results/baseline_results.json")
our_results = load_obj("../outputs/mice_results/crim13_results.pkl")

df1 = (
    pd.DataFrame(our_results)
    .assign(
        program=lambda d: d["program"].map(ns.render_s_expression),
        method="astar_neurosymlib",
    )
    .pipe(lambda d: d.join(d.pop("report").apply(pd.Series)))
    .round(4)
)

df2 = pd.DataFrame(baseline_results).T.reset_index(names=["method"])

df = pd.concat([df1, df2], ignore_index=True).dropna(axis=1)

# Display desired columns
final_df = df[["method", "time", "f1_score", "unweighted_f1", "hamming_accuracy", "program"]]
final_df

Unnamed: 0,method,time,f1_score,unweighted_f1,hamming_accuracy,program
0,astar_neurosymlib,979.8875,0.9439,0.4719,0.8937,(output (map (affine_distance)))
1,astar_neurosymlib,979.8876,0.585,0.3919,0.642,(output (map (affine_position)))
2,astar_neurosymlib,1114.0062,0.1903,0.1259,0.1312,(output (map (add (affine_angle) (affine_dista...
3,astar_neurosymlib,1173.9176,0.8078,0.5069,0.8094,(output (map (add (affine_position) (affine_an...
4,astar_neurosymlib,1173.9176,0.8704,0.564,0.8582,(output (map (add (affine_position) (affine_po...
5,astar_near,2275.7735,0.9439,0.4719,0.8937,Start(MapPrefixes(Last5Avg(PositionSelect())))
6,enumeration,271.8955,0.1076,0.4719,0.8937,Start(Map(PositionSelect()))
7,iddfs_near,19691.4884,0.9439,0.4719,0.8937,Start(MapPrefixes(Last5Avg(AccelerationSelect(...
