# Interpret NDNF-MT with Logic-based Programming

We take the output from experiment `sc_ws_ppo_ndnf_mt_l4_1e5_aux_2151` on
`SmallCorridorEnv`.

In [1]:
import sys

sys.path.append("..")

In [2]:
import torch

from neural_dnf.neural_dnf import NeuralDNFMutexTanh

### Thresholded sc_ws_ppo_ndnf_mt_l4_1e5_aux_2151

Threshold upper bound: 0.67

Best threshold: 0.5299999713897705

KL divergence: 0.01657593995332718

After 2nd Prune

Action distribution: tensor([[0.6590, 0.3410], [0.0413, 0.9587], [0.8226, 0.1774]])

KL divergence cmp to after prune: 0.01657593995332718

Model:

Conjunction: tensor([[ 6., -0.], [ 6., -6.], [-0.,  6.], [ 6., -0.]])

Disjunction: tensor([[-0.0000, -0.0000,  0.1501, -0.5853], [ 0.6486,  0.3062, -0.2876,  0.3611]])

In [3]:
experiment_name = "sc_ws_ppo_ndnf_mt_l4_1e5_aux_2151"
conjunction_tensor = torch.Tensor([[6.0, -0.0], [6.0, -6.0], [-0.0, 6.0], [6.0, -0.0]])
disjunction_tensor = torch.Tensor(
    [[-0.0000, -0.0000, 0.1501, -0.5853], [0.6486, 0.3062, -0.2876, 0.3611]]
)

ndnf_mt = NeuralDNFMutexTanh(
    num_preds=conjunction_tensor.shape[1],
    num_conjuncts=conjunction_tensor.shape[0],
    n_out=disjunction_tensor.shape[0],
    delta=1.0,
)
ndnf_mt.conjunctions.weights.data = conjunction_tensor
ndnf_mt.disjunctions.weights.data = disjunction_tensor

In [4]:
all_possible_inputs = torch.Tensor([
    [1, -1],    # left wall, no right wall
    [-1, -1],   # no left wall, no right wall
])

## Method

Step 1: Prune the model

Step 2: Compute deterministic conjunction via thresholding

Step 3: Re-prune the model

Step 4: Raw enumeration of the disjunctions

-  This step requires computing the bias of the disjunction layer.

Step 5: Condensation via logical equivalence

Step 6: Rule simplification based on experienced observations

Step 7: Interpretation of conjunction based on experienced observations

Step 8: ProbLog rules with annotated disjunction based on experienced observations

In [5]:
# Step 4: Raw enumeration of the disjunctions
# Compute the bias of the disjunction layer

abs_weight = torch.abs(disjunction_tensor)
# abs_weight: Q x P
max_abs_w = torch.max(abs_weight, dim=1)[0]
# max_abs_w: Q
sum_abs_w = torch.sum(abs_weight, dim=1)
# sum_abs_w: Q
bias = sum_abs_w - max_abs_w
# bias: Q
bias

tensor([0.1501, 0.9549])

### Method applied sc_ws_ppo_ndnf_mt_l4_1e5_aux_2151

`disj_0` is `turn_left`, and `disj_1` is `turn_right`.

`conj_0 :- left_wall.`

`conj_1 :- left_wall, not right_wall.`

`conj_2 :- right_wall.`

`conj_3 :- left_wall.`
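
The four rules above can be read directly off the signs of the thresholded conjunction weights. The following is a minimal sketch of that reading, not the author's code: `conjunction_to_rule` and `PREDICATES` are hypothetical names, and the input order `[left_wall, right_wall]` is assumed from the comments on `all_possible_inputs`.

```python
# Hypothetical helper (not part of neural_dnf): turn one row of thresholded
# conjunction weights into a rule body.
PREDICATES = ["left_wall", "right_wall"]  # assumed input order

# Conjunction weights copied from the model printed earlier in this notebook
CONJUNCTION = [[6.0, -0.0], [6.0, -6.0], [-0.0, 6.0], [6.0, -0.0]]


def conjunction_to_rule(index, weights, names=PREDICATES):
    literals = []
    for w, name in zip(weights, names):
        if w > 0:
            literals.append(name)           # positive weight: atom must hold
        elif w < 0:
            literals.append(f"not {name}")  # negative weight: atom must not hold
        # w == 0 (including -0.0): the input was pruned, so it is skipped
    return f"conj_{index} :- {', '.join(literals)}."


rules = [conjunction_to_rule(i, row) for i, row in enumerate(CONJUNCTION)]
for rule in rules:
    print(rule)
```

The sketch does not handle a conjunction whose weights are all zero, which does not occur in this model.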

**Step 4: Raw enumeration of the disjunctions**

```prolog
disj_0 = 0.1501 * conj_2 - 0.5853 * conj_3 + 0.1501.

disj_1 = 0.6486 * conj_0 + 0.3062 * conj_1 - 0.2876 * conj_2 + 0.3611 * conj_3 + 0.9549.
```
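
As a cross-check, the raw enumeration can be generated from the disjunction weights with plain Python. This is only a sketch: it uses the same bias formula as the cell above (sum of absolute weights minus the largest absolute weight), and the helper names `disjunction_bias` and `enumerate_disjunction` are made up for illustration.

```python
# Disjunction weights copied from the model printed earlier in this notebook
DISJUNCTION = [
    [-0.0, -0.0, 0.1501, -0.5853],
    [0.6486, 0.3062, -0.2876, 0.3611],
]


def disjunction_bias(weights):
    # Same formula as the notebook cell: sum(|w|) - max(|w|)
    abs_w = [abs(w) for w in weights]
    return sum(abs_w) - max(abs_w)


def enumerate_disjunction(index, weights):
    # Skip pruned (zero) weights and keep the sign in front of every term
    terms = [f"{w:+.4f} * conj_{j}" for j, w in enumerate(weights) if w != 0]
    return f"disj_{index} = {' '.join(terms)} {disjunction_bias(weights):+.4f}."


biases = [disjunction_bias(row) for row in DISJUNCTION]
lines = [enumerate_disjunction(i, row) for i, row in enumerate(DISJUNCTION)]
for line in lines:
    print(line)
```

The printed expressions carry an explicit sign on every term, so they differ cosmetically from the hand-written ones while using the same coefficients.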

**Step 5: Condensation via logical equivalence**

`conj_0` and `conj_3` are equivalent. We replace `conj_3` with `conj_0`.

```prolog
disj_0 = 0.1501 * conj_2 - 0.5853 * conj_0 + 0.1501.

disj_1 = 0.6486 * conj_0 + 0.3062 * conj_1 - 0.2876 * conj_2 + 0.3611 * conj_0 + 0.9549

       = 1.0097 * conj_0 + 0.3062 * conj_1 - 0.2876 * conj_2 + 0.9549.
```
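
The condensation amounts to summing the weights of conjunctions that share the same body. A minimal sketch, assuming the equivalence `conj_3 = conj_0` identified above is supplied by hand through a hypothetical `ALIASES` mapping:

```python
from collections import defaultdict

# Non-zero weights per conjunction index, copied from the Step 4 expressions
disj_0 = {2: 0.1501, 3: -0.5853}
disj_1 = {0: 0.6486, 1: 0.3062, 2: -0.2876, 3: 0.3611}

# Hypothetical alias table: conj_3 has the same body as conj_0
ALIASES = {3: 0}


def condense(weights, aliases):
    merged = defaultdict(float)
    for idx, w in weights.items():
        # Add each weight to the representative of its equivalence class
        merged[aliases.get(idx, idx)] += w
    return {idx: round(w, 4) for idx, w in sorted(merged.items())}


condensed_0 = condense(disj_0, ALIASES)
condensed_1 = condense(disj_1, ALIASES)
print(condensed_0)
print(condensed_1)
```

The bias is unchanged by this step because it was already computed from the original weights in Step 4.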

**Step 6: Rule simplification based on experienced observations**

In the experienced observations, `not right_wall` is always true and `right_wall` is always false.

`conj_1` therefore reduces to `conj_1 :- left_wall`, which makes it equivalent to `conj_0`.

`conj_2` is never true, so it always evaluates to -1.

```prolog
disj_0 = 0.1501 * (-1) - 0.5853 * conj_0 + 0.1501

       = -0.5853 * conj_0.

disj_1 = 1.0097 * conj_0 + 0.3062 * conj_0 - 0.2876 * (-1) + 0.9549

       = 1.3159 * conj_0 + 1.2425.
```

**Step 7: Interpretation of conjunction based on experienced observations**

`conj_0` is equivalent to `left_wall`.

```prolog
disj_0 = -0.5853 * left_wall.

disj_1 = 1.3159 * left_wall + 1.2425.
```

**Step 8: ProbLog rules with annotated disjunction based on experienced observations**

Input: [1, -1] (`left_wall`, `-right_wall`)


```python

disj_0 = -0.5853, disj_1 = 2.5584

tanh([disj_0, disj_1]) = [-0.5265,  0.9881]

mutex_tanh([disj_0, disj_1]) = [-0.9173,  0.9173]

prob([disj_0, disj_1]) = [0.0413, 0.9587]

```

Input: [-1, -1] (`-left_wall`, `-right_wall`)

```python

disj_0 = 0.5853, disj_1 = -0.0734

tanh([disj_0, disj_1]) = [0.5265, -0.0733]

mutex_tanh([disj_0, disj_1]) = [0.3179, -0.3179]

prob([disj_0, disj_1]) = [0.6590, 0.3410]

```
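
The numbers for both inputs can be reproduced from the simplified Step 7 expressions without the model. This sketch assumes that mutex-tanh is `2 * softmax(x) - 1`, which is consistent with the values shown above but is not taken from the library source; the function names are illustrative.

```python
import math


def disjunction_values(left_wall):
    # Simplified Step 7 expressions; left_wall is +1 (wall) or -1 (no wall)
    return [-0.5853 * left_wall, 1.3159 * left_wall + 1.2425]


def softmax(values):
    m = max(values)  # subtract the maximum for numerical stability
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def mutex_tanh(values):
    # Assumption: mutex-tanh rescales softmax from [0, 1] to [-1, 1]
    return [2 * p - 1 for p in softmax(values)]


results = {}
for left_wall in (1, -1):
    raw = disjunction_values(left_wall)
    # Same conversion as the notebook: (mutex_tanh + 1) / 2
    probs = [(v + 1) / 2 for v in mutex_tanh(raw)]
    results[left_wall] = {"raw": raw, "prob": probs}
    print(left_wall, [round(v, 4) for v in raw], [round(p, 4) for p in probs])
```

Under that assumption the probabilities are simply the softmax of the raw disjunction values, which is why they sum to one for each input.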

Convert to ProbLog rules:

```prolog

rule_1 :- left_wall.

0.0413::turn_left; 0.9587::turn_right :- rule_1.

rule_2 :- not left_wall.

0.6590::turn_left; 0.3410::turn_right :- rule_2.

```
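
The rule text above can also be generated from the probabilities instead of being typed by hand. A minimal sketch using string formatting only; `annotated_disjunction` is a hypothetical helper, and the action order `(turn_left, turn_right)` follows the `disj_0` / `disj_1` mapping stated earlier.

```python
def annotated_disjunction(body, probs, actions=("turn_left", "turn_right")):
    # Build "p1::a1; p2::a2 :- body." with four decimal places per probability
    head = "; ".join(f"{p:.4f}::{a}" for p, a in zip(probs, actions))
    return f"{head} :- {body}."


program = "\n".join([
    "rule_1 :- left_wall.",
    annotated_disjunction("rule_1", [0.0413, 0.9587]),
    "rule_2 :- not left_wall.",
    annotated_disjunction("rule_2", [0.6590, 0.3410]),
])
print(program)
```

This only produces the program text; evaluating it would still require a ProbLog installation and evidence for `left_wall`.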

In [9]:
with torch.no_grad():
    out_dict = ndnf_mt.get_all_forms(all_possible_inputs)
print(out_dict)
prob = (out_dict["disjunction"]["mutex_tanh"] + 1) / 2
print(prob)

{'conjunction': {'raw': tensor([[ 6.,  6., -6.,  6.],
        [-6., -6., -6., -6.]]), 'tanh': tensor([[ 1.0000,  1.0000, -1.0000,  1.0000],
        [-1.0000, -1.0000, -1.0000, -1.0000]])}, 'disjunction': {'raw': tensor([[-0.5853,  2.5584],
        [ 0.5853, -0.0734]]), 'tanh': tensor([[-0.5265,  0.9881],
        [ 0.5265, -0.0733]]), 'mutex_tanh': tensor([[-0.9173,  0.9173],
        [ 0.3179, -0.3179]])}}
tensor([[0.0413, 0.9587],
        [0.6590, 0.3410]])
