Skip to content

Commit

Permalink
Update implementation for int_vector branch of bit_vector
Browse files Browse the repository at this point in the history
  • Loading branch information
leonardt committed Aug 2, 2018
1 parent 87e4a6a commit 1d4dd86
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 10 deletions.
16 changes: 15 additions & 1 deletion fault/python_simulator_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .target import Target
from magma.simulator.python_simulator import PythonSimulator
from fault.array import Array
from bit_vector import BitVector


def convert_value(val):
Expand Down Expand Up @@ -30,7 +31,20 @@ def __check_value(self, port, expected_val):
f"is {expected_val}")
sim_val = self._simulator.get_value(port)
expected_val = convert_value(expected_val)
assert sim_val == expected_val
assert self.__check(sim_val, expected_val), \
f"Expected {expected_val}, got {sim_val}"

def __check(self, sim_val, expected_val):
if expected_val is None:
# Expected None, skipping
return True
if isinstance(sim_val, list):
if isinstance(expected_val, BitVector):
return expected_val.__class__(sim_val) == expected_val
assert isinstance(expected_val, list)
return all(self.__check(x, y)
for x, y in zip(sim_val, expected_val))
return sim_val == expected_val

def __parse_tv(self, tv):
inputs = {}
Expand Down
14 changes: 10 additions & 4 deletions fault/test_vectors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from magma import BitType, ArrayType, SIntType
from magma.simulator.python_simulator import PythonSimulator
from magma.bitutils import seq2int
from bit_vector import BitVector
from bit_vector import BitVector, SIntVector
from inspect import signature
from itertools import product
import pytest
Expand Down Expand Up @@ -81,7 +81,7 @@ def generate_simulator_test_vectors(circuit, input_ranges=None,
input_range = range(start, end)
else:
input_range = input_ranges[i]
args.append([BitVector(x, num_bits=num_bits, signed=True)
args.append([SIntVector(x, num_bits=num_bits)
for x in input_range])
else:
if input_ranges is None:
Expand All @@ -101,7 +101,10 @@ def generate_simulator_test_vectors(circuit, input_ranges=None,
for i, (name, port) in enumerate(circuit.interface.ports.items()):
# circuit defn output is an input to the idefinition
if port.isoutput():
testv[i] = test[j].as_int()
if isinstance(port, SIntType):
testv[i] = test[j].as_sint()
else:
testv[i] = test[j].as_uint()
val = test[j].as_bool_list()
if len(val) == 1:
val = val[0]
Expand All @@ -114,7 +117,10 @@ def generate_simulator_test_vectors(circuit, input_ranges=None,
# circuit defn input is an input of the definition
if port.isinput():
val = simulator.get_value(getattr(circuit, name))
val = BitVector(val, signed=isinstance(port, SIntType)).as_int()
if isinstance(port, SIntType):
val = SIntVector(val).as_sint()
else:
val = BitVector(val).as_uint()
testv[i] = val

tests.append(testv)
Expand Down
5 changes: 3 additions & 2 deletions fault/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,16 @@ def __init__(self, circuit, clock=None):

def get_initial_value(self, port):
if isinstance(port, m._BitType):
return BitVector(None, 1)
return None
elif isinstance(port, m.ArrayType):
return self.get_array_val(port)
else:
raise NotImplementedError(port)

def get_array_val(self, arr, val=None):
if isinstance(arr.T, m._BitKind):
val = BitVector(val, len(arr))
if val is not None:
val = BitVector(val, len(arr))
elif isinstance(arr, m.ArrayType) and isinstance(arr.T, m.ArrayKind):
val = Array([self.get_array_val(x) for x in arr], len(arr))
else:
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
-e git://github.com/leonardt/bit_vector.git@int_vector#egg=bit_vector
-e git://github.com/phanrahan/magma.git#egg=magma
-e git://github.com/phanrahan/mantle.git#egg=mantle
6 changes: 3 additions & 3 deletions tests/test_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_tester_basic():
assert tester.test_vectors == [[BitVector(0, 1), BitVector(0, 1)]]
tester.eval()
assert tester.test_vectors == [[BitVector(0, 1), BitVector(0, 1)],
[BitVector(0, 1), BitVector(0, 1)]]
[BitVector(0, 1), None]]


def test_tester_clock():
Expand All @@ -22,7 +22,7 @@ def test_tester_clock():
tester.poke(circ.I, 0)
tester.expect(circ.O, 0)
assert tester.test_vectors == [
[BitVector(0, 1), BitVector(0, 1), BitVector(None, 1)]
[BitVector(0, 1), BitVector(0, 1), None]
]
tester.poke(circ.CLK, 0)
assert tester.test_vectors == [
Expand All @@ -31,7 +31,7 @@ def test_tester_clock():
tester.step()
assert tester.test_vectors == [
[BitVector(0, 1), BitVector(0, 1), BitVector(0, 1)],
[BitVector(0, 1), BitVector(0, 1), BitVector(1, 1)]
[BitVector(0, 1), None, BitVector(1, 1)]
]


Expand Down

0 comments on commit 1d4dd86

Please sign in to comment.