# Transition System Verification

This notebook demonstrates how to verify stateful hardware by modelling a circuit as a transition system and checking invariants with Z3.

Based on slides: **02-transition-system.pdf**

## Concepts
- **State**: Registers + Memory
- **Init**: Initial state formula
- **Step**: Transition relation (state → state')
- **Invariant**: Property that holds in all reachable states
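
For very small systems, the definition of an invariant can be checked directly by enumerating the reachable states. The sketch below does this in plain Python; the names `reachable_states`, `holds_in_all_reachable`, and the toy 3-bit counter are illustrative assumptions, not part of the slides or the notebook's helper modules.

```python
def reachable_states(init_states, step):
    """Return every state reachable from init_states via step.

    step(s) must return an iterable of successor states of s.
    """
    seen = set(init_states)
    frontier = list(seen)
    while frontier:
        s = frontier.pop()
        for t in step(s):
            if t not in seen:
                seen.add(t)
                frontier.append(t)
    return seen


def holds_in_all_reachable(init_states, step, prop):
    """True if prop holds in every reachable state (prop is then an invariant)."""
    return all(prop(s) for s in reachable_states(init_states, step))


# Hypothetical toy system: a 3-bit counter that starts at 0 and adds 2 per step.
toy_init = [0]


def toy_step(s):
    return [(s + 2) % 8]


toy_reach = reachable_states(toy_init, toy_step)
even_everywhere = holds_in_all_reachable(toy_init, toy_step, lambda s: s % 2 == 0)
print(sorted(toy_reach), even_everywhere)
```

Enumeration only scales to tiny state spaces, which is why the rest of the notebook uses symbolic checks with an SMT solver instead.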

In [None]:
import sys
sys.path.insert(0, '.')

import pyrtl
from pyrtl import *
import z3
from circuit import net_to_smt
from transition_system import net_to_smt_stateful
from verification_utils import CHCs

## Example 1: Simple Counter

From the slides:
- **Init**: `r = 00`
- **Step**: `r' = r + 01`

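We'll verify that `r < 4`. Every 2-bit value satisfies this, so the invariant used below is simply `True`. The proof is by induction and has two parts: *initiation* (`Init ⇒ Inv`) and *consecution* (`Inv ∧ Step ⇒ Inv'`). Each part is checked by asserting its negation and expecting `unsat`.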

In [None]:
# Create a simple counter circuit
pyrtl.reset_working_block()

r = pyrtl.Register(bitwidth=2, name='r')
r.next <<= (r + 1)[:2]  # Increment and wrap at 2 bits

wb = pyrtl.working_block()
print("Counter circuit:")
for net in wb:
    print(f"  {net.op:3} {str(net.op_param or ''):10} "
          f"{[(a.name, a.bitwidth) for a in net.args]} -> "
          f"{[(d.name, d.bitwidth) for d in net.dests]}")

In [None]:
# Create state variables
r_current = z3.BitVec('r', 2)
r_next = z3.BitVec("r'", 2)

# Initial state: r = 0
init = (r_current == 0)

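# Step relation: r' = r + 1 (mod 4)
# 2-bit bit-vector addition already wraps, so the Extract is a no-op
# kept only to mirror the circuit's [:2] slice.
step = (r_next == z3.Extract(1, 0, r_current + 1))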

# Invariant: always true (counter always fits in 2 bits)
inv = z3.BoolVal(True)

print("Transition System:")
print(f"  Init: {init}")
print(f"  Step: {step}")
print(f"  Inv:  {inv}")

# Verify initiation
s = z3.Solver()
s.add(init)
s.add(z3.Not(inv))
print(f"\nInitiation check: {s.check()}")
print("  (should be unsat - init implies inv)")

# Verify consecution
s = z3.Solver()
s.add(inv)
s.add(step)
inv_next = inv  # inv is always true, so inv' is also true
s.add(z3.Not(inv_next))
print(f"\nConsecution check: {s.check()}")
print("  (should be unsat - inv && step implies inv')")

## Example 2: Memory Loop

From verify.ipynb:
```python
sp = Register(bitwidth=3, name='sp')
mem = MemBlock(bitwidth=3, addrwidth=3, name='mem')

mem[sp] <<= (mem[sp] + 1)[:3]
mem[0] <<= 0
```

This increments memory at location `sp` and sets `mem[0] = 0`.

In [None]:
# Create memory loop circuit
pyrtl.reset_working_block()

sp = pyrtl.Register(bitwidth=3, name='sp')
mem = pyrtl.MemBlock(bitwidth=3, addrwidth=3, name='mem', max_write_ports=2)

# Increment memory at location sp
mem[sp] <<= (mem[sp] + 1)[:3]
# Always set mem[0] = 0
mem[0] <<= 0

# Keep sp constant for now
sp.next <<= sp

wb = pyrtl.working_block()
print("Memory loop circuit:")
for net in wb:
    print(f"  {net.op:3} {str(net.op_param or ''):10} "
          f"{[(a.name, a.bitwidth) for a in net.args]} -> "
          f"{[(d.name, d.bitwidth) for d in net.dests]}")

In [None]:
# Translate to SMT with state
sp_current = z3.BitVec('sp', 3)
sp_next = z3.BitVec("sp'", 3)
mem_current = z3.Array('mem', z3.BitVecSort(3), z3.BitVecSort(3))
mem_next = z3.Array("mem'", z3.BitVecSort(3), z3.BitVecSort(3))

# Manually encode the transition (simplified)
init = z3.And(
    sp_current == 1,
    z3.Select(mem_current, 0) == 0,
    z3.Select(mem_current, 1) == 0
)

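# Step: mem'[sp] = mem[sp] + 1, mem'[0] = 0, sp' = sp
# Simplifications: there is no frame condition, so other addresses of mem'
# are unconstrained. When sp == 0 the two memory constraints only agree if
# mem[0] + 1 == 0, so under inv those transitions are excluded entirely.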
step = z3.And(
    sp_next == sp_current,
    z3.Select(mem_next, sp_current) == z3.Extract(2, 0, z3.Select(mem_current, sp_current) + 1),
    z3.Select(mem_next, 0) == 0
)

# Invariant: mem[0] is always 0
inv = (z3.Select(mem_current, 0) == 0)
inv_next = (z3.Select(mem_next, 0) == 0)

print("Verifying invariant: mem[0] == 0")

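# Check initiation: init ∧ ¬inv must be unsat
s = z3.Solver()
s.add(init)
s.add(z3.Not(inv))
init_result = s.check()
print(f"\nInitiation: {init_result}")
if init_result == z3.sat:
    print("  Counterexample found!")
elif init_result == z3.unsat:
    print("  ✓ Init implies inv")

# Check consecution: inv ∧ step ∧ ¬inv' must be unsat
s = z3.Solver()
s.add(inv)
s.add(step)
s.add(z3.Not(inv_next))
cons_result = s.check()
print(f"\nConsecution: {cons_result}")
if cons_result == z3.sat:
    print("  Counterexample found!")
    model = s.model()
    print(f"  sp = {model.eval(sp_current)}")
elif cons_result == z3.unsat:
    print("  ✓ Inv && step implies inv'")

# Report success only if both checks actually returned unsat
print("\n" + "="*60)
if init_result == z3.unsat and cons_result == z3.unsat:
    print("Invariant verified! mem[0] is always 0.")
else:
    print("Invariant NOT verified.")
print("="*60)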

## Constrained Horn Clauses (CHCs)

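CHCs express the same proof obligations as implications over an unknown predicate `Inv`. Instead of supplying the invariant by hand, we ask the solver to find an interpretation of `Inv` that satisfies all of the clauses: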

```
init(state) => Inv(state)
Inv(state) ∧ step(state, state') => Inv(state')
Inv(state) => Safety(state)
```

Z3's Spacer engine can often find such an `Inv` automatically, although it may also time out or return `unknown`.

In [None]:
# Example: Counter with CHCs
from z3 import Function, ForAll, Implies, Solver, sat, unsat

# Declare uninterpreted predicate
Inv = z3.Function('Inv', z3.BitVecSort(2), z3.BoolSort())

# Variables
r = z3.BitVec('r', 2)
r_prime = z3.BitVec("r'", 2)

# CHC rules
rules = [
    # Init => Inv
    z3.Implies(r == 0, Inv(r)),
    
    # Inv(r) && r' = r+1 => Inv(r')
    z3.Implies(
        z3.And(Inv(r), r_prime == z3.Extract(1, 0, r + 1)),
        Inv(r_prime)
    ),
    
    # Inv(r) => r < 4 (always true for 2-bit values)
    z3.Implies(Inv(r), z3.ULE(r, 3))
]

# Create CHC solver
chc_solver = CHCs(rules, {Inv})
result = chc_solver.solve()

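print(f"CHC verification result: {type(result).__name__}")
if hasattr(result, 'model'):
    # A model interprets Inv, i.e. an inductive invariant was found
    print("Verification successful! Invariant found:")
    print(f"Inv = {result[Inv]}")
else:
    # Assumption: solve() returns a non-model value when it cannot produce an invariant
    print("No invariant found; the property was not verified.")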

## Next Steps

From the assignment (02-transition-system.pdf):

1. ✓ Extend `net_to_smt` to handle state
2. ✓ Verify simple memory loop
3. ⬜ Encode stack machine as CHCs
4. ⬜ Verify PUSH/POP operations
5. ⬜ Specify and verify PUSH/POP semantics

Continue in: `verify_stack_machine.ipynb`