Skip to content

Add SOFTPLUS_FWD and SWISH_FWD pointwise ops#355

Merged
rsuderman merged 2 commits intoiree-org:mainfrom
rsuderman:softplus_swish
Apr 20, 2026
Merged

Add SOFTPLUS_FWD and SWISH_FWD pointwise ops#355
rsuderman merged 2 commits intoiree-org:mainfrom
rsuderman:softplus_swish

Conversation

@rsuderman
Copy link
Copy Markdown
Contributor

Emit torch.aten.softplus with configurable beta and threshold attributes (defaulting to 1.0 and 20.0) and torch.aten.silu for swish. Adds sample and lit tests for both ops.

rsuderman and others added 2 commits April 17, 2026 15:21
Emit torch.aten.softplus with configurable beta and threshold attributes
(defaulting to 1.0 and 20.0) and torch.aten.silu for swish. Adds sample
and lit tests for both ops.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Rob Suderman <rob.suderman@gmail.com>
Copy link
Copy Markdown
Contributor

@keshavvinayak01 keshavvinayak01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@rsuderman rsuderman merged commit aeed46f into iree-org:main Apr 20, 2026
10 checks passed
Copy link
Copy Markdown
Member

@sjain-stanford sjain-stanford left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This breaks compat with cudnn / hipdnn's swish for beta != 1. The swish function is not a SILU when beta !=1.

Swish: https://en.wikipedia.org/wiki/Swish_function
https://docs.pytorch.org/docs/stable/generated/torch.nn.SiLU.html

Requesting changes (in a follow-on since this landed before I could review):
Include setSwishBeta and getSwishBeta methods on the attribute (with default 1.0f). Then either 1) reject non-default beta in pre-validate, or 2) emit aten.sigmoid + aten.mul for beta != 1, and aten.silu otherwise.

In general, please wait for a comment from either me or one of the other contributors who is familiar with adding new ops (like @IanWood1).

double xD = static_cast<double>(x);
// SOFTPLUS(x) = log(1 + exp(x)); for x > threshold, result ~= x.
// Matches torch.aten.softplus with beta=1.0 and threshold=20.0.
constexpr double threshold = 20.0;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: use k-prefix convention for constants.. kSoftplusThreshold

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The k prefix only applies to global constants. We could change the clang-tidy rules to apply to local ones, too. There are some other cases without the prefix. e.g.

constexpr size_t channelsIdx = 1;

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants