TODO

In [1]:
import os
import sys
import torch
import numpy as np
import torch.nn as nn
import numpy as np
from sklearn.decomposition import PCA
from sklearn.cluster import DBSCAN
from sklearn.cluster import DBSCAN
from collections import defaultdict
from torch.nn.utils.rnn import pad_sequence
from hopcroft import hopcroft_minimize

# Add project root (one level up) to Python path 
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from models.parity_rnn import ParityRNN

Load Trained model

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ParityRNN(input_size=2, hidden_size=100, output_size=2).to(device)
checkpoint = torch.load('../models/checkpoints/parity_rnn_checkpoint.pt', map_location=device)
model.load_state_dict(checkpoint["model_state"])
model.eval()

  checkpoint = torch.load('../models/checkpoints/parity_rnn_checkpoint.pt', map_location=device)


ParityRNN(
  (rnn): RNN(2, 100, batch_first=True)
  (readout): Linear(in_features=100, out_features=2, bias=True)
)

# --- 2. Load validation sequences ---


In [3]:
val = np.load('../data/validation/val_data_50.npz', allow_pickle=True)
val_seqs = val['X']  # shape (100,), each an array of bits
val_labels = val['Y']

# --- 3. Collect Hidden State Trajectories ---

### 3a. Full trajectories for DFA extraction


In [4]:
H_all, X_all, idx_map = [], [], []
model.eval()
with torch.no_grad():
    for seq_i, seq in enumerate(val_seqs):
        # build one-hot batch
        x = torch.zeros(len(seq), 2, device=device)
        for t, b in enumerate(seq):
            x[t, b] = 1.0
        out, _ = model.rnn(x.unsqueeze(0))       # [1, T, H]
        h_seq = out.squeeze(0).cpu().numpy()     # (T, H)
        H_all.append(h_seq)
        X_all.append(list(seq))
        # record index map
        for t in range(h_seq.shape[0]):
            idx_map.append((seq_i, t))

### 3b. Final hidden states for clustering parities

In [5]:
H_finals, parities = [], []
for seq_i, label in enumerate(val_labels):
    h_seq = H_all[seq_i]              # full trajectory for sequence i
    H_finals.append(h_seq[-1])        # last hidden state
    parities.append(int(label))
H_finals = np.stack(H_finals)         # (N, H)
parities  = np.array(parities) 
print(f"Collected {len(H_all)} sequences of trajectories and {H_finals.shape[0]} final hidden-state vectors.") 
print(f"Collected {H_finals.shape[0]} final hidden-state vectors.")

Collected 100 sequences of trajectories and 100 final hidden-state vectors.
Collected 100 final hidden-state vectors.


# --- 4. Cluster Hidden States with DBSCAN ---

In [91]:
# 3.1 Flatten all hidden states (order = idx_map order)
flat_H = np.vstack(H_all)          # shape [N_tot, H]

# 3.2 PCA to lower dimension (tune n_components if you like)
pca = PCA(n_components=8, random_state=42)
flat_H_p = pca.fit_transform(flat_H)

print(f"PCA explained variance ratio sum: {pca.explained_variance_ratio_.sum():.3f}")

PCA explained variance ratio sum: 0.952


In [95]:
# 3.3 DBSCAN with an adaptive eps search until we get a reasonable #clusters

eps = 0.08
min_samples = 5
db = DBSCAN(eps=eps, min_samples=min_samples).fit(flat_H_p)
all_labels = db.labels_
unique = set(all_labels)
# we expect at least the two final-state parity clusters plus noise
if len(unique) >= 3:
    print(f"DBSCAN clustering over all states with eps={eps} gave clusters={unique}")
else:
    raise ValueError("DBSCAN failed to find enough clusters; adjust eps_values/min_samples.")

DBSCAN clustering over all states with eps=0.08 gave clusters={0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22}


In [96]:
# 3.4 Rebuild per-sequence labels using idx_map
labels_per_seq = [[] for _ in range(len(H_all))]
for (seq_i, t), lbl in zip(idx_map, all_labels):
    labels_per_seq[seq_i].append(lbl)

final_labels = np.array([seq_lbls[-1] for seq_lbls in labels_per_seq])  # label at last timestep

# Identify final-time non-noise clusters
clusters_final = set(final_labels) - {-1}

# Map each final cluster to parity by ground truth
means = {}
for c in clusters_final:
    idxs = np.where(final_labels == c)[0]
    means[c] = val_labels[idxs].mean()  # assuming val_labels are 0/1 parity
odd_cluster  = max(means, key=means.get)
even_cluster = min(means, key=means.get)
print(f"Cluster→parity: {even_cluster}=even, {odd_cluster}=odd")

Cluster→parity: 2=even, 3=odd


# --- 6. Build the Raw DFA Transitions ---

In [97]:
Sigma = [0, 1]
raw_Q = sorted(unique - {-1})  # drop noise label

# Initialize count dict for majority vote
from collections import defaultdict
trans_counts = {(q,b): defaultdict(int) for q in raw_Q for b in Sigma}

# Fill counts by walking trajectories
for (seq_i, t), lbl in zip(idx_map, all_labels):
    if lbl not in raw_Q:          # skip noise
        continue
    if t + 1 >= len(X_all[seq_i]):
        continue                  # no next state
    bit = X_all[seq_i][t]
    nxt = all_labels[idx_map.index((seq_i, t+1))]
    if nxt in raw_Q:
        trans_counts[(lbl, bit)][nxt] += 1

# Majority vote to decide single next state
delta = {}
for key, cnts in trans_counts.items():
    if cnts:  # if we saw anything
        delta[key] = max(cnts, key=cnts.get)
    else:
        # default self-loop if never observed
        q,_ = key
        delta[key] = q

print("Raw DFA transitions:")
for (q,b), nxt in delta.items():
    print(f"  {q} --{b}--> {nxt}")


Raw DFA transitions:
  0 --0--> 0
  0 --1--> 21
  1 --0--> 1
  1 --1--> 5
  2 --0--> 13
  2 --1--> 2
  3 --0--> 3
  3 --1--> 4
  4 --0--> 4
  4 --1--> 2
  5 --0--> 5
  5 --1--> 9
  6 --0--> 7
  6 --1--> 6
  7 --0--> 7
  7 --1--> 8
  8 --0--> 8
  8 --1--> 9
  9 --0--> 9
  9 --1--> 2
  10 --0--> 12
  10 --1--> 10
  11 --0--> 11
  11 --1--> 11
  12 --0--> 12
  12 --1--> 17
  13 --0--> 14
  13 --1--> 13
  14 --0--> 22
  14 --1--> 14
  15 --0--> 15
  15 --1--> 16
  16 --0--> 10
  16 --1--> 16
  17 --0--> 3
  17 --1--> 17
  18 --0--> 20
  18 --1--> 18
  19 --0--> 19
  19 --1--> 4
  20 --0--> 22
  20 --1--> 20
  21 --0--> 10
  21 --1--> 21
  22 --0--> 22
  22 --1--> 22


In [9]:
# --- 7. Hopcroft Minimization ---

In [105]:
# ========= STEP 6: Build reps, totalize δ, prune, (optional) merge unstable, minimize =========

# 6.0 Fast index (seq_i,t) -> flat idx
idx2flat = {pair: i for i, pair in enumerate(idx_map)}

# 6.1 Initial state set Q (all non-noise clusters)
Q = sorted(unique - {-1})
print(f"Initial states (non-noise): {len(Q)}")

# 6.2 representative_hidden for every q in Q
representative_hidden = {}
for (seq_i, t), lbl in zip(idx_map, all_labels):
    if lbl in Q and lbl not in representative_hidden:
        representative_hidden[lbl] = torch.from_numpy(H_all[seq_i][t]).to(device)
    if len(representative_hidden) == len(Q):
        break
# fill any missing with a dummy
if len(representative_hidden) < len(Q):
    any_rep = next(iter(representative_hidden.values()))
    for q in Q:
        representative_hidden.setdefault(q, any_rep)

# 6.3 Outputs & probabilities (Moore machine)
outputs_full = {}
probs_full   = {}
model.eval()
with torch.no_grad():
    for q, h in representative_hidden.items():
        logits = model.readout(h.unsqueeze(0))[0]          # [2]
        probs = torch.softmax(logits, dim=0)                # [2]
        outputs_full[q] = int(torch.argmax(logits).item()) # 0/1
        probs_full[q]   = probs[1].item()                   # P(odd)

# 6.4 Ensure δ is total & closed by adding a sink if needed
Sigma = [0, 1]
sink = max(Q) + 1
need_sink = False
for q in Q:
    for a in Sigma:
        if (q, a) not in delta:
            delta[(q, a)] = sink
            need_sink = True

if need_sink:
    Q.append(sink)
    # give sink a rep & output
    rep_any = next(iter(representative_hidden.values()))
    representative_hidden[sink] = rep_any
    outputs_full[sink] = 0
    probs_full[sink]   = 0.0

# paranoia: ensure every target is in Q
targets = {delta[(q,a)] for q in Q for a in Sigma}
missing = targets - set(Q)
if missing:
    rep_any = next(iter(representative_hidden.values()))
    for m in missing:
        Q.append(m)
        representative_hidden[m] = rep_any
        outputs_full[m] = 0
        probs_full[m]   = 0.0

# 6.5 Reachability prune
start_lbl = all_labels[idx2flat[(0, 0)]]
reachable = {start_lbl}
frontier = {start_lbl}
while frontier:
    nxt = set()
    for q in frontier:
        for a in Sigma:
            q2 = delta[(q, a)]
            if q2 not in reachable:
                reachable.add(q2)
                nxt.add(q2)
    frontier = nxt

Q = sorted(reachable)
delta = {(q,a): delta[(q,a)] for q in Q for a in Sigma}
outputs = {q: outputs_full[q] for q in Q}
probs   = {q: probs_full[q]   for q in Q}
representative_hidden = {q: representative_hidden[q] for q in Q}

print(f"Reachable states: {len(Q)}")

# (Optionally you could also re-run reachability / drop unstable from Q,
# but it's okay to leave them; Hopcroft will merge them if truly equivalent.)

# assume even_cluster, odd_cluster from final labels
cent_even = representative_hidden[even_cluster]
cent_odd  = representative_hidden[odd_cluster]

def to_parity_block(q):
    h = representative_hidden[q]
    return even_cluster if torch.norm(h-cent_even) < torch.norm(h-cent_odd) else odd_cluster

# coarsen transitions
delta_par = {}
for (q,a), qn in delta.items():
    delta_par[(q,a)] = to_parity_block(qn)

# parity outputs
outputs_par = {even_cluster:0, odd_cluster:1}
Q_par = [even_cluster, odd_cluster]
Sigma = [0,1]
for q in Q_par:
    for a in Sigma:
        delta_par.setdefault((q,a), q)

print(f"Total states (after merge step): {len(Q)} | Output classes: {len(set(outputs.values()))}")

# 6.7 Minimize (Moore version)
min_states, Sigma, min_delta, min_outputs, state_map = hopcroft_minimize(
    Q_par, Sigma, delta_par, outputs_par
)
print(f"Minimized to {len(min_states)} states: {min_states}")
print("Outputs per minimized state:", min_outputs)

# Debug: show merged blocks
blocks = {i: [] for i in min_states}
for q_old, q_new in state_map.items():
    blocks[q_new].append(q_old)
print("Merged blocks (old->new):")
for new_id, olds in blocks.items():
    print(f"  {new_id}: {sorted(olds)}")


Initial states (non-noise): 23
Reachable states: 11
Total states (after merge step): 11 | Output classes: 2
Minimized to 2 states: [0, 1]
Outputs per minimized state: {0: 0, 1: 1}
Merged blocks (old->new):
  0: [2]
  1: [3]


In [116]:
# Assume: min_states, state_map, blocks, Sigma, representative_hidden, model are defined
Wxh = model.rnn.weight_ih_l0
Whh = model.rnn.weight_hh_l0
b   = model.rnn.bias_ih_l0 + model.rnn.bias_hh_l0
act = torch.nn.ReLU()

def step(h, bit):
    x = torch.zeros(2, device=h.device); x[bit] = 1.0
    return act(Wxh @ x + Whh @ h + b)

# Representative hidden per minimized block
rep_block = {}
for s, olds in blocks.items():
    Hs = torch.stack([representative_hidden[q] for q in olds], 0)
    rep_block[s] = Hs.median(0).values  # median is more robust than mean

# Map output→state (Moore)
block_with_output = {}
for s in min_states:
    block_with_output[min_outputs[s]] = s

min_delta_sim = {}
for s, olds in blocks.items():
    for a in Sigma:
        votes = []
        for q_old in olds:
            h1 = step(representative_hidden[q_old], a)
            # First try output-based routing
            with torch.no_grad():
                logits = model.readout(h1.unsqueeze(0))[0]
                out = int(torch.argmax(logits))
            if out in block_with_output:
                nxt = block_with_output[out]
            else:
                # fallback: nearest centroid
                dmin, nxt = 1e9, None
                for s2, hb in rep_block.items():
                    d = torch.norm(h1 - hb).item()
                    if d < dmin:
                        dmin, nxt = d, s2
            votes.append(nxt)
        min_delta_sim[(s,a)] = max(set(votes), key=votes.count)

names = {s: ("odd" if min_outputs[s]==1 else "even") for s in min_states}
print("Final minimized DFA transitions (simulated & voted):")
for s in min_states:
    for a in Sigma:
        nxt = min_delta_sim[(s,a)]
        print(f"  {names[s]} ({s}) --{a}--> {names[nxt]} ({nxt})")




Final minimized DFA transitions (simulated & voted):
  even (0) --0--> even (0)
  even (0) --1--> odd (1)
  odd (1) --0--> odd (1)
  odd (1) --1--> even (0)


In [124]:
import json

# ---------- 1) Choose which minimized delta to export ----------
# Use the simulated & voted one you just fixed:
final_delta = min_delta_sim          # dict[(s,a)] = s_next
states      = list(min_states)       # [0, 1]
alphabet    = [0, 1]
outputs     = {int(s): int(min_outputs[s]) for s in states}

# Figure out which minimized state is the start (map original start cluster)
orig_start = all_labels[idx2flat[(0,0)]]
start_state = state_map[even_cluster]   # minimized id

# ---------- 2) Save to JSON ----------
dfa_json = {
    "states": states,
    "alphabet": alphabet,
    "start": int(start_state),
    "transitions": {f"{s},{a}": int(final_delta[(s,a)]) for s in states for a in alphabet},
    "outputs": outputs                # Moore machine outputs
}
with open("dfas/minimal_dfa.json", "w") as f:
    json.dump(dfa_json, f, indent=2)
print("Saved JSON to dfas/minimal_dfa.json")

# ---------- 3) Make a Graphviz DOT (looks like the paper’s figure) ----------
# Helpers to label states as even/odd
name = {s: ("even" if outputs[s]==0 else "odd") for s in states}

dot_lines = [
    "digraph DFA {",
    "  rankdir=LR;",
    '  node [shape=circle, style=filled, fillcolor="#cfeeee", fontsize=18];'
]

# Invisible start node arrow
dot_lines.append('  __start [shape=point, width=0];')
dot_lines.append(f'  __start -> {start_state};')

# State nodes (doublecircle optional if you want to highlight something)
for s in states:
    lab = f"{name[s]}\\n({outputs[s]})"
    dot_lines.append(f'  {s} [label="{lab}"];')

# Edges
for s in states:
    for a in alphabet:
        nxt = final_delta[(s,a)]
        # put label "0" or "1" on the edge
        dot_lines.append(f'  {s} -> {nxt} [label="{a}", fontsize=18];')

dot_lines.append("}")
dot_str = "\n".join(dot_lines)
with open("dfas/minimal_dfa.dot", "w") as f:
    f.write(dot_str)
print("Saved DOT to dfas/minimal_dfa.dot")

# If graphviz is installed in your environment, auto-render to PNG:
try:
    import graphviz
    g = graphviz.Source(dot_str)
    g.render("dfas/minimal_dfa", format="png", cleanup=True)
    print("Rendered to minimal_dfa.png")
except Exception as e:
    print("Graphviz python package not available; run:\n  dot -Tpng minimal_dfa.dot -o minimal_dfa.png")



Saved JSON to dfas/minimal_dfa.json
Saved DOT to dfas/minimal_dfa.dot
Rendered to minimal_dfa.png
