@@ -1684,7 +1684,7 @@ def mem_space_to_kind(mem_space: MemorySpace) -> str:
1684
1684
assert False , "unreachable"
1685
1685
1686
1686
1687
- @cache (max_size = 4096 , trace_context_in_key = False )
1687
+ @cache (max_size = 4096 , trace_context_in_key = True )
1688
1688
def update_aval_with_sharding (aval , sharding ):
1689
1689
if isinstance (sharding , NamedSharding ):
1690
1690
return aval .update (
@@ -2082,6 +2082,17 @@ def modify_spec_for_auto_manual(spec, mesh) -> P:
2082
2082
if mesh ._name_to_type [u ] == AxisType .Explicit }
2083
2083
return P (* new_spec , unreduced = new_unreduced , reduced = new_reduced )
2084
2084
2085
+ def remove_size_one_mesh_axis (spec , mesh ) -> P :
2086
+ new_spec = [] # type: ignore
2087
+ for s in spec :
2088
+ if s is None :
2089
+ new_spec .append (s ) # type: ignore
2090
+ elif isinstance (s , tuple ):
2091
+ new_spec .append (tuple (i for i in s if mesh .shape [i ] != 1 ))
2092
+ else :
2093
+ new_spec .append (None if mesh .shape [s ] == 1 else s ) # type: ignore
2094
+ return P (* new_spec , unreduced = spec .unreduced , reduced = spec .reduced )
2095
+
2085
2096
def _maybe_modify_sharding (sharding , ndim ):
2086
2097
if len (sharding .spec ) == 0 or all (s is None for s in sharding .spec ):
2087
2098
out = sharding
@@ -2090,6 +2101,8 @@ def _maybe_modify_sharding(sharding, ndim):
2090
2101
else :
2091
2102
out = sharding .update (spec = modify_spec_for_auto_manual (
2092
2103
sharding .spec , sharding .mesh ))
2104
+ if config .remove_size_one_mesh_axis_from_type .value :
2105
+ out = out .update (spec = remove_size_one_mesh_axis (out .spec , out .mesh ))
2093
2106
if len (out .spec ) != ndim :
2094
2107
out = _make_lengths_same (out , ndim )
2095
2108
return out
@@ -2108,7 +2121,7 @@ def _check_divisibility(sharding, shape):
2108
2121
f" { size } times, but does not evenly divide the dimension size { sh } ."
2109
2122
f" Got shape: { shape } and sharding { sharding } " )
2110
2123
2111
- @cache (max_size = 4096 , trace_context_in_key = False )
2124
+ @cache (max_size = 4096 , trace_context_in_key = True )
2112
2125
def get_sharding (sharding , shape ):
2113
2126
"""Modifies and checks the sharding.
2114
2127
0 commit comments