From f3296513b81244ae7b5ff06e37bf86c3f79adf6e Mon Sep 17 00:00:00 2001
From: Felix Dangel <48687646+f-dangel@users.noreply.github.com>
Date: Wed, 7 Jul 2021 09:34:12 +0200
Subject: [PATCH] [TEST] Reduce run time (#199)

- make automated settings' last linear layers non-differentiable to save time
- update new test suite, as it assumes all model parameters are differentiable
- update `ggn_vector_product` to work with models that contain
  non-differentiable parameters and fully-document `ggnvp.py`

---

* [DOC] Fully-document GGNVP and ignore non-differentiable parameters

* [TEST] Make automated settings' last linear layers non-differentiable

  The final layers can have large dimensions that make the computation of
  second-order quantities expensive. Disabling their `requires_grad` speeds up
  the tests.

* [TEST] Check non-differentiable parameters while collecting results

* [DOC] Fully document automated test setting helpers
---
 backpack/hessianfree/ggnvp.py              |  80 ++++----
 fully_documented.txt                       |   3 +
 test/extensions/automated_settings.py      | 179 ++++++++++--------
 test/extensions/implementation/autograd.py |  17 +-
 test/extensions/implementation/backpack.py |  59 ++----
 test/extensions/problem.py                 |  38 ++++
 .../secondorder/sqrt_ggn/test_sqrt_ggn.py  |   4 +-
 7 files changed, 213 insertions(+), 167 deletions(-)

diff --git a/backpack/hessianfree/ggnvp.py b/backpack/hessianfree/ggnvp.py
index 92c08363..165aff25 100644
--- a/backpack/hessianfree/ggnvp.py
+++ b/backpack/hessianfree/ggnvp.py
@@ -1,42 +1,56 @@
-from .hvp import hessian_vector_product
-from .lop import L_op
-from .rop import R_op
-
-
-def ggn_vector_product(loss, output, model, v):
+"""Autodiff-only matrix-free multiplication by the generalized Gauss-Newton/Fisher."""
+from typing import List, Tuple
+
+from torch import Tensor
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+
+from backpack.hessianfree.hvp import hessian_vector_product
+from backpack.hessianfree.lop import L_op
+from backpack.hessianfree.rop import R_op
+
+
+def ggn_vector_product(
+    loss: Tensor, output: Tensor, model: Module, v: List[Tensor]
+) -> Tuple[Tensor]:
+    """Multiply a vector with the generalized Gauss-Newton/Fisher.
+
+    Note:
+        ``G v = J.T @ H @ J @ v`` where ``J`` is the Jacobian of ``output`` w.r.t.
+        ``model``'s trainable parameters and ``H`` is the Hessian of ``loss`` w.r.t.
+        ``output``. ``v`` is the flattened and concatenated version of the passed
+        list of vectors.
+
+    Args:
+        loss: Scalar tensor that represents the loss.
+        output: Model output.
+        model: The model.
+        v: Vector specified as list of tensors matching the trainable parameters.
+
+    Returns:
+        GGN-vector product in list format, i.e. as list that matches the sizes
+        of trainable model parameters.
     """
-    Multiplies the vector `v` with the Generalized Gauss-Newton,
-    `ggn_v = J.T @ H @ J @ v`
-
-    where `J` is the Jacobian of `output` w.r.t. `model.parameters()`
-    and `H` is the Hessian of `loss` w.r.t. `output`.
+    return ggn_vector_product_from_plist(
+        loss, output, [p for p in model.parameters() if p.requires_grad], v
+    )
 
-    Example usage:
-    ```
-    X, Y = data()
-    model = torch.nn.Linear(784, 10)
-    lossfunc = torch.nn.CrossEntropyLoss()
 
-    output = model(X)
-    loss = lossfunc(output, Y)
+def ggn_vector_product_from_plist(
+    loss: Tensor, output: Tensor, plist: List[Parameter], v: List[Tensor]
+) -> Tuple[Tensor]:
+    """Multiply a vector with a sub-block of the generalized Gauss-Newton/Fisher.
- v = list([torch.randn_like(p) for p in model.parameters]) + Args: + loss: Scalar tensor that represents the loss. + output: Model output. + plist: List of trainable parameters whose GGN block is used for multiplication. + v: Vector specified as list of tensors matching the sizes of ``plist``. - GGNv = ggn_vector_product(loss, output, model, v) - ``` - - Parameters: - ----------- - loss: torch.Tensor - output: torch.Tensor - model: torch.nn.Module - v: [torch.Tensor] - List of tensors matching the sizes of model.parameters() + Returns: + GGN-vector product in list format, i.e. as list that matches the sizes of + ``plist``. """ - return ggn_vector_product_from_plist(loss, output, list(model.parameters()), v) - - -def ggn_vector_product_from_plist(loss, output, plist, v): Jv = R_op(output, plist, v) HJv = hessian_vector_product(loss, output, Jv) JTHJv = L_op(output, plist, HJv) diff --git a/fully_documented.txt b/fully_documented.txt index 15528fdb..f3542299 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -39,10 +39,13 @@ backpack/extensions/secondorder/diag_hessian/conv2d.py backpack/extensions/secondorder/diag_hessian/conv3d.py backpack/extensions/secondorder/sqrt_ggn/ +backpack/hessianfree/ggnvp.py + backpack/utils/linear.py backpack/utils/subsampling.py backpack/utils/__init__.py +test/extensions/automated_settings.py test/extensions/problem.py test/extensions/test_backprop_extension.py test/extensions/firstorder/firstorder_settings.py diff --git a/test/extensions/automated_settings.py b/test/extensions/automated_settings.py index 58fbec8a..f2334c51 100644 --- a/test/extensions/automated_settings.py +++ b/test/extensions/automated_settings.py @@ -1,35 +1,45 @@ +"""Contains helpers to create CNN test cases.""" from test.core.derivatives.utils import classification_targets +from typing import Any, Tuple, Type -import torch +from torch import Tensor, rand +from torch.nn import Conv2d, CrossEntropyLoss, Flatten, Linear, Module, ReLU, Sequential -### -# Helpers -### +def set_requires_grad(model: Module, new_requires_grad: bool) -> None: + """Set the ``requires_grad`` attribute of the model parameters. -def make_simple_act_setting(act_cls, bias): + Args: + model: Network or layer. + new_requires_grad: New value for ``requires_grad``. """ - input: Activation function & Bias setting - return: simple CNN Network + for p in model.parameters(): + p.requires_grad = new_requires_grad - This function is used to automatically create a - simple CNN Network consisting of CNN & Linear layer - for different activation functions. - It is used to test `test.extensions`. + +def make_simple_act_setting(act_cls: Type[Module], bias: bool) -> dict: + """Create a simple CNN with activation as test case dictionary. + + Make parameters of final linear layer non-differentiable to save run time. + + Args: + act_cls: Class of the activation function. + bias: Use bias in the convolution. + + Returns: + Dictionary representation of the simple CNN test case. 
""" - def make_simple_cnn(act_cls, bias): - return torch.nn.Sequential( - torch.nn.Conv2d(3, 2, 2, bias=bias), - act_cls(), - torch.nn.Flatten(), - torch.nn.Linear(72, 5), - ) + def _make_simple_cnn(act_cls: Type[Module], bias: bool) -> Sequential: + linear = Linear(72, 5) + set_requires_grad(linear, False) + + return Sequential(Conv2d(3, 2, 2, bias=bias), act_cls(), Flatten(), linear) dict_setting = { - "input_fn": lambda: torch.rand(3, 3, 7, 7), - "module_fn": lambda: make_simple_cnn(act_cls, bias), - "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(), + "input_fn": lambda: rand(3, 3, 7, 7), + "module_fn": lambda: _make_simple_cnn(act_cls, bias), + "loss_function_fn": lambda: CrossEntropyLoss(), "target_fn": lambda: classification_targets((3,), 5), "id_prefix": "automated-simple-cnn-act", } @@ -37,40 +47,37 @@ def make_simple_cnn(act_cls, bias): return dict_setting -def make_simple_cnn_setting(input_size, conv_class, conv_params): - """ - input_size: tuple of input size of (N*C*Image Size) - conv_class: convolutional class - conv_params: configurations for convolutional class - return: simple CNN Network - - This function is used to automatically create a - simple CNN Network consisting of CNN & Linear layer - for different convolutional layers. - It is used to test `test.extensions`. +def make_simple_cnn_setting( + input_size: Tuple[int], conv_cls: Type[Module], conv_params: Tuple[Any] +) -> dict: + """Create ReLU CNN with convolution hyperparameters as test case dictionary. + + Make parameters of final linear layer non-differentiable to save run time. + + Args: + input_size: Input shape ``[N, C_in, ...]``. + conv_cls: Class of convolution layer. + conv_params: Convolution hyperparameters. + + Returns: + Dictionary representation of the test case. 
""" - def make_cnn(conv_class, output_size, conv_params): - """Note: output class size is assumed to be 5""" - return torch.nn.Sequential( - conv_class(*conv_params), - torch.nn.ReLU(), - torch.nn.Flatten(), - torch.nn.Linear(output_size, 5), - ) + def _make_cnn( + conv_cls: Type[Module], output_dim: int, conv_params: Tuple + ) -> Sequential: + linear = Linear(output_dim, 5) + set_requires_grad(linear, False) - def get_output_shape(module, module_params, input): - """Returns the output shape for a given layer.""" - output = module(*module_params)(input) - return output.numel() // output.shape[0] + return Sequential(conv_cls(*conv_params), ReLU(), Flatten(), linear) - input = torch.rand(input_size) - output_size = get_output_shape(conv_class, conv_params, input) + input = rand(input_size) + output_dim = _get_output_dim(conv_cls(*conv_params), input) dict_setting = { - "input_fn": lambda: torch.rand(input_size), - "module_fn": lambda: make_cnn(conv_class, output_size, conv_params), - "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), + "input_fn": lambda: rand(input_size), + "module_fn": lambda: _make_cnn(conv_cls, output_dim, conv_params), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"), "target_fn": lambda: classification_targets((3,), 5), "id_prefix": "automated-simple-cnn", } @@ -78,49 +85,59 @@ def get_output_shape(module, module_params, input): return dict_setting -def make_simple_pooling_setting(input_size, conv_class, pool_cls, pool_params): - """ - input_size: tuple of input size of (N*C*Image Size) - conv_class: convolutional class - conv_params: configurations for convolutional class - return: simple CNN Network - - This function is used to automatically create a - simple CNN Network consisting of CNN & Linear layer - for different convolutional layers. - It is used to test `test.extensions`. +def make_simple_pooling_setting( + input_size: Tuple[int], + conv_cls: Type[Module], + pool_cls: Type[Module], + pool_params: Tuple[Any], +) -> dict: + """Create CNN with convolution and pooling layer as test case dictionary. + + Make parameters of final linear layer non-differentiable to save run time. + + Args: + input_size: Input shape ``[N, C_in, ...]``. + conv_cls: Class of convolution layer. + pool_cls: Class of pooling layer. + pool_params: Pooling hyperparameters. + + Returns: + Dictionary representation of the test case. 
""" - def make_cnn(conv_class, output_size, conv_params, pool_cls, pool_params): - """Note: output class size is assumed to be 5""" - return torch.nn.Sequential( - conv_class(*conv_params), - torch.nn.ReLU(), - pool_cls(*pool_params), - torch.nn.Flatten(), - torch.nn.Linear(output_size, 5), + def _make_cnn( + conv_cls: Type[Module], + output_size: int, + conv_params: Tuple[Any], + pool_cls: Type[Module], + pool_params: Tuple[Any], + ) -> Sequential: + linear = Linear(output_size, 5) + set_requires_grad(linear, False) + + return Sequential( + conv_cls(*conv_params), ReLU(), pool_cls(*pool_params), Flatten(), linear ) - def get_output_shape(module, module_params, input, pool, pool_params): - """Returns the output shape for a given layer.""" - output_1 = module(*module_params)(input) - output = pool_cls(*pool_params)(output_1) - return output.numel() // output.shape[0] - conv_params = (3, 2, 2) - input = torch.rand(input_size) - output_size = get_output_shape( - conv_class, conv_params, input, pool_cls, pool_params + input = rand(input_size) + output_dim = _get_output_dim( + Sequential(conv_cls(*conv_params), pool_cls(*pool_params)), input ) dict_setting = { - "input_fn": lambda: torch.rand(input_size), - "module_fn": lambda: make_cnn( - conv_class, output_size, conv_params, pool_cls, pool_params + "input_fn": lambda: rand(input_size), + "module_fn": lambda: _make_cnn( + conv_cls, output_dim, conv_params, pool_cls, pool_params ), - "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="sum"), + "loss_function_fn": lambda: CrossEntropyLoss(reduction="sum"), "target_fn": lambda: classification_targets((3,), 5), "id_prefix": "automated-simple-cnn", } return dict_setting + + +def _get_output_dim(module: Module, input: Tensor) -> int: + output = module(input) + return output.numel() // output.shape[0] diff --git a/test/extensions/implementation/autograd.py b/test/extensions/implementation/autograd.py index c86350eb..2edcf9f1 100644 --- a/test/extensions/implementation/autograd.py +++ b/test/extensions/implementation/autograd.py @@ -18,14 +18,14 @@ def batch_grad(self) -> List[Tensor]: # noqa: D102 N = self.problem.input.shape[0] batch_grads = [ torch.zeros(N, *p.size()).to(self.problem.device) - for p in self.problem.model.parameters() + for p in self.problem.trainable_parameters() ] loss_list = torch.zeros((N)) gradients_list = [] for b in range(N): _, _, loss = self.problem.forward_pass(sample_idx=b) - gradients = torch.autograd.grad(loss, self.problem.model.parameters()) + gradients = torch.autograd.grad(loss, self.problem.trainable_parameters()) gradients_list.append(gradients) loss_list[b] = loss @@ -47,14 +47,14 @@ def sgs(self) -> List[Tensor]: # noqa: D102 N = self.problem.input.shape[0] sgs = [ torch.zeros(*p.size()).to(self.problem.device) - for p in self.problem.model.parameters() + for p in self.problem.trainable_parameters() ] loss_list = torch.zeros((N)) gradients_list = [] for b in range(N): _, _, loss = self.problem.forward_pass(sample_idx=b) - gradients = torch.autograd.grad(loss, self.problem.model.parameters()) + gradients = torch.autograd.grad(loss, self.problem.trainable_parameters()) loss_list[b] = loss gradients_list.append(gradients) @@ -81,7 +81,7 @@ def extract_ith_element_of_diag_ggn(i, p, loss, output): return GGN_v[i] diag_ggns = [] - for p in list(self.problem.model.parameters()): + for p in list(self.problem.trainable_parameters()): diag_ggn_p = torch.zeros_like(p).view(-1) for parameter_index in range(p.numel()): @@ -146,7 +146,7 @@ def 
extract_ith_element_of_diag_h(i, p, df_dx): return Hv[i] diag_hs = [] - for p in list(self.problem.model.parameters()): + for p in list(self.problem.trainable_parameters()): diag_h_p = torch.zeros_like(p).view(-1) df_dx = torch.autograd.grad(loss, [p], create_graph=True, retain_graph=True) @@ -179,8 +179,9 @@ def diag_h_batch(self) -> List[Tensor]: # noqa: D102 def ggn(self) -> Tensor: # noqa: D102 _, output, loss = self.problem.forward_pass() model = self.problem.model + params = list(self.problem.trainable_parameters()) - num_params = sum(p.numel() for p in model.parameters()) + num_params = sum(p.numel() for p in params) ggn = torch.zeros(num_params, num_params).to(self.problem.device) for i in range(num_params): @@ -189,7 +190,7 @@ def ggn(self) -> Tensor: # noqa: D102 e_i[i] = 1.0 # convert to model parameter shapes - e_i_list = vector_to_parameter_list(e_i, model.parameters()) + e_i_list = vector_to_parameter_list(e_i, params) ggn_i_list = ggn_vector_product(loss, output, model, e_i_list) ggn_i = parameters_to_vector(ggn_i_list) diff --git a/test/extensions/implementation/backpack.py b/test/extensions/implementation/backpack.py index c8d4d90b..3340654e 100644 --- a/test/extensions/implementation/backpack.py +++ b/test/extensions/implementation/backpack.py @@ -30,15 +30,13 @@ def batch_grad(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.BatchGrad()): _, _, loss = self.problem.forward_pass() loss.backward() - batch_grads = [p.grad_batch for p in self.problem.model.parameters()] - return batch_grads + return self.problem.collect_data("grad_batch") def batch_l2_grad(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.BatchL2Grad()): _, _, loss = self.problem.forward_pass() loss.backward() - batch_l2_grad = [p.batch_l2 for p in self.problem.model.parameters()] - return batch_l2_grad + return self.problem.collect_data("batch_l2") def batch_l2_grad_extension_hook(self) -> List[Tensor]: """Individual gradient squared ℓ₂ norms via extension hook. @@ -50,15 +48,13 @@ def batch_l2_grad_extension_hook(self) -> List[Tensor]: with backpack(new_ext.BatchGrad(), extension_hook=hook): _, _, loss = self.problem.forward_pass() loss.backward() - batch_l2_grad = [p.batch_l2_hook for p in self.problem.model.parameters()] - return batch_l2_grad + return self.problem.collect_data("batch_l2_hook") def sgs(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.SumGradSquared()): _, _, loss = self.problem.forward_pass() loss.backward() - sgs = [p.sum_grad_squared for p in self.problem.model.parameters()] - return sgs + return self.problem.collect_data("sum_grad_squared") def sgs_extension_hook(self) -> List[Tensor]: """Individual gradient second moment via extension hook. 
@@ -70,47 +66,37 @@ def sgs_extension_hook(self) -> List[Tensor]: with backpack(new_ext.BatchGrad(), extension_hook=hook): _, _, loss = self.problem.forward_pass() loss.backward() - sgs = [p.sum_grad_squared_hook for p in self.problem.model.parameters()] - return sgs + return self.problem.collect_data("sum_grad_squared_hook") def variance(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.Variance()): _, _, loss = self.problem.forward_pass() loss.backward() - variances = [p.variance for p in self.problem.model.parameters()] - return variances + return self.problem.collect_data("variance") def diag_ggn(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.DiagGGNExact()): _, _, loss = self.problem.forward_pass() loss.backward() - diag_ggn = [p.diag_ggn_exact for p in self.problem.model.parameters()] - return diag_ggn + return self.problem.collect_data("diag_ggn_exact") def diag_ggn_exact_batch(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.BatchDiagGGNExact()): _, _, loss = self.problem.forward_pass() loss.backward() - diag_ggn_exact_batch = [ - p.diag_ggn_exact_batch for p in self.problem.model.parameters() - ] - return diag_ggn_exact_batch + return self.problem.collect_data("diag_ggn_exact_batch") def diag_ggn_mc(self, mc_samples) -> List[Tensor]: # noqa:D102 with backpack(new_ext.DiagGGNMC(mc_samples=mc_samples)): _, _, loss = self.problem.forward_pass() loss.backward() - diag_ggn_mc = [p.diag_ggn_mc for p in self.problem.model.parameters()] - return diag_ggn_mc + return self.problem.collect_data("diag_ggn_mc") def diag_ggn_mc_batch(self, mc_samples) -> List[Tensor]: # noqa:D102 with backpack(new_ext.BatchDiagGGNMC(mc_samples=mc_samples)): _, _, loss = self.problem.forward_pass() loss.backward() - diag_ggn_mc_batch = [ - p.diag_ggn_mc_batch for p in self.problem.model.parameters() - ] - return diag_ggn_mc_batch + return self.problem.collect_data("diag_ggn_mc_batch") def diag_ggn_mc_chunk(self, mc_samples: int, chunks: int = 10) -> List[Tensor]: """Like ``diag_ggn_mc``, but can handle more samples by chunking. 
@@ -198,40 +184,31 @@ def diag_h(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.DiagHessian()): _, _, loss = self.problem.forward_pass() loss.backward() - diag_h = [p.diag_h for p in self.problem.model.parameters()] - return diag_h + return self.problem.collect_data("diag_h") def kfac(self, mc_samples: int = 1) -> List[List[Tensor]]: # noqa:D102 with backpack(new_ext.KFAC(mc_samples=mc_samples)): _, _, loss = self.problem.forward_pass() loss.backward() - kfac = [p.kfac for p in self.problem.model.parameters()] - - return kfac + return self.problem.collect_data("kfac") def kflr(self) -> List[List[Tensor]]: # noqa:D102 with backpack(new_ext.KFLR()): _, _, loss = self.problem.forward_pass() loss.backward() - kflr = [p.kflr for p in self.problem.model.parameters()] - - return kflr + return self.problem.collect_data("kflr") def kfra(self) -> List[List[Tensor]]: # noqa:D102 with backpack(new_ext.KFRA()): _, _, loss = self.problem.forward_pass() loss.backward() - kfra = [p.kfra for p in self.problem.model.parameters()] - - return kfra + return self.problem.collect_data("kfra") def diag_h_batch(self) -> List[Tensor]: # noqa:D102 with backpack(new_ext.BatchDiagHessian()): _, _, loss = self.problem.forward_pass() loss.backward() - diag_h_batch = [p.diag_h_batch for p in self.problem.model.parameters()] - - return diag_h_batch + return self.problem.collect_data("diag_h_batch") def ggn(self) -> Tensor: # noqa:D102 return self._square_sqrt_ggn(self.sqrt_ggn()) @@ -245,8 +222,7 @@ def sqrt_ggn(self) -> List[Tensor]: with backpack(new_ext.SqrtGGNExact()): _, _, loss = self.problem.forward_pass() loss.backward() - - return [p.sqrt_ggn_exact for p in self.problem.model.parameters()] + return self.problem.collect_data("sqrt_ggn_exact") def sqrt_ggn_mc(self, mc_samples: int) -> List[Tensor]: """Compute the approximate matrix square root of the generalized Gauss-Newton. @@ -260,8 +236,7 @@ def sqrt_ggn_mc(self, mc_samples: int) -> List[Tensor]: with backpack(new_ext.SqrtGGNMC(mc_samples=mc_samples)): _, _, loss = self.problem.forward_pass() loss.backward() - - return [p.sqrt_ggn_mc for p in self.problem.model.parameters()] + return self.problem.collect_data("sqrt_ggn_mc") def ggn_mc(self, mc_samples: int, chunks: int = 1) -> Tensor: # noqa:D102 samples = self.chunk_sizes(mc_samples, chunks) diff --git a/test/extensions/problem.py b/test/extensions/problem.py index 05f25afe..0d940cec 100644 --- a/test/extensions/problem.py +++ b/test/extensions/problem.py @@ -2,8 +2,10 @@ import copy from test.core.derivatives.utils import get_available_devices +from typing import Any, Iterator, List import torch +from torch.nn.parameter import Parameter from backpack import extend @@ -191,3 +193,39 @@ def get_reduction_factor(self, loss, unreduced_loss): f"'mean': {mean_loss}, 'sum': {sum_loss}, loss: {loss}", ) return factor + + def trainable_parameters(self) -> Iterator[Parameter]: + """Yield the model's trainable parameters. + + Yields: + Model parameter with gradients enabled. + """ + for p in self.model.parameters(): + if p.requires_grad: + yield p + + def collect_data(self, savefield: str) -> List[Any]: + """Collect BackPACK attributes from trainable parameters. + + Args: + savefield: Attribute name. + + Returns: + List of attributes saved under the trainable model parameters. + + Raises: + RuntimeError: If a non-differentiable parameter with the attribute is + encountered. 
+        """
+        data = []
+
+        for p in self.model.parameters():
+            if p.requires_grad:
+                data.append(getattr(p, savefield))
+            else:
+                if hasattr(p, savefield):
+                    raise RuntimeError(
+                        f"Found non-differentiable parameter with attribute '{savefield}'."
+                    )
+
+        return data
diff --git a/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py b/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py
index ba3a0436..75b1dbee 100644
--- a/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py
+++ b/test/extensions/secondorder/sqrt_ggn/test_sqrt_ggn.py
@@ -41,9 +41,7 @@ def small_problem(
     Yields:
         Instantiated test case whose model is small enough.
     """
-    num_params = sum(
-        p.numel() for p in instantiated_problem.model.parameters() if p.requires_grad
-    )
+    num_params = sum(p.numel() for p in instantiated_problem.trainable_parameters())
    if num_params <= max_num_params:
        yield instantiated_problem
    else:
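
The diff for `ggnvp.py` drops the inline usage example from `ggn_vector_product`'s old docstring. Below is a minimal sketch of how the updated function is called when some parameters are frozen — assuming the patched `backpack.hessianfree.ggnvp` from this patch; the small CNN, input shape, and frozen final linear layer are borrowed from the automated test settings and are not part of the diff. Note that the vectors in `v` must match only the trainable parameters:

```python
import torch
from torch.nn import Conv2d, CrossEntropyLoss, Flatten, Linear, ReLU, Sequential

from backpack.hessianfree.ggnvp import ggn_vector_product

# Small CNN mirroring the automated settings: the final linear layer is frozen,
# so its parameters are excluded from the GGN block (cf. set_requires_grad).
linear = Linear(72, 5)
for p in linear.parameters():
    p.requires_grad = False
model = Sequential(Conv2d(3, 2, 2), ReLU(), Flatten(), linear)

X, y = torch.rand(3, 3, 7, 7), torch.randint(0, 5, (3,))
output = model(X)
loss = CrossEntropyLoss()(output, y)

# One vector per *trainable* parameter (here: the convolution's weight and bias).
# ggn_vector_product relies only on autograd, so no backpack.extend(...) is needed.
v = [torch.randn_like(p) for p in model.parameters() if p.requires_grad]

GGNv = ggn_vector_product(loss, output, model, v)
print([g.shape for g in GGNv])  # matches the shapes of the trainable parameters
```

Before this change, `ggn_vector_product` forwarded all of `model.parameters()` to `ggn_vector_product_from_plist`, so a frozen layer made the parameter list and `v` inconsistent; filtering by `requires_grad` is what lets the reduced automated settings above run.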