Skip to content

Commit

Permalink
fix unsloth
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur committed Jun 13, 2024
1 parent 49a1285 commit b9c8c60
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
11 changes: 1 addition & 10 deletions src/autotrain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")

logger = Logger().get_logger()
__version__ = "0.7.120.dev0"
__version__ = "0.7.121.dev0"


def is_colab():
Expand All @@ -51,12 +51,3 @@ def is_colab():
return True
except ImportError:
return False


def is_unsloth_available():
    """Check whether the optional ``unsloth`` dependency is importable.

    Returns:
        bool: True if ``unsloth.FastLanguageModel`` can be imported,
        False otherwise (i.e. the package is not installed).
    """
    try:
        from unsloth import FastLanguageModel  # noqa: F401
    except ImportError:
        return False
    return True
22 changes: 16 additions & 6 deletions src/autotrain/trainers/clm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from autotrain import is_unsloth_available, logger
from autotrain import logger
from autotrain.trainers.clm.callbacks import LoadBestPeftModelCallback, SavePeftModelCallback
from autotrain.trainers.common import (
ALLOW_REMOTE_CODE,
Expand Down Expand Up @@ -567,19 +567,29 @@ def get_model(config, tokenizer):
)
model_type = model_config.model_type
unsloth_target_modules = None
logger.info(f"Unsloth available: {is_unsloth_available()}")
can_use_unloth = config.unsloth and is_unsloth_available() and config.trainer in ("default", "sft")
if model_type in ("llama", "mistral", "gemma", "qwen2"):
is_unsloth_available = False
can_use_unloth = False
try:
from unsloth import FastLanguageModel

is_unsloth_available = True
except ImportError:
pass
logger.info(f"Unsloth available: {is_unsloth_available}")

if config.unsloth and is_unsloth_available and config.trainer in ("default", "sft"):
can_use_unloth = True

if model_type in ("llama", "mistral", "gemma", "qwen2") and config.unsloth:
if config.target_modules.strip().lower() == "all-linear":
unsloth_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
else:
unsloth_target_modules = get_target_modules(config)
else:
can_use_unloth = False

logger.info(f"Can use unsloth: {can_use_unloth}")
if can_use_unloth:
from unsloth import FastLanguageModel

load_in_4bit = False
load_in_8bit = False
if config.peft and config.quantization == "int4":
Expand Down

0 comments on commit b9c8c60

Please sign in to comment.