Skip to content

Commit

Permalink
Fix bugs in matrix operations in binary computation.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkskeller committed Apr 9, 2024
1 parent bec7265 commit 74baee3
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
13 changes: 8 additions & 5 deletions Compiler/GC/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,8 +776,9 @@ def load_mem(cls, address):
return cls.from_vec(sb.bit_compose(
sbit.load_mem(address + i + j * n) for j in range(size))
for i in range(n))
if not isinstance(address, int) and len(address) == n:
return cls.from_vec(sbit.load_mem(x) for x in address)
if not isinstance(address, int):
v = [sbit.load_mem(x, size=n).v[0] for x in address]
return cls(v)
else:
return cls.from_vec(sbit.load_mem(address + i)
for i in range(n))
Expand All @@ -787,10 +788,12 @@ def store_in_mem(self, address):
if not util.is_constant(x):
size = max(size, x.n)
v = [sbits.get_type(size).conv(x) for x in self.v]
if not isinstance(address, int) and len(address) == n:
assert max_n == 1
if not isinstance(address, int) and len(address) != 1:
v = self.elements()
assert len(v) == len(address)
for x, y in zip(v, address):
x.store_in_mem(y)
for i, xx in enumerate(x.bit_decompose(n)):
xx.store_in_mem(y + i)
else:
assert isinstance(address, int) or len(address) == 1
for i in range(n):
Expand Down
8 changes: 5 additions & 3 deletions Compiler/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6489,7 +6489,7 @@ def assign_slice_vector(self, slice, vector):

def get_part_size(self):
assert self.value_type.n_elements() == 1
return reduce(operator.mul, self.sizes[1:])
return reduce(operator.mul, self.sizes[1:]) * self.value_type.mem_size()

def get_slice_addresses(self, slice, part_size=None):
part_size = part_size or self.get_part_size()
Expand Down Expand Up @@ -6992,7 +6992,8 @@ def get_column(self, index):
:param index: regint/cint/int
"""
assert self.value_type.n_elements() == 1
addresses = regint.inc(self.sizes[0], self.address + index,
addresses = regint.inc(self.sizes[0], self.address + \
index * self.value_type.mem_size(),
self.get_part_size())
return self.value_type.load_mem(addresses)

Expand All @@ -7003,7 +7004,8 @@ def set_column(self, index, vector):
:param vector: short enought vector of compatible type
"""
assert self.value_type.n_elements() == 1
addresses = regint.inc(self.sizes[0], self.address + index,
addresses = regint.inc(self.sizes[0], self.address + \
index * self.value_type.mem_size(),
self.get_part_size())
self.value_type.conv(vector).store_in_mem(addresses)

Expand Down

0 comments on commit 74baee3

Please sign in to comment.