diff --git a/packages/tasks/src/model-data.ts b/packages/tasks/src/model-data.ts
index 625f319fce..fd5b88f190 100644
--- a/packages/tasks/src/model-data.ts
+++ b/packages/tasks/src/model-data.ts
@@ -44,6 +44,10 @@ export interface ModelData {
 		quant_method?: string;
 	};
 	tokenizer_config?: TokenizerConfig;
+	processor_config?: {
+		chat_template?: string;
+	};
+	chat_template_jinja?: string;
 	adapter_transformers?: {
 		model_name?: string;
 		model_class?: string;
diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts
index 1651d9a224..6966b29f58 100644
--- a/packages/tasks/src/model-libraries-snippets.ts
+++ b/packages/tasks/src/model-libraries-snippets.ts
@@ -1476,6 +1476,11 @@ export const terratorch = (model: ModelData): string[] => [
 model = BACKBONE_REGISTRY.build("${model.id}")`,
 ];
 
+const hasChatTemplate = (model: ModelData): boolean =>
+	model.config?.tokenizer_config?.chat_template !== undefined ||
+	model.config?.processor_config?.chat_template !== undefined ||
+	model.config?.chat_template_jinja !== undefined;
+
 export const transformers = (model: ModelData): string[] => {
 	const info = model.transformersInfo;
 	if (!info) {
@@ -1498,7 +1503,7 @@ export const transformers = (model: ModelData): string[] => {
 		`${processorVarName} = ${info.processor}.from_pretrained("${model.id}"` + remote_code_snippet + ")",
 		`model = ${info.auto_model}.from_pretrained("${model.id}"` + remote_code_snippet + ")"
 	);
-	if (model.tags.includes("conversational")) {
+	if (model.tags.includes("conversational") && hasChatTemplate(model)) {
 		if (model.tags.includes("image-text-to-text")) {
 			autoSnippet.push(
 				"messages = [",