In [5]:
import sys, pathlib
sys.path.insert(0, str(pathlib.Path.cwd().parent.parent))  # add repo root
from tsum import tsum
import torch
import json

from ndtools import fun_binary_graph as fbg # ndtools available at github.com/jieunbyun/network-datasets
from ndtools.graphs import build_graph
from pathlib import Path
import networkx as nx   

In [6]:
DATASET = Path("data")


def _load_json(filename):
    """Read ``DATASET / filename`` as UTF-8 text and return the parsed JSON."""
    json_path = DATASET / filename
    return json.loads(json_path.read_text(encoding="utf-8"))


# Raw dataset files: nodes, edges, and per-component state probabilities.
nodes = _load_json("nodes.json")
edges = _load_json("edges.json")
probs_dict = _load_json("probs_bin.json")

# Base network graph assembled by ndtools.graphs.build_graph from the three inputs above.
G_base: nx.Graph = build_graph(nodes, edges, probs_dict)

In [7]:
# Origin / destination nodes for the travel-time system function.
# (An earlier run of this notebook used origin = 'n1'.)
origin = 'n32'
dests = ['n22', 'n66']

# Parameters forwarded to fbg.eval_travel_time_to_nearest.
AVG_SPEED_KMH = 60       # km/h
TARGET_MAX_HOURS = 0.5   # hours: it shouldn't take longer than this compared to the original travel time
LENGTH_ATTR = 'length_km'


def s_fun(comps_st):
    """System function passed to tsum rule extraction.

    Evaluates the travel time from ``origin`` to the nearest node in ``dests``
    on ``G_base`` under the component states ``comps_st``, using
    ``fbg.eval_travel_time_to_nearest``.

    Args:
        comps_st: component states, forwarded unchanged to
            ``fbg.eval_travel_time_to_nearest`` (format defined by ndtools/tsum).

    Returns:
        Tuple ``(travel_time, sys_st, min_comps_st)``:
          - ``travel_time``, ``sys_st``: as returned by
            ``fbg.eval_travel_time_to_nearest``.
          - ``min_comps_st``: when ``sys_st == 's'``, a dict mapping every edge
            id in ``info['path_filtered_edges']`` to ``('>=', 1)`` plus the entry
            ``'sys': ('>=', 1)``; otherwise ``None``.
    """
    travel_time, sys_st, info = fbg.eval_travel_time_to_nearest(
        comps_st, G_base, origin, dests,
        avg_speed=AVG_SPEED_KMH,
        target_max=TARGET_MAX_HOURS,
        length_attr=LENGTH_ATTR,
    )
    if sys_st != 's':
        return travel_time, sys_st, None

    # Every edge on the path returned by the evaluator must be in state >= 1 (working).
    min_comps_st = {eid: ('>=', 1) for eid in info['path_filtered_edges']}
    # 'sys' is the system-state row (not an edge); the resulting rule also requires it >= 1.
    min_comps_st['sys'] = ('>=', 1)
    return travel_time, sys_st, min_comps_st

# Row order handed to tsum: one row per edge, followed by the system row 'sys'.
edge_names = list(edges.keys())
row_names = edge_names + ['sys']
n_state = 2  # binary states

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-edge state probability table: row i is edge_names[i], column s is
# probs_dict[edge][str(s)]['p'] for s in 0..n_state-1. Built from edge_names
# directly (rather than row_names[:-1]) so it does not depend on 'sys' being last,
# and from range(n_state) so it stays consistent with n_state.
_edges_without_probs = [eid for eid in edge_names if eid not in probs_dict]
if _edges_without_probs:
    raise KeyError(f"probs_dict has no entry for edges: {_edges_without_probs[:5]}")

probs = torch.tensor(
    [[probs_dict[eid][str(s)]['p'] for s in range(n_state)] for eid in edge_names],
    dtype=torch.float32,
    device=device,
)


In [14]:
# Rule extraction. tsum provides two entry points, tsum.run_rule_extraction and
# tsum.run_rule_extraction_by_mcs; this notebook uses the MCS-based one.
# The extraction draws samples (see "found from sampling" in the log), so seed
# the RNG to make re-runs reproducible.
SEED = 0
torch.manual_seed(SEED)  # NOTE(review): assumes tsum samples via torch's global RNG — confirm; seed tsum's actual RNG otherwise.

result = tsum.run_rule_extraction_by_mcs(
    sfun=s_fun,
    probs=probs,
    row_names=row_names,
    n_state=n_state,
    output_dir="tsum_res",
    surv_json_name="rules_surv.json",
    fail_json_name="rules_fail.json",
    unk_prob_thres=1e-6,  # presumably stops once the "Unk. prob." in the log drops below this — confirm in tsum
)

---
Round: 1, Unk. prob.: 1.000e+00
No. of non-dominant rules: 0, Survival rules: 0, Failure rules: 0
Survival sample found from sampling.
No. of existing rules removed:  0
New rule added. System state: s, System value: 0.407880105486707. Total samples: 100000.
New rule (No. of conditions: 4): {'e0063': ('>=', 1), 'e0054': ('>=', 1), 'e0053': ('>=', 1), 'e0042': ('>=', 1), 'sys': ('>=', 1)}
Updated sys_vals: [0.408]
---
Round: 2, Unk. prob.: 1.000e+00
No. of non-dominant rules: 1, Survival rules: 1, Failure rules: 0
Survival sample found from sampling.
No. of existing rules removed:  0
New rule added. System state: s, System value: 0.3984885399102102. Total samples: 100000.
New rule (No. of conditions: 4): {'e0063': ('>=', 1), 'e0049': ('>=', 1), 'e0047': ('>=', 1), 'e0041': ('>=', 1), 'sys': ('>=', 1)}
Updated sys_vals: [0.398, 0.408]
---
Round: 3, Unk. prob.: 3.462e-01
No. of non-dominant rules: 2, Survival rules: 2, Failure rules: 0
Survival sample found from sampling.
No. of existi