# Runtime Comparison - Taxi

Compare the run time of different models: Q table, MLP, NDNF-MT and logic-based
programs.

Each model is run 100 episodes.


In [1]:
from datetime import datetime
from pathlib import Path
import sys

sys.path.append("..")

In [2]:
import gymnasium as gym
import numpy as np
import pandas as pd
import torch

In [3]:
from neural_dnf.neural_dnf import NeuralDNFMutexTanh

from eval.problog_inference_common import prolog_inference_in_env_single_run
from eval.taxi_ppo_rl_eval_common import eval_model_on_environment
from eval.taxi_problog_rules_inference import taxi_problog_context_gen_fn
from eval.taxi_distillation_rl_eval_common import (
    eval_on_environments,
    eval_get_ndnf_action,
)
from taxi_common import construct_model, taxi_env_preprocess_obs

In [4]:
DEVICE = torch.device("cpu")
NUM_EPISODES = 10000

In [5]:
taxi_env = gym.make("Taxi-v3", render_mode=None)

# Q table


In [6]:
# Q table
with open(
    Path(
        "../results/Taxi-TAB/TAXI-TAB-q-1e4/TAXI-TAB-q-1e4-1771/taxi_tab_q_1e4_1771.csv"
    ),
    "r",
) as f:
    df = pd.read_csv(f, index_col=None)
target_policy = df.to_numpy()


def get_action_from_q_table(
    q_table: np.ndarray, obs: int, use_argmax: bool, epsilon: float
) -> int:
    if use_argmax:
        return int(np.argmax(q_table[obs]))
    else:
        if np.random.rand() < epsilon:
            return np.random.randint(2)
        return int(np.argmax(q_table[obs]))


start_time = datetime.now().timestamp()

for _ in range(NUM_EPISODES):
    obs, _ = taxi_env.reset()
    terminated, truncated = False, False
    while not terminated and not truncated:
        action = get_action_from_q_table(target_policy, obs, False, 0.1)
        obs, _, terminated, truncated, _ = taxi_env.step(action)

end_time = datetime.now().timestamp()

print(f"Time taken: {end_time - start_time}")
print(f"Avg time per episode: {(end_time - start_time) / NUM_EPISODES}")

Time taken: 1.2021000385284424
Avg time per episode: 0.00012021000385284423


## MLP


In [7]:
# MLP
mlp_model = construct_model(
    actor_latent_size=256,
    use_ndnf=False,
    use_decode_obs=False,
    use_eo=False,
    use_mt=False,
    share_layer_with_critic=False,
    critic_latent_1=256,
    critic_latent_2=256,
    pretrained_critic=None,
    mlp_actor_disable_bias=False,
)
mlp_model.to(DEVICE)
sd = torch.load(
    "../taxi_ppo_storage/taxi_ppo_mlp_raw_al256_cr256x256tanh_3e6_4839/model.pth",
    map_location=DEVICE,
)
mlp_model.load_state_dict(sd)
mlp_model.eval()

  sd = torch.load(


TaxiEnvPPOMLPAgent(
  (actor): Sequential(
    (0): Linear(in_features=500, out_features=256, bias=True)
    (1): Tanh()
    (2): Linear(in_features=256, out_features=6, bias=True)
  )
  (critic): Sequential(
    (0): Linear(in_features=500, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=1, bias=True)
  )
)

In [8]:
# Non-parallel
reward_list = []
start_time = datetime.now().timestamp()

for _ in range(NUM_EPISODES):
    obs, _ = taxi_env.reset()
    reward_sum = 0
    terminated, truncated = False, False
    while not terminated and not truncated:
        with torch.no_grad():
            obs_tensor = taxi_env_preprocess_obs(
                obs=np.array([obs]),
                use_ndnf=False,
                device=DEVICE,
            )
            action = mlp_model.get_actions(obs_tensor, use_argmax=False)[0]
        obs, reward, terminated, truncated, _ = taxi_env.step(action)
        reward_sum += reward
    reward_list.append(reward_sum)

end_time = datetime.now().timestamp()

print(f"Avg reward: {np.mean(reward_list)}")
print(f"Time taken: {end_time - start_time}")
print(f"Avg time per episode: {(end_time - start_time) / NUM_EPISODES}")

Avg reward: 7.6184
Time taken: 17.337646961212158
Avg time per episode: 0.0017337646961212158


In [9]:
# Parallel
start_time = datetime.now().timestamp()
ret = eval_model_on_environment(
    model=mlp_model,
    device=DEVICE,
    use_argmax=False,
    eval_num_runs=NUM_EPISODES,
)
end_time = datetime.now().timestamp()

print(f"Avg reward: {np.mean(ret['return_per_episode'])}")
print(f"Time taken: {end_time - start_time}")
print(f"Avg time per episode: {(end_time - start_time) / NUM_EPISODES}")

Avg reward: 7.5872
Time taken: 7.279323101043701
Avg time per episode: 0.0007279323101043702


# NDNF-MT


In [10]:
# NDNF-MT
ndnf_mt_model = NeuralDNFMutexTanh(
    num_preds=500,
    num_conjuncts=64,
    n_out=6,
    delta=1.0,
)
ndnf_mt_model.to(DEVICE)
sd = torch.load(
    "../taxi_distillation_storage/taxi_distillation_ndnf_mt_actdist_nc64_e5e3_5874/model.pth",
    map_location=DEVICE,
)
ndnf_mt_model.load_state_dict(sd)
ndnf_mt_model.eval()

  sd = torch.load(


NeuralDNFMutexTanh(
  (conjunctions): SemiSymbolic(in_features=500, out_features=64, layer_type=SemiSymbolicLayerType.CONJUNCTION,current_delta=1.00)
  (disjunctions): SemiSymbolicMutexTanh(in_features=64, out_features=6, layer_type=SemiSymbolicLayerType.DISJUNCTION,current_delta=1.00)
)

In [11]:
# Non-parallel

reward_list = []
has_truncation = False
start_time = datetime.now().timestamp()

for _ in range(NUM_EPISODES):
    obs, _ = taxi_env.reset()
    reward_sum = 0
    terminated, truncated = False, False
    while not terminated and not truncated:
        with torch.no_grad():
            action = eval_get_ndnf_action(
                ndnf_mt_model, np.array([obs]), DEVICE, use_argmax=False
            )[0][0].item()
        obs, reward, terminated, truncated, _ = taxi_env.step(action)
        reward_sum += reward
    reward_list.append(reward_sum)
    if truncated:
        has_truncation = True

end_time = datetime.now().timestamp()

print(f"Avg reward: {np.mean(reward_list)}")
print(f"Time taken: {end_time - start_time}")
print(f"Avg time per episode: {(end_time - start_time) / NUM_EPISODES}")

Avg reward: 7.415
Time taken: 49.95405912399292
Avg time per episode: 0.004995405912399292


In [12]:
# Parallel
start_time = datetime.now().timestamp()
ret = eval_on_environments(
    ndnf_model=ndnf_mt_model,
    device=DEVICE,
    use_argmax=False,
    num_episodes=NUM_EPISODES,
)
end_time = datetime.now().timestamp()

print(f"Avg reward: {ret['env_eval_avg_return_per_episode']}")
print(f"Time taken: {end_time - start_time}")
print(f"Avg time per episode: {(end_time - start_time) / NUM_EPISODES}")

Avg reward: 7.4479
Time taken: 9.790256023406982
Avg time per episode: 0.0009790256023406983


# ProbLog

Each inference takes more than 30min. We do not evaluate the run time of ProbLog.