Skip to content

Commit

Permalink
cleanup + bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
Hardik authored and Hardik committed Oct 3, 2018
1 parent 5ba9354 commit b8afedf
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 50 deletions.
9 changes: 9 additions & 0 deletions dnnweaver2/graph.py
Expand Up @@ -132,6 +132,15 @@ def get_op_name(self, name, op_type):
self.op_type_counter[op_type] += 1
return name

def get_ops(self):
total_ops = {}
for opname, op in self.op_registry.iteritems():
for op_type, num_ops in op.get_ops().iteritems():
if op_type not in total_ops:
total_ops[op_type] = 0
total_ops[op_type] += num_ops
return total_ops

@contextmanager
def name_scope(self, name):
current_scope = self.current_scope
Expand Down
6 changes: 5 additions & 1 deletion dnnweaver2/scalar/dtypes.py
Expand Up @@ -33,10 +33,14 @@ def __init__(self, exp_bits):
self.bits = 2
self.exp_bits = exp_bits

class Binary(Dtype):
class Binary(FixedPoint):
def __init__(self):
self.bits = 1
self.op_str = 'Binary'
self.frac_bits = 0
self.int_bits = 1
def __str__(self):
return 'Binary'

class CustomFloat(Dtype):
def __init__(self, bits, exp_bits):
Expand Down
6 changes: 6 additions & 0 deletions dnnweaver2/scalar/ops.py
Expand Up @@ -23,6 +23,7 @@ def __init__(self):
self.CmpOp = {}
self.AddOp = {}
self.SubOp = {}
self.RshiftOp = {}
def MUL(self, dtypes):
assert len(dtypes) == 2
if dtypes not in self.MulOp:
Expand Down Expand Up @@ -53,6 +54,11 @@ def SUB(self, dtypes):
if dtypes not in self.SubOp:
self.SubOp[dtypes] = ScalarOp('Subtract', dtypes)
return self.SubOp[dtypes]
def RSHIFT(self, dtypes):
assert isinstance(dtypes, Dtype), 'Got Dtypes: {}'.format(dtypes)
if dtypes not in self.RshiftOp:
self.RshiftOp[dtypes] = ScalarOp('Rshift', dtypes)
return self.RshiftOp[dtypes]


Ops = ScalarOpTypes()

0 comments on commit b8afedf

Please sign in to comment.