enable autotrain params for lm training
abhishekkrthakur committed Apr 29, 2023
1 parent 13431fe commit b708f9e
Showing 2 changed files with 26 additions and 30 deletions.
22 changes: 6 additions & 16 deletions src/autotrain/app.py
@@ -340,22 +340,12 @@ def app():  # username, valid_orgs):
     st.sidebar.markdown("### Parameters")
     if model_choice != "AutoTrain":
         # on_change reset st.session_stat["jobs"]
-        if task != "lm_training":
-            param_choice = st.sidebar.selectbox(
-                "Parameter Choice",
-                ["AutoTrain", "Manual"],
-                key="param_choice",
-                on_change=on_change_reset_jobs(),
-            )
-        else:
-            param_choice = st.sidebar.selectbox(
-                "Parameter Choice",
-                ["Manual", "AutoTrain"],
-                key="param_choice",
-                on_change=on_change_reset_jobs(),
-                index=0,
-                disabled=True,
-            )
+        param_choice = st.sidebar.selectbox(
+            "Parameter Choice",
+            ["AutoTrain", "Manual"],
+            key="param_choice",
+            on_change=on_change_reset_jobs(),
+        )
     else:
         param_choice = st.sidebar.selectbox(
             "Parameter Choice", ["AutoTrain", "Manual"], key="param_choice", index=0, disabled=True
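After this change, every non-"AutoTrain" model choice gets the same "Parameter Choice" selectbox, with "AutoTrain" listed first as the default, and lm_training is no longer forced into a disabled, Manual-only widget. A rough, standalone sketch of the pattern (not the actual AutoTrain code; reset_jobs and the "jobs" session key are placeholder names standing in for on_change_reset_jobs and the app's real session state):

# Hypothetical minimal Streamlit sketch; reset_jobs and "jobs" are assumed names.
import streamlit as st

def reset_jobs():
    # Clear any jobs queued under the previously selected parameter mode.
    st.session_state["jobs"] = []

# A single selectbox now serves every task, lm_training included;
# "AutoTrain" is the default because it appears first in the options list.
param_choice = st.sidebar.selectbox(
    "Parameter Choice",
    ["AutoTrain", "Manual"],
    key="param_choice",
    on_change=reset_jobs,
)
st.sidebar.write(f"Selected: {param_choice}")

(Run with `streamlit run sketch.py` to see the sidebar widget.)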
34 changes: 20 additions & 14 deletions src/autotrain/params.py
@@ -246,20 +246,26 @@ def _tabular_binary_classification(self):
         }

     def _lm_training(self):
-        return {
-            "learning_rate": LMLearningRate,
-            "optimizer": Optimizer,
-            "scheduler": Scheduler,
-            "train_batch_size": LMTrainBatchSize,
-            "num_train_epochs": LMEpochs,
-            "percentage_warmup": PercentageWarmup,
-            "gradient_accumulation_steps": GradientAccumulationSteps,
-            "weight_decay": WeightDecay,
-            "lora_r": LoraR,
-            "lora_alpha": LoraAlpha,
-            "lora_dropout": LoraDropout,
-            "training_type": LMTrainingType,
-        }
+        if self.param_choice == "manual":
+            return {
+                "learning_rate": LMLearningRate,
+                "optimizer": Optimizer,
+                "scheduler": Scheduler,
+                "train_batch_size": LMTrainBatchSize,
+                "num_train_epochs": LMEpochs,
+                "percentage_warmup": PercentageWarmup,
+                "gradient_accumulation_steps": GradientAccumulationSteps,
+                "weight_decay": WeightDecay,
+                "lora_r": LoraR,
+                "lora_alpha": LoraAlpha,
+                "lora_dropout": LoraDropout,
+                "training_type": LMTrainingType,
+            }
+        if self.param_choice == "autotrain":
+            return {
+                "num_models": NumModels,
+            }
+        raise ValueError("param_choice must be either autotrain or manual")

     def _tabular_multi_class_classification(self):
         return self._tabular_binary_classification()
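The params.py side now dispatches on param_choice: Manual mode exposes the full LM hyperparameter set (including the LoRA fields), while AutoTrain mode only asks how many candidate models to train. A rough, self-contained sketch of that dispatch, with an illustrative LMParams class and string placeholders standing in for the real parameter field classes (LMLearningRate, NumModels, etc.):

# Illustrative sketch only; LMParams and the string values are assumed stand-ins.
class LMParams:
    def __init__(self, param_choice: str):
        self.param_choice = param_choice.lower()

    def lm_training(self) -> dict:
        if self.param_choice == "manual":
            # Full set of user-tunable hyperparameters, LoRA knobs included.
            return {
                "learning_rate": "LMLearningRate",
                "train_batch_size": "LMTrainBatchSize",
                "lora_r": "LoraR",
                # ... remaining manual fields omitted for brevity
            }
        if self.param_choice == "autotrain":
            # AutoTrain mode only needs the number of candidate models.
            return {"num_models": "NumModels"}
        raise ValueError("param_choice must be either autotrain or manual")


print(LMParams("autotrain").lm_training())            # {'num_models': 'NumModels'}
print(list(LMParams("manual").lm_training().keys()))  # ['learning_rate', 'train_batch_size', 'lora_r']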
