[WIP] Unblock bloom model #439
Changes from all commits: f60393a, 12a9b47, f176db9, 9397ce3, 6e8f0c8
@@ -57,6 +57,8 @@ def aten_acosh(self: TFloat) -> TFloat:

@torch_op("aten::add")
def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
    # add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
    # TODO(titaiwang): Delete this when we have type promotion
    other = op.CastLike(other, self)
    alpha = op.CastLike(alpha, other)
    other = op.Mul(other, alpha)
    return op.Add(self, other)

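Several of the changed kernels use this same CastLike pattern, so a rough NumPy analogue may help show what the workaround buys until real type promotion lands (illustrative only; add_like and its behavior are not part of this PR):

import numpy as np

def add_like(self_, other, alpha=1.0):
    # Mirrors the CastLike pattern above: force `other` to self's dtype and
    # `alpha` to other's dtype before computing self + other * alpha.
    other = np.asarray(other).astype(self_.dtype)   # op.CastLike(other, self)
    alpha = np.asarray(alpha).astype(other.dtype)   # op.CastLike(alpha, other)
    return self_ + other * alpha                    # op.Add(self, op.Mul(other, alpha))

x = np.array([1.0, 2.0], dtype=np.float16)
y = np.array([3, 4], dtype=np.int64)   # mismatched dtype from x
print(add_like(x, y))                  # -> [4. 6.], result stays float16
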
@@ -765,10 +767,18 @@ def aten_bitwise_left_shift(self: TInt, other: TInt) -> TInt:

@torch_op("aten::bitwise_not")
def aten_bitwise_not(self: TInt) -> TInt:
    # bitwise_not(Tensor self) -> Tensor
    # TODO(titaiwang): Support BOOL input
    return op.BitwiseNot(self)


@torch_op("aten::bitwise_not", overload=True)
def aten_bitwise_not_bool(self: BOOL) -> BOOL:
    # bitwise_not(Tensor self) -> Tensor
    # FIXME(titaiwang): This is a hack to get around the fact that we don't have op.BitwiseNot supporting bool now.
Review thread on this line:
- We will have to go with this, I think. It's not a hack, as we need a dispatcher to handle different dtypes.
- Need `if` for ...
- Oh, we cannot do `if` because the dtypes are different, and we don't know the dtype within a function.
- So even trace-only won't work for us here.
- So how should we know the dtype beforehand?
- Maybe only certain kinds of ops can be inputs to bitwise operations, so we could do the processing in those ops?
- Guess we can forget it...

  @_onnx_symbolic("aten::bitwise_not")
  @_beartype.beartype
  def bitwise_not(g: jit_utils.GraphContext, input):
      if not symbolic_helper._is_bool(input):
          raise errors.SymbolicValueError(
              "ONNX export does NOT support exporting bitwise Not "
              "for non-boolean input values",
              input,
          )
      return g.op("Not", input)

- The exporter knows about dtypes, so it can choose which function to use.
    # We should remove this once we have a proper implementation.
    return op.Not(self)

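A minimal sketch of the dispatch idea from the thread above, assuming the exporter can see each input's dtype; dispatch_bitwise_not is a hypothetical helper, not an API from this PR or from the exporter:

import torch

def dispatch_bitwise_not(input_dtype: torch.dtype):
    # Hypothetical exporter-side dispatch: bool inputs use the overload that
    # lowers to ONNX Not; integer inputs use the one that lowers to BitwiseNot.
    if input_dtype == torch.bool:
        return aten_bitwise_not_bool
    return aten_bitwise_not
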
@torch_op("aten::bitwise_or") | ||
def aten_bitwise_or(self: TInt, other: TInt) -> TInt: | ||
# bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor | ||
|
@@ -3146,9 +3156,8 @@ def aten_margin_ranking_loss(

@torch_op("aten::masked_fill")
def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor:

Review comment: Could you create another PR for the non-hacks we identified, including the changes to masked_fill and bitwise_not (both versions), so that we can merge them into main? The rest can stay here for reference, but we are probably not going to merge it?

    # masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor
-   mask_cast = op.Cast(mask, to=BOOL.dtype)
    value_cast = op.CastLike(value, self)
-   return op.Where(mask_cast, value_cast, self)
+   return op.Where(mask, value_cast, self)


def aten_masked_scatter(self: TensorType, mask: TensorType, source: TensorType) -> TensorType:

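For reference, a NumPy analogue (illustrative, not from this PR) of what the simplified masked_fill computes; since mask is already typed BOOL, Where can take it directly and the extra Cast is redundant:

import numpy as np

self_ = np.array([1.0, 2.0, 3.0], dtype=np.float32)
mask = np.array([True, False, True])
value = np.float32(-1.0)            # already cast to self's dtype (CastLike)
out = np.where(mask, value, self_)  # op.Where(mask, value_cast, self)
print(out)                          # [-1.  2. -1.]
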
@@ -3648,7 +3657,8 @@ def aten_msort(self: TensorType) -> TensorType:

@torch_op("aten::mul")
def aten_mul(self: TReal, other: TReal) -> TReal:
    # mul.Tensor(Tensor self, Tensor other) -> Tensor
    # TODO(titaiwang): Delete this when we have type promotion
    other = op.CastLike(other, self)
    return op.Mul(self, other)

@@ -4697,6 +4707,7 @@ def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:

def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
    # rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
    alpha = op.CastLike(alpha, self)
    return op.Sub(other, op.Mul(self, alpha))

@@ -4848,7 +4859,8 @@ def aten_slice(
    else:
        step = op.Constant(value_ints=[1])

-   return op.Slice(self, start, end, dim, step)
+   # TODO(titaiwang): Delete this Cast when we have type promotion
+   return op.Cast(op.Slice(self, start, end, dim, step), to=FLOAT.dtype)


def aten_slice_backward(

@@ -5017,9 +5029,10 @@ def aten_stft(

@torch_op("aten::sub")
def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
    # sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
    # TODO(titaiwang): Delete this when we have type promotion
    other = op.CastLike(other, self)
    alpha = op.CastLike(alpha, other)
    other = op.Mul(other, alpha)
    return op.Sub(self, other)