Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion swift/trainers/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(self,
self.compute_loss_func = None # Compatible with the older version of transformers

if args.check_model and hasattr(model, 'model_dir'):
with ms_logger_context(logging.CRITICAL):
with ms_logger_context(logging.CRITICAL), self._patch_timeout():
check_local_model_is_latest(
model.model_dir, user_agent={
'invoked_by': 'local_trainer',
Expand Down Expand Up @@ -132,6 +132,24 @@ def _get_mean_metric():
# so reading train_state is skipped here.
self.args.resume_from_checkpoint = None

@contextmanager
def _patch_timeout(self):
from modelscope.hub.api import HubApi
__init__ = HubApi.__init__

def __new_init__(self, *args, **kwargs):
timeout = kwargs.get('timeout')
if timeout is not None and timeout > 5:
kwargs['timeout'] = 5
Comment on lines +142 to +143
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a magic number like 5 for the timeout cap makes the code harder to understand and maintain. It's better to define it as a named constant (e.g., _HUB_API_TIMEOUT_CAP = 5) at the beginning of the method or at the class level. This clarifies the purpose of the value and makes it easier to find and change if needed.

__init__(self, *args, **kwargs)

HubApi.__init__ = __new_init__

try:
yield
finally:
HubApi.__init__ = __init__

@property
def tokenizer(self):
# compat transformers5.0
Expand Down
5 changes: 4 additions & 1 deletion swift/trainers/rlhf_trainer/reward_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ def visualize_samples(self, num_print_samples: int):
break
df = pd.DataFrame(table)
if self.accelerator.process_index == 0:
print_rich_table(df[:num_print_samples])
try:
print_rich_table(df[:num_print_samples])
except Exception as e:
logger.error(e)
Comment on lines +85 to +86
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Catching a broad Exception can mask underlying issues and might unintentionally catch system-exiting exceptions like KeyboardInterrupt. It's generally safer to catch more specific exceptions. If the goal is to catch any error related to table printing, consider catching a tuple of likely exceptions, such as (ValueError, TypeError, ImportError), or at least re-raising system-level exceptions.

if 'wandb' in self.args.report_to:
import wandb

Expand Down
Loading