Skip to content

Commit

Permalink
fix colab
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur committed Jun 17, 2024
1 parent 3813e3c commit 3569d35
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 2 deletions.
28 changes: 28 additions & 0 deletions colabs/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
backend: local
base_model: openai-community/gpt2
data:
column_mapping:
text: text
path: ''
train_split: null
valid_split: null
hub:
push_to_hub: true
token: ${{HF_TOKEN}}
username: ${{HF_USERNAME}}
log: tensorboard
params:
batch_size: 2
block_size: 1024
epochs: 3
gradient_accumulation: 4
lr: 3.0e-05
mixed_precision: fp16
model_max_length: 2048
optimizer: adamw_torch
peft: true
scheduler: linear
target_modules: all-linear
unsloth: false
project_name: autotrain-ao159-ddjag
task: llm
2 changes: 1 addition & 1 deletion src/autotrain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")

logger = Logger().get_logger()
__version__ = "0.7.123"
__version__ = "0.7.124"


def is_colab():
Expand Down
10 changes: 9 additions & 1 deletion src/autotrain/app/colab.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def colab_app():
"Token Classification",
"DreamBooth LoRA",
"Image Classification",
"Image Regression",
"Object Detection",
"Tabular Classification",
"Tabular Regression",
Expand All @@ -56,6 +57,7 @@ def colab_app():
"Token Classification": "token-classification",
"DreamBooth LoRA": "dreambooth",
"Image Classification": "image-classification",
"Image Regression": "image-regression",
"Object Detection": "image-object-detection",
"Tabular Classification": "tabular:classification",
"Tabular Regression": "tabular:regression",
Expand Down Expand Up @@ -267,6 +269,10 @@ def update_col_mapping(*args):
col_mapping.value = '{"image": "image", "label": "label"}'
dataset_source_dropdown.disabled = False
valid_split.disabled = False
elif task == "image-regression":
col_mapping.value = '{"image": "image", "label": "target"}'
dataset_source_dropdown.disabled = False
valid_split.disabled = False
elif task == "image-object-detection":
col_mapping.value = '{"image": "image", "objects": "objects"}'
dataset_source_dropdown.disabled = False
Expand Down Expand Up @@ -379,8 +385,10 @@ def start_training(b):
"push_to_hub": push_to_hub,
},
}
if task_dropdown.value.startswith("llm"):
if TASK_MAP[task_dropdown.value].startswith("llm"):
config["data"]["chat_template"] = chat_template
if config["data"]["chat_template"] == "none":
config["data"]["chat_template"] = None
else:
config = {
"task": TASK_MAP[task_dropdown.value],
Expand Down

0 comments on commit 3569d35

Please sign in to comment.