diff --git a/README.md b/README.md
index e6985af..e6f119f 100644
--- a/README.md
+++ b/README.md
@@ -44,7 +44,7 @@ from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_f
 m = Model(...)
 
 # optional: filter modules from being eligible for float8 conversion
-def module_filter_fn(fqn: str, mod: torch.nn.Module):
+def module_filter_fn(mod: torch.nn.Module, fqn: str):
     # don't convert the output module
     if fqn == "output":
         return False
@@ -91,9 +91,9 @@ from float8_experimental.float8_linear import TensorScalingType
 # create model
 m = Model(...)
 
-# optional: configure for compatibility with FSDP. Note that workarounds 
+# optional: configure for compatibility with FSDP. Note that workarounds
 # gated with config.enable_amax_init and
-# config.enable_pre_and_post_forward are needed for 
+# config.enable_pre_and_post_forward are needed for
 # autocast + compile + FSDP + float8 to work
 from float8_experimental import Float8LinearConfig, TensorScalingType, Float8TensorCastConfig
 config = Float8LinearConfig(
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index 41365df..328304c 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -60,7 +60,7 @@ def swap_linear_layers(
     module: nn.Module,
     from_float_func: Callable[[nn.Linear], nn.Linear],
     *,
-    module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
+    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
 ) -> Optional[nn.Module]:
     """
     Generic function to swap linear layers in a module with a new type of linear layer.
@@ -74,13 +74,13 @@
         from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
         module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
             that pass the filter function will be swapped. The inputs to the
-            filter function are the FQN and module instance.
+            filter function are the module instance, and the FQN.
 
     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
     if isinstance(module, nn.Linear) and (
-        module_filter_fn is None or module_filter_fn("", module)
+        module_filter_fn is None or module_filter_fn(module, "")
     ):
         if len(list(module.children())) > 0:
             raise AssertionError(
@@ -109,9 +109,7 @@ def post_order_traversal(
             post_order_traversal(child_module, new_fqn, module)
 
         if isinstance(module, nn.Linear) and (
-            # linear_layer_filter is None or linear_layer_filter(module)
-            module_filter_fn is None
-            or module_filter_fn(cur_fqn, module)
+            module_filter_fn is None or module_filter_fn(module, cur_fqn)
         ):
             assert (
                 parent_module is not None
@@ -127,7 +125,7 @@
 def swap_linear_with_float8_linear(
     module: nn.Module,
     *,
-    module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
+    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
     config: Float8LinearConfig = None,
 ) -> Optional[nn.Module]:
     """
@@ -137,7 +135,7 @@
         module: Module to modify.
         module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
             that pass the filter function will be swapped. The inputs to the
-            filter function are the FQN and module instance.
+            filter function are the module instance and the FQN.
         config (Float8LinearConfig): configuration for conversion to float8
 
     Returns:
diff --git a/float8_experimental/inference.py b/float8_experimental/inference.py
index 1742317..e5b17a6 100644
--- a/float8_experimental/inference.py
+++ b/float8_experimental/inference.py
@@ -213,7 +213,7 @@ def quantize_to_float8(
     module: nn.Module,
     quant_config: QuantConfig,
     *,
-    module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
+    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
     use_fast_accum: bool = True,
 ) -> Optional[nn.Module]:
     """
@@ -228,7 +228,7 @@
         quant_config (QuantConfig): Quantization configuration for Float8 conversion.
         module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
             that pass the filter function will be swapped. The inputs to the
-            filter function are the FQN and module instance.
+            filter function are the module instance and the FQN.
         use_fast_accum : Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
 
     Returns:
diff --git a/test/test_base.py b/test/test_base.py
index 10d53b6..644a690 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -649,7 +649,7 @@ def __init__(self, dim: int):
 
         size_limit = 32
 
-        def module_filter_fn(fqn, mod):
+        def module_filter_fn(mod, fqn):
             return (
                 mod.in_features >= size_limit
                 and mod.out_features >= size_limit
@@ -682,7 +682,7 @@ def __init__(self, dim: int):
                 self.lin2 = nn.Linear(4 * dim, dim)
 
         model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
-        module_filter_fn = lambda fqn, mod: fqn not in [
+        module_filter_fn = lambda mod, fqn: fqn not in [
             "0.lin2",
             "2.lin1",
         ]
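
Taken together, the hunks above flip the `module_filter_fn` callback from `(fqn, mod)` to `(mod, fqn)` everywhere it is accepted. The sketch below is not part of the diff; it is a minimal, hypothetical caller showing how an existing filter would be updated to the new argument order, assuming `swap_linear_with_float8_linear` is imported from `float8_experimental.float8_linear_utils` as patched above and that the default `config` is acceptable.

# Hypothetical usage sketch (not from the diff): the filter now receives the
# module instance first and its fully qualified name (FQN) second.
import torch.nn as nn

from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# toy model; the children of an nn.Sequential have FQNs "0" and "1"
model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 8))

# old order: def module_filter_fn(fqn: str, mod: nn.Module) -> bool
# new order: module first, FQN second
def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
    # convert only reasonably wide layers and skip the final projection
    return mod.out_features >= 16 and fqn != "1"

swap_linear_with_float8_linear(model, module_filter_fn=module_filter_fn)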