MLFlowLogger fails when logging hyperparameters as Trainer already does automatically #19889

Open
CristoJV opened this issue May 22, 2024 · 0 comments
Labels
bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.1.x

Bug description

I encountered an MLflow exception when logging my model hyperparameters in the on_fit_start or on_train_start hooks:

def on_fit_start(self):
    if self.trainer.is_global_zero:
        hparams = copy.deepcopy(self.hparams)
        hparams = self.clean_hparams(hparams)
        if isinstance(self.logger, MLFlowLogger):
            self.logger.log_hyperparams(hparams)
        elif isinstance(self.logger, TensorBoardLogger):
            self.logger.log_hyperparams(hparams)
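
For reference, the cleaning step mainly converts tensor values into plain Python types before they are sent to the logger. A rough sketch of the idea behind clean_hparams (simplified illustration, not the exact implementation):

import torch

# Defined as a method on the LightningModule in my code; shown standalone here.
def clean_hparams(self, hparams):
    cleaned = {}
    for key, value in hparams.items():
        # Tensors become plain lists/scalars so they serialize consistently.
        cleaned[key] = value.tolist() if isinstance(value, torch.Tensor) else value
    return cleaned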

The on_fit_start hook successfully logs the hyperparameters after cleaning them (verified with the MLflow client). However, immediately afterwards the following exception occurs:

mlflow.exceptions.RestException:
INVALID_PARAMETER_VALUE: Changing param values is not allowed. Param with key='loss_params/alpha' was already logged with value='[0.15253213047981262, 0.170266255736351, 0.15075302124023438, 0.28234609961509705, 0.24410249292850494]' for run ID='f74bdff19b6c4e9aa3abf5fd054f9c1c'. Attempted logging new value 'tensor([0.1525, 0.1703, 0.1508, 0.2823, 0.2441])'.
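
For context, MLflow treats run parameters as immutable and compares them by their stringified value, so logging the same key twice with a different representation is rejected. A minimal standalone sketch of that behavior (values shortened, default local tracking store assumed):

import mlflow

with mlflow.start_run():
    # The first call stores the stringified list.
    mlflow.log_param("loss_params/alpha", [0.15, 0.17, 0.15])
    # Same key, different string representation -> raises MlflowException
    # with INVALID_PARAMETER_VALUE ("Changing param values is not allowed").
    mlflow.log_param("loss_params/alpha", "tensor([0.1500, 0.1700, 0.1500])")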

Since MLflow does not allow changing a parameter's value once it has been logged, I investigated whether the hyperparameters were being logged twice. The stack trace confirms that the Trainer internally calls _log_hyperparams within _run, so the hyperparameters end up being logged a second time:

face_sequence.py 405 main
trainer.fit(

trainer.py 544 fit
call._call_and_handle_interrupt(

call.py 43 _call_and_handle_interrupt
return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)

subprocess_script.py 102 launch
return function(*args, **kwargs)

trainer.py 580 _fit_impl
self._run(model, ckpt_path=ckpt_path)

trainer.py 972 _run
_log_hyperparams(self)

utilities.py 93 _log_hyperparams
logger.log_hyperparams(hparams_initial)

rank_zero.py 42 wrapped_fn
return fn(*args, **kwargs)

mlflow.py 233 log_hyperparams
self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100])

client.py 1093 log_batch
return self._tracking_client.log_batch(

client.py 444 log_batch
self.store.log_batch(

rest_store.py 323 log_batch
self._call_endpoint(LogBatch, req_body)

rest_store.py 59 _call_endpoint
return call_endpoint(self.get_host_creds(), endpoint, method, json_body, response_proto)

rest_utils.py 219 call_endpoint
response = verify_rest_response(response, endpoint)

rest_utils.py 151 verify_rest_response
raise RestException(json.loads(response.text))

Here is the extracted code from trainer.py (reduced version):

def _run(
    self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
    if self.state.fn == TrainerFn.FITTING:
        min_epochs, max_epochs = _parse_loop_limits(
            self.min_steps, self.max_steps, self.min_epochs, self.max_epochs, self
        )
        self.fit_loop.min_epochs = min_epochs
        self.fit_loop.max_epochs = max_epochs

    _log_hyperparams(self)

    if self.strategy.restore_checkpoint_after_setup:
        log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
        self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)

    return results

And the _log_hyperparams function:

def _log_hyperparams(trainer: "pl.Trainer") -> None:
    if not trainer.loggers:
        return

    pl_module = trainer.lightning_module
    datamodule_log_hyperparams = trainer.datamodule._log_hyperparams if trainer.datamodule is not None else False

    hparams_initial = None
    if pl_module._log_hyperparams and datamodule_log_hyperparams:
        datamodule_hparams = trainer.datamodule.hparams_initial
        lightning_hparams = pl_module.hparams_initial
        inconsistent_keys = []
        for key in lightning_hparams.keys() & datamodule_hparams.keys():
            lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
            if (
                type(lm_val) != type(dm_val)
                or (isinstance(lm_val, Tensor) and id(lm_val) != id(dm_val))
                or lm_val != dm_val
            ):
                inconsistent_keys.append(key)
        if inconsistent_keys:
            raise RuntimeError(
                f"Error while merging hparams: the keys {inconsistent_keys} are present "
                "in both the LightningModule's and LightningDataModule's hparams "
                "but have different values."
            )
        hparams_initial = {**lightning_hparams, **datamodule_hparams}
    elif pl_module._log_hyperparams:
        hparams_initial = pl_module.hparams_initial
    elif datamodule_log_hyperparams:
        hparams_initial = trainer.datamodule.hparams_initial

    for logger in trainer.loggers:
        if hparams_initial is not None:
            logger.log_hyperparams(hparams_initial)
        logger.log_graph(pl_module)
        logger.save()

Is there any workaround to prevent the Trainer from forcibly logging the hyperparameters?
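
One workaround that looks viable from the _log_hyperparams code above is to pass logger=False to save_hyperparameters, so that pl_module._log_hyperparams is False and the Trainer skips its automatic call, leaving only the explicit call in on_fit_start. A rough sketch (class name and constructor arguments are placeholders, not tested in my exact setup):

import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def __init__(self, loss_params, lr):
        super().__init__()
        # logger=False keeps the values available via self.hparams but tells the
        # Trainer's _log_hyperparams utility to skip them during _run.
        self.save_hyperparameters(logger=False)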

PyTorch Lightning version: 2.1.4

What version are you seeing the problem on?

v2.1

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

CristoJV added the bug and needs triage labels on May 22, 2024