From 45c90a0c7955fed382e28f746580c6d54ec3fd52 Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Wed, 8 Oct 2025 13:54:16 +0200 Subject: [PATCH 1/2] Read `modules.json` when `dense_paths` is None As that might imply that the user originally provided a local path rather than a Hugging Face Hub ID, meaning that the `dense_paths` variable won't be filled, meaning that we need to read those from `modules.json` Note that this is just a premature quick solution, ideally this should be handled within `backends/src/lib.rs` rather than directly within the `CandleBackend` as otherwise we end up duplicating a lot of unnecessary code --- backends/candle/src/lib.rs | 145 +++++++++++++++++++++++++------------ 1 file changed, 98 insertions(+), 47 deletions(-) diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index bd8a8c60..2efed6df 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -115,6 +115,46 @@ enum Config { XlmRoberta(BertConfig), } +#[derive(Debug, Clone, Deserialize, PartialEq)] +enum ModuleType { + #[serde(rename = "sentence_transformers.models.Dense")] + Dense, + #[serde(rename = "sentence_transformers.models.Normalize")] + Normalize, + #[serde(rename = "sentence_transformers.models.Pooling")] + Pooling, + #[serde(rename = "sentence_transformers.models.Transformer")] + Transformer, +} + +#[derive(Debug, Clone, Deserialize)] +struct ModuleConfig { + #[allow(dead_code)] + idx: usize, + #[allow(dead_code)] + name: String, + path: String, + #[serde(rename = "type")] + module_type: ModuleType, +} + +fn parse_dense_paths_from_modules(model_path: &Path) -> Result, std::io::Error> { + let modules_path = model_path.join("modules.json"); + if !modules_path.exists() { + return Ok(vec![]); + } + + let content = std::fs::read_to_string(&modules_path)?; + let modules: Vec = serde_json::from_str(&content) + .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?; + + Ok(modules + .into_iter() + .filter(|module| module.module_type == ModuleType::Dense) + .map(|module| module.path) + .collect::>()) +} + pub struct CandleBackend { device: Device, model: Box, @@ -511,55 +551,66 @@ impl CandleBackend { } }; - // Load Dense layers from the provided Dense paths + // Load modules.json and read the Dense paths from there, unless `dense_paths` is provided + // in such case simply use the `dense_paths` + // 1. If `dense_paths` is None then try to read the `modules.json` file and parse the + // content to read the paths of the default Dense paths, useful when the model directory + // is provided as the `model-id` rather than the ID from the Hugging Face Hub + // 2. If `dense_paths` is Some (even if empty), respect that explicit choice and do not + // read from modules.json, this allows users to explicitly disable dense layers let mut dense_layers = Vec::new(); - if let Some(dense_paths) = &dense_paths { - if !dense_paths.is_empty() { - tracing::info!("Loading Dense module/s from path/s: {dense_paths:?}"); - - for dense_path in dense_paths.iter() { - let dense_safetensors = - model_path.join(format!("{dense_path}/model.safetensors")); - let dense_pytorch = model_path.join(format!("{dense_path}/pytorch_model.bin")); - - if dense_safetensors.exists() || dense_pytorch.exists() { - let dense_config_path = - model_path.join(format!("{dense_path}/config.json")); - - let dense_config_str = std::fs::read_to_string(&dense_config_path) - .map_err(|err| { - BackendError::Start(format!( - "Unable to read `{dense_path}/config.json` file: {err:?}", - )) - })?; - let dense_config: DenseConfig = serde_json::from_str(&dense_config_str) - .map_err(|err| { - BackendError::Start(format!( - "Unable to parse `{dense_path}/config.json`: {err:?}", - )) - })?; - - let dense_vb = if dense_safetensors.exists() { - unsafe { - VarBuilder::from_mmaped_safetensors( - &[dense_safetensors], - dtype, - &device, - ) - } - .s()? - } else { - VarBuilder::from_pth(&dense_pytorch, dtype, &device).s()? - }; - - let dense_layer = Box::new(Dense::load(dense_vb, &dense_config).s()?) - as Box; - dense_layers.push(dense_layer); - - tracing::info!("Loaded Dense module from path: {dense_path}"); + + let paths_to_load = if let Some(dense_paths) = &dense_paths { + // If dense_paths is explicitly provided (even if empty), respect that choice + dense_paths.clone() + } else { + // Try to parse modules.json only if dense_paths is None + parse_dense_paths_from_modules(model_path).unwrap_or_default() + }; + + if !paths_to_load.is_empty() { + tracing::info!("Loading Dense module/s from path/s: {paths_to_load:?}"); + + for dense_path in paths_to_load.iter() { + let dense_safetensors = model_path.join(format!("{dense_path}/model.safetensors")); + let dense_pytorch = model_path.join(format!("{dense_path}/pytorch_model.bin")); + + if dense_safetensors.exists() || dense_pytorch.exists() { + let dense_config_path = model_path.join(format!("{dense_path}/config.json")); + + let dense_config_str = + std::fs::read_to_string(&dense_config_path).map_err(|err| { + BackendError::Start(format!( + "Unable to read `{dense_path}/config.json` file: {err:?}", + )) + })?; + let dense_config: DenseConfig = serde_json::from_str(&dense_config_str) + .map_err(|err| { + BackendError::Start(format!( + "Unable to parse `{dense_path}/config.json`: {err:?}", + )) + })?; + + let dense_vb = if dense_safetensors.exists() { + unsafe { + VarBuilder::from_mmaped_safetensors( + &[dense_safetensors], + dtype, + &device, + ) + } + .s()? } else { - tracing::warn!("Dense module files not found for path: {dense_path}",); - } + VarBuilder::from_pth(&dense_pytorch, dtype, &device).s()? + }; + + let dense_layer = Box::new(Dense::load(dense_vb, &dense_config).s()?) + as Box; + dense_layers.push(dense_layer); + + tracing::info!("Loaded Dense module from path: {dense_path}"); + } else { + tracing::warn!("Dense module files not found for path: {dense_path}",); } } } From 9894e2a71b5f310a5d9276ac0f031ccfa0b797ab Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Wed, 8 Oct 2025 14:33:24 +0200 Subject: [PATCH 2/2] Move `modules.json` logic to `backends/src/lib.rs` instead --- backends/candle/src/lib.rs | 144 ++++++++++++------------------------- backends/src/lib.rs | 33 ++++++++- 2 files changed, 78 insertions(+), 99 deletions(-) diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index 2efed6df..ff824f55 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -115,46 +115,6 @@ enum Config { XlmRoberta(BertConfig), } -#[derive(Debug, Clone, Deserialize, PartialEq)] -enum ModuleType { - #[serde(rename = "sentence_transformers.models.Dense")] - Dense, - #[serde(rename = "sentence_transformers.models.Normalize")] - Normalize, - #[serde(rename = "sentence_transformers.models.Pooling")] - Pooling, - #[serde(rename = "sentence_transformers.models.Transformer")] - Transformer, -} - -#[derive(Debug, Clone, Deserialize)] -struct ModuleConfig { - #[allow(dead_code)] - idx: usize, - #[allow(dead_code)] - name: String, - path: String, - #[serde(rename = "type")] - module_type: ModuleType, -} - -fn parse_dense_paths_from_modules(model_path: &Path) -> Result, std::io::Error> { - let modules_path = model_path.join("modules.json"); - if !modules_path.exists() { - return Ok(vec![]); - } - - let content = std::fs::read_to_string(&modules_path)?; - let modules: Vec = serde_json::from_str(&content) - .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?; - - Ok(modules - .into_iter() - .filter(|module| module.module_type == ModuleType::Dense) - .map(|module| module.path) - .collect::>()) -} - pub struct CandleBackend { device: Device, model: Box, @@ -551,66 +511,54 @@ impl CandleBackend { } }; - // Load modules.json and read the Dense paths from there, unless `dense_paths` is provided - // in such case simply use the `dense_paths` - // 1. If `dense_paths` is None then try to read the `modules.json` file and parse the - // content to read the paths of the default Dense paths, useful when the model directory - // is provided as the `model-id` rather than the ID from the Hugging Face Hub - // 2. If `dense_paths` is Some (even if empty), respect that explicit choice and do not - // read from modules.json, this allows users to explicitly disable dense layers let mut dense_layers = Vec::new(); - - let paths_to_load = if let Some(dense_paths) = &dense_paths { - // If dense_paths is explicitly provided (even if empty), respect that choice - dense_paths.clone() - } else { - // Try to parse modules.json only if dense_paths is None - parse_dense_paths_from_modules(model_path).unwrap_or_default() - }; - - if !paths_to_load.is_empty() { - tracing::info!("Loading Dense module/s from path/s: {paths_to_load:?}"); - - for dense_path in paths_to_load.iter() { - let dense_safetensors = model_path.join(format!("{dense_path}/model.safetensors")); - let dense_pytorch = model_path.join(format!("{dense_path}/pytorch_model.bin")); - - if dense_safetensors.exists() || dense_pytorch.exists() { - let dense_config_path = model_path.join(format!("{dense_path}/config.json")); - - let dense_config_str = - std::fs::read_to_string(&dense_config_path).map_err(|err| { - BackendError::Start(format!( - "Unable to read `{dense_path}/config.json` file: {err:?}", - )) - })?; - let dense_config: DenseConfig = serde_json::from_str(&dense_config_str) - .map_err(|err| { - BackendError::Start(format!( - "Unable to parse `{dense_path}/config.json`: {err:?}", - )) - })?; - - let dense_vb = if dense_safetensors.exists() { - unsafe { - VarBuilder::from_mmaped_safetensors( - &[dense_safetensors], - dtype, - &device, - ) - } - .s()? + if let Some(dense_paths) = dense_paths { + if !dense_paths.is_empty() { + tracing::info!("Loading Dense module/s from path/s: {dense_paths:?}"); + + for dense_path in dense_paths.iter() { + let dense_safetensors = + model_path.join(format!("{dense_path}/model.safetensors")); + let dense_pytorch = model_path.join(format!("{dense_path}/pytorch_model.bin")); + + if dense_safetensors.exists() || dense_pytorch.exists() { + let dense_config_path = + model_path.join(format!("{dense_path}/config.json")); + + let dense_config_str = std::fs::read_to_string(&dense_config_path) + .map_err(|err| { + BackendError::Start(format!( + "Unable to read `{dense_path}/config.json` file: {err:?}", + )) + })?; + let dense_config: DenseConfig = serde_json::from_str(&dense_config_str) + .map_err(|err| { + BackendError::Start(format!( + "Unable to parse `{dense_path}/config.json`: {err:?}", + )) + })?; + + let dense_vb = if dense_safetensors.exists() { + unsafe { + VarBuilder::from_mmaped_safetensors( + &[dense_safetensors], + dtype, + &device, + ) + } + .s()? + } else { + VarBuilder::from_pth(&dense_pytorch, dtype, &device).s()? + }; + + let dense_layer = Box::new(Dense::load(dense_vb, &dense_config).s()?) + as Box; + dense_layers.push(dense_layer); + + tracing::info!("Loaded Dense module from path: {dense_path}"); } else { - VarBuilder::from_pth(&dense_pytorch, dtype, &device).s()? - }; - - let dense_layer = Box::new(Dense::load(dense_vb, &dense_config).s()?) - as Box; - dense_layers.push(dense_layer); - - tracing::info!("Loaded Dense module from path: {dense_path}"); - } else { - tracing::warn!("Dense module files not found for path: {dense_path}",); + tracing::warn!("Dense module files not found for path: {dense_path}",); + } } } } diff --git a/backends/src/lib.rs b/backends/src/lib.rs index 7b30db65..93921106 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -433,7 +433,38 @@ async fn init_backend( tracing::info!("Dense modules downloaded in {:?}", start.elapsed()); Some(dense_paths) } else { - None + // For local models, try to parse modules.json and handle dense_path logic + let modules_json_path = model_path.join("modules.json"); + if modules_json_path.exists() { + match parse_dense_paths_from_modules(&modules_json_path).await { + Ok(module_paths) => match module_paths.len() { + 0 => Some(vec![]), + 1 => { + let path_to_use = if let Some(ref user_path) = dense_path { + if user_path != &module_paths[0] { + tracing::info!("`{}` found in `modules.json`, but using provided `--dense-path={user_path}` instead", module_paths[0]); + } + user_path.clone() + } else { + module_paths[0].clone() + }; + Some(vec![path_to_use]) + } + _ => { + if dense_path.is_some() { + tracing::warn!("A value for `--dense-path` was provided, but since there's more than one subsequent Dense module, then the provided value will be ignored."); + } + Some(module_paths) + } + }, + Err(err) => { + tracing::warn!("Failed to parse local modules.json: {err}"); + None + } + } + } else { + None + } }; let backend = CandleBackend::new(