# Neurosymbolic Software Tutorial - ECG Dataset

<a target="_blank" href="https://colab.research.google.com/github/kavigupta/neurosym-lib/blob/main/tutorial/ecg_exercise.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

## Instruction
- Navigating this notebook on Google Colab: There will be text blocks and code blocks throughout the notebook. The text blocks, such as this one, will contain instructions and questions for you to consider. The code blocks, such as the one below, will contain executable code. Sometimes you will have to modify the code blocks following the instructions in the text blocks. You can run a code block either by pressing control/cmd + enter or by clicking the arrow on the left-hand side.
- Saving Work: If you wish to save your work in this .ipynb, we recommend downloading the compressed repository from GitHub, unzipping it, uploading it to Google Drive, and opening this notebook from within Google Drive.

## Notebook

In this notebook, you will construct a DSL to analyze ECG data.

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging
import os

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

import neurosym as ns
from neurosym.examples import near

pl = ns.import_pytorch_lightning()


  from .autonotebook import tqdm as notebook_tqdm


## Data

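We load the ECG train and test splits, keep only the samples that have exactly one label, and rescale each feature column to [0, 1]. Each sample is a vector of 144 features, which the DSL below interprets as 12 channels × 6 features × 2 measurement types (interval and amplitude).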

In [2]:
def load_dataset_npz(features_pth, label_pth):
    """Load a feature array and a label array from two NumPy files.

    Despite the name, both paths are read with ``np.load`` directly, so plain
    ``.npy`` files work.

    Args:
        features_pth: Path to the feature array file.
        label_pth: Path to the label array file.

    Returns:
        tuple: ``(X, y)`` -- the loaded features and labels.

    Raises:
        AssertionError: If either path does not exist (checked features first).
    """
    for path in (features_pth, label_pth):
        assert os.path.exists(path), f"{path} does not exist."
    return np.load(features_pth), np.load(label_pth)


def filter_multilabel(split, data_dir="ecg_exercise"):
    """Keep the single-label samples of one split, normalize them, and save them.

    Reads ``{data_dir}/x_{split}.npy`` (features) and ``{data_dir}/y_{split}.npy``
    (label vectors) and drops every row whose label vector does not sum to
    exactly 1. It then min-max rescales each feature column of X to [0, 1] and
    writes both arrays as float32 to ``{data_dir}/x_{split}_filtered.npy`` and
    ``{data_dir}/y_{split}_filtered.npy``.

    Args:
        split: Split name used in the file names, e.g. ``"train"`` or ``"test"``.
        data_dir: Directory holding the input files; outputs are written there too.

    Note:
        The min/max statistics are computed per call, so every split is
        normalized with its own statistics rather than with the training split's.
    """
    X = np.load(os.path.join(data_dir, f"x_{split}.npy"))
    y = np.load(os.path.join(data_dir, f"y_{split}.npy"))

    # Keep only samples with exactly one active label.
    mask = y.sum(-1) == 1
    X = X[mask]
    y = y[mask]

    # Min-max normalize each column. A constant column would divide by zero
    # and turn into NaN, so it gets a unit range instead (mapping it to 0).
    col_min = X.min(0)
    col_range = X.max(0) - col_min
    col_range = np.where(col_range == 0, 1, col_range)
    X = (X - col_min) / col_range

    # Build the output names explicitly rather than with str.replace, which
    # would rewrite every occurrence of `split` anywhere in the path.
    np.save(os.path.join(data_dir, f"x_{split}_filtered.npy"), X.astype(np.float32))
    np.save(os.path.join(data_dir, f"y_{split}_filtered.npy"), y.astype(np.float32))


# Write single-label, min-max-normalized copies of both splits to
# ecg_exercise/{x,y}_{split}_filtered.npy; dataset_factory reads these files.
filter_multilabel("train")
filter_multilabel("test")

In [3]:
def dataset_factory(train_seed):
    """Build the ECG ``DatasetWrapper`` over the filtered ``.npy`` splits.

    Args:
        train_seed: Forwarded to the training ``DatasetFromNpy``; the test
            split is created with ``None`` in that position.

    Returns:
        ``ns.DatasetWrapper`` over the filtered train and test files, with
        ``batch_size=200`` and ``num_workers=0``.
    """
    return ns.DatasetWrapper(
        ns.DatasetFromNpy(
            "ecg_exercise/x_train_filtered.npy",
            "ecg_exercise/y_train_filtered.npy",
            train_seed,
        ),
        ns.DatasetFromNpy(
            "ecg_exercise/x_test_filtered.npy",
            "ecg_exercise/y_test_filtered.npy",
            None,
        ),
        batch_size=200,
        num_workers=0,
    )


datamodule = dataset_factory(42)

In [4]:
def plot_trajectory(trajectory, color):
    """Exercise placeholder for visualizing ECG samples; currently draws nothing.

    Args:
        trajectory: The sample(s) to visualize.
        color: Color label supplied by callers (e.g. ``"C0"``).

    Returns:
        None.
    """
    # TODO: What is a good way to visualize the trajectory?
    return None

In [5]:
# Plot the first three training samples, one color per sample (the loop index
# selects both the sample and the color).
for i in range(3):
    plot_trajectory(datamodule.train.inputs[i], f"C{i}")

## Exercise: DSL

Fill in `ecg_dsl` to parameterize the space of programs that map a 144-dimensional ECG feature vector to a class prediction.

In [6]:
# Inspect the dimensions reported by the training dataset (presumably
# (input width, output width)).
# NOTE(review): the recorded output is (144, 2), but the DSL below is built with
# output_dim = 9 -- confirm which label width the filtered data actually has.
datamodule.train.get_io_dims()

(144, 2)

In [7]:
def subset_selector_all_feat(x, channel, typ):
    """Select the 6 interval or amplitude features of a (weighted) ECG channel.

    ``x`` is reshaped to ``(B, 12, 6, 2)``: 12 channels with 6 features each,
    where the last axis holds the interval (index 0) and amplitude (index 1)
    values. ``channel`` maps the flattened ``(B, 144)`` input to per-sample
    channel weights of shape ``(B, 12)`` (e.g. a one-hot vector). The
    channels are summed with those weights before the feature type is picked.

    Args:
        x: Tensor holding 144 values per sample.
        channel: Callable returning ``(B, 12)`` channel weights.
        typ: ``"interval"`` or ``"amplitude"``.

    Returns:
        Tensor of shape ``(B, 6)``.

    Raises:
        ValueError: If ``typ`` is not ``"interval"`` or ``"amplitude"``.
    """
    type_index = {"interval": 0, "amplitude": 1}
    if typ not in type_index:
        raise ValueError(f"typ must be 'interval' or 'amplitude', got {typ!r}")
    x = x.reshape(-1, 12, 6, 2)
    channel_mask = channel(x.reshape(-1, 144))  # [B, 12]
    # Weighted sum over the channel axis -> [B, 6, 2].
    masked_x = (x * channel_mask[..., None, None]).sum(1)
    # The same feature type is used for every sample, so plain indexing suffices.
    return masked_x[..., type_index[typ]]

# def subset_selector_all_feat(x, channel, typ):
#     x = x.reshape(-1, 12, 6, 2)
#     typ_idx = torch.full(
#         size=(x.shape[0],), fill_value=(0 if typ == "interval" else 1), device=x.device
#     )
#     channel_mask = channel(x.reshape(-1, 144))  # [B, 12]
#     masked_x = (x * channel_mask[..., None, None]).sum(1)
#     return masked_x[torch.arange(x.shape[0]), :, typ_idx]

# def guard_callables(fn, **kwargs):
#     is_callable = [callable(kwargs[k]) for k in kwargs]
#     if any(is_callable):
#         return lambda z: fn(
#             **{
#                 k: (kwargs[k](z) if is_callable[i] else kwargs[k])
#                 for i, k in enumerate(kwargs)
#             }
#         )
#     else:
#         return fn(**kwargs)

# def filter_constants(x):
#     match x:
#         case ns.ArrowType(a, b):
#             return filter_constants(a) and filter_constants(b)
#         case ns.AtomicType(a):
#             return a not in ["channel", "feature"]
#         case _:
#             return True

# def filter_same_type(x):
#     raise NotImplementedError

# def ecg_dsl():
#     L = 144
#     O = 2
#     F = 6
#     dslf = ns.DSLFactory(L=L, O=O, F=F, max_overall_depth=10)
#     dslf.typedef("fInp", "{f, $L}")
#     dslf.typedef("fOut", "{f, $O}")
#     dslf.typedef("fFeat", "{f, $F}")

#     for i in range(12):
#         dslf.concrete(
#             f"channel_{i}",
#             "() -> channel",
#             # onehot vector where the ith element is 1
#             # lambda: lambda x: torch.full(
#             #     tuple(x.shape[:-1] + (1,)), i, device=x.device
#             # ),
#             lambda: lambda x: torch.nn.functional.one_hot(
#                 torch.full(tuple(x.shape[:-1]), i, device=x.device, dtype=torch.long),
#                 num_classes=12,
#             ),
#         )

#     dslf.concrete(
#         "select_interval",
#         "(channel) -> ($fInp) -> $fFeat",
#         lambda ch: lambda x: subset_selector_all_feat(x, ch, "interval"),
#     )

#     dslf.concrete(
#         "select_amplitude",
#         "(channel) -> ($fInp) -> $fFeat",
#         lambda ch: lambda x: subset_selector_all_feat(x, ch, "amplitude"),
#     )

#     dslf.filtered_type_variable("num", lambda x: filter_constants(x))
#     dslf.filtered_type_variable("num", lambda x: filter_same_type(x))
#     dslf.concrete(
#         "add",
#         "(%num, %num) -> %num",
#         lambda x, y: guard_callables(fn=lambda x, y: x + y, x=x, y=y),
#     )
#     dslf.concrete(
#         "mul",
#         "(%num, %num) -> %num",
#         lambda x, y: guard_callables(fn=lambda x, y: x * y, x=x, y=y),
#     )

#     dslf.parameterized(
#         "linear",
#         "(($fInp) -> $fFeat) -> $fInp -> {f, 1}",
#         lambda f, lin: lambda x: lin(f(x)),
#         dict(lin=lambda: nn.Linear(F, 1)),
#     )

#     dslf.parameterized(
#         "output",
#         "(($fInp) -> $fFeat) -> $fInp -> $fOut",
#         lambda f, lin: lambda x: lin(f(x)),
#         dict(lin=lambda: nn.Linear(F, O)),
#     )

#     # dslf.concrete("ite_ab", "(#a -> f, #a -> #b, #a -> #b) -> #a -> #b", near.operations.ite_torch)
#     # dslf.concrete("ite_aa", "(#a -> f, #a -> #a, #a -> #a) -> #a -> #a", near.operations.ite_torch)
#     # dslf.concrete(
#     #     "map", "(#a -> #b) -> [#a] -> [#b]", lambda f: lambda x: near.operations.map_torch(f, x)
#     # )
#     dslf.prune_to("($fInp) -> $fOut")
#     return dslf.finalize()
# dsl = ecg_dsl()


from neurosym.dsl.dsl_factory import DSLFactory
from neurosym.examples.near.operations.basic import ite_torch
from neurosym.types.type import ArrowType, AtomicType


def ecg_dsl(input_dim, output_dim, max_overall_depth=6):
    """Creates a domain-specific language (DSL) for neural symbolic computation.

    This function sets up a DSL with basic operations like addition, multiplication,
    and folds, as well as neural network components like linear layers.

    Args:
        input_dim (int): The dimensionality of the input features.
        output_dim (int): The dimensionality of the output features.

    Returns:
        DSLFactory: An instance of `DSLFactory` with the defined operations and types.
    """
    feature_dim = 6
    dslf = DSLFactory(
        I=input_dim, O=output_dim, F=feature_dim, max_overall_depth=max_overall_depth
    )
    dslf.typedef("fInp", "{f, $I}")
    dslf.typedef("fOut", "{f, $O}")
    dslf.typedef("fFeat", "{f, $F}")

    # "add(select_interval(channel_2), select_amplitude(channel_8))
    # "add(select_interval_channel_2, select_amplitude_channel_8)
    # "add(??, ??)"

    for i in range(12):
        dslf.concrete(
            f"channel_{i}",
            "() -> channel",
            # onehot vector where the ith element is 1
            # lambda: lambda x: torch.full(
            #     tuple(x.shape[:-1] + (1,)), i, device=x.device
            # ),
            lambda: lambda x: torch.nn.functional.one_hot(
                torch.full(tuple(x.shape[:-1]), i, device=x.device, dtype=torch.long),
                num_classes=12,
            ),
        )

    # for i in range(6):
    #     dslf.concrete(
    #         f"feature_{i}",
    #         "() -> () -> feature",
    #         lambda: lambda x: torch.full(
    #             tuple(x.shape[:-1] + (1,)), i, device=x.device
    #         ),
    #     )

    dslf.concrete(
        "select_interval",
        "(channel) -> ($fInp) -> $fFeat",
        lambda ch: lambda x: subset_selector_all_feat(x, ch, "interval"),
    )

    dslf.concrete(
        "select_amplitude",
        "(channel) -> ($fInp) -> $fFeat",
        lambda ch: lambda x: subset_selector_all_feat(x, ch, "amplitude"),
    )

    def guard_callables(fn, **kwargs):
        is_callable = [callable(kwargs[k]) for k in kwargs]
        if any(is_callable):
            return lambda z: fn(
                **{
                    k: (kwargs[k](z) if is_callable[i] else kwargs[k])
                    for i, k in enumerate(kwargs)
                }
            )
        else:
            return fn(**kwargs)

    def filter_constants(x):
        match x:
            case ArrowType(a, b):
                return filter_constants(a) and filter_constants(b)
            case AtomicType(a):
                return a not in ["channel", "feature"]
            case _:
                return True

    # def filter_same_type(x):
    #     raise NotImplementedError

    dslf.filtered_type_variable("num", lambda x: filter_constants(x))
    # dslf.filtered_type_variable("num", lambda x: filter_same_type(x))
    dslf.concrete(
        "add",
        "(%num, %num) -> %num",
        lambda x, y: guard_callables(fn=lambda x, y: x + y, x=x, y=y),
    )
    dslf.concrete(
        "mul",
        "(%num, %num) -> %num",
        lambda x, y: guard_callables(fn=lambda x, y: x * y, x=x, y=y),
    )
    # dslf.concrete(
    #     "sub",
    #     "(%num, %num) -> %num",
    #     lambda x, y: guard_callables(fn=lambda x, y: x - y, x=x, y=y),
    # )

    # dslf.parameterized("linear_bool", "() -> $fFeat -> $fFeat", lambda lin: lin, dict(lin=lambda: nn.Linear(input_dim, 1)))
    dslf.parameterized(
        "linear",
        "(($fInp) -> $fFeat) -> $fInp -> {f, 1}",
        lambda f, lin: lambda x: lin(f(x)),
        dict(lin=lambda: nn.Linear(feature_dim, 1)),
    )

    dslf.parameterized(
        "output",
        "(($fInp) -> $fFeat) -> $fInp -> $fOut",
        lambda f, lin: lambda x: lin(f(x)),
        dict(lin=lambda: nn.Linear(feature_dim, output_dim)),
    )

    # dslf.concrete("iteA", "(#a -> $fFeat, #a -> #a, #a -> #a) -> #a -> #a", lambda cond, fx, fy: ite_torch(cond, fx, fy))
    dslf.concrete(
        "ite",
        "(#a -> {f, 1}, #a -> #b, #a -> #b) -> #a -> #b",
        lambda cond, fx, fy: ite_torch(cond, fx, fy),
        # lambda cond, fx, fy: guard_callables(fn=partial(ite_torch, condition=cond), if_true=fx, if_else=fy),
    )
    # dslf.concrete("map", "(#a -> #b) -> [#a] -> [#b]", lambda f: lambda x: map_torch(f, x))
    # dslf.concrete("compose", "(#a -> #b, #b -> #c) -> #a -> #c", lambda f, g: lambda x: g(f(x)))
    # dslf.concrete("fold", "((#a, #a) -> #a) -> [#a] -> #a", lambda f: lambda x: fold_torch(f, x))

    dslf.prune_to("($fInp) -> $fOut")
    return dslf.finalize(), dslf.t

# 144 = 12 channels x 6 features x {interval, amplitude}, the layout that
# subset_selector_all_feat reshapes the input into. output_dim is presumably
# the number of label columns.
# NOTE(review): get_io_dims() reported an output width of 2 -- confirm that 9
# matches the filtered label files.
input_dim, output_dim = 144, 9

dsl, dsl_type_env = ecg_dsl(input_dim=input_dim, output_dim=output_dim)

### DSL Printout

See your DSL printed below, and ensure it is what you would expect

In [8]:
print(dsl.render())

      channel_0 :: () -> channel
      channel_1 :: () -> channel
      channel_2 :: () -> channel
      channel_3 :: () -> channel
      channel_4 :: () -> channel
      channel_5 :: () -> channel
      channel_6 :: () -> channel
      channel_7 :: () -> channel
      channel_8 :: () -> channel
      channel_9 :: () -> channel
     channel_10 :: () -> channel
     channel_11 :: () -> channel
select_interval :: channel -> {f, 144} -> {f, 6}
select_amplitude :: channel -> {f, 144} -> {f, 6}
            add :: (%num, %num) -> %num
            mul :: (%num, %num) -> %num
            ite :: (#a -> {f, 1}, #a -> #b, #a -> #b) -> #a -> #b
    linear[lin] :: ({f, 144} -> {f, 6}) -> {f, 144} -> {f, 1}
    output[lin] :: ({f, 144} -> {f, 6}) -> {f, 144} -> {f, 9}


### Setting up Neural DSL

In [9]:
# Type definer with an `fL` alias for `{f, $L}`.
# NOTE(review): `t` is referenced only by the commented-out entry in the "mlp"
# type list below; the live code uses `dsl_type_env` -- consider removing it.
t = ns.TypeDefiner(L=input_dim, O=output_dim)
t.typedef("fL", "{f, $L}")
# Attach neural modules for the listed types to the symbolic DSL (presumably
# used by NEAR as stand-ins for unfinished sub-programs -- see neurosym docs).
neural_dsl = near.NeuralDSL.from_dsl(
    dsl=dsl,
    modules={
        # MLPs (hidden size 10) for functions of the input vector.
        **near.create_modules(
            "mlp",
            [
                dsl_type_env("($fInp) -> $fInp"),
                dsl_type_env("($fInp) -> $fOut"),
                dsl_type_env("($fInp) -> $fFeat"),
                dsl_type_env("($fInp) -> {f, 1}"),
                #  t("([$fI]) -> [$fI]"), t("([$fI]) -> [$fO]")
            ],
            near.mlp_factory(hidden_size=10),
        ),
        # Learned selectors for `channel` constants; `channel` atoms have shape
        # (12,) and `feature` atoms shape (6,).
        **near.create_modules(
            "constant_int",
            [dsl_type_env("() -> channel")],
            near.selector_factory(input_dim=input_dim),
            known_atom_shapes=dict(channel=(12,), feature=(6,)),
        ),
    },
)
# Drop messages below WARNING from these two PyTorch Lightning loggers.
logging.getLogger("pytorch_lightning.utilities.rank_zero").setLevel(logging.WARNING)
logging.getLogger("pytorch_lightning.accelerators.cuda").setLevel(logging.WARNING)

def cross_entropy_callback(predicitons, targets):
    """Training loss: cross-entropy between predicted logits and targets.

    ``targets`` may be class indices or class probabilities (e.g. one-hot
    rows), whichever ``torch.nn.functional.cross_entropy`` accepts for the
    given shapes. The parameter name ``predicitons`` (sic) is kept unchanged
    in case the callback is invoked by keyword.
    """
    loss_fn = torch.nn.functional.cross_entropy
    loss = loss_fn(predicitons, targets)
    return loss

# Optimization settings used for every program NEAR trains during search.
trainer_cfg = near.NEARTrainerConfig(
    lr=1e-2,
    max_seq_len=300,
    n_epochs=30,
    num_labels=output_dim,
    train_steps=len(datamodule.train),
    loss_callback=cross_entropy_callback,
    scheduler="cosine",
    optimizer=torch.optim.Adam,
)

# Search cost: validation loss of a program trained with `trainer_cfg`, with
# early stopping on val_loss. Use the GPU when one is available and fall back
# to the CPU otherwise, so the search also runs on CPU-only runtimes.
validation_cost = near.ValidationCost(
    neural_dsl=neural_dsl,
    trainer_cfg=trainer_cfg,
    datamodule=datamodule,
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    devices=1,
    callbacks=[
        pl.callbacks.EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=5, mode='min')
    ],
    enable_progress_bar=False,
    enable_model_summary=False,
    progress_by_epoch=True,
)

# Search graph over programs of type ({f, 144}) -> {f, output_dim}; a node is a
# goal once its program has no holes left.
g = near.near_graph(
    neural_dsl,
    ns.parse_type(
        s="({f, $L}) -> {f, $O}", env=ns.TypeDefiner(L=input_dim, O=output_dim)
    ),
    is_goal=neural_dsl.program_has_no_holes,
)


### Run NEAR

In [10]:
from contextlib import redirect_stdout, redirect_stderr

# Run the NEAR search with all output redirected to a log file, keeping the
# first three complete (hole-free) programs it yields with their validation cost.
with (
    open("ecg_exercise/output3.logs", "w", encoding="utf-8") as f,
    redirect_stdout(f),
    redirect_stderr(f),
):
    iterator = ns.search.bounded_astar(g, validation_cost, max_depth=16)
    # iterator = ns.search.bounded_astar_async(g, validation_cost, max_depth=16, max_workers=5)
    best_program_nodes = []
    while len(best_program_nodes) < 3:
        # Only advancing the search should be able to end the loop via StopIteration.
        try:
            node = next(iterator)
        except StopIteration:
            print("No more programs found.")
            break
        # Score the finished program with the same cost used by the search.
        cost = validation_cost(node)
        best_program_nodes.append((node, cost))
        print("Got another program")

### Top 3 Programs

The code below assumes the search above found up to three programs and stored them, as `(node, cost)` pairs, in the `best_program_nodes` variable.

In [None]:
# Rank the collected programs by validation cost (lowest first) and print them.
best_program_nodes = sorted(best_program_nodes, key=lambda entry: entry[1])
for i, (node, cost) in enumerate(best_program_nodes):
    print(f"({i}) Cost: {cost:.4f}, {ns.render_s_expression(node.program)}")

The function below fine-tunes a program found by the search on the training data, using early stopping on the validation loss.

In [None]:
def testProgram(best_program_node):
    """Fine-tune a program found by the search and return its module.

    Trains the program on the training split with early stopping on
    ``val_loss`` (up to 4000 epochs, on CPU).

    Args:
        best_program_node: A ``(node, cost)`` pair as stored in
            ``best_program_nodes``; only ``node.program`` is used.

    Returns:
        near.TorchProgramModule: The module that was passed to the trainer,
        returned after ``trainer.fit`` so the caller can evaluate it.
    """
    module = near.TorchProgramModule(
        dsl=neural_dsl, program=best_program_node[0].program
    )
    pl_model = near.NEARTrainer(module, config=trainer_cfg)
    trainer = pl.Trainer(
        max_epochs=4000,
        devices="auto",
        accelerator="cpu",
        enable_checkpointing=False,
        logger=False,
        callbacks=[
            pl.callbacks.EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=5)
        ],
        enable_progress_bar=False,
    )

    trainer.fit(pl_model, datamodule.train_dataloader(), datamodule.val_dataloader())
    return module

In [19]:
# # We generate trajectories for the top 2 programs.
# trajectory = testProgram(best_program_nodes[0])
# trajectoryb = testProgram(best_program_nodes[1])

## Plotting Trajectories

In [20]:
# plt.figure(figsize=(8, 8))

# plot_trajectory(trajectory, "C0")
# plot_trajectory(trajectoryb, "C1")
# plot_trajectory(datamodule.train.inputs[0], "black")

# plt.title("Bouncing ball (ground truth in black)")
# plt.show()
