From f334bbd99e397bc824d48bb039849d4a3aa1e0ab Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Wed, 10 Apr 2024 07:53:53 +0000 Subject: [PATCH] chore: config_name_to_class uses config.model_type now --- optimum/tpu/modeling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/optimum/tpu/modeling.py b/optimum/tpu/modeling.py index e651e260..2af1ac21 100644 --- a/optimum/tpu/modeling.py +++ b/optimum/tpu/modeling.py @@ -24,7 +24,8 @@ from optimum.tpu.modeling_gemma import TpuGemmaForCausalLM def config_name_to_class(pretrained_model_name_or_path: str): - if "gemma" in pretrained_model_name_or_path: + config = AutoConfig.from_pretrained(pretrained_model_name_or_path) + if config.model_type == "gemma": return TpuGemmaForCausalLM return BaseAutoModelForCausalLM