2 changes: 1 addition & 1 deletion examples/pytorch/llm/scripts/qwen_7b/qlora_ddp/infer.sh
@@ -3,7 +3,7 @@ CUDA_VISIBLE_DEVICES=0 \
 python src/llm_infer.py \
     --model_type qwen-7b \
     --sft_type lora \
-    --template_type chatml \
+    --template_type default \
     --dtype bf16 \
     --ckpt_dir "runs/qwen-7b/vx_xxx/checkpoint-xxx" \
     --eval_human true \
2 changes: 1 addition & 1 deletion examples/pytorch/llm/scripts/qwen_7b/qlora_ddp/sft.sh
@@ -18,7 +18,7 @@ torchrun \
     --quantization_bit 4 \
     --bnb_4bit_comp_dtype bf16 \
     --lora_rank 64 \
-    --lora_alpha 16 \
+    --lora_alpha 32 \
     --lora_dropout_p 0.05 \
     --lora_target_modules ALL \
     --gradient_checkpointing true \
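
Note (not part of the diff): in standard LoRA the low-rank update is scaled by lora_alpha / lora_rank, so with rank fixed at 64 this change doubles the effective scale from 0.25 to 0.5. A minimal sketch of that arithmetic, using the generic convention rather than code from this PR:

# Rough sketch of the scaling these flags imply; the values mirror the scripts,
# but the formula is the common lora_alpha / lora_rank convention, not PR code.
lora_rank = 64
for lora_alpha in (16, 32):           # old vs. new value in sft.sh
    scaling = lora_alpha / lora_rank  # factor applied to the low-rank update B @ A
    print(f"alpha={lora_alpha:>2} -> scaling={scaling}")
# alpha=16 -> scaling=0.25
# alpha=32 -> scaling=0.5
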
2 changes: 1 addition & 1 deletion examples/pytorch/llm/scripts/qwen_7b_chat/qlora/sft.sh
@@ -13,7 +13,7 @@ python src/llm_sft.py \
     --quantization_bit 4 \
     --bnb_4bit_comp_dtype bf16 \
     --lora_rank 64 \
-    --lora_alpha 16 \
+    --lora_alpha 32 \
     --lora_dropout_p 0.05 \
     --lora_target_modules ALL \
     --gradient_checkpointing true \
2 changes: 1 addition & 1 deletion examples/pytorch/llm/scripts/qwen_7b_chat/qlora_ddp/sft.sh
@@ -18,7 +18,7 @@ torchrun \
     --quantization_bit 4 \
     --bnb_4bit_comp_dtype bf16 \
     --lora_rank 64 \
-    --lora_alpha 16 \
+    --lora_alpha 32 \
     --lora_dropout_p 0.05 \
     --lora_target_modules ALL \
     --gradient_checkpointing true \
2 changes: 1 addition & 1 deletion examples/pytorch/llm/scripts/qwen_agent/qlora_ddp/sft.sh
@@ -18,7 +18,7 @@ torchrun \
     --quantization_bit 4 \
     --bnb_4bit_comp_dtype bf16 \
     --lora_rank 64 \
-    --lora_alpha 16 \
+    --lora_alpha 32 \
     --lora_dropout_p 0.05 \
     --lora_target_modules ALL \
     --gradient_checkpointing true \
2 changes: 1 addition & 1 deletion examples/pytorch/llm/scripts/qwen_vl/qlora_ddp/infer.sh
@@ -3,7 +3,7 @@ CUDA_VISIBLE_DEVICES=0 \
 python src/llm_infer.py \
     --model_type qwen-vl \
     --sft_type lora \
-    --template_type chatml \
+    --template_type default \
     --dtype bf16 \
     --ckpt_dir "runs/qwen-vl/vx_xxx/checkpoint-xxx" \
     --eval_human false \
2 changes: 1 addition & 1 deletion examples/pytorch/llm/scripts/qwen_vl/qlora_ddp/sft.sh
@@ -18,7 +18,7 @@ torchrun \
     --quantization_bit 4 \
     --bnb_4bit_comp_dtype bf16 \
     --lora_rank 64 \
-    --lora_alpha 16 \
+    --lora_alpha 32 \
     --lora_dropout_p 0.05 \
     --lora_target_modules ALL \
     --gradient_checkpointing true \
2 changes: 1 addition & 1 deletion examples/pytorch/llm/scripts/qwen_vl_chat/qlora/sft.sh
@@ -13,7 +13,7 @@ python src/llm_sft.py \
     --quantization_bit 4 \
     --bnb_4bit_comp_dtype bf16 \
     --lora_rank 64 \
-    --lora_alpha 16 \
+    --lora_alpha 32 \
     --lora_dropout_p 0.05 \
     --lora_target_modules ALL \
     --gradient_checkpointing true \
2 changes: 1 addition & 1 deletion examples/pytorch/llm/scripts/qwen_vl_chat/qlora_ddp/sft.sh
@@ -18,7 +18,7 @@ torchrun \
     --quantization_bit 4 \
     --bnb_4bit_comp_dtype bf16 \
     --lora_rank 64 \
-    --lora_alpha 16 \
+    --lora_alpha 32 \
     --lora_dropout_p 0.05 \
     --lora_target_modules ALL \
     --gradient_checkpointing true \
8 changes: 6 additions & 2 deletions examples/pytorch/llm/src/llm_infer.py
@@ -1,3 +1,4 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
 import os
 # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 from dataclasses import dataclass, field
@@ -102,9 +103,12 @@ def llm_infer(args: InferArguments) -> None:
     print_model_info(model)

     # ### Inference
-    template_type = MODEL_MAPPING[args.model_type]['template']
     preprocess_func = get_preprocess(
-        template_type, tokenizer, args.system, args.max_length, batched=False)
+        args.template_type,
+        tokenizer,
+        args.system,
+        args.max_length,
+        batched=False)
     streamer = TextStreamer(
         tokenizer, skip_prompt=True, skip_special_tokens=True)
     generation_config = GenerationConfig(
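
Note (not part of the diff): the effect of this hunk is that inference now builds the preprocessor from the user-supplied --template_type instead of the template registered for the model in MODEL_MAPPING, which is why the infer.sh scripts above switch the flag to "default". A self-contained toy illustration of that behavioural change (the mapping and helper functions below are stand-ins, not the real ones):

# Toy stand-ins, not code from this PR.
MODEL_MAPPING = {'qwen-7b': {'template': 'chatml'}}

def pick_template_old(model_type, cli_template):
    return MODEL_MAPPING[model_type]['template']  # CLI flag was ignored

def pick_template_new(model_type, cli_template):
    return cli_template                           # CLI flag is used directly

print(pick_template_old('qwen-7b', 'default'))  # chatml
print(pick_template_new('qwen-7b', 'default'))  # default
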
1 change: 1 addition & 0 deletions examples/pytorch/llm/src/llm_sft.py
@@ -1,3 +1,4 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
 import os
 # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 from dataclasses import dataclass, field
4 changes: 2 additions & 2 deletions examples/pytorch/llm/src/utils/__init__.py
@@ -3,5 +3,5 @@
 from .preprocess import TEMPLATE_MAPPING, get_preprocess
 from .utils import (broadcast_string, download_dataset,
                     find_all_linear_for_lora, get_dist_setting, inference,
-                    is_dist, is_master, plot_images, select_bnb, select_dtype,
-                    show_layers)
+                    is_dist, is_local_master, is_master, plot_images,
+                    select_bnb, select_dtype, show_layers)
1 change: 1 addition & 0 deletions examples/pytorch/llm/src/utils/dataset.py
@@ -1,3 +1,4 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
 import ast
 import os
 import re
24 changes: 12 additions & 12 deletions examples/pytorch/llm/src/utils/model.py
@@ -1,15 +1,16 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
 # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 from types import MethodType
 from typing import NamedTuple, Optional

 import torch
 import torch.distributed as dist
 from modelscope import (AutoConfig, AutoModel, AutoModelForCausalLM,
                         AutoTokenizer, Model, read_config, snapshot_download)
 from torch import dtype as Dtype

 from swift import get_logger
-from .utils import broadcast_string, is_dist, is_master
+from .utils import is_local_master

 logger = get_logger()

@@ -313,16 +314,15 @@ def get_model_tokenizer(model_type: str,

     model_dir = kwargs.pop('model_dir', None)
     if model_dir is None:
-        if is_master():
-            model_dir = model_id
-            if not os.path.exists(model_id):
-                revision = data.get('revision', 'master')
-                model_dir = snapshot_download(
-                    model_id,
-                    revision,
-                    ignore_file_pattern=ignore_file_pattern)
-        if is_dist():
-            model_dir = broadcast_string(model_dir)
+        if not is_local_master():
+            dist.barrier()
+        model_dir = model_id
+        if not os.path.exists(model_id):
+            revision = data.get('revision', 'master')
+            model_dir = snapshot_download(
+                model_id, revision, ignore_file_pattern=ignore_file_pattern)
+        if is_local_master():
+            dist.barrier()

     model, tokenizer = get_function(model_dir, torch_dtype, load_model,
                                     **kwargs)
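
Note (not part of the diff): the rewritten block replaces the "master downloads, then broadcast_string the path" approach with a "local master downloads, everyone else waits at a barrier" pattern, so each node fetches the checkpoint exactly once. A minimal, self-contained sketch of that pattern under torch.distributed; the function and parameter names below are illustrative, and the process group is assumed to be initialised (e.g. by torchrun):

# Sketch of the download-once-per-node pattern; NOT taken verbatim from this PR.
import os
import torch.distributed as dist

def fetch_once_per_node(download_fn, target_dir):
    local_rank = int(os.environ.get('LOCAL_RANK', -1))
    local_master = local_rank in (-1, 0)
    if dist.is_initialized() and not local_master:
        dist.barrier()            # non-masters wait until the download is done
    if local_master and not os.path.exists(target_dir):
        download_fn(target_dir)   # only rank 0 on each node hits the network
    if dist.is_initialized() and local_master:
        dist.barrier()            # release the waiting ranks
    return target_dir
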
1 change: 1 addition & 0 deletions examples/pytorch/llm/src/utils/preprocess.py
@@ -1,3 +1,4 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 from transformers import PreTrainedTokenizer
1 change: 1 addition & 0 deletions examples/pytorch/llm/src/utils/trainer_patch.py
@@ -1,3 +1,4 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
 import os

 import json
6 changes: 6 additions & 0 deletions examples/pytorch/llm/src/utils/utils.py
@@ -1,3 +1,4 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
 import logging
 import os
 import shutil
@@ -47,6 +48,11 @@ def is_master():
     return rank in {-1, 0}


+def is_local_master():
+    local_rank = get_dist_setting()[1]
+    return local_rank in {-1, 0}
+
+
 def is_dist():
     """Determine if the training is distributed"""
     rank, local_rank, _, _ = get_dist_setting()
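
Note (not part of the diff): RANK is global across all nodes while LOCAL_RANK restarts at 0 on every node, so is_local_master is true for exactly one process per machine, whereas is_master is true for one process in the whole job. A small, self-contained illustration with hard-coded ranks standing in for the torchrun environment variables:

# Toy 2-node x 2-process job; real code reads RANK / LOCAL_RANK via get_dist_setting().
for rank, local_rank in [(0, 0), (1, 1), (2, 0), (3, 1)]:
    is_master = rank in {-1, 0}              # one process in the whole job
    is_local_master = local_rank in {-1, 0}  # one process per node
    print(rank, local_rank, is_master, is_local_master)
# 0 0 True True    <- global master
# 1 1 False False
# 2 0 False True   <- local master on the second node
# 3 1 False False
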
2 changes: 1 addition & 1 deletion swift/tuners/prompt.py
@@ -61,7 +61,7 @@ class PromptConfig(SwiftConfig):
             'help':
             'When set to True, prompt is attached in front of the embedding'
         })

     extract_embedding: bool = field(
         default=False,
         metadata={