# Python equivalent for `readme_for_test.m`


Automatically reload modules when their source changes; useful during development.


In [None]:
# Load IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell is executed.
%load_ext autoreload
%autoreload 2

Library imports for the NumPy, CuPy, and JAX versions.


In [2]:
import functools
import time
from typing import Any, Callable, Literal, cast

import cupy as cp
import jax
import jax.numpy as jnp
import numpy as np
from numpy.random import Generator

# JAX-MatDNF
from mat_dnf.jax.losses import (
    acc_classi as jacc_classi,
)
from mat_dnf.jax.losses import (
    acc_dnf as jacc_dnf,
)
from mat_dnf.jax.losses import (
    logi_conseq as jlogi_conseq,
)
from mat_dnf.jax.losses import (
    logi_equiv as jlogi_equiv,
)
from mat_dnf.jax.models import MatDNF as JAXMatDNF
from mat_dnf.jax.models import train_mat_dnf as jtrain_mat_dnf

# NumPy/CuPy-MatDNF
from mat_dnf.numpy.losses import acc_classi, acc_dnf, logi_conseq, logi_equiv
from mat_dnf.numpy.models import (
    MatDNF,
    train_mat_dnf,
)

# Utility functions
from mat_dnf.simplifications import simp_dnf
from mat_dnf.utils import (
    MeanLogger,
    n_parity_function,
    random_dnf,
    random_function,
    read_nth_dnf,
)

# Enable 64-bit datatypes in JAX (JAX otherwise defaults to 32-bit types);
# this notebook configures an int64 `dtype` for its data.
jax.config.update("jax_enable_x64", True)

## Parameters


### Problem size and selection


In [3]:
# Names of the supported problem sources; each one is dispatched to the
# function of the same name when the problem generator is built.
TargetFunction = Literal[
    "random_dnf", "random_function", "n_parity_function", "read_nth_dnf"
]
# Problem source used by this run.
target_fn_name: TargetFunction = "read_nth_dnf"

When generating DNF


In [4]:
# Number of Boolean variables (default: 5).
# Also determines `l=2**n` passed to MeanLogger in the learning loops.
n = 5
# Maximum number of disjuncts (default: 10).
h_gen = 10
# Number of "enabled" disjuncts (default: 3).
d_size = 3
# Maximum number of literals in each disjunct (default: 5).
c_max = 5
# Forwarded as `add_noise` to the generator functions.
add_noise = False

When reading DNF


In [5]:
i = 0  # index of the DNF to read from a (*_T.bin.csv) file
# Input CSV file (relative path), passed to `read_nth_dnf` as `fname`.
fname = "../data/E-MTAB-1908/01_T.bin.csv"

### Learning parameters


In [6]:
# Forwarded to the trainers as `alpha`.
# NOTE(review): presumably a step size / learning rate -- confirm.
alpha = 0.1
# Forwarded to the trainers as `max_try` (default: 20; 4 for a quick run).
max_try = 20
# Forwarded to the trainers as `max_itr` (default: 500; 50 for a quick run).
max_itr = 500
# Forwarded to the NumPy/CuPy trainer as `extra_itr` (default: 0).
extra_itr = 0
# Forwarded to the trainers as `er_max`.
Er_max = 0
# Number of folds, i.e. iterations of the outer learning loop
# (10 here; 4 for a quick run).
i_max = 10
# Domain ratio: the first floor(l2 * dr) data columns form the training split,
# the remaining columns form the test split (default: 0.5; 1.0 uses all data
# for training).
dr = 0.5
# Forwarded to `create_random` as `h` (default: 1000).
# NOTE(review): presumably the number of disjuncts in the initial guess -- confirm.
h = 1000
# Forwarded to `create_random` as `aa`.
aa = 4

### Data type and device selection


For multiple GPUs (JAX only), refer to https://docs.jax.dev/en/latest/sharded-computation.html


In [7]:
# Element dtype handed to the problem generator / reader functions.
dtype = np.dtype(np.int64)
# "gpu": CuPy arrays for the NumPy-style implementation and a JAX GPU device.
# "cpu": NumPy arrays and a JAX CPU device.
device_name: Literal["gpu", "cpu"] = "gpu"

## Compute


### Select device


In [8]:
# Array namespace for the NumPy-style implementation: CuPy on GPU, NumPy on CPU.
xp = cp if device_name == "gpu" else np

# Use the first device of the requested kind as JAX's default device.
# Indexing (instead of unpacking a one-element sequence) keeps this cell
# working on hosts that expose several devices of that kind.
device = jax.devices(device_name)[0]
jax.config.update("jax_default_device", device)

### Problem generator functions


In [9]:
rng = np.random.default_rng()

# TODO: Explicit return typing
match target_fn_name:
    case "random_dnf":
        target_fn: Callable[[], Any] = functools.partial(
            random_dnf,
            rng=rng,
            n=n,
            h_gen=h_gen,
            d_size=d_size,
            c_max=c_max,
            add_noise=add_noise,
            dtype=dtype,
        )
    case "random_function":
        target_fn: Callable[[], Any] = functools.partial(
            random_function,
            rng=rng,
            n=n,
            add_noise=add_noise,
            dtype=dtype,
        )
    case "n_parity_function":
        target_fn: Callable[[], Any] = functools.partial(
            n_parity_function,
            rng=rng,
            n=n,
            add_noise=add_noise,
            dtype=dtype,
        )
    case "read_nth_dnf":
        target_fn: Callable[[], Any] = functools.partial(
            read_nth_dnf,
            fname=fname,
            i=i,
            delimiter=",",
            dtype=dtype,
        )

### Numpy / CuPy


NOTE: Ideally, the problem generator would also use CuPy's RNG,
but since the implementation is incomplete, only the optimization loop uses it.


In [10]:
if device_name == "gpu":
    # Replace the host generator with a device-side one for the optimization
    # loop. It is seeded so GPU runs are reproducible; the cast only satisfies
    # static typing, since the object is created from `xp.random`, not NumPy.
    rng = cast(Generator, xp.random.default_rng(42))

Learning starts


In [None]:
# Accumulates the per-fold metrics; the summary is printed after the loop.
mean_logger = MeanLogger(dr=dr, l=2**n)
# The loop variable is `fold`, not `i`: `i` is the configured DNF index, and
# overwriting it here would silently change which DNF is read if the problem
# generator cell were re-run afterwards.
for fold in range(i_max):
    # Read or generate random DNF
    i1, i2_k = target_fn()
    l2 = i1.shape[1]

    # Split train / test data: the first floor(l2 * dr) columns are used for
    # training, the rest is the test split.
    x = np.floor(l2 * dr).astype(np.int64)
    i1_dr = i1[:, :x]
    i1_test = i1[:, x:]
    i2_k_dr = i2_k[:x]
    i2_k_test = i2_k[x:]

    # Initialize model
    model = MatDNF.create_random(h=h, n=i1_dr.shape[0], aa=aa, xp=xp, rng=rng)

    # Transfer training data arrays to GPU
    if device_name == "gpu":
        i1_dr = cp.asarray(i1_dr)
        i2_k_dr = cp.asarray(i2_k_dr)
        i2_k = cp.asarray(i2_k)
        i1 = cp.asarray(i1)
        # CuPy enqueues GPU work asynchronously. Drain the stream so pending
        # initialisation / transfer work is not billed to the training time.
        cp.cuda.get_current_stream().synchronize()

    s = time.monotonic()
    model, v_k_th, learned_dnf = train_mat_dnf(
        model=model,
        i_in=i1_dr,
        i_out=i2_k_dr,
        er_max=Er_max,
        alpha=alpha,
        max_itr=max_itr,
        max_try=max_try,
        extra_itr=extra_itr,
        fold=fold,
        use_perturbation=True,
        use_sam=True,
        rng=rng,
    )
    if device_name == "gpu":
        # Wait for outstanding GPU work before stopping the clock.
        cp.cuda.get_current_stream().synchronize()
    e = time.monotonic()
    elapsed_time = e - s

    # Simplify the learned DNF, then evaluate it against the full data set
    # (`i1`, `i2_k`), not only the held-out part.
    learned_dnf_s = simp_dnf(learned_dnf)
    cnsq, _ = logi_conseq(learned_dnf_s, i2_k, i1)
    eqv, _ = logi_equiv(learned_dnf_s, i2_k, i1)

    mean_logger.append(
        i1_test.shape[1],
        elapsed_time,
        float(acc_classi(model.d_k, v_k_th, i1, i2_k, l2, model.c)),  # type: ignore
        float(acc_dnf(learned_dnf_s, i1, i2_k, l2)),
        cnsq,
        eqv,
    )

print(mean_logger)

### JAX


Initialize PRNG key


In [12]:
# Root PRNG key for the JAX run, created from a fixed seed; the learning loop
# splits it once per fold.
key = jax.random.key(42)

Learning starts


In [None]:
dtype = jnp.dtype(dtype)

# Accumulates the per-fold metrics; the summary is printed after the loop.
mean_logger = MeanLogger(dr=dr, l=2**n)
# The loop variable is `fold`, not `i`: `i` is the configured DNF index, and
# overwriting it here would silently change which DNF is read if the problem
# generator cell were re-run afterwards.
for fold in range(i_max):
    # Derive independent keys for model initialisation and training, and
    # carry the remaining key over to the next fold.
    key, model_key, train_key = jax.random.split(key, num=3)

    # Read or generate random DNF
    i1, i2_k = target_fn()
    l2 = i1.shape[1]
    i1 = jnp.array(i1)
    i2_k = jnp.array(i2_k)

    # Split train / test data: the first floor(l2 * dr) columns are used for
    # training. The split index is computed on the host as a plain int, so
    # slicing does not need a device array as a bound.
    x = int(np.floor(l2 * dr))
    i1_dr = i1[:, :x]
    i1_test = i1[:, x:]
    i2_k_dr = i2_k[:x]
    i2_k_test = i2_k[x:]

    # Initialize model
    model = JAXMatDNF.create_random(h=h, n=i1_dr.shape[0], aa=aa, key=model_key)

    # JAX dispatches work asynchronously. Make sure the inputs are materialised
    # before starting the clock and the outputs before stopping it, so the
    # measured interval covers the training computation itself.
    jax.block_until_ready((model, i1_dr, i2_k_dr))

    s = time.monotonic()
    # with jax.log_compiles(True):
    # NOTE(review): unlike the NumPy/CuPy run, `extra_itr` is not passed and
    # perturbation / SAM are disabled here -- confirm this is intended.
    model, v_k_th, learned_dnf = jax.block_until_ready(
        jtrain_mat_dnf(
            model=model,
            key=train_key,
            fold=fold,
            i_in=i1_dr,
            i_out=i2_k_dr,
            er_max=Er_max,
            alpha=alpha,
            max_itr=max_itr,
            max_try=max_try,
            use_perturbation=False,
            use_sam=False,
        )
    )
    e = time.monotonic()
    elapsed_time = e - s

    # The simplification runs through the shared `simp_dnf` on a NumPy copy and
    # is converted back to a JAX array afterwards.
    # learned_dnf_s = jsimp_dnf(learned_dnf)
    learned_dnf_s = jnp.array(simp_dnf(np.array(learned_dnf)))
    cnsq, _ = jlogi_conseq(learned_dnf_s, i2_k, i1)
    eqv, _ = jlogi_equiv(learned_dnf_s, i2_k, i1)

    mean_logger.append(
        i1_test.shape[1],
        elapsed_time,
        float(jacc_classi(model.d_k, v_k_th, i1, i2_k, l2, model.c)),
        float(jacc_dnf(learned_dnf_s, i1, i2_k, l2)),
        cnsq,
        eqv,
    )

print(mean_logger)