Skip to content

[Fix] eliminate_two_op_chain breaks DAG when second op has multiple outputs (#70)#71

Open
ArthurPendragn wants to merge 1 commit into
deem-data:mainfrom
ArthurPendragn:twoOpChainFix-70
Open

[Fix] eliminate_two_op_chain breaks DAG when second op has multiple outputs (#70)#71
ArthurPendragn wants to merge 1 commit into
deem-data:mainfrom
ArthurPendragn:twoOpChainFix-70

Conversation

@ArthurPendragn
Copy link
Copy Markdown
Contributor

  • replace_op_in_outputs now rewires all consumers of op2 to x, correctly handling fan-out on the second op in the chain
  • add tests covering op2 fan-out, root-safe elimination, and combined fan-out on both op1 input and op2

Closes #70

@codecov
Copy link
Copy Markdown

codecov Bot commented May 17, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

Flag Coverage Δ
unittests 92.50% <100.00%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files with missing lines Coverage Δ
stratum/optimizer/_numeric_rewrites.py 100.00% <100.00%> (ø)
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Copy Markdown
Collaborator

@e-strauss e-strauss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @ArthurPendragn , the fix looks good, also thanks for adding the additional test cases. I only have some minor comments.

else:
x.outputs = []
replace_op_in_outputs(op2, x)
x.outputs = [out for out in x.outputs if out is not op1]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice finding, i forgot about the util method. can we change the order of line 21 and 22?
so first eliminate op1 as output of x, than update op2 outputs (and implicit add outputs to x), since we should do the copy of the list before we append op2 outputs

Comment on lines +38 to +39
df_ops = [op for op in linearized_dag if isinstance(op, ValueOp) and op.value == 1.0]
a_op = next(op for op in df_ops if len(op.outputs) > 1)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think is okay to use simply linearized_dag[0] here since the ValueOp should be always the first operator, since every other op depends on it. maybe with an assertion for isinstance(op, ValueOp) and op.value == 1.0]

Comment on lines +71 to +72
df_ops = [op for op in linearized_dag if isinstance(op, ValueOp) and op.value == 1.0]
a_op = df_ops[0]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

"""
Scenario:
a -> [log, d]
log -> [exp] (root)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the exp op is not root here, isn't it here:

exp -> BinOp (root)

or am I wrong?

Comment on lines +125 to +126
df_ops = [op for op in linearized_dag if isinstance(op, ValueOp) and op.value == 1.0]
a_op = next(op for op in df_ops if len(op.outputs) > 0)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] eliminate_two_op_chain breaks DAG when second op has multiple outputs

2 participants