diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py
index cdac7674db..cdf8900525 100644
--- a/swift/trainers/mixin.py
+++ b/swift/trainers/mixin.py
@@ -75,7 +75,7 @@ def __init__(self,
         self.compute_loss_func = None  # Compatible with the older version of transformers
         if args.check_model and hasattr(model, 'model_dir'):
-            with ms_logger_context(logging.CRITICAL):
+            with ms_logger_context(logging.CRITICAL), self._patch_timeout():
                 check_local_model_is_latest(
                     model.model_dir,
                     user_agent={
                         'invoked_by': 'local_trainer',
@@ -132,6 +132,24 @@ def _get_mean_metric():
             # so reading train_state is skipped here.
             self.args.resume_from_checkpoint = None
 
+    @contextmanager
+    def _patch_timeout(self):
+        from modelscope.hub.api import HubApi
+        __init__ = HubApi.__init__
+
+        def __new_init__(self, *args, **kwargs):
+            timeout = kwargs.get('timeout')
+            if timeout is not None and timeout > 5:
+                kwargs['timeout'] = 5
+            __init__(self, *args, **kwargs)
+
+        HubApi.__init__ = __new_init__
+
+        try:
+            yield
+        finally:
+            HubApi.__init__ = __init__
+
     @property
     def tokenizer(self):
         # compat transformers5.0
diff --git a/swift/trainers/rlhf_trainer/reward_trainer.py b/swift/trainers/rlhf_trainer/reward_trainer.py
index 693feb12be..f022779bcd 100644
--- a/swift/trainers/rlhf_trainer/reward_trainer.py
+++ b/swift/trainers/rlhf_trainer/reward_trainer.py
@@ -80,7 +80,10 @@ def visualize_samples(self, num_print_samples: int):
                 break
 
         df = pd.DataFrame(table)
         if self.accelerator.process_index == 0:
-            print_rich_table(df[:num_print_samples])
+            try:
+                print_rich_table(df[:num_print_samples])
+            except Exception as e:
+                logger.error(e)
             if 'wandb' in self.args.report_to:
                 import wandb
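
For reviewers: the new `_patch_timeout` context manager works by temporarily swapping `HubApi.__init__` for a wrapper that clamps the `timeout` kwarg to 5 seconds, then restoring the original constructor in a `finally` block so the patch cannot leak past the `check_local_model_is_latest` call. A minimal standalone sketch of the same pattern follows; `capped_timeout` and `Client` are hypothetical names used only for illustration, not part of this PR:

```python
from contextlib import contextmanager


@contextmanager
def capped_timeout(cls, max_timeout=5):
    """Temporarily clamp any explicit `timeout` kwarg passed to cls(...)."""
    original_init = cls.__init__

    def patched_init(self, *args, **kwargs):
        # Clamp an explicitly passed timeout; defaults are left untouched,
        # mirroring the behaviour of _patch_timeout in the diff above.
        timeout = kwargs.get('timeout')
        if timeout is not None and timeout > max_timeout:
            kwargs['timeout'] = max_timeout
        original_init(self, *args, **kwargs)

    cls.__init__ = patched_init
    try:
        yield
    finally:
        # Restore the original constructor even if the body raised.
        cls.__init__ = original_init


class Client:
    def __init__(self, timeout=60):
        self.timeout = timeout


with capped_timeout(Client, max_timeout=5):
    assert Client(timeout=60).timeout == 5   # clamped inside the context
assert Client(timeout=60).timeout == 60      # original behaviour restored
```

Note that, as in the diff, only an explicitly passed `timeout` kwarg is clamped; a large timeout baked into the constructor's default would pass through unchanged.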