Skip to content

Commit

Permalink
Fix import (#266)
Browse files Browse the repository at this point in the history
  • Loading branch information
tastelikefeet committed Dec 29, 2023
1 parent 5c1d0a7 commit 8a275ff
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 43 deletions.
40 changes: 7 additions & 33 deletions swift/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from swift.utils.import_utils import _LazyModule
from .app_ui import gradio_chat_demo, gradio_generation_demo, llm_app_ui
from .infer import llm_infer, merge_lora, prepare_model_template
from .rome import rome_infer
# Recommend using `xxx_main`
from .run import (app_ui_main, dpo_main, infer_main, merge_lora_main,
rome_main, sft_main)
from .sft import llm_sft
from .utils import *

if TYPE_CHECKING:
from .app_ui import gradio_chat_demo, gradio_generation_demo, llm_app_ui
from .infer import llm_infer, merge_lora, prepare_model_template
from .rome import rome_infer
# Recommend using `xxx_main`
from .run import (app_ui_main, dpo_main, infer_main, merge_lora_main,
rome_main, sft_main)
from .sft import llm_sft
else:
_import_structure = {
'app_ui': ['gradio_chat_demo', 'gradio_generation_demo', 'llm_app_ui'],
'infer': ['llm_infer', 'merge_lora', 'prepare_model_template'],
'rome': ['rome_infer'],
'run': [
'app_ui_main', 'dpo_main', 'infer_main', 'merge_lora_main',
'rome_main', 'sft_main'
],
'sft': ['llm_sft'],
}

import sys

sys.modules[__name__] = _LazyModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)
17 changes: 7 additions & 10 deletions swift/trainers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from transformers.trainer_utils import (EvaluationStrategy, FSDPOption,
HPSearchBackend, HubStrategy,
IntervalStrategy, SchedulerType)

from swift.utils.import_utils import _LazyModule

if TYPE_CHECKING:
from .arguments import Seq2SeqTrainingArguments, TrainingArguments
from .dpo_trainers import DPOTrainer
from .trainers import Seq2SeqTrainer, Trainer
from .utils import EvaluationStrategy, FSDPOption, HPSearchBackend, HubStrategy, \
IntervalStrategy, SchedulerType, ShardedDDPOption
else:
_import_structure = {
'arguments': ['Seq2SeqTrainingArguments', 'TrainingArguments'],
'dpo_trainers': ['DPOTrainer'],
'trainers': ['Seq2SeqTrainer', 'Trainer'],
'utils': [
'EvaluationStrategy', 'FSDPOption', 'HPSearchBackend',
'HubStrategy', 'IntervalStrategy', 'SchedulerType',
'ShardedDDPOption'
]
}

import sys
Expand All @@ -27,9 +30,3 @@
module_spec=__spec__,
extra_objects={},
)

try:
# https://github.com/huggingface/transformers/pull/25702
from transformers.trainer_utils import ShardedDDPOption
except ImportError:
pass
9 changes: 9 additions & 0 deletions swift/trainers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@
from typing import List, Union

from torch.nn import Module
from transformers.trainer_utils import (EvaluationStrategy, FSDPOption,
HPSearchBackend, HubStrategy,
IntervalStrategy, SchedulerType)

try:
# https://github.com/huggingface/transformers/pull/25702
from transformers.trainer_utils import ShardedDDPOption
except ImportError:
ShardedDDPOption = None


def can_return_loss(model: Module) -> List[str]:
Expand Down

0 comments on commit 8a275ff

Please sign in to comment.