Skip to content

Commit

Permalink
Avoid lhs_shape & rhs_shape parameters in conv_general_dilated
Browse files Browse the repository at this point in the history
They're removed as of google/jax#14211

PiperOrigin-RevId: 508756760
  • Loading branch information
Jake VanderPlas authored and romanngg committed Feb 10, 2023
1 parent 0272225 commit adaf9e2
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions neural_tangents/_src/utils/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,23 +503,17 @@ def _conv_general_dilated_e(
trimmed_invals: List[ShapedArray],
trimmed_cts_in: ShapedArray
) -> Dict[str, Any]:
# `conv_general_dilated` has `lhs_shape` and `rhs_shape` arguments that are
# for some reason not inferred from the `lhs` and `rhs` themselves.
# TODO(romann): ask JAX why these are there.
lhs, rhs = trimmed_invals
dn = params['dimension_numbers']

if (params['feature_group_count'] == params['lhs_shape'][dn[0][1]] and
params['feature_group_count'] == params['rhs_shape'][dn[1][0]]):
if (params['feature_group_count'] > lhs.shape[dn[0][1]] or
params['feature_group_count'] > rhs.shape[dn[1][0]]):
params['feature_group_count'] = 1

if (params['batch_group_count'] == params['rhs_shape'][dn[1][0]] and
params['batch_group_count'] == params['lhs_shape'][dn[0][0]]):
if (params['batch_group_count'] > rhs.shape[dn[1][0]] or
params['batch_group_count'] > lhs.shape[dn[0][0]]):
params['batch_group_count'] = 1

lhs, rhs = trimmed_invals
params['lhs_shape'] = lhs.shape
params['rhs_shape'] = rhs.shape

return params

STRUCTURE_RULES[lax.conv_general_dilated_p] = _conv_general_dilated_s
Expand Down

0 comments on commit adaf9e2

Please sign in to comment.