Skip to content

Commit 9fa22bb

Browse files
authored
[bugfix] patch timeout & fix print_rich_table (#6137)
1 parent d09ae00 commit 9fa22bb

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

swift/trainers/mixin.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __init__(self,
7575
self.compute_loss_func = None # Compatible with the older version of transformers
7676

7777
if args.check_model and hasattr(model, 'model_dir'):
78-
with ms_logger_context(logging.CRITICAL):
78+
with ms_logger_context(logging.CRITICAL), self._patch_timeout():
7979
check_local_model_is_latest(
8080
model.model_dir, user_agent={
8181
'invoked_by': 'local_trainer',
@@ -132,6 +132,24 @@ def _get_mean_metric():
132132
# so reading train_state is skipped here.
133133
self.args.resume_from_checkpoint = None
134134

135+
@contextmanager
136+
def _patch_timeout(self):
137+
from modelscope.hub.api import HubApi
138+
__init__ = HubApi.__init__
139+
140+
def __new_init__(self, *args, **kwargs):
141+
timeout = kwargs.get('timeout')
142+
if timeout is not None and timeout > 5:
143+
kwargs['timeout'] = 5
144+
__init__(self, *args, **kwargs)
145+
146+
HubApi.__init__ = __new_init__
147+
148+
try:
149+
yield
150+
finally:
151+
HubApi.__init__ = __init__
152+
135153
@property
136154
def tokenizer(self):
137155
# compat transformers5.0

swift/trainers/rlhf_trainer/reward_trainer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ def visualize_samples(self, num_print_samples: int):
8080
break
8181
df = pd.DataFrame(table)
8282
if self.accelerator.process_index == 0:
83-
print_rich_table(df[:num_print_samples])
83+
try:
84+
print_rich_table(df[:num_print_samples])
85+
except Exception as e:
86+
logger.error(e)
8487
if 'wandb' in self.args.report_to:
8588
import wandb
8689

0 commit comments

Comments
 (0)