In [3]:
import sys, pathlib
sys.path.insert(0, str(pathlib.Path.cwd().parent.parent))  # add repo root
from tsum import tsum
import torch
import json

from ndtools import fun_binary_graph as fbg # ndtools available at github.com/jieunbyun/network-datasets
from ndtools.graphs import build_graph
from pathlib import Path
import networkx as nx   


In [4]:
DATASET = Path("data") 

nodes = json.loads((DATASET / "nodes.json").read_text(encoding="utf-8"))
edges = json.loads((DATASET / "edges.json").read_text(encoding="utf-8"))
probs_dict = json.loads((DATASET / "probs.json").read_text(encoding="utf-8"))

# build base graph
G_base: nx.Graph = build_graph(nodes, edges, probs_dict)

# all edges ON (example); add node/edge 0s as needed
states = {eid: 1 for eid in edges.keys()}

k_val, status, _ = fbg.eval_global_conn_k(states, G_base, target_k=2)
print("k =", k_val, "status =", status)

k = 2 status = s


In [5]:
s_fun = lambda comps_st: fbg.eval_global_conn_k(comps_st, G_base, target_k=2)
row_names = list(edges.keys()) + ['sys']
n_state = 2  # binary states: 0, 1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
probs = [[probs_dict[n]['0']['p'], probs_dict[n]['1']['p']] for n in row_names[:-1]]
probs = torch.tensor(probs, dtype=torch.float32, device=device)

In [6]:
result = tsum.run_rule_extraction(
    # Problem-specific callables / data
    sfun=s_fun,
    probs=probs,
    row_names=row_names,
    n_state=n_state,
    output_dir="toy_tsum",
    surv_json_name="rules_surv.json",
    fail_json_name="rules_fail.json"
) 

Round: 1, Unk. prob.: 1.000e+00
No. of non-dominant rules: 0, Survival rules: 0, Failure rules: 0
Survival sample found from survival rules üëç
No. of existing rules removed:  0
New rule added. System state: s, flow: 2. Total samples: 1024.
New rule (No. of conditions: 11): {'e01': ('>=', 1), 'e02': ('>=', 1), 'e03': ('>=', 1), 'e04': ('>=', 1), 'e05': ('>=', 1), 'e06': ('>=', 1), 'e07': ('>=', 1), 'e08': ('>=', 1), 'e09': ('>=', 1), 'e10': ('>=', 1), 'e11': ('>=', 1), 'sys': ('>=', 1)}
Updated sys_vals: [2]
Failure sample found from failure rules üëç
No. of existing rules removed:  0
New rule added. System state: f, flow: 0. Total samples: 1024.
New rule (No. of conditions: 11): {'e01': ('<=', 0), 'e02': ('<=', 0), 'e03': ('<=', 0), 'e04': ('<=', 0), 'e05': ('<=', 0), 'e06': ('<=', 0), 'e07': ('<=', 0), 'e08': ('<=', 0), 'e09': ('<=', 0), 'e10': ('<=', 0), 'e11': ('<=', 0), 'sys': ('<=', 0)}
Unique system values: [2, 0]
Round: 2, Unk. prob.: 1.000e+00
No. of non-dominant rules: 2, S

In [None]:
import torch
from typing import Dict, List, Tuple, Any

def classify_samples_with_indices(
    samples: torch.Tensor,
    survival_rules: List[torch.Tensor],
    failure_rules: List[torch.Tensor],
    *,
    return_masks: bool = False
) -> Dict[str, Any]:
    """
    Classify samples as survival, failure, or unknown using subset checks,
    and return indices for each class.

    Args:
        samples: (n_sample, n_var, n_state) binary tensor
        survival_rules: list of rule tensors, each (n_var, n_state) or (n_var+1, n_state)
        failure_rules: list of rule tensors, each (n_var, n_state) or (n_var+1, n_state)
        return_masks: if True, also return boolean masks per class

    Returns:
        {
          'survival': int,
          'failure' : int,
          'unknown' : int,
          'idx_survival': LongTensor[ns],
          'idx_failure' : LongTensor[nf],
          'idx_unknown' : LongTensor[nu],
          # optionally:
          'mask_survival': BoolTensor[n_sample],
          'mask_failure' : BoolTensor[n_sample],
          'mask_unknown' : BoolTensor[n_sample],
        }
    """
    device = samples.device
    n_sample = samples.shape[0]

    # Tracking masks
    classified = torch.zeros(n_sample, dtype=torch.bool, device=device)
    survival_mask = torch.zeros(n_sample, dtype=torch.bool, device=device)
    failure_mask = torch.zeros(n_sample, dtype=torch.bool, device=device)

    # Build (rule_tensor, label) list; drop system row if requested
    def _prep_rules(rules, label):
        out = []
        for r in rules:
            r_ok = r[:-1, :]
            out.append((r_ok.to(device=device, dtype=torch.bool), label))
        return out

    all_rules = _prep_rules(survival_rules, 'survival') + _prep_rules(failure_rules, 'failure')

    # Classification loop
    samples_b = samples.to(device=device, dtype=torch.bool)
    for rule_tensor, label in all_rules:
        unclassified_idx = ~classified
        if not unclassified_idx.any():
            break

        current_samples = samples_b[unclassified_idx]  # (n_curr, n_var, n_state)
        # Subset check: sample ‚äÜ rule  <=>  (sample & rule) == sample  across (var, state)
        is_subset = torch.all((current_samples & rule_tensor) == current_samples, dim=(1, 2))

        # Map back to original indices
        idx_all = torch.where(unclassified_idx)[0]
        matched_idx = idx_all[is_subset]

        if matched_idx.numel() == 0:
            continue

        if label == 'survival':
            survival_mask[matched_idx] = True
        else:
            failure_mask[matched_idx] = True

        classified[matched_idx] = True

    unknown_mask = ~classified

    # Indices
    idx_survival = torch.where(survival_mask)[0]
    idx_failure  = torch.where(failure_mask)[0]
    idx_unknown  = torch.where(unknown_mask)[0]

    result: Dict[str, Any] = {
        'survival': int(survival_mask.sum().item()),
        'failure' : int(failure_mask.sum().item()),
        'unknown' : int(unknown_mask.sum().item()),
        'idx_survival': idx_survival,
        'idx_failure' : idx_failure,
        'idx_unknown' : idx_unknown,
    }

    if return_masks:
        result['mask_survival'] = survival_mask
        result['mask_failure']  = failure_mask
        result['mask_unknown']  = unknown_mask

    return result


In [None]:
import torch
from typing import Dict, Any, Sequence, Optional
from torch import Tensor

def get_comp_cond_sys_prob(
    rules_mat_surv: Tensor,
    rules_mat_fail: Tensor,
    probs: Tensor,
    comps_st_cond: Dict[str, int],
    row_names: Sequence[str],
    s_fun,                          # Callable[[Dict[str,int]], tuple]
    n_sample: int = 1_000_000,
    n_batch:  int = 1_000_000,
    *,
    sys_row: Optional[int] = -1,    # index of the system row to exclude when building comps; set to None to include all
) -> Dict[str, float]:
    """
    P(system state | given component states).

    - 'probs' is (n_var, n_state) categorical; we condition rows listed in comps_st_cond to one-hot.
    - We classify samples using rules; for unknowns we call s_fun(comps_dict) to resolve.
    - Returns probabilities over {'survival','failure','unknown'} that sum ~ 1.0.

    Notes
    -----
    - If sys_row is not None, that row is excluded when building the comps dict for s_fun.
      Default -1 means the last row is the system variable.
    """
    # --- clone probs and apply conditioning ---
    if torch.is_tensor(probs):
        probs_cond = probs.clone()
        n_comps, n_states = probs_cond.shape
        n_vars = n_comps + 1 # system event
    else:
        raise TypeError("Expected 'probs' to be a torch.Tensor of shape (n_var, n_state).")

    if len(row_names) != n_vars:
        raise ValueError(f"row_names length ({len(row_names)}) must match probs rows ({n_vars}).")

    for x, s in comps_st_cond.items():
        try:
            row_idx = row_names.index(x)
        except ValueError:
            raise ValueError(f"Component {x} not found in row_names.")
        if not (0 <= int(s) < n_states):
            raise ValueError(f"State {s} for component {x} is out of bounds [0,{n_states-1}].")
        probs_cond[row_idx].zero_()
        probs_cond[row_idx, int(s)] = 1.0

    # --- sampling loop (exactly n_sample draws) ---
    batch_size = max(1, min(int(n_batch), int(n_sample)))
    remaining = int(n_sample)

    counts = {"survival": 0, "failure": 0, "unknown": 0}

    while remaining > 0:
        b = min(batch_size, remaining)
        # IMPORTANT: sample from the *conditioned* probs
        samples = tsum.sample_categorical(probs_cond, b)  # (b, n_var, n_state) one-hot

        res = classify_samples_with_indices(
            samples, rules_mat_surv, rules_mat_fail, return_masks=True
        )

        counts["survival"] += int(res["survival"])
        counts["failure"]  += int(res["failure"])

        # Resolve unknowns with s_fun
        idx_unknown = res["idx_unknown"]
        if idx_unknown.numel() > 0:
            # precompute the system row index if excluding
            sys_idx = None
            if sys_row is not None:
                sys_idx = sys_row if sys_row >= 0 else (n_vars + sys_row)

            for j in idx_unknown.tolist():
                sample_j = samples[j]  # (n_var, n_state)
                # convert one-hot row -> state index per var
                states = torch.argmax(sample_j, dim=1).tolist()

                # build comps dict for s_fun, excluding system row if requested
                if sys_idx is not None:
                    comps = {row_names[k]: int(states[k]) for k in range(n_vars) if k != sys_idx}
                else:
                    comps = {row_names[k]: int(states[k]) for k in range(n_vars)}

                _, sys_st, _ = s_fun(comps)

                if sys_st in ("s", "survival", 1, True):
                    counts["survival"] += 1
                elif sys_st in ("f", "failure", 0, False):
                    counts["failure"] += 1

        remaining -= b

    # --- normalize to probabilities (denominator = requested n_sample) ---
    total = float(n_sample)
    cond_probs = {k: counts[k] / total for k in counts}
    return cond_probs


In [12]:
TSUMPATH = Path("toy_tsum") 

rules_mat_surv = torch.load(r"toy_tsum/rules_surv.pt", map_location="cpu")
rules_mat_surv = rules_mat_surv.to(device)
rules_mat_fail = torch.load(r"toy_tsum/rules_fail.pt", map_location="cpu")
rules_mat_fail = rules_mat_fail.to(device)


  rules_mat_surv = torch.load(r"toy_tsum/rules_surv.pt", map_location="cpu")
  rules_mat_fail = torch.load(r"toy_tsum/rules_fail.pt", map_location="cpu")


In [13]:
pr_cond = get_comp_cond_sys_prob(
    rules_mat_surv,
    rules_mat_fail,
    probs,
    comps_st_cond = {'e01': 1, 'e02': 1},
    row_names = row_names,
    s_fun = s_fun
)

ValueError: row_names length (12) must match probs rows (11).

In [None]:
import copy

def get_comp_cond_sys_prob(rules_mat_surv, rules_mat_fail, probs, comps_st_cond, row_names, s_fun, n_sample = 1_000_000, n_batch = 1_000_000):
    """
    Get conditional system survival/failure probabilities given component states: P(sys | comps_st_cond)
    """
    
    # Update probs based on comps_st_cond
    probs_cond = copy.deepcopy(probs)
    for x, s in comps_st_cond.items():
        if x not in row_names:
            raise ValueError(f"Component {x} not found in row_names.")
        row_idx = row_names.index(x)
        if s < 0 or s >= probs.shape[1]:
            raise ValueError(f"State {s} for component {x} is out of bounds.")
        probs_cond[row_idx, :] = 0.0
        probs_cond[row_idx, s] = 1.0
    
    sample_batch_size = min(n_sample, n_batch)
    total_loops = max(n_sample // sample_batch_size, 1)
    counts = {"survival": 0, "failure": 0, "unknown": 0}
    for i in range(total_loops):
        samples = tsum.sample_categorical(probs, sample_batch_size)
        res = classify_samples_with_indices(samples, rules_mat_surv, rules_mat_fail, return_masks=True)

        counts["survival"] += res["survival"]
        counts["failure"] += res["failure"]

        unknown_mask = res["mask_unknown"]                   # BoolTensor[n_sample]
        unknown_samples = samples[unknown_mask]              # (n_unknown, n_var, n_state)

        for sample_j in unknown_samples:
            states = torch.argmax(sample_j, dim=1).tolist()
            comps = dict(zip(row_names[:-1], states[:-1])) # exclude system state
            fval, sys_st, _ = s_fun(comps)

            if sys_st == 's':
                counts["survival"] += 1
            else:
                counts["failure"] += 1

    cond_probs = {k: v / n_sample for k, v in counts.items()}

    return cond_probs
