From b708f9e601bab8f869a3864f58f24119d08b9d48 Mon Sep 17 00:00:00 2001
From: Abhishek Thakur
Date: Sat, 29 Apr 2023 12:19:11 +0200
Subject: [PATCH] enable autotrain params for lm training

---
 src/autotrain/app.py    | 22 ++++++----------------
 src/autotrain/params.py | 34 ++++++++++++++++++++--------------
 2 files changed, 26 insertions(+), 30 deletions(-)

diff --git a/src/autotrain/app.py b/src/autotrain/app.py
index bafd2823b3..04f7578223 100644
--- a/src/autotrain/app.py
+++ b/src/autotrain/app.py
@@ -340,22 +340,12 @@ def app():  # username, valid_orgs):
     st.sidebar.markdown("### Parameters")
     if model_choice != "AutoTrain":
         # on_change reset st.session_state["jobs"]
-        if task != "lm_training":
-            param_choice = st.sidebar.selectbox(
-                "Parameter Choice",
-                ["AutoTrain", "Manual"],
-                key="param_choice",
-                on_change=on_change_reset_jobs(),
-            )
-        else:
-            param_choice = st.sidebar.selectbox(
-                "Parameter Choice",
-                ["Manual", "AutoTrain"],
-                key="param_choice",
-                on_change=on_change_reset_jobs(),
-                index=0,
-                disabled=True,
-            )
+        param_choice = st.sidebar.selectbox(
+            "Parameter Choice",
+            ["AutoTrain", "Manual"],
+            key="param_choice",
+            on_change=on_change_reset_jobs(),
+        )
     else:
         param_choice = st.sidebar.selectbox(
             "Parameter Choice", ["AutoTrain", "Manual"], key="param_choice", index=0, disabled=True
diff --git a/src/autotrain/params.py b/src/autotrain/params.py
index 63feb8667e..30a9781e7b 100644
--- a/src/autotrain/params.py
+++ b/src/autotrain/params.py
@@ -246,20 +246,26 @@ def _tabular_binary_classification(self):
         }
 
     def _lm_training(self):
-        return {
-            "learning_rate": LMLearningRate,
-            "optimizer": Optimizer,
-            "scheduler": Scheduler,
-            "train_batch_size": LMTrainBatchSize,
-            "num_train_epochs": LMEpochs,
-            "percentage_warmup": PercentageWarmup,
-            "gradient_accumulation_steps": GradientAccumulationSteps,
-            "weight_decay": WeightDecay,
-            "lora_r": LoraR,
-            "lora_alpha": LoraAlpha,
-            "lora_dropout": LoraDropout,
-            "training_type": LMTrainingType,
-        }
+        if self.param_choice == "manual":
+            return {
+                "learning_rate": LMLearningRate,
+                "optimizer": Optimizer,
+                "scheduler": Scheduler,
+                "train_batch_size": LMTrainBatchSize,
+                "num_train_epochs": LMEpochs,
+                "percentage_warmup": PercentageWarmup,
+                "gradient_accumulation_steps": GradientAccumulationSteps,
+                "weight_decay": WeightDecay,
+                "lora_r": LoraR,
+                "lora_alpha": LoraAlpha,
+                "lora_dropout": LoraDropout,
+                "training_type": LMTrainingType,
+            }
+        if self.param_choice == "autotrain":
+            return {
+                "num_models": NumModels,
+            }
+        raise ValueError("param_choice must be either autotrain or manual")
 
     def _tabular_multi_class_classification(self):
        return self._tabular_binary_classification()
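
For context, a minimal runnable sketch of the dispatch this patch adds to
Params._lm_training. The plain values below are illustrative stand-ins for
the real parameter classes (LMLearningRate, NumModels, etc.), and
lm_training_params is a hypothetical free-function rendering of the method,
assuming param_choice has already been lowercased before it reaches here:

    # Sketch only: dicts of literals stand in for autotrain's parameter classes.
    MANUAL_LM_PARAMS = {
        "learning_rate": 2e-5,
        "train_batch_size": 8,
        "num_train_epochs": 3,
    }

    AUTOTRAIN_LM_PARAMS = {
        "num_models": 5,  # AutoTrain mode only asks how many models to try
    }

    def lm_training_params(param_choice: str) -> dict:
        # Mirrors the patched _lm_training: branch on the choice and
        # fail loudly on anything else instead of returning a default.
        if param_choice == "manual":
            return MANUAL_LM_PARAMS
        if param_choice == "autotrain":
            return AUTOTRAIN_LM_PARAMS
        raise ValueError("param_choice must be either autotrain or manual")

    if __name__ == "__main__":
        print(lm_training_params("autotrain"))  # {'num_models': 5}
        print(lm_training_params("manual"))

Raising ValueError on an unrecognized value keeps an unexpected UI state from
silently falling through to one of the two parameter sets.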