Skip to content

Commit

Permalink
implement some existing operators for TF (#55)
Browse files Browse the repository at this point in the history
* impl some forward funcs for tf that already exist in torch

* fix: wrong gpu device name

* fix typo

* use list_logical_devices which is valid for tf.device
  • Loading branch information
Co1lin committed Sep 26, 2022
1 parent a8307f2 commit 1fc1c1c
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 1 deletion.
6 changes: 6 additions & 0 deletions nnsmith/abstract/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1782,6 +1782,12 @@ class ReduceMean(ReduceBase):
out_dtypes = [(i,) for i in DTYPE_FLOATS]


@mark_materialize("core")
class ReduceProd(ReduceBase):
    # Product reduction. Bool dtypes are excluded on both sides; every
    # non-bool dtype maps to itself (one single-element tuple per dtype).
    in_dtypes = [(i,) for i in DTYPE_NON_BOOLS]
    out_dtypes = [(i,) for i in DTYPE_NON_BOOLS]


@mark_materialize("core")
class ArgMin(ReduceBase):
in_dtypes = [(i,) for i in DTYPE_NON_BOOLS]
Expand Down
2 changes: 1 addition & 1 deletion nnsmith/materialize/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,6 @@ def device(self) -> tf.device:
class TFModelGPU(TFModel):
    """TFModel variant whose tensors/ops are placed on the first GPU."""

    @property
    def device(self) -> tf.device:
        # NOTE: logical (not physical) device names are the ones accepted by
        # tf.device — this is why list_logical_devices is used here. The
        # earlier duplicated list_physical_devices call was dead code (its
        # result was immediately overwritten) and has been removed.
        gpus = tf.config.list_logical_devices("GPU")
        assert gpus, "No GPU available"
        return tf.device(gpus[0].name)
82 changes: 82 additions & 0 deletions nnsmith/materialize/tensorflow/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,3 +293,85 @@ def forward_fn(op: NHWCConv2d):
dtype=op.input_like[0].dtype.tensorflow(),
autocast=False,
)


@operator_impl(Squeeze)
def forward_fn(op: Squeeze):
    """Build the forward callable for Squeeze.

    Squeezes only the recorded ``reduce_dim`` when one is set; otherwise
    squeezes every size-1 dimension.
    """
    dim = op.extra_attrs["reduce_dim"]
    if dim is None:
        return lambda x: tf.squeeze(x)
    return lambda x: tf.squeeze(x, axis=dim)


@operator_impl(ReduceSum)
def forward_fn(op: ReduceSum):
    """Build the forward callable for ReduceSum.

    Sums over the recorded ``reduce_dim`` when one is set; otherwise
    reduces over all axes.
    """
    dim = op.extra_attrs["reduce_dim"]
    if dim is None:
        return lambda x: tf.math.reduce_sum(x)
    return lambda x: tf.math.reduce_sum(x, axis=dim)


@operator_impl(ReduceMin)
def forward_fn(op: ReduceMin):
    """Build the forward callable for ReduceMin (min over ``reduce_dim``,
    or over all axes when no dim is recorded)."""
    dim = op.extra_attrs["reduce_dim"]
    if dim is not None:
        return lambda x: tf.math.reduce_min(x, axis=dim)
    return lambda x: tf.math.reduce_min(x)


@operator_impl(ReduceMax)
def forward_fn(op: ReduceMax):
    """Build the forward callable for ReduceMax (max over ``reduce_dim``,
    or over all axes when no dim is recorded)."""
    dim = op.extra_attrs["reduce_dim"]
    if dim is not None:
        return lambda x: tf.math.reduce_max(x, axis=dim)
    return lambda x: tf.math.reduce_max(x)


@operator_impl(ReduceMean)
def forward_fn(op: ReduceMean):
    """Build the forward callable for ReduceMean (mean over ``reduce_dim``,
    or over all axes when no dim is recorded)."""
    dim = op.extra_attrs["reduce_dim"]
    if dim is None:
        return lambda x: tf.math.reduce_mean(x)
    return lambda x: tf.math.reduce_mean(x, axis=dim)


@operator_impl(ReduceProd)
def forward_fn(op: ReduceProd):
    """Build the forward callable for ReduceProd (product over
    ``reduce_dim``, or over all axes when no dim is recorded)."""
    dim = op.extra_attrs["reduce_dim"]
    if dim is None:
        return lambda x: tf.math.reduce_prod(x)
    return lambda x: tf.math.reduce_prod(x, axis=dim)


@operator_impl(ArgMin)
def forward_fn(op: ArgMin):
    """Build the forward callable for ArgMin: index of the minimum along
    the recorded ``reduce_dim`` (tf default axis when none is recorded)."""
    dim = op.extra_attrs["reduce_dim"]
    if dim is not None:
        return lambda x: tf.math.argmin(x, axis=dim)
    return lambda x: tf.math.argmin(x)


@operator_impl(ArgMax)
def forward_fn(op: ArgMax):
    """Build the forward callable for ArgMax: index of the maximum along
    the recorded ``reduce_dim`` (tf default axis when none is recorded).

    Fix: the decorator and annotation previously said ``ArgMin`` — a
    copy-paste slip. The body calls ``tf.math.argmax``, and registering
    ``ArgMin`` twice would conflict with the real ArgMin impl above.
    """
    if op.extra_attrs["reduce_dim"] is not None:
        return lambda x: tf.math.argmax(x, axis=op.extra_attrs["reduce_dim"])
    return lambda x: tf.math.argmax(x)


@operator_impl(Tril)
def forward_fn(op: Tril):
    """Build the forward callable for Tril: lower triangle with diagonal
    offset ``op.diagonal`` (NumPy-compatible TF API)."""
    k = op.diagonal
    return lambda x: tf.experimental.numpy.tril(x, k=k)


@operator_impl(Triu)
def forward_fn(op: Triu):
    """Build the forward callable for Triu: upper triangle with diagonal
    offset ``op.diagonal`` (NumPy-compatible TF API)."""
    k = op.diagonal
    return lambda x: tf.experimental.numpy.triu(x, k=k)


@operator_impl(Concat)
def forward_fn(op: Concat):
    """Build the forward callable for Concat: join all inputs along the
    axis recorded in ``extra_attrs``."""
    ax = op.extra_attrs["axis"]

    def _concat(*tensors):
        return tf.concat(tensors, axis=ax)

    return _concat


@operator_impl(Cast)
def forward_fn(op: Cast):
    """Build the forward callable for Cast: convert the input to the TF
    dtype recorded in ``extra_attrs['to']``."""
    target_dtype = op.extra_attrs["to"].tensorflow()
    return lambda x: tf.cast(x, dtype=target_dtype)


@operator_impl(MatMul)
def forward_fn(op: MatMul):
    """Forward for MatMul: delegate directly to tf.linalg.matmul.

    No per-op attributes are read, so the library function itself is the
    forward callable.
    """
    return tf.linalg.matmul

0 comments on commit 1fc1c1c

Please sign in to comment.