-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Broadcast array arg in binary ops if it's a valid leaf array type #51
Conversation
Is the "not device scalar" case useful? In the general case, |
If the broadcasting for any of the leaf arrays is illegal, we would see an error, which IMO is a reasonable user experience. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @alexfikl for taking a look. A few more thoughts below.
arraycontext/container/arithmetic.py
Outdated
@@ -343,6 +343,17 @@ def {fname}(arg1): | |||
else: | |||
raise ValueError(msg)""") | |||
gen(f"return cls({zip_init_args})") | |||
if _cls_has_array_context_attr: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should get its own flag, defaulted to the same value as bcast_number
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, see here: 93013c9.
arraycontext/container/arithmetic.py
Outdated
@@ -343,6 +343,17 @@ def {fname}(arg1): | |||
else: | |||
raise ValueError(msg)""") | |||
gen(f"return cls({zip_init_args})") | |||
if _cls_has_array_context_attr: | |||
gen("if isinstance(arg2," |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ordering here matters a great deal. (Maybe there should be a comment stating this.) These cases should be sorted from most likely to least. Is this the second-most-likely case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO, this use-case would come just after host-scalars, as implemented in 93013c9.
arraycontext/container/arithmetic.py
Outdated
lambda a: {op_str.format("a", "arg2")}, | ||
arg1)) | ||
""") | ||
|
||
gen(f""" | ||
if {bool(outer_bcast_type_names)}: # optimized away | ||
if isinstance(arg2, {tup_str(outer_bcast_type_names)}): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should also work (and be tested) for the reverse operators.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a test in f0e0587.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me!
Just left a very small nitpick.
Co-authored-by: Alex Fikl <alexfikl@gmail.com>
replaced with only if Co-authored-by: Alex Fikl <alexfikl@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Some comments below.
arraycontext/context.py
Outdated
def get_array_types(self): | ||
""" | ||
Returns a :class:`tuple` of types that are valid base array classes | ||
the context can operate on. | ||
""" | ||
return () |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- I think this could be an attribute, to avoid the function call overhead. (If you need non-global import to set it, set it in the constructor.)
- I think documenting this is probably OK.
- Would it be sensible to only allow one type here, to avoid the splat above?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. 9255f22
Would it be sensible to only allow one type here, to avoid the splat above?
I think there is some value in keeping it tuple, default value is much nicer since the only intent is to perform type checking using it + also accounts for the slightest of chances that an array context might not have a single base array type.
arraycontext/container/arithmetic.py
Outdated
if bcast_actx_array_type: | ||
all_outer_bcast_type_names = ( | ||
outer_bcast_type_names | ||
+ ("*arg1.array_context.get_array_types()",)) | ||
else: | ||
all_outer_bcast_type_names = outer_bcast_type_names |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Move this closer to the usage site.
- Make this produce a tuple that's either empty or has the actx array type(s) in it. Add both together in the argument of
tup_str
.
(Same below for reverse.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I think 2040a8d makes it better.
Unsubscribing... @-mention or request review once it's ready for a look or needs attention. |
Co-authored-by: Andreas Kloeckner <andreask@illinois.edu>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Just a few more minor things, then this is ready to go.
Unsubscribing... @-mention or request review once it's ready for a look or needs attention. |
Closes #49.