# Parametric k-bit Adder Verification

This notebook demonstrates:
1. Creating a parametric k-bit adder
2. Translating it to SMT using `net_to_smt`
3. Verifying correctness with Z3
4. Benchmarking verification time vs. bit width

In [None]:
import sys
# Make the project-local `circuit` module importable; keep this ahead of the
# imports below, matching the original ordering.
sys.path.insert(0, '../base')

import time

import matplotlib.pyplot as plt
import pandas as pd
import pyrtl
# NOTE(review): the cells in this notebook only use `pyrtl.`-qualified names;
# this wildcard import looks removable.
from pyrtl import *
import z3

from circuit import net_to_smt

## Define the Parametric Adder

In [None]:
def make_adder(k):
    """
    Build a k-bit modular adder in a freshly reset PyRTL working block.

    Args:
        k: bit width shared by both inputs and the output

    Returns:
        tuple: (a, b, sum) - the two Input wires and the Output wire

    Note:
        Calls ``pyrtl.reset_working_block()`` first, so any previously built
        circuit is no longer the working block after this call.
    """
    pyrtl.reset_working_block()

    # Inputs are created in the order a, then b.
    operand_a, operand_b = (pyrtl.Input(bitwidth=k, name=port) for port in ('a', 'b'))
    total = pyrtl.Output(bitwidth=k, name='sum')

    # Keep only the low k bits of the sum, dropping any carry-out, so the
    # output is (a + b) mod 2**k.
    full_width_sum = operand_a + operand_b
    total <<= full_width_sum[:k]

    return operand_a, operand_b, total

## Verify a Single Adder

In [None]:
# Build a small example adder and list the logic nets PyRTL generated for it.
k = 4
print(f"Creating {k}-bit adder...")
a, b, sum_out = make_adder(k)
wb = pyrtl.working_block()

net_count = sum(1 for _ in wb)
print(f"Circuit has {net_count} logic nets")
print("\nCircuit structure:")
for net in wb:
    # Use `arg`/`dest` rather than `a` so the adder input `a` is not visually shadowed.
    arg_info = [(arg.name, arg.bitwidth) for arg in net.args]
    dest_info = [(dest.name, dest.bitwidth) for dest in net.dests]
    param_text = str(net.op_param or '')
    print(f"  {net.op:3} {param_text:10} {arg_info} -> {dest_info}")

In [None]:
# Lower the working block to Z3. `assertions` is the list of constraints
# describing the circuit; `wires` supports lookup of a wire's Z3 term by name.
print("\nTranslating to SMT...")
wires, ops, assertions = net_to_smt(wb)

print(f"Generated {len(assertions)} SMT assertions:")
for smt_assertion in assertions:
    print(f"  {smt_assertion}")

In [None]:
# Prove correctness by refutation: assert the circuit's constraints together
# with the NEGATION of the specification. If Z3 reports unsat, no input can
# make the circuit disagree with modular addition.
print("\nVerifying correctness...")
solver = z3.Solver()
solver.add(*assertions)

# Z3 terms for the adder's ports, fetched by wire name.
a_var, b_var, sum_var = (wires.lookup(name) for name in ('a', 'b', 'sum'))

# Specification: sum == (a + b) mod 2^k, i.e. the low k bits of the sum.
expected_sum = z3.Extract(k-1, 0, a_var + b_var)
correctness = sum_var == expected_sum

# Ask for a counterexample to the specification.
solver.add(z3.Not(correctness))
result = solver.check()

print(f"\nVerification result: {result}")
if result == z3.sat:
    print(f"✗ The {k}-bit adder is INCORRECT!")
    model = solver.model()
    print("  Counterexample:")
    for label, term in (("a", a_var), ("b", b_var), ("sum", sum_var), ("expected", expected_sum)):
        print(f"    {label} = {model.eval(term)}")
elif result == z3.unsat:
    print(f"✓ The {k}-bit adder is CORRECT!")
    print("  No counterexample exists.")
else:
    print(f"? Verification UNKNOWN: {solver.reason_unknown()}")

## Benchmark: Verification Time vs. Bit Width

In [None]:
def verify_adder(k, verbose=False, timeout_ms=None):
    """
    Check with Z3 that the k-bit adder built by ``make_adder`` computes
    (a + b) mod 2**k.

    The check searches for a counterexample: an input for which the circuit's
    ``sum`` wire differs from the low k bits of ``a + b``. The adder is proven
    correct exactly when no counterexample exists (``unsat``).

    Args:
        k: bit width of the adder to build and verify.
        verbose: if True, print a one-line status and timing for this k.
        timeout_ms: optional Z3 solver timeout in milliseconds. ``None`` (the
            default) sets no timeout, matching the previous behaviour.

    Returns:
        tuple: (verified, time_taken_seconds). ``verified`` is True only when
        Z3 proves correctness (``unsat``). It is False both for a real
        counterexample (``sat``) and for an inconclusive run (``unknown``,
        e.g. when ``timeout_ms`` is hit).

    Side effects:
        Resets the PyRTL working block (via ``make_adder``).
    """
    # Only the working-block side effect is needed here; the port wires are
    # looked up by name in the SMT translation below.
    make_adder(k)
    block = pyrtl.working_block()

    # Time translation and solving together. perf_counter is monotonic and
    # high-resolution, unlike time.time(), which can jump with clock changes.
    start_time = time.perf_counter()
    wires, _ops, assertions = net_to_smt(block)

    solver = z3.Solver()
    if timeout_ms is not None:
        solver.set(timeout=timeout_ms)
    for assertion in assertions:
        solver.add(assertion)

    a_var = wires.lookup('a')
    b_var = wires.lookup('b')
    sum_var = wires.lookup('sum')

    # Look for an input where the circuit disagrees with modular addition.
    expected_sum = z3.Extract(k-1, 0, a_var + b_var)
    solver.add(z3.Not(sum_var == expected_sum))

    result = solver.check()
    elapsed_time = time.perf_counter() - start_time

    if verbose:
        if result == z3.unsat:
            status = "✓ CORRECT"
        elif result == z3.sat:
            status = "✗ INCORRECT"
        else:
            status = "? UNKNOWN"
        print(f"k={k:2d}: {status} ({elapsed_time:.4f}s)")

    return (result == z3.unsat), elapsed_time

In [None]:
# Run the benchmark over a range of bit widths.
# The loop variables are deliberately NOT named `k` / `result`. Those globals
# belong to the single-adder cells above, and overwriting them (e.g. leaving
# k == 32) would break re-running those cells after this one.
k_values = [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 16, 20, 24, 32]

print("Benchmarking adder verification...")
print("="*50)

times = []
results = []

for width in k_values:
    is_verified, elapsed = verify_adder(width, verbose=True)
    results.append(is_verified)
    times.append(elapsed)

print("="*50)
print("\nBenchmark complete!")

In [None]:
# Plot verification time vs. bit width twice. The linear y-axis shows absolute
# cost; the log y-axis makes the growth trend easier to judge.
fig, (ax_linear, ax_log) = plt.subplots(1, 2, figsize=(12, 5))

panels = [
    (ax_linear, '#2E86AB', 'Adder Verification Time (Linear Scale)', 'linear'),
    (ax_log, '#A23B72', 'Adder Verification Time (Log Scale)', 'log'),
]
for ax, color, title, yscale in panels:
    ax.plot(k_values, times, 'o-', linewidth=2, markersize=8, color=color)
    ax.set_xlabel('Bit Width (k)', fontsize=12)
    ax.set_ylabel('Verification Time (seconds)', fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.set_yscale(yscale)
    ax.grid(True, alpha=0.3)

fig.tight_layout()
fig.savefig('adder_verification_benchmark.png', dpi=150, bbox_inches='tight')
print("\n✓ Plot saved to 'adder_verification_benchmark.png'")
plt.show()

In [None]:
# Summary table: one row per benchmarked bit width.
# Times stay numeric (floats) so the frame remains usable for sorting or
# further analysis; they are formatted only when rendered.
# verify_adder returns True only for a proof (unsat). False covers both a
# counterexample and an inconclusive run, hence "NOT VERIFIED".
assert len(k_values) == len(times) == len(results), "benchmark lists are out of sync"

df = pd.DataFrame({
    'Bit Width (k)': k_values,
    'Time (s)': times,
    'Result': ['✓ CORRECT' if r else '✗ NOT VERIFIED' for r in results]
})

print("\n" + "="*50)
print("Summary Table")
print("="*50)
print(df.to_string(index=False, float_format='{:.6f}'.format))
print("="*50)

## Analysis

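Questions to consider:
1. How does verification time grow with bit width $k$?
2. What is the largest $k$ you can verify in reasonable time (< 10 seconds)?
3. Is the growth roughly linear, higher-order polynomial, or exponential? (On the log-scale plot, exponential growth shows up as a straight line.)
4. What factors affect verification time?

Note: each measured time covers both the `net_to_smt` translation and the Z3 solve, so it is an end-to-end cost rather than pure solver time.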