In [1]:
from copy import deepcopy
import json
from pathlib import Path
import sys
from typing import Any
import warnings
from IPython.display import Video


# Add the parent directory to the import path; presumably this is what lets the
# project-local imports further down (`common`, `corridor_grid`,
# `door_corridor_ppo`, `eval`) resolve -- confirm against the repo layout.
sys.path.append("..")
# NOTE(review): this silences every warning for the whole kernel session,
# including ones that could point at real problems.
warnings.filterwarnings('ignore')

In [2]:
import clingo
import gymnasium as gym
from gymnasium.wrappers import RecordVideo

import matplotlib.pyplot as plt

import numpy as np
import torch

In [3]:
from common import synthesize
from corridor_grid.envs import (
    DoorCorridorEnv,
    DoorCorridorTEnv,
    DoorCorridorOTEnv,
)
from door_corridor_ppo import construct_model, make_env, DCPPONDNFMutexTanhAgent
from eval.door_corridor_ppo_ndnf_mt_multirun_eval import (
    get_ndnf_action,
    simulate_fn,
)

In [4]:
MODEL_SEED = 6731

# The experiment name appears both in the checkpoint folder name and in the
# model config, so it is spelled out once and reused.
EXPERIMENT_NAME = "dc5_ppo_ndnf_mt_k1eoc4_tanh_exl16_3e5_aux"

# Folder name of the trained model: "<experiment name>_<seed>".
model_name = "_".join([EXPERIMENT_NAME, str(MODEL_SEED)])
print(model_name)

# Image-encoder settings. The key spellings (including "chanel") are consumed
# as-is by `construct_model`, so they must not be "corrected" here.
image_encoder_cfg = dict(
    encoder_output_chanel=4,
    last_act="tanh",
    kernel_size=1,
    use_extra_layer=True,
    extra_layer_out=16,
    extra_layer_use_bias=True,
)

# Full configuration handed to `construct_model` in the cells below.
model_cfg = dict(
    experiment_name=EXPERIMENT_NAME,
    customised_image_encoder=image_encoder_cfg,
    use_eo=False,
    use_mt=True,
    use_argmax_to_choose_action=True,
    discretise_img_encoding=True,
)

dc5_ppo_ndnf_mt_k1eoc4_tanh_exl16_3e5_aux_6731


In [5]:
NUM_PROCESSES = 8
NUM_EPISODES = 100
DEVICE = torch.device("cpu")

# Root folder that holds one sub-folder per trained model.
BASE_STORAGE_DIR = Path("../dc_ppo_storage")

# Single (non-vectorised) environment; also used below to read the image
# observation space when constructing models.
single_env = DoorCorridorEnv(render_mode="rgb_array")

# One environment factory per process index, wrapped in a synchronous vector env.
env_fns = [make_env(idx, idx, False) for idx in range(NUM_PROCESSES)]
envs = gym.vector.SyncVectorEnv(env_fns)


def simulate(action_fn):
    """Run `simulate_fn` on the shared vectorised `envs` using `action_fn`."""
    return simulate_fn(envs, action_fn)

# Original Model and DoorCorridorEnv

In [6]:
# Folder of the trained model and the checkpoint file loaded from it.
model_dir = BASE_STORAGE_DIR / model_name
checkpoint_path = model_dir / "thresholded_model.pth"

# Build the agent from the config, the number of environment actions and the
# image observation space of the single environment.
model: DCPPONDNFMutexTanhAgent = construct_model(
    model_cfg,  # type: ignore
    DoorCorridorEnv.get_num_actions(),
    True,
    single_env.observation_space["image"],  # type: ignore
)
model.to(DEVICE)

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted location.
model_state = torch.load(checkpoint_path, map_location=DEVICE)
model.load_state_dict(model_state)
model.eval()

print("Model loaded!")
print(model)

Model loaded!
DCPPONDNFMutexTanhAgent(
  (image_encoder): Sequential(
    (0): Conv2d(2, 4, kernel_size=(1, 1), stride=(1, 1))
    (1): Tanh()
  )
  (extra_layer): Sequential(
    (0): Linear(in_features=36, out_features=16, bias=True)
    (1): Tanh()
  )
  (actor): NeuralDNFMutexTanh(
    (conjunctions): SemiSymbolic(in_features=16, out_features=12, layer_type=SemiSymbolicLayerType.CONJUNCTION,current_delta=1.00)
    (disjunctions): SemiSymbolicMutexTanh(in_features=12, out_features=4, layer_type=SemiSymbolicLayerType.DISJUNCTION,current_delta=1.00)
  )
  (critic): Sequential(
    (0): Linear(in_features=16, out_features=64, bias=True)
    (1): Tanh()
    (2): Linear(in_features=64, out_features=1, bias=True)
  )
)


In [7]:
def _ndnf_mt_dis_action_fn(obs):
    """Action function for `simulate`: query the loaded `model` via `get_ndnf_action`.

    The literal `True` is forwarded unchanged as the second argument of
    `get_ndnf_action`; its meaning is defined by that helper.
    """
    return get_ndnf_action(model, True, obs)


def _simulate_with_print(action_fn, model_name: str) -> dict[str, Any]:
    """Run `simulate` with `action_fn`, print a summary, and return the logs.

    Args:
        action_fn: Callable handed to `simulate` to pick actions.
        model_name: Label printed at the start of the summary line.

    Returns:
        The log dictionary produced by `simulate`. The keys read here are
        "num_frames_per_episode", "return_per_episode", "mutual_exclusivity"
        and "missing_actions".
    """
    logs = simulate(action_fn)

    frames_per_episode = logs["num_frames_per_episode"]
    total_frames = sum(frames_per_episode)
    return_stats = synthesize(logs["return_per_episode"])
    frame_stats = synthesize(frames_per_episode)

    # Label, total frames, then the values of each `synthesize` result in the
    # order that helper returns them (the template expects four per group).
    summary_template = (
        "{}\tF {} | R:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | F:μσmM {:.1f} {:.1f} {} {}"
    )
    summary_fields = [
        model_name,
        total_frames,
        *return_stats.values(),
        *frame_stats.values(),
    ]
    print(summary_template.format(*summary_fields))

    print(f"Mutual exclusivity: {logs['mutual_exclusivity']}")
    print(f"Missing actions: {logs['missing_actions']}")
    return logs

In [8]:
# Evaluate the loaded model on the vectorised environments and print the
# summary; the returned logs are not kept.
_simulate_with_print(_ndnf_mt_dis_action_fn, "NDNF-MT final model")
# Trailing blank line; it also keeps the cell from displaying the returned logs.
print()

NDNF-MT final model	F 832.0 | R:μσmM -8.00 0.00 -8.00 -8.00 | F:μσmM 8.0 0.0 8.0 8.0
Mutual exclusivity: True
Missing actions: False



In [9]:
obs, _ = single_env.reset()

terminated = False
truncated = False
reward_sum = 0

# Roll out one episode on the single environment, printing at every step the
# image-encoding entries that are positive, written as ASP facts "a_<index>.".
while not (terminated or truncated):
    with torch.no_grad():
        # Build the batched float observation once and share it between both
        # model calls instead of constructing the identical tensor twice.
        obs_dict = {
            "image": torch.tensor(obs["image"].copy(), device=DEVICE)
            .unsqueeze(0)
            .float()
        }
        raw_img_encoding = model.get_img_encoding(
            preprocessed_obs=obs_dict
        ).squeeze(0)
        actions = model.get_actions(preprocessed_obs=obs_dict)
    img_encoding = [
        f"a_{a.item()}." for a in torch.nonzero(raw_img_encoding > 0)
    ]
    print(img_encoding)
    obs, reward, terminated, truncated, _ = single_env.step(actions[0].item())
    reward_sum += reward

['a_1.', 'a_2.', 'a_6.', 'a_8.', 'a_13.']
['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_8.', 'a_9.', 'a_10.', 'a_13.']
['a_2.', 'a_6.', 'a_10.']
['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_7.', 'a_8.', 'a_9.', 'a_10.', 'a_13.', 'a_14.']
['a_2.', 'a_6.', 'a_7.', 'a_10.', 'a_14.']
['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_7.', 'a_8.', 'a_9.', 'a_10.', 'a_13.', 'a_14.']
['a_2.', 'a_6.', 'a_7.', 'a_11.', 'a_14.', 'a_15.']
['a_2.', 'a_6.', 'a_7.', 'a_11.', 'a_13.', 'a_14.']


In [10]:
# Read the extracted ASP program, keeping only the non-empty lines.
asp_rules_path = model_dir / "asp_rules.lp"
with open(asp_rules_path, "r") as f:
    asp_rules = [line for line in f.read().split("\n") if line != ""]

print("ASP rules:")
for rule in asp_rules:
    print(rule)

print()
print("Interpretation:")

# The interpretation is stored as JSON; the entries printed here live under
# the key "0" -> "parsed_program".
interpret_path = model_dir / "interpret_result.json"
with open(interpret_path, "r") as f:
    interpret_result = json.load(f)
interpretation = interpret_result["0"]["parsed_program"]
for entry in interpretation:
    print(entry)

ASP rules:
disj_1 :- conj_5.
disj_2 :- conj_4.
disj_3 :- not conj_6.
conj_4 :- not a_8.
conj_5 :- not a_9, a_13, not a_14.
conj_6 :- not a_9.

Interpretation:
a_8 :- top_right_corner_unseen.
a_9 :- one_step_ahead_closed_door.
a_13 :- not one_step_ahead_open_door.
a_14 :- curr_location_open_door.


turn_right :- not one_step_ahead_closed_door, not one_step_ahead_open_door, not curr_location_open_door.

forward :- not top_right_corner_unseen.

toggle :- one_step_ahead_closed_door. 

# NDNF-MT on Door Corridor T

A modified version of `DoorCorridorEnv`: to finish an episode, the agent must
stand in front of the goal cell and toggle it, instead of moving onto it.

In [11]:
# Door Corridor T environment, rendering to RGB arrays.
dct = DoorCorridorTEnv(render_mode="rgb_array")
# Set the advertised render rate to 1 frame per second (presumably picked up
# by the video recorder used later -- confirm).
dct.metadata["render_fps"] = 1
# Restore the original, unmodified weights into `model`. As the last
# expression of the cell, the load result is displayed.
model.load_state_dict(model_state)

<All keys matched successfully>

In [12]:
obs, _ = dct.reset()

terminated = False
truncated = False
reward_sum = 0

# Roll out the original model on Door Corridor T until the episode terminates
# or is truncated.
while not (terminated or truncated):
    # Convert the observation array directly and add the batch dimension
    # afterwards. Wrapping the ndarray in a Python list before calling
    # torch.tensor takes PyTorch's slow list-of-ndarrays path; this form gives
    # the same float tensor and matches the other cells of this notebook.
    obs_dict = {
        "image": torch.tensor(obs["image"].copy(), device=DEVICE)
        .unsqueeze(0)
        .float()
    }
    with torch.no_grad():
        actions = model.get_actions(
            preprocessed_obs=obs_dict,
            use_argmax=True,
            discretise_img_encoding=True,
        )
    # The first element of the result holds the chosen actions for the batch
    # (a batch of one here).
    actions = actions[0]
    obs, reward, terminated, truncated, _ = dct.step(actions[0])
    reward_sum += reward

print(f"Reward: {reward_sum}")
print(f"Terminated: {terminated}")
print(f"Truncated: {truncated}")

Reward: -270
Terminated: False
Truncated: True


### Modified ASP rules (changes relative to the learned rules are in **bold**):

disj_1 :- conj_5.

disj_2 :- conj_4.

disj_3 :- not conj_6.

**disj_3 :- conj_0.**

**conj_0 :- a_11, a_13.**

conj_4 :- not a_8, **not a_13**.

conj_5 :- not a_9, a_13, not a_14.

conj_6 :- not a_9.


Image encoding at each timestep:

0 ['a_1.', 'a_2.', 'a_6.', 'a_8.', 'a_13.']

1 ['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_8.', 'a_9.', 'a_10.', 'a_13.']

2 ['a_2.', 'a_6.', 'a_10.']

3 ['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_7.', 'a_8.', 'a_9.', 'a_10.', 'a_13.', 'a_14.']

4 ['a_2.', 'a_6.', 'a_7.', 'a_10.', 'a_14.']

5 ['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_7.', 'a_8.', 'a_9.', 'a_10.', 'a_13.', 'a_14.']

6 ['a_2.', 'a_6.', 'a_7.', 'a_11.', 'a_14.', 'a_15.']

7 ['a_2.', 'a_6.', 'a_7.', 'a_11.', 'a_13.', 'a_14.']

### Evaluate the rules

In [13]:
obs, _ = dct.reset()

terminated = False
truncated = False
reward_sum = 0

# Hand-edited ASP policy for Door Corridor T (see "Modified ASP rules" above).
new_rules = [
    "disj_1 :- conj_5.",
    "disj_2 :- conj_4.",
    "disj_3 :- not conj_6.",
    "disj_3 :- conj_0.",
    "conj_4 :- not a_8, not a_13.",
    "conj_5 :- not a_9, a_13, not a_14.",
    "conj_6 :- not a_9.",
    "conj_0 :- a_11, a_13.",
]

# The #show directives do not depend on the observation, so build them once
# instead of on every step.
show_statements = [
    f"#show disj_{i}/0." for i in range(DoorCorridorEnv.get_num_actions())
]

# Roll out one episode: encode the observation, turn the positive encoding
# entries into ASP facts, solve the program, and step with the single shown
# disj_<k> atom as action k.
while not (terminated or truncated):
    with torch.no_grad():
        raw_img_encoding = model.get_img_encoding(
            preprocessed_obs={
                "image": torch.tensor(obs["image"].copy(), device=DEVICE)
                .unsqueeze(0)
                .float()
            }
        ).squeeze(0)
    img_encoding = [
        f"a_{a.item()}." for a in torch.nonzero(raw_img_encoding > 0)
    ]
    print(img_encoding)

    # A fresh solver per step: facts + #show directives + policy rules.
    ctl = clingo.Control(["--warn=none"])
    ctl.add("base", [], " ".join(img_encoding + show_statements + new_rules))
    ctl.ground([("base", [])])
    with ctl.solve(yield_=True) as handle:  # type: ignore
        all_answer_sets = [str(a) for a in handle]

    if len(all_answer_sets) != 1:
        # No model or multiple answer sets, should not happen
        print("No model or multiple answer sets when evaluating rules.")
        break

    # An empty string means no disj_<k> atom is shown in the answer set.
    # (Splitting a non-empty string always yields at least one item, so no
    # further emptiness check is needed afterwards.)
    if all_answer_sets[0] == "":
        print("No output action!")
        break

    output_classes = all_answer_sets[0].split(" ")
    # Drop the 5-character "disj_" prefix to recover the action index.
    output_classes_set = {int(o[5:]) for o in output_classes}

    if len(output_classes_set) != 1:
        print(f"Output set: {output_classes_set} not exactly one item!")
        break

    action = next(iter(output_classes_set))
    print(f"Action: {action}")
    obs, reward, terminated, truncated, _ = dct.step(action)
    reward_sum += reward

print(f"Reward: {reward_sum}")
print(f"Terminated: {terminated}")
print(f"Truncated: {truncated}")

['a_1.', 'a_2.', 'a_6.', 'a_8.', 'a_13.']
Action: 1
['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_8.', 'a_9.', 'a_10.', 'a_13.']
Action: 3
['a_2.', 'a_6.', 'a_10.']
Action: 2
['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_7.', 'a_8.', 'a_9.', 'a_10.', 'a_13.', 'a_14.']
Action: 3
['a_2.', 'a_6.', 'a_7.', 'a_10.', 'a_14.']
Action: 2
['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_7.', 'a_8.', 'a_9.', 'a_10.', 'a_13.', 'a_14.']
Action: 3
['a_2.', 'a_6.', 'a_7.', 'a_11.', 'a_14.', 'a_15.']
Action: 2
['a_2.', 'a_6.', 'a_7.', 'a_11.', 'a_13.', 'a_14.']
Action: 3
Reward: -8
Terminated: True
Truncated: False


### Modify the model and evaluate

In [14]:
# Work on a deep copy so the original `model_state` stays untouched.
modified_sd = deepcopy(model_state)
conj_weights = modified_sd["actor.conjunctions.weights"]
disj_weights = modified_sd["actor.disjunctions.weights"]

# The edits below mirror the modified ASP rules for Door Corridor T. Judging by
# those rules, +6 encodes a positive literal and -6 a negated one -- confirm
# against the layer's thresholding convention.

# conj_0 :- a_11, a_13.   (row cleared first, then the two inputs set)
conj_weights[0] *= 0
conj_weights[0, 11] = 6
conj_weights[0, 13] = 6

# conj_4 :- not a_8, not a_13.
conj_weights[4] *= 0
conj_weights[4, 8] = -6
conj_weights[4, 13] = -6

# disj_3 :- conj_0.
disj_weights[3, 0] = 6

In [15]:
# Build a second agent with the same configuration as `model` and load the
# hand-modified weights (`modified_sd`) into it.
# NOTE(review): the meaning of the positional `True` is defined by
# `construct_model` and is not visible in this notebook.
dct_model: DCPPONDNFMutexTanhAgent = construct_model(
    model_cfg, # type: ignore
    DoorCorridorEnv.get_num_actions(),
    True,
    single_env.observation_space["image"],  # type: ignore
)
dct_model.to(DEVICE)
dct_model.load_state_dict(modified_sd)
# Switch to evaluation mode; as the last expression of the cell, the returned
# module is displayed.
dct_model.eval()

DCPPONDNFMutexTanhAgent(
  (image_encoder): Sequential(
    (0): Conv2d(2, 4, kernel_size=(1, 1), stride=(1, 1))
    (1): Tanh()
  )
  (extra_layer): Sequential(
    (0): Linear(in_features=36, out_features=16, bias=True)
    (1): Tanh()
  )
  (actor): NeuralDNFMutexTanh(
    (conjunctions): SemiSymbolic(in_features=16, out_features=12, layer_type=SemiSymbolicLayerType.CONJUNCTION,current_delta=1.00)
    (disjunctions): SemiSymbolicMutexTanh(in_features=12, out_features=4, layer_type=SemiSymbolicLayerType.DISJUNCTION,current_delta=1.00)
  )
  (critic): Sequential(
    (0): Linear(in_features=16, out_features=64, bias=True)
    (1): Tanh()
    (2): Linear(in_features=64, out_features=1, bias=True)
  )
)

In [16]:
obs, _ = dct.reset()

terminated = False
truncated = False
reward_sum = 0

# Roll out the weight-modified model on Door Corridor T until the episode
# terminates or is truncated.
while not (terminated or truncated):
    # Convert the observation array directly and add the batch dimension
    # afterwards. Wrapping the ndarray in a Python list before calling
    # torch.tensor takes PyTorch's slow list-of-ndarrays path; this form gives
    # the same float tensor and matches the other cells of this notebook.
    obs_dict = {
        "image": torch.tensor(obs["image"].copy(), device=DEVICE)
        .unsqueeze(0)
        .float()
    }
    with torch.no_grad():
        actions = dct_model.get_actions(
            preprocessed_obs=obs_dict,
            use_argmax=True,
            discretise_img_encoding=True,
        )
    # The first element of the result holds the chosen actions for the batch
    # (a batch of one here).
    actions = actions[0]
    obs, reward, terminated, truncated, _ = dct.step(actions[0])
    reward_sum += reward

print(f"Reward: {reward_sum}")
print(f"Terminated: {terminated}")
print(f"Truncated: {truncated}")

Reward: -8
Terminated: True
Truncated: False


### Record the environment

In [17]:
# Wrap the Door Corridor T environment so that every episode is written as a
# video into the "video" folder with file names starting with "dct".
record_dct = RecordVideo(
    dct,
    video_folder="video",
    name_prefix="dct",
    episode_trigger=lambda x: True,
    disable_logger=True,
)
obs, _ = record_dct.reset()

terminated = False
truncated = False
reward_sum = 0

# Hand-edited ASP policy for Door Corridor T (see "Modified ASP rules" above).
new_rules = [
    "disj_1 :- conj_5.",
    "disj_2 :- conj_4.",
    "disj_3 :- not conj_6.",
    "disj_3 :- conj_0.",
    "conj_4 :- not a_8, not a_13.",
    "conj_5 :- not a_9, a_13, not a_14.",
    "conj_6 :- not a_9.",
    "conj_0 :- a_11, a_13.",
]

# The #show directives do not depend on the observation, so build them once
# instead of on every step.
show_statements = [
    f"#show disj_{i}/0." for i in range(DoorCorridorEnv.get_num_actions())
]

# Recorded rollout driven by the ASP rules: encode the observation, turn the
# positive encoding entries into ASP facts, solve the program, and step with
# the single shown disj_<k> atom as action k.
while not (terminated or truncated):
    with torch.no_grad():
        raw_img_encoding = dct_model.get_img_encoding(
            preprocessed_obs={
                "image": torch.tensor(obs["image"].copy(), device=DEVICE)
                .unsqueeze(0)
                .float()
            }
        ).squeeze(0)
    img_encoding = [
        f"a_{a.item()}." for a in torch.nonzero(raw_img_encoding > 0)
    ]
    print(img_encoding)

    # A fresh solver per step: facts + #show directives + policy rules.
    ctl = clingo.Control(["--warn=none"])
    ctl.add("base", [], " ".join(img_encoding + show_statements + new_rules))
    ctl.ground([("base", [])])
    with ctl.solve(yield_=True) as handle:  # type: ignore
        all_answer_sets = [str(a) for a in handle]

    if len(all_answer_sets) != 1:
        # No model or multiple answer sets, should not happen
        print("No model or multiple answer sets when evaluating rules.")
        break

    # An empty string means no disj_<k> atom is shown in the answer set.
    # (Splitting a non-empty string always yields at least one item, so no
    # further emptiness check is needed afterwards.)
    if all_answer_sets[0] == "":
        print("No output action!")
        break

    output_classes = all_answer_sets[0].split(" ")
    # Drop the 5-character "disj_" prefix to recover the action index.
    output_classes_set = {int(o[5:]) for o in output_classes}

    if len(output_classes_set) != 1:
        print(f"Output set: {output_classes_set} not exactly one item!")
        break

    action = next(iter(output_classes_set))
    print(f"Action: {action}")
    obs, reward, terminated, truncated, _ = record_dct.step(action)
    reward_sum += reward

print(f"Reward: {reward_sum}")
print(f"Terminated: {terminated}")
print(f"Truncated: {truncated}")

['a_1.', 'a_2.', 'a_6.', 'a_8.', 'a_13.']
Action: 1
['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_8.', 'a_9.', 'a_10.', 'a_13.']
Action: 3
['a_2.', 'a_6.', 'a_10.']
Action: 2
['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_7.', 'a_8.', 'a_9.', 'a_10.', 'a_13.', 'a_14.']
Action: 3
['a_2.', 'a_6.', 'a_7.', 'a_10.', 'a_14.']
Action: 2
['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_7.', 'a_8.', 'a_9.', 'a_10.', 'a_13.', 'a_14.']
Action: 3
['a_2.', 'a_6.', 'a_7.', 'a_11.', 'a_14.', 'a_15.']
Action: 2
['a_2.', 'a_6.', 'a_7.', 'a_11.', 'a_13.', 'a_14.']
Action: 3
Reward: -8
Terminated: True
Truncated: False


In [18]:
terminated = False
truncated = False
reward_sum = 0

# Recorded rollout of the weight-modified model on Door Corridor T. The
# recorder is closed in `finally` so it is released even if a step raises.
try:
    obs, _ = record_dct.reset()

    while not (terminated or truncated):
        # Convert the observation array directly and add the batch dimension
        # afterwards. Wrapping the ndarray in a Python list before calling
        # torch.tensor takes PyTorch's slow list-of-ndarrays path; this form
        # gives the same float tensor and matches the other cells.
        obs_dict = {
            "image": torch.tensor(obs["image"].copy(), device=DEVICE)
            .unsqueeze(0)
            .float()
        }
        with torch.no_grad():
            actions = dct_model.get_actions(
                preprocessed_obs=obs_dict,
                use_argmax=True,
                discretise_img_encoding=True,
            )
        # The first element of the result holds the chosen actions for the
        # batch (a batch of one here).
        actions = actions[0]
        print(actions)
        obs, reward, terminated, truncated, _ = record_dct.step(actions[0])
        reward_sum += reward

    print(f"Reward: {reward_sum}")
    print(f"Terminated: {terminated}")
    print(f"Truncated: {truncated}")
finally:
    record_dct.close()

[1]
[3]
[2]
[3]
[2]
[3]
[2]
[3]
Reward: -8
Terminated: True
Truncated: False


# NDNF-MT on Door Corridor OT

A modified version of `DoorCorridorEnv`: to finish an episode, the agent must
stand on the goal cell and toggle it.


In [19]:
# Door Corridor OT environment, rendering to RGB arrays.
dcot = DoorCorridorOTEnv(render_mode="rgb_array")
# Set the advertised render rate to 1 frame per second (presumably picked up
# by the video recorder used later -- confirm).
dcot.metadata["render_fps"] = 1
# Restore the original, unmodified weights into `model`. As the last
# expression of the cell, the load result is displayed.
model.load_state_dict(model_state)

<All keys matched successfully>

In [20]:
obs, _ = dcot.reset()

terminated = False
truncated = False
reward_sum = 0

# Only the encodings of the first steps are printed, so an episode that runs
# until truncation does not flood the output.
NUM_STEPS_TO_PRINT = 9

# Roll out the original model on Door Corridor OT, printing the positive
# image-encoding entries (as ASP facts) for the first few steps.
step_idx = 0
while not (terminated or truncated):
    with torch.no_grad():
        # One batched observation shared by the encoder and the policy call.
        image_batch = (
            torch.tensor(obs["image"].copy(), device=DEVICE).unsqueeze(0).float()
        )
        obs_dict = {"image": image_batch}
        raw_img_encoding = model.get_img_encoding(
            preprocessed_obs=obs_dict
        ).squeeze(0)
        actions = model.get_actions(
            preprocessed_obs=obs_dict,
            use_argmax=True,
            discretise_img_encoding=True,
        )

    if step_idx < NUM_STEPS_TO_PRINT:
        img_encoding = [
            f"a_{a.item()}." for a in torch.nonzero(raw_img_encoding > 0)
        ]
        print(img_encoding)

    # The first element of the result holds the chosen actions for the batch.
    actions = actions[0]
    obs, reward, terminated, truncated, _ = dcot.step(actions[0])
    reward_sum += reward
    step_idx += 1

print(f"Reward: {reward_sum}")
print(f"Terminated: {terminated}")
print(f"Truncated: {truncated}")

['a_1.', 'a_2.', 'a_6.', 'a_8.', 'a_13.']
['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_8.', 'a_9.', 'a_10.', 'a_13.']
['a_2.', 'a_6.', 'a_10.']
['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_7.', 'a_8.', 'a_9.', 'a_10.', 'a_13.', 'a_14.']
['a_2.', 'a_6.', 'a_7.', 'a_10.', 'a_14.']
['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_7.', 'a_8.', 'a_9.', 'a_10.', 'a_13.', 'a_14.']
['a_2.', 'a_6.', 'a_7.', 'a_11.', 'a_14.', 'a_15.']
['a_2.', 'a_6.', 'a_7.', 'a_11.', 'a_13.', 'a_14.']
['a_1.', 'a_2.', 'a_6.', 'a_8.', 'a_10.', 'a_13.']
Reward: -270
Terminated: False
Truncated: True


### Modified ASP rules (changes relative to the learned rules are in **bold**):

disj_1 :- conj_5.

disj_2 :- conj_4.

**disj_3 :- conj_0.**

disj_3 :- not conj_6.

**conj_0 :- a_8, not a_9, a_10.**

conj_4 :- not a_8.

conj_5 :- not a_9, **not a_10**, a_13, not a_14.

conj_6 :- not a_9.


Image encoding at each timestep:

0 ['a_1.', 'a_2.', 'a_6.', 'a_8.', 'a_13.']

1 ['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_8.', 'a_9.', 'a_10.', 'a_13.']

2 ['a_2.', 'a_6.', 'a_10.']

3 ['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_7.', 'a_8.', 'a_9.', 'a_10.', 'a_13.', 'a_14.']

4 ['a_2.', 'a_6.', 'a_7.', 'a_10.', 'a_14.']

5 ['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_7.', 'a_8.', 'a_9.', 'a_10.', 'a_13.', 'a_14.']

6 ['a_2.', 'a_6.', 'a_7.', 'a_11.', 'a_14.', 'a_15.']

7 ['a_2.', 'a_6.', 'a_7.', 'a_11.', 'a_13.', 'a_14.']

8 ['a_1.', 'a_2.', 'a_6.', 'a_8.', 'a_10.', 'a_13.']

In [21]:
# Start a fresh Door Corridor OT episode for the ASP-rule rollout below.
obs, _ = dcot.reset()

# Episode status flags and running return, updated by the rollout loop.
terminated = False
truncated = False
reward_sum = 0

# Hand-edited ASP policy for Door Corridor OT (see "Modified ASP rules" above).
# Every rule is its own list element: a missing comma previously made Python
# concatenate the two disj_3 rules into one string.
new_rules = [
    "disj_1 :- conj_5.",
    "disj_2 :- conj_4.",
    "disj_3 :- conj_0.",
    "disj_3 :- not conj_6.",
    "conj_0 :- a_8, not a_9, a_10.",
    "conj_4 :- not a_8.",
    "conj_5 :- not a_9, not a_10, a_13, not a_14.",
    "conj_6 :- not a_9.",
]

# The #show directives do not depend on the observation, so build them once
# instead of on every step.
show_statements = [
    f"#show disj_{i}/0." for i in range(DoorCorridorEnv.get_num_actions())
]

# Roll out one episode on Door Corridor OT driven by `new_rules`: encode the
# observation, turn the positive encoding entries into ASP facts, solve the
# program, and step with the single shown disj_<k> atom as action k.
while not (terminated or truncated):
    with torch.no_grad():
        raw_img_encoding = model.get_img_encoding(
            preprocessed_obs={
                "image": torch.tensor(obs["image"].copy(), device=DEVICE)
                .unsqueeze(0)
                .float()
            }
        ).squeeze(0)
    img_encoding = [
        f"a_{a.item()}." for a in torch.nonzero(raw_img_encoding > 0)
    ]
    print(img_encoding)

    # A fresh solver per step: facts + #show directives + policy rules.
    ctl = clingo.Control(["--warn=none"])
    ctl.add("base", [], " ".join(img_encoding + show_statements + new_rules))
    ctl.ground([("base", [])])
    with ctl.solve(yield_=True) as handle:  # type: ignore
        all_answer_sets = [str(a) for a in handle]

    if len(all_answer_sets) != 1:
        # No model or multiple answer sets, should not happen
        print("No model or multiple answer sets when evaluating rules.")
        break

    # An empty string means no disj_<k> atom is shown in the answer set.
    # (Splitting a non-empty string always yields at least one item, so no
    # further emptiness check is needed afterwards.)
    if all_answer_sets[0] == "":
        print("No output action!")
        break

    output_classes = all_answer_sets[0].split(" ")
    # Drop the 5-character "disj_" prefix to recover the action index.
    output_classes_set = {int(o[5:]) for o in output_classes}

    if len(output_classes_set) != 1:
        print(f"Output set: {output_classes_set} not exactly one item!")
        break

    action = next(iter(output_classes_set))
    print(f"Action: {action}")
    obs, reward, terminated, truncated, _ = dcot.step(action)
    reward_sum += reward

['a_1.', 'a_2.', 'a_6.', 'a_8.', 'a_13.']
Action: 1
['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_8.', 'a_9.', 'a_10.', 'a_13.']
Action: 3
['a_2.', 'a_6.', 'a_10.']
Action: 2
['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_7.', 'a_8.', 'a_9.', 'a_10.', 'a_13.', 'a_14.']
Action: 3
['a_2.', 'a_6.', 'a_7.', 'a_10.', 'a_14.']
Action: 2
['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_7.', 'a_8.', 'a_9.', 'a_10.', 'a_13.', 'a_14.']
Action: 3
['a_2.', 'a_6.', 'a_7.', 'a_11.', 'a_14.', 'a_15.']
Action: 2
['a_2.', 'a_6.', 'a_7.', 'a_11.', 'a_13.', 'a_14.']
Action: 2
['a_1.', 'a_2.', 'a_6.', 'a_8.', 'a_10.', 'a_13.']
Action: 3


In [22]:
# Work on a deep copy so the original `model_state` stays untouched.
modified_ot_sd = deepcopy(model_state)
ot_conj_weights = modified_ot_sd["actor.conjunctions.weights"]
ot_disj_weights = modified_ot_sd["actor.disjunctions.weights"]

# The edits below mirror the modified ASP rules for Door Corridor OT. Judging
# by those rules, +6 encodes a positive literal and -6 a negated one --
# confirm against the layer's thresholding convention.

# conj_0 :- a_8, not a_9, a_10.   (row cleared first, then the inputs set)
ot_conj_weights[0] *= 0
ot_conj_weights[0, 8] = 6
ot_conj_weights[0, 9] = -6
ot_conj_weights[0, 10] = 6

# conj_5 additionally gets "not a_10"; its other inputs are left as learned.
ot_conj_weights[5, 10] = -6

# disj_3 :- conj_0.
ot_disj_weights[3, 0] = 6

In [23]:
# Build another agent with the same configuration as `model` and load the
# hand-modified Door Corridor OT weights (`modified_ot_sd`) into it.
# NOTE(review): the meaning of the positional `True` is defined by
# `construct_model` and is not visible in this notebook.
dcot_model: DCPPONDNFMutexTanhAgent = construct_model(
    model_cfg, # type: ignore
    DoorCorridorEnv.get_num_actions(),
    True,
    single_env.observation_space["image"],  # type: ignore
)
dcot_model.to(DEVICE)
dcot_model.load_state_dict(modified_ot_sd)
# Switch to evaluation mode; as the last expression of the cell, the returned
# module is displayed.
dcot_model.eval()

DCPPONDNFMutexTanhAgent(
  (image_encoder): Sequential(
    (0): Conv2d(2, 4, kernel_size=(1, 1), stride=(1, 1))
    (1): Tanh()
  )
  (extra_layer): Sequential(
    (0): Linear(in_features=36, out_features=16, bias=True)
    (1): Tanh()
  )
  (actor): NeuralDNFMutexTanh(
    (conjunctions): SemiSymbolic(in_features=16, out_features=12, layer_type=SemiSymbolicLayerType.CONJUNCTION,current_delta=1.00)
    (disjunctions): SemiSymbolicMutexTanh(in_features=12, out_features=4, layer_type=SemiSymbolicLayerType.DISJUNCTION,current_delta=1.00)
  )
  (critic): Sequential(
    (0): Linear(in_features=16, out_features=64, bias=True)
    (1): Tanh()
    (2): Linear(in_features=64, out_features=1, bias=True)
  )
)

In [24]:
obs, _ = dcot.reset()

terminated = False
truncated = False
reward_sum = 0

# Roll out the weight-modified model on Door Corridor OT until the episode
# terminates or is truncated.
while not (terminated or truncated):
    # Convert the observation array directly and add the batch dimension
    # afterwards. Wrapping the ndarray in a Python list before calling
    # torch.tensor takes PyTorch's slow list-of-ndarrays path; this form gives
    # the same float tensor and matches the other cells of this notebook.
    obs_dict = {
        "image": torch.tensor(obs["image"].copy(), device=DEVICE)
        .unsqueeze(0)
        .float()
    }
    with torch.no_grad():
        actions = dcot_model.get_actions(
            preprocessed_obs=obs_dict,
            use_argmax=True,
            discretise_img_encoding=True,
        )
    # The first element of the result holds the chosen actions for the batch
    # (a batch of one here).
    actions = actions[0]
    obs, reward, terminated, truncated, _ = dcot.step(actions[0])
    reward_sum += reward

print(f"Reward: {reward_sum}")
print(f"Terminated: {terminated}")
print(f"Truncated: {truncated}")

Reward: -9
Terminated: True
Truncated: False


### Record the environment

In [25]:
# Wrap the Door Corridor OT environment so that every episode is written as a
# video into the "video" folder with file names starting with "dcot"
# (the episode trigger returns True for every episode).
record_dcot = RecordVideo(
    dcot,
    video_folder="video",
    name_prefix="dcot",
    episode_trigger=lambda x: True,
    disable_logger=True,
)
# Start the first recorded episode.
obs, _ = record_dcot.reset()

# Episode status flags and running return, updated by the rollout loop.
terminated = False
truncated = False
reward_sum = 0

# Hand-edited ASP policy for Door Corridor OT (see "Modified ASP rules" above).
# Every rule is its own list element: a missing comma previously made Python
# concatenate the two disj_3 rules into one string.
new_rules = [
    "disj_1 :- conj_5.",
    "disj_2 :- conj_4.",
    "disj_3 :- conj_0.",
    "disj_3 :- not conj_6.",
    "conj_0 :- a_8, not a_9, a_10.",
    "conj_4 :- not a_8.",
    "conj_5 :- not a_9, not a_10, a_13, not a_14.",
    "conj_6 :- not a_9.",
]

# The #show directives do not depend on the observation, so build them once
# instead of on every step.
show_statements = [
    f"#show disj_{i}/0." for i in range(DoorCorridorEnv.get_num_actions())
]

# Recorded rollout on Door Corridor OT driven by `new_rules`: encode the
# observation, turn the positive encoding entries into ASP facts, solve the
# program, and step with the single shown disj_<k> atom as action k.
while not (terminated or truncated):
    with torch.no_grad():
        raw_img_encoding = dcot_model.get_img_encoding(
            preprocessed_obs={
                "image": torch.tensor(obs["image"].copy(), device=DEVICE)
                .unsqueeze(0)
                .float()
            }
        ).squeeze(0)
    img_encoding = [
        f"a_{a.item()}." for a in torch.nonzero(raw_img_encoding > 0)
    ]
    print(img_encoding)

    # A fresh solver per step: facts + #show directives + policy rules.
    ctl = clingo.Control(["--warn=none"])
    ctl.add("base", [], " ".join(img_encoding + show_statements + new_rules))
    ctl.ground([("base", [])])
    with ctl.solve(yield_=True) as handle:  # type: ignore
        all_answer_sets = [str(a) for a in handle]

    if len(all_answer_sets) != 1:
        # No model or multiple answer sets, should not happen
        print("No model or multiple answer sets when evaluating rules.")
        break

    # An empty string means no disj_<k> atom is shown in the answer set.
    # (Splitting a non-empty string always yields at least one item, so no
    # further emptiness check is needed afterwards.)
    if all_answer_sets[0] == "":
        print("No output action!")
        break

    output_classes = all_answer_sets[0].split(" ")
    # Drop the 5-character "disj_" prefix to recover the action index.
    output_classes_set = {int(o[5:]) for o in output_classes}

    if len(output_classes_set) != 1:
        print(f"Output set: {output_classes_set} not exactly one item!")
        break

    action = next(iter(output_classes_set))
    print(f"Action: {action}")
    obs, reward, terminated, truncated, _ = record_dcot.step(action)
    reward_sum += reward

print(f"Reward: {reward_sum}")
print(f"Terminated: {terminated}")
print(f"Truncated: {truncated}")

['a_1.', 'a_2.', 'a_6.', 'a_8.', 'a_13.']
Action: 1
['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_8.', 'a_9.', 'a_10.', 'a_13.']
Action: 3
['a_2.', 'a_6.', 'a_10.']
Action: 2
['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_7.', 'a_8.', 'a_9.', 'a_10.', 'a_13.', 'a_14.']
Action: 3
['a_2.', 'a_6.', 'a_7.', 'a_10.', 'a_14.']
Action: 2
['a_1.', 'a_2.', 'a_5.', 'a_6.', 'a_7.', 'a_8.', 'a_9.', 'a_10.', 'a_13.', 'a_14.']
Action: 3
['a_2.', 'a_6.', 'a_7.', 'a_11.', 'a_14.', 'a_15.']
Action: 2
['a_2.', 'a_6.', 'a_7.', 'a_11.', 'a_13.', 'a_14.']
Action: 2
['a_1.', 'a_2.', 'a_6.', 'a_8.', 'a_10.', 'a_13.']
Action: 3
Reward: -9
Terminated: True
Truncated: False


In [26]:
terminated = False
truncated = False
reward_sum = 0

# Recorded rollout of the weight-modified model on Door Corridor OT. The
# recorder is closed in `finally` so it is released even if a step raises.
try:
    obs, _ = record_dcot.reset()

    while not (terminated or truncated):
        # Convert the observation array directly and add the batch dimension
        # afterwards. Wrapping the ndarray in a Python list before calling
        # torch.tensor takes PyTorch's slow list-of-ndarrays path; this form
        # gives the same float tensor and matches the other cells.
        obs_dict = {
            "image": torch.tensor(obs["image"].copy(), device=DEVICE)
            .unsqueeze(0)
            .float()
        }
        with torch.no_grad():
            actions = dcot_model.get_actions(
                preprocessed_obs=obs_dict,
                use_argmax=True,
                discretise_img_encoding=True,
            )
        # The first element of the result holds the chosen actions for the
        # batch (a batch of one here).
        actions = actions[0]
        print(actions)
        obs, reward, terminated, truncated, _ = record_dcot.step(actions[0])
        reward_sum += reward

    print(f"Reward: {reward_sum}")
    print(f"Terminated: {terminated}")
    print(f"Truncated: {truncated}")
finally:
    record_dcot.close()

[1]
[3]
[2]
[3]
[2]
[3]
[2]
[2]
[3]
Reward: -9
Terminated: True
Truncated: False


# Display the recorded videos

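The video files written by `RecordVideo` in the cells above were renamed (outside this notebook) so that each file name states the environment (`dct` / `dcot`), the model seed, and the policy type (`asp` / `ndnf-mt`).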

In [27]:
# Door Corridor T, driven by the modified ASP rules.
# Actions: 1, 3, 2, 3, 2, 3, 2, 3
# (per the interpretation printed earlier: 1 = turn_right, 2 = forward, 3 = toggle)
Video("video/dct-6731-asp.mp4")

In [28]:
# Door Corridor T, driven by the weight-modified NDNF-MT model.
# Actions: 1, 3, 2, 3, 2, 3, 2, 3
# (per the interpretation printed earlier: 1 = turn_right, 2 = forward, 3 = toggle)
Video("video/dct-6731-ndnf-mt.mp4")

In [29]:
# Door Corridor OT, driven by the modified ASP rules.
# Actions: 1, 3, 2, 3, 2, 3, 2, 2, 3
# (per the interpretation printed earlier: 1 = turn_right, 2 = forward, 3 = toggle)
Video("video/dcot-6731-asp.mp4")

In [30]:
# Door Corridor OT, driven by the weight-modified NDNF-MT model.
# Actions: 1, 3, 2, 3, 2, 3, 2, 2, 3
# (per the interpretation printed earlier: 1 = turn_right, 2 = forward, 3 = toggle)
Video("video/dcot-6731-ndnf-mt.mp4")