diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 761c2721..7440e0c6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -119,7 +119,7 @@ jobs: run: cargo binstall cargo-llvm-cov --force --no-confirm - name: Generate coverage - run: cargo llvm-cov --workspace --all-features --lcov --output-path lcov.info + run: make coverage - name: Upload to codecov uses: codecov/codecov-action@v2 @@ -163,7 +163,10 @@ jobs: with: toolchain: nightly + - name: Install test dependencies + run: make setup + - run: rustup component add --toolchain nightly-x86_64-unknown-linux-gnu clippy - name: Run Clippy - run: cargo clippy --all-features --all-targets -- -D warnings + run: make lint diff --git a/.vscode/settings.json b/.vscode/settings.json index d1cdf47a..979fd32b 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,5 @@ { - "rust-analyzer.cargo.features": ["bench"], + "rust-analyzer.cargo.features": ["dev-utils"], "rust-analyzer.procMacro.enable": true, "rust-analyzer.cargo.buildScripts.enable": true, "rust-analyzer.linkedProjects": [ diff --git a/Cargo.lock b/Cargo.lock index b3276d96..6e1cb176 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -760,12 +760,14 @@ version = "0.1.2" dependencies = [ "anyhow", "clap", - "clap-markdown", "clap_derive", "directories", + "encoderfile-core", "figment", "lazy_static", + "ndarray", "ort", + "rand 0.9.2", "schemars", "serde", "serde_json", diff --git a/Makefile b/Makefile index f25a3369..4e9322c1 100644 --- a/Makefile +++ b/Makefile @@ -12,6 +12,22 @@ format: @echo "Formatting rust..." @cargo fmt +.PHONY: lint +lint: + cargo clippy \ --all-features \ --all-targets \ -- \ -D warnings + +.PHONY: coverage +coverage: + cargo llvm-cov \ --workspace \ --all-features \ --lcov \ --output-path lcov.info + .PHONY: licenses licenses: @echo "Generating licenses..." 
@@ -59,5 +75,3 @@ generate-docs: # generate JSON schema for encoderfile config @cargo run \ --bin generate-encoderfile-config-schema -# generate CLI docs for encoderfile build - @cargo run --bin generate-encoderfile-cli-docs --features="_internal" diff --git a/docs/reference/cli.md b/docs/reference/cli.md index cdb1e710..edf78f0d 100644 --- a/docs/reference/cli.md +++ b/docs/reference/cli.md @@ -71,7 +71,10 @@ encoderfile: transform: path: ./transforms/normalize.lua # OR inline transform: - # transform: "return normalize(output)" + # transform: "function Postprocess(logits) return logits:lp_normalize(2.0, 2.0) end" + + # Whether to validate transform with a dry-run (optional, defaults to true) + validate_transform: true # Whether to build the binary (optional, defaults to true) build: true diff --git a/encoderfile-core/benches/benchmark_transforms.rs b/encoderfile-core/benches/benchmark_transforms.rs index 1682c332..a35c0c89 100644 --- a/encoderfile-core/benches/benchmark_transforms.rs +++ b/encoderfile-core/benches/benchmark_transforms.rs @@ -1,3 +1,4 @@ +use encoderfile_core::transforms::Postprocessor; use ndarray::{Array2, Array3}; use rand::Rng; @@ -21,9 +22,9 @@ fn get_random_3d(x: usize, y: usize, z: usize) -> Array3 { #[divan::bench(args = [(16, 16, 16), (32, 128, 384), (32, 256, 768)])] fn bench_embedding_l2_normalization(bencher: divan::Bencher, (x, y, z): (usize, usize, usize)) { - let engine = encoderfile_core::transforms::Transform::new(include_str!( + let engine = encoderfile_core::transforms::EmbeddingTransform::new(Some(include_str!( "../../transforms/embedding/l2_normalize_embeddings.lua" - )) + ))) .unwrap(); let test_tensor = get_random_3d(x, y, z); @@ -35,8 +36,8 @@ fn bench_embedding_l2_normalization(bencher: divan::Bencher, (x, y, z): (usize, #[divan::bench(args = [(16, 2), (32, 8), (128, 32)])] fn bench_seq_cls_softmax(bencher: divan::Bencher, (x, y): (usize, usize)) { - let engine = encoderfile_core::transforms::Transform::new(include_str!( - 
"../../transforms/sequence_classification/softmax_logits.lua" + let engine = encoderfile_core::transforms::SequenceClassificationTransform::new(Some( + include_str!("../../transforms/sequence_classification/softmax_logits.lua"), )) .unwrap(); @@ -49,8 +50,8 @@ fn bench_seq_cls_softmax(bencher: divan::Bencher, (x, y): (usize, usize)) { #[divan::bench(args = [(16, 16, 2), (32, 128, 8), (128, 256, 32)])] fn bench_tok_cls_softmax(bencher: divan::Bencher, (x, y, z): (usize, usize, usize)) { - let engine = encoderfile_core::transforms::Transform::new(include_str!( - "../../transforms/token_classification/softmax_logits.lua" + let engine = encoderfile_core::transforms::TokenClassificationTransform::new(Some( + include_str!("../../transforms/token_classification/softmax_logits.lua"), )) .unwrap(); diff --git a/encoderfile-core/src/cli.rs b/encoderfile-core/src/cli.rs index 3a07282d..5833b2ba 100644 --- a/encoderfile-core/src/cli.rs +++ b/encoderfile-core/src/cli.rs @@ -3,11 +3,12 @@ use crate::{ EmbeddingRequest, ModelType, SentenceEmbeddingRequest, SequenceClassificationRequest, TokenClassificationRequest, }, - runtime::AppState, + runtime::{AppState, get_model, get_model_config, get_model_type, get_tokenizer}, server::{run_grpc, run_http, run_mcp}, services::{embedding, sentence_embedding, sequence_classification, token_classification}, }; use anyhow::Result; +use clap::Parser; use clap_derive::{Parser, Subcommand, ValueEnum}; use opentelemetry::trace::TracerProvider as _; use opentelemetry_otlp::{Protocol, WithExportConfig}; @@ -15,6 +16,37 @@ use opentelemetry_sdk::trace::SdkTracerProvider; use std::{fmt::Display, io::Write}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +pub async fn cli_entrypoint( + model_bytes: &[u8], + config_str: &str, + tokenizer_json: &str, + model_type: &str, + model_id: &str, + transform_str: Option<&str>, +) -> Result<()> { + let cli = Cli::parse(); + + let session = get_model(model_bytes); + let config = 
get_model_config(config_str); + let tokenizer = get_tokenizer(tokenizer_json, &config); + let model_type = get_model_type(model_type); + let transform_str = transform_str.map(|t| t.to_string()); + let model_id = model_id.to_string(); + + let state = AppState { + session, + config, + tokenizer, + model_type, + model_id, + transform_str, + }; + + cli.command.execute(state).await?; + + Ok(()) +} + macro_rules! generate_cli_route { ($req:ident, $fn:path, $format:ident, $out_dir:expr, $state:expr) => {{ let result = $fn($req, &$state)?; diff --git a/encoderfile-core/src/common/mod.rs b/encoderfile-core/src/common/mod.rs index 14a800d6..640a3fe5 100644 --- a/encoderfile-core/src/common/mod.rs +++ b/encoderfile-core/src/common/mod.rs @@ -1,4 +1,5 @@ mod embedding; +mod model_config; mod model_metadata; mod model_type; mod sentence_embedding; @@ -7,6 +8,7 @@ mod token; mod token_classification; pub use embedding::*; +pub use model_config::*; pub use model_metadata::*; pub use model_type::*; pub use sentence_embedding::*; diff --git a/encoderfile-core/src/common/model_config.rs b/encoderfile-core/src/common/model_config.rs new file mode 100644 index 00000000..32f216ac --- /dev/null +++ b/encoderfile-core/src/common/model_config.rs @@ -0,0 +1,87 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[derive(Debug, Serialize, Deserialize)] +pub struct ModelConfig { + pub model_type: String, + pub pad_token_id: u32, + pub num_labels: Option<usize>, + pub id2label: Option<HashMap<u32, String>>, + pub label2id: Option<HashMap<String, u32>>, +} + +impl ModelConfig { + pub fn id2label(&self, id: u32) -> Option<&str> { + self.id2label.as_ref()?.get(&id).map(|s| s.as_str()) + } + + pub fn label2id(&self, label: &str) -> Option<u32> { + self.label2id.as_ref()?.get(label).copied() + } + + pub fn num_labels(&self) -> Option<usize> { + if self.num_labels.is_some() { + return self.num_labels; + } + + if let Some(id2label) = &self.id2label { + return Some(id2label.len()); + } + + if let Some(label2id) = &self.label2id { + return Some(label2id.len()); + } + + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_num_labels() { + let test_labels: Vec<(String, u32)> = vec![("a", 1), ("b", 2), ("c", 3)] + .into_iter() + .map(|(i, j)| (i.to_string(), j)) + .collect(); + + let label2id: HashMap<String, u32> = test_labels.clone().into_iter().collect(); + let id2label: HashMap<u32, String> = test_labels + .clone() + .into_iter() + .map(|(i, j)| (j, i)) + .collect(); + + let config = ModelConfig { + model_type: "MyModel".to_string(), + pad_token_id: 0, + num_labels: Some(3), + id2label: Some(id2label.clone()), + label2id: Some(label2id.clone()), + }; + + assert_eq!(config.num_labels(), Some(3)); + + let config = ModelConfig { + model_type: "MyModel".to_string(), + pad_token_id: 0, + num_labels: None, + id2label: Some(id2label.clone()), + label2id: Some(label2id.clone()), + }; + + assert_eq!(config.num_labels(), Some(3)); + + let config = ModelConfig { + model_type: "MyModel".to_string(), + pad_token_id: 0, + num_labels: None, + id2label: None, + label2id: Some(label2id.clone()), + }; + + assert_eq!(config.num_labels(), Some(3)); + } +} diff --git a/encoderfile-core/src/common/model_type.rs b/encoderfile-core/src/common/model_type.rs index 97b11ebe..4b3b2121 100644 --- a/encoderfile-core/src/common/model_type.rs +++ b/encoderfile-core/src/common/model_type.rs @@ -1,8 +1,9 @@ #[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)] #[serde(rename_all = "snake_case")] +#[repr(u8)] pub enum ModelType { - Embedding, - SequenceClassification, - TokenClassification, - SentenceEmbedding, + Embedding = 1, + SequenceClassification = 2, + TokenClassification = 3, + SentenceEmbedding = 4, } diff --git a/encoderfile-core/src/dev_utils/mod.rs b/encoderfile-core/src/dev_utils/mod.rs index fe824224..daa7f695 100644 --- a/encoderfile-core/src/dev_utils/mod.rs +++ b/encoderfile-core/src/dev_utils/mod.rs @@ -1,7 +1,6 @@ use crate::{ - common::ModelType, - runtime::{AppState, ModelConfig}, - 
transforms::Transform, + common::{ModelConfig, ModelType}, + runtime::AppState, }; use ort::session::Session; use parking_lot::Mutex; @@ -22,7 +21,7 @@ pub fn get_state(dir: &str, model_type: ModelType) -> AppState { config, model_type, model_id: "test-model".to_string(), - transform_factory: || Transform::new("").unwrap(), + transform_str: None, } } diff --git a/encoderfile-core/src/inference/embedding.rs b/encoderfile-core/src/inference/embedding.rs index 4535927a..fcd51403 100644 --- a/encoderfile-core/src/inference/embedding.rs +++ b/encoderfile-core/src/inference/embedding.rs @@ -5,6 +5,7 @@ use crate::{ common::{TokenEmbedding, TokenEmbeddingSequence, TokenInfo}, error::ApiError, runtime::AppState, + transforms::{EmbeddingTransform, Postprocessor}, }; #[tracing::instrument(skip_all)] @@ -24,7 +25,7 @@ pub fn embedding<'a>( .expect("Model does not return tensor of shape [n_batch, n_tokens, hidden_dim]") .into_owned(); - outputs = state.transform().postprocess(outputs)?; + outputs = EmbeddingTransform::new(state.transform_str())?.postprocess(outputs)?; let embeddings = postprocess(outputs, encodings); diff --git a/encoderfile-core/src/inference/sentence_embedding.rs b/encoderfile-core/src/inference/sentence_embedding.rs index ea937d14..31150930 100644 --- a/encoderfile-core/src/inference/sentence_embedding.rs +++ b/encoderfile-core/src/inference/sentence_embedding.rs @@ -1,7 +1,12 @@ use ndarray::{Array2, Axis, Ix2, Ix3}; use tokenizers::Encoding; -use crate::{common::SentenceEmbedding, error::ApiError, runtime::AppState}; +use crate::{ + common::SentenceEmbedding, + error::ApiError, + runtime::AppState, + transforms::{Postprocessor, SentenceEmbeddingTransform}, +}; #[tracing::instrument(skip_all)] pub fn sentence_embedding<'a>( @@ -28,9 +33,9 @@ pub fn sentence_embedding<'a>( .expect("Model does not return tensor of shape [n_batch, n_tokens, hidden_dim]") .into_owned(); - let transform = state.transform(); + let transform = 
SentenceEmbeddingTransform::new(state.transform_str())?; - let pooled_outputs = transform.pool(outputs, a_mask_arr)?; + let pooled_outputs = transform.postprocess((outputs, a_mask_arr))?; let embeddings = postprocess(pooled_outputs, encodings); diff --git a/encoderfile-core/src/inference/sequence_classification.rs b/encoderfile-core/src/inference/sequence_classification.rs index 20d9051f..fa219826 100644 --- a/encoderfile-core/src/inference/sequence_classification.rs +++ b/encoderfile-core/src/inference/sequence_classification.rs @@ -1,7 +1,8 @@ use crate::{ - common::SequenceClassificationResult, + common::{ModelConfig, SequenceClassificationResult}, error::ApiError, - runtime::{AppState, ModelConfig}, + runtime::AppState, + transforms::{Postprocessor, SequenceClassificationTransform}, }; use ndarray::{Array2, Axis, Ix2}; use ndarray_stats::QuantileExt; @@ -24,7 +25,7 @@ pub fn sequence_classification<'a>( .expect("Model does not return tensor of shape [n_batch, n_labels]") .into_owned(); - outputs = state.transform().postprocess(outputs)?; + outputs = SequenceClassificationTransform::new(state.transform_str())?.postprocess(outputs)?; let results = postprocess(outputs, &state.config); diff --git a/encoderfile-core/src/inference/token_classification.rs b/encoderfile-core/src/inference/token_classification.rs index 6e6da224..7ea58566 100644 --- a/encoderfile-core/src/inference/token_classification.rs +++ b/encoderfile-core/src/inference/token_classification.rs @@ -1,7 +1,8 @@ use crate::{ - common::{TokenClassification, TokenClassificationResult, TokenInfo}, + common::{ModelConfig, TokenClassification, TokenClassificationResult, TokenInfo}, error::ApiError, - runtime::{AppState, ModelConfig}, + runtime::AppState, + transforms::{Postprocessor, TokenClassificationTransform}, }; use ndarray::{Array3, Axis, Ix3}; use ndarray_stats::QuantileExt; @@ -24,7 +25,7 @@ pub fn token_classification<'a>( .expect("Model does not return tensor of shape [n_batch, n_tokens, 
n_labels]") .into_owned(); - outputs = state.transform().postprocess(outputs)?; + outputs = TokenClassificationTransform::new(state.transform_str())?.postprocess(outputs)?; let predictions = postprocess(outputs, encodings, &state.config); diff --git a/encoderfile-core/src/lib.rs b/encoderfile-core/src/lib.rs index d7501b35..5a4330e1 100644 --- a/encoderfile-core/src/lib.rs +++ b/encoderfile-core/src/lib.rs @@ -26,5 +26,7 @@ pub mod transport; pub mod dev_utils; pub use assets::get_banner; +#[cfg(feature = "transport")] +pub use cli::cli_entrypoint; #[cfg(feature = "runtime")] pub use runtime::AppState; diff --git a/encoderfile-core/src/runtime/config.rs b/encoderfile-core/src/runtime/config.rs index d8e02fb5..549403e3 100644 --- a/encoderfile-core/src/runtime/config.rs +++ b/encoderfile-core/src/runtime/config.rs @@ -1,8 +1,4 @@ -use std::collections::HashMap; - -use serde::{Deserialize, Serialize}; - -use crate::common::ModelType; +use crate::common::{ModelConfig, ModelType}; use std::sync::{Arc, OnceLock}; static MODEL_TYPE: OnceLock = OnceLock::new(); @@ -28,21 +24,3 @@ pub fn get_model_type(model_type: &str) -> ModelType { }) .clone() } - -#[derive(Debug, Serialize, Deserialize)] -pub struct ModelConfig { - pub model_type: String, - pub pad_token_id: u32, - pub id2label: Option>, - pub label2id: Option>, -} - -impl ModelConfig { - pub fn id2label(&self, id: u32) -> Option<&str> { - self.id2label.as_ref()?.get(&id).map(|s| s.as_str()) - } - - pub fn label2id(&self, label: &str) -> Option { - self.label2id.as_ref()?.get(label).copied() - } -} diff --git a/encoderfile-core/src/runtime/mod.rs b/encoderfile-core/src/runtime/mod.rs index a77cb679..b914b3ac 100644 --- a/encoderfile-core/src/runtime/mod.rs +++ b/encoderfile-core/src/runtime/mod.rs @@ -2,10 +2,8 @@ mod config; mod model; mod state; mod tokenizer; -mod transform; -pub use config::{ModelConfig, get_model_config, get_model_type}; +pub use config::{get_model_config, get_model_type}; pub use model::{Model, 
get_model}; pub use state::AppState; pub use tokenizer::{encode_text, get_tokenizer, get_tokenizer_from_string}; -pub use transform::get_transform; diff --git a/encoderfile-core/src/runtime/state.rs b/encoderfile-core/src/runtime/state.rs index b4f2f923..5ad845f4 100644 --- a/encoderfile-core/src/runtime/state.rs +++ b/encoderfile-core/src/runtime/state.rs @@ -4,7 +4,7 @@ use ort::session::Session; use parking_lot::Mutex; use tokenizers::Tokenizer; -use crate::{common::ModelType, runtime::config::ModelConfig, transforms::Transform}; +use crate::common::{ModelConfig, ModelType}; #[derive(Debug, Clone)] pub struct AppState { @@ -13,11 +13,14 @@ pub struct AppState { pub config: Arc, pub model_type: ModelType, pub model_id: String, - pub transform_factory: fn() -> Transform, + pub transform_str: Option, } impl AppState { - pub fn transform(&self) -> Transform { - (self.transform_factory)() + pub fn transform_str(&self) -> Option<&str> { + match &self.transform_str { + Some(t) => Some(t.as_ref()), + None => None, + } } } diff --git a/encoderfile-core/src/runtime/tokenizer.rs b/encoderfile-core/src/runtime/tokenizer.rs index ad531f31..a6d52918 100644 --- a/encoderfile-core/src/runtime/tokenizer.rs +++ b/encoderfile-core/src/runtime/tokenizer.rs @@ -1,4 +1,4 @@ -use crate::{error::ApiError, runtime::config::ModelConfig}; +use crate::{common::ModelConfig, error::ApiError}; use anyhow::Result; use std::str::FromStr; use std::sync::{Arc, OnceLock}; diff --git a/encoderfile-core/src/runtime/transform.rs b/encoderfile-core/src/runtime/transform.rs deleted file mode 100644 index c4b2e45a..00000000 --- a/encoderfile-core/src/runtime/transform.rs +++ /dev/null @@ -1,12 +0,0 @@ -use crate::transforms::Transform; - -#[cfg(not(tarpaulin_include))] -pub fn get_transform(transform_str: Option<&str>) -> Transform { - if let Some(script) = transform_str { - let engine = Transform::new(script).expect("Failed to create transform"); - - return engine; - } - - 
Transform::new("").expect("Failed to create transform") -} diff --git a/encoderfile-core/src/transforms/engine.rs b/encoderfile-core/src/transforms/engine.rs deleted file mode 100644 index c95a4dc2..00000000 --- a/encoderfile-core/src/transforms/engine.rs +++ /dev/null @@ -1,375 +0,0 @@ -use crate::error::ApiError; - -use super::tensor::Tensor; -use mlua::prelude::*; -use ndarray::{Array, Array2, Array3, Axis}; - -#[derive(Debug)] -pub struct Transform { - #[allow(dead_code)] - lua: Lua, - postprocessor: Option, -} - -impl Transform { - #[tracing::instrument(name = "new_transform", skip_all)] - pub fn new(transform: &str) -> Result { - let lua = new_lua()?; - - lua.load(transform) - .exec() - .map_err(|e| ApiError::LuaError(e.to_string()))?; - - let postprocessor = lua - .globals() - .get::>("Postprocess") - .map_err(|e| ApiError::LuaError(e.to_string()))?; - - Ok(Self { lua, postprocessor }) - } - - pub fn pool(&self, data: Array3, mask: Array2) -> Result, ApiError> { - let func = match &self.postprocessor { - Some(p) => p, - None => { - let batch = data.len_of(Axis(0)); - let hidden = data.len_of(Axis(2)); - - let mut out = Array2::::zeros((batch, hidden)); - - for b in 0..batch { - let emb = data.slice(ndarray::s![b, .., ..]); // [seq_len, hidden] - let m = mask.slice(ndarray::s![b, ..]); // [seq_len] - - // expand mask to [seq_len, hidden] - let m2 = m.insert_axis(Axis(1)); - - let weighted = &emb * &m2; // zero out padded tokens - let sum = weighted.sum_axis(Axis(0)); // sum over seq_len - let count = m.sum(); // number of real tokens - - out.slice_mut(ndarray::s![b, ..]).assign(&(sum / count)); - } - - return Ok(out); - } - }; - - let data_shape: Vec = data.shape().to_vec(); - let batch_size = data_shape[0]; - let embedding_dim = data_shape[2]; - let tensor = Tensor(data.into_dyn()); - - let Tensor(result) = func - .call::((tensor, Tensor(mask.into_dyn()))) - .map_err(|e| ApiError::LuaError(e.to_string()))?; - - // before pooling, input vector is shape 
[batch_size, n_tokens, embedding_dim] - // result should be [batch_size, embedding_dim] - if [batch_size, embedding_dim] != result.shape() { - return Err(ApiError::LuaError(format!( - "Postprocess function returned tensor of dim {:?}, expected {:?}", - result.shape(), - data_shape - ))); - } - - #[cfg(not(tarpaulin_include))] - result.into_dimensionality().map_err(|e| { - tracing::error!("Failed to cast array into Ix2: {e}"); - ApiError::InternalError( - "Failed to cast array into correct dim. This is not supposed to happen.", - ) - }) - } - - #[tracing::instrument(name = "transform_postprocess", skip_all)] - pub fn postprocess( - &self, - data: Array, - ) -> Result, ApiError> { - let func = match &self.postprocessor { - Some(p) => p, - None => return Ok(data), - }; - - let data_shape: Vec = data.shape().to_vec(); - let tensor = Tensor(data.into_dyn()); - - let Tensor(result) = func - .call::(tensor) - .map_err(|e| ApiError::LuaError(e.to_string()))?; - - if data_shape.as_slice() != result.shape() { - return Err(ApiError::LuaError(format!( - "Postprocess function returned tensor of dim {:?}, expected {:?}", - result.shape(), - data_shape - ))); - } - - #[cfg(not(tarpaulin_include))] - result.into_dimensionality::().map_err(|e| { - tracing::error!("Failed to cast array into Ix3: {e}"); - ApiError::InternalError( - "Failed to cast array into correct dim. This is not supposed to happen.", - ) - }) - } -} - -fn new_lua() -> Result { - let lua = Lua::new_with( - mlua::StdLib::TABLE | mlua::StdLib::STRING | mlua::StdLib::MATH, - mlua::LuaOptions::default(), - ) - .map_err(|e| { - tracing::error!( - "Failed to create new Lua engine. This should not happen. 
More details: {:?}", - e - ); - ApiError::InternalError("Failed to create new Lua engine") - })?; - - let globals = lua.globals(); - globals - .set( - "Tensor", - lua.create_function(|lua, value| Tensor::from_lua(value, lua)) - .map_err(|e| { - tracing::error!("Failed to create Lua tensor library: More details: {:?}", e); - ApiError::InternalError("Failed to create new Lua tensor library") - })?, - ) - .map_err(|e| { - tracing::error!("Failed to create Lua tensor library: More details: {:?}", e); - ApiError::InternalError("Failed to create new Lua tensor library") - })?; - - Ok(lua) -} - -#[cfg(test)] -mod tests { - use super::*; - - fn new_test_lua() -> Lua { - new_lua().expect("Failed to create new lua") - } - - #[test] - fn test_create_tensor() { - let lua = new_test_lua(); - lua.load( - r#" - function MyTensor() - return Tensor({1, 2, 3}) - end - "#, - ) - .exec() - .unwrap(); - - let function = lua - .globals() - .get::("MyTensor") - .expect("Failed to get MyTensor"); - - assert!(function.call::(()).is_ok()) - } - - #[test] - fn test_no_pooling() { - let engine = Transform::new("").expect("Failed to create engine"); - - let arr = ndarray::Array3::::from_elem((16, 32, 128), 2.0); - let mask = ndarray::Array2::::from_elem((16, 32), 1.0); - - let result = engine - .pool(arr.clone(), mask) - .expect("Failed to compute pool"); - - assert_eq!(result.shape(), [16, 128]); - - // if all elements are the same and all mask = 1, should return mean axis array - assert_eq!(arr.mean_axis(Axis(1)), Some(result)); - } - - #[test] - fn test_successful_pool() { - let engine = Transform::new( - r##" - function Postprocess(arr, mask) - -- sum along second axis (lol) - return arr:sum_axis(2) - end - "##, - ) - .expect("Failed to create engine"); - - let arr = ndarray::Array3::::from_elem((16, 32, 128), 2.0); - let mask = ndarray::Array2::::from_elem((16, 32), 1.0); - - let result = engine.pool(arr, mask).expect("Failed to compute pool"); - - assert_eq!(result.shape(), [16, 128]) - 
} - - #[test] - fn test_bad_dim_pool() { - let engine = Transform::new( - r##" - function Postprocess(arr, mask) - return arr - end - "##, - ) - .expect("Failed to create engine"); - - let arr = ndarray::Array3::::from_elem((16, 32, 128), 2.0); - let mask = ndarray::Array2::::from_elem((16, 32), 1.0); - - let result = engine.pool(arr, mask); - - assert!(result.is_err()); - } - - #[test] - fn test_no_transform_postprocessing() { - let engine = Transform::new("").expect("Failed to create Transform"); - - let arr = ndarray::Array2::::from_elem((3, 3), 2.0); - - let result = engine.postprocess(arr.clone()).expect("Failed"); - - assert_eq!(arr, result); - } - - #[test] - fn test_bad_output_transform_postprocessing() { - let engine = Transform::new( - r##" - function Postprocess(x) - return 1 - end - "##, - ) - .unwrap(); - - let arr = ndarray::Array2::::from_elem((3, 3), 2.0); - - let result = engine.postprocess(arr.clone()); - - assert!(result.is_err()) - } - - #[test] - fn test_bad_dimensionality_transform_postprocessing() { - let engine = Transform::new( - r##" - function Postprocess(x) - return x:sum_axis(1) - end - "##, - ) - .unwrap(); - - let arr = ndarray::Array2::::from_elem((3, 3), 2.0); - let result = engine.postprocess(arr.clone()); - - assert!(result.is_err()); - - if let Err(e) = result { - match e { - ApiError::LuaError(s) => { - assert!(s.contains("Postprocess function returned tensor of dim")) - } - _ => panic!("Didn't return lua error"), - } - } - } -} - -#[cfg(test)] -mod sandbox_tests { - use super::*; - - fn new_test_lua() -> Lua { - new_lua().expect("Failed to create new lua") - } - - #[test] - fn test_no_unsafe_stdlibs_loaded() { - let engine = new_test_lua(); - - // Should evaluate to nil, not a table or function - let val: mlua::Value = engine.load("return os").eval().unwrap(); - assert!(matches!(val, mlua::Value::Nil)); - - let val: mlua::Value = engine.load("return io").eval().unwrap(); - assert!(matches!(val, mlua::Value::Nil)); - - let val: 
mlua::Value = engine.load("return debug").eval().unwrap(); - assert!(matches!(val, mlua::Value::Nil)); - } - - #[test] - fn test_cannot_access_environment_or_execute_commands() { - let lua = new_lua().expect("Failed to create new Lua"); - - // `os.execute` shouldn't exist or be callable - let res = lua - .load("return type(os) == 'table' and type(os.execute) == 'function'") - .eval::(); - - assert!( - matches!(res, Ok(false) | Err(_)), - "os.execute should not be callable" - ); - } - - #[test] - fn test_no_file_system_access_via_package() { - let lua = new_test_lua(); - - // 'require' should not be usable - let res = lua.load("require('os')").exec(); - assert!(res.is_err()); - - // 'package' table should not exist - let res = lua.load("package").eval::(); - assert!(res.unwrap().is_nil()) - } - - #[test] - fn test_tensor_function_is_only_safe_binding() { - let lua = new_test_lua(); - - // Tensor should exist - let tensor_res = lua.load("return Tensor").eval::(); - assert!(tensor_res.is_ok()); - - // But nothing else custom - let res = lua.load("return DangerousFunction").eval::(); - assert!(res.unwrap().is_nil()); - } - - #[test] - fn test_limited_math_and_string_stdlibs() { - let lua = new_test_lua(); - - // math should work - assert_eq!(lua.load("return math.sqrt(9)").eval::().unwrap(), 3.0); - - // string manipulation should work - assert_eq!( - lua.load("return string.upper('sandbox')") - .eval::() - .unwrap(), - "SANDBOX" - ); - - // io.open should NOT exist - assert!(lua.load("return io.open").eval::().is_err()); - } -} diff --git a/encoderfile-core/src/transforms/engine/embedding.rs b/encoderfile-core/src/transforms/engine/embedding.rs new file mode 100644 index 00000000..59c15ceb --- /dev/null +++ b/encoderfile-core/src/transforms/engine/embedding.rs @@ -0,0 +1,125 @@ +use crate::error::ApiError; + +use super::{super::tensor::Tensor, EmbeddingTransform, Postprocessor}; +use ndarray::{Array3, Ix3}; + +impl Postprocessor for EmbeddingTransform { + type Input = 
Array3; + type Output = Array3; + + fn postprocess(&self, data: Self::Input) -> Result { + let func = match self.postprocessor() { + Some(p) => p, + None => return Ok(data), + }; + + let batch_size = data.shape()[0]; + let seq_len = data.shape()[1]; + + let tensor = Tensor(data.into_dyn()); + + let result = func + .call::(tensor) + .map_err(|e| ApiError::LuaError(e.to_string()))? + .into_inner() + .into_dimensionality::().map_err(|e| { + tracing::error!("Transform error: Failed to cast array into Ix3: {e}. Check your lua transform to make sure it returns a tensor of shape [batch_size, seq_len, *]"); + ApiError::LuaError("Error postprocessing embeddings".to_string()) + })?; + + let result_shape = result.shape(); + + if batch_size != result_shape[0] || seq_len != result_shape[1] { + tracing::error!( + "Transform error: expected tensor of shape [{}, {}, *], got tensor of shape {:?}", + batch_size, + seq_len, + result_shape + ); + + return Err(ApiError::LuaError( + "Error postprocessing embeddings".to_string(), + )); + } + + Ok(result) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_embedding_no_transform() { + let engine = EmbeddingTransform::new(Some("")).expect("Failed to create Transform"); + + let arr = ndarray::Array3::::from_elem((16, 32, 128), 2.0); + + let result = engine.postprocess(arr.clone()).expect("Failed"); + + assert_eq!(arr, result); + } + + #[test] + fn test_embedding_identity_transform() { + let engine = EmbeddingTransform::new(Some( + r##" + function Postprocess(arr) + return arr + end + "##, + )) + .expect("Failed to create engine"); + + let arr = ndarray::Array3::::from_elem((16, 32, 128), 2.0); + + let result = engine.postprocess(arr.clone()).expect("Failed"); + + assert_eq!(arr, result); + } + + #[test] + fn test_embedding_transform_bad_fn() { + let engine = EmbeddingTransform::new(Some( + r##" + function Postprocess(arr) + return 1 + end + "##, + )) + .expect("Failed to create engine"); + + let arr = 
ndarray::Array3::::from_elem((16, 32, 128), 2.0); + + let result = engine.postprocess(arr.clone()); + + assert!(result.is_err()) + } + + #[test] + fn test_bad_dimensionality_transform_postprocessing() { + let engine = EmbeddingTransform::new(Some( + r##" + function Postprocess(x) + return x:sum_axis(1) + end + "##, + )) + .unwrap(); + + let arr = ndarray::Array3::::from_elem((3, 3, 3), 2.0); + let result = engine.postprocess(arr.clone()); + + assert!(result.is_err()); + + if let Err(e) = result { + match e { + ApiError::LuaError(s) => { + assert!(s.contains("Error postprocessing embeddings")) + } + _ => panic!("Didn't return lua error"), + } + } + } +} diff --git a/encoderfile-core/src/transforms/engine/mod.rs b/encoderfile-core/src/transforms/engine/mod.rs new file mode 100644 index 00000000..eab32492 --- /dev/null +++ b/encoderfile-core/src/transforms/engine/mod.rs @@ -0,0 +1,202 @@ +use crate::{common::ModelType, error::ApiError}; + +use super::tensor::Tensor; +use mlua::prelude::*; + +mod embedding; +mod sentence_embedding; +mod sequence_classification; +mod token_classification; + +macro_rules! 
transform { + ($type_name:ident, $mt:ident) => { + pub type $type_name = Transform<{ ModelType::$mt as u8 }>; + }; +} + +transform!(EmbeddingTransform, Embedding); +transform!(SequenceClassificationTransform, SequenceClassification); +transform!(TokenClassificationTransform, TokenClassification); +transform!(SentenceEmbeddingTransform, SentenceEmbedding); + +pub trait Postprocessor: TransformSpec { + type Input; + type Output; + + fn postprocess(&self, data: Self::Input) -> Result; +} + +pub trait TransformSpec { + fn has_postprocessor(&self) -> bool; +} + +#[derive(Debug)] +pub struct Transform { + #[allow(dead_code)] + lua: Lua, + postprocessor: Option, +} + +impl Transform { + fn postprocessor(&self) -> &Option { + &self.postprocessor + } + + #[tracing::instrument(name = "new_transform", skip_all)] + pub fn new(transform: Option<&str>) -> Result { + let lua = new_lua()?; + + lua.load(transform.unwrap_or("")) + .exec() + .map_err(|e| ApiError::LuaError(e.to_string()))?; + + let postprocessor = lua + .globals() + .get::>("Postprocess") + .map_err(|e| ApiError::LuaError(e.to_string()))?; + + Ok(Self { lua, postprocessor }) + } +} + +impl TransformSpec for Transform { + fn has_postprocessor(&self) -> bool { + self.postprocessor.is_some() + } +} + +fn new_lua() -> Result { + let lua = Lua::new_with( + mlua::StdLib::TABLE | mlua::StdLib::STRING | mlua::StdLib::MATH, + mlua::LuaOptions::default(), + ) + .map_err(|e| { + tracing::error!( + "Failed to create new Lua engine. This should not happen. 
More details: {:?}", + e + ); + ApiError::InternalError("Failed to create new Lua engine") + })?; + + let globals = lua.globals(); + globals + .set( + "Tensor", + lua.create_function(|lua, value| Tensor::from_lua(value, lua)) + .map_err(|e| { + tracing::error!("Failed to create Lua tensor library: More details: {:?}", e); + ApiError::InternalError("Failed to create new Lua tensor library") + })?, + ) + .map_err(|e| { + tracing::error!("Failed to create Lua tensor library: More details: {:?}", e); + ApiError::InternalError("Failed to create new Lua tensor library") + })?; + + Ok(lua) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn new_test_lua() -> Lua { + new_lua().expect("Failed to create new lua") + } + + #[test] + fn test_create_tensor() { + let lua = new_test_lua(); + lua.load( + r#" + function MyTensor() + return Tensor({1, 2, 3}) + end + "#, + ) + .exec() + .unwrap(); + + let function = lua + .globals() + .get::("MyTensor") + .expect("Failed to get MyTensor"); + + assert!(function.call::(()).is_ok()) + } + + #[test] + fn test_no_unsafe_stdlibs_loaded() { + let engine = new_test_lua(); + + // Should evaluate to nil, not a table or function + let val: mlua::Value = engine.load("return os").eval().unwrap(); + assert!(matches!(val, mlua::Value::Nil)); + + let val: mlua::Value = engine.load("return io").eval().unwrap(); + assert!(matches!(val, mlua::Value::Nil)); + + let val: mlua::Value = engine.load("return debug").eval().unwrap(); + assert!(matches!(val, mlua::Value::Nil)); + } + + #[test] + fn test_cannot_access_environment_or_execute_commands() { + let lua = new_lua().expect("Failed to create new Lua"); + + // `os.execute` shouldn't exist or be callable + let res = lua + .load("return type(os) == 'table' and type(os.execute) == 'function'") + .eval::(); + + assert!( + matches!(res, Ok(false) | Err(_)), + "os.execute should not be callable" + ); + } + + #[test] + fn test_no_file_system_access_via_package() { + let lua = new_test_lua(); + + // 'require' 
should not be usable + let res = lua.load("require('os')").exec(); + assert!(res.is_err()); + + // 'package' table should not exist + let res = lua.load("package").eval::(); + assert!(res.unwrap().is_nil()) + } + + #[test] + fn test_tensor_function_is_only_safe_binding() { + let lua = new_test_lua(); + + // Tensor should exist + let tensor_res = lua.load("return Tensor").eval::(); + assert!(tensor_res.is_ok()); + + // But nothing else custom + let res = lua.load("return DangerousFunction").eval::(); + assert!(res.unwrap().is_nil()); + } + + #[test] + fn test_limited_math_and_string_stdlibs() { + let lua = new_test_lua(); + + // math should work + assert_eq!(lua.load("return math.sqrt(9)").eval::().unwrap(), 3.0); + + // string manipulation should work + assert_eq!( + lua.load("return string.upper('sandbox')") + .eval::() + .unwrap(), + "SANDBOX" + ); + + // io.open should NOT exist + assert!(lua.load("return io.open").eval::().is_err()); + } +} diff --git a/encoderfile-core/src/transforms/engine/sentence_embedding.rs b/encoderfile-core/src/transforms/engine/sentence_embedding.rs new file mode 100644 index 00000000..a40065ea --- /dev/null +++ b/encoderfile-core/src/transforms/engine/sentence_embedding.rs @@ -0,0 +1,171 @@ +use crate::error::ApiError; + +use super::{super::tensor::Tensor, Postprocessor, SentenceEmbeddingTransform}; +use ndarray::{Array2, Array3, Ix2}; + +impl Postprocessor for SentenceEmbeddingTransform { + type Input = (Array3, Array2); + type Output = Array2; + + fn postprocess(&self, (data, mask): Self::Input) -> Result { + let func = match &self.postprocessor { + Some(p) => p, + None => { + let Tensor(mean_pooled) = Tensor(data.into_dyn()) + .mean_pool(Tensor(mask.into_dyn())) + .map_err(|e| { + tracing::error!( + "Failed to mean pool. This should not happen. 
More details: {:?}", + e + ); + ApiError::InternalError("Failed to postprocess embeddings") + })?; + + return mean_pooled.into_dimensionality::() + .map_err(|e| { + tracing::error!("Failed to cast mean pool results into Ix2. This should not happen. More details: {:?}", e); + ApiError::InternalError("Failed to postprocess embeddings") + }); + } + }; + + let batch_size = data.shape()[0]; + + let tensor = Tensor(data.into_dyn()); + + let result = func + .call::((tensor, Tensor(mask.into_dyn()))) + .map_err(|e| ApiError::LuaError(e.to_string()))? + .into_inner() + .into_dimensionality::().map_err(|e| { + tracing::error!("Failed to cast array into Ix2: {e}. Check your lua transform to make sure it returns a tensor of shape [batch_size, *]"); + ApiError::LuaError("Error postprocessing embeddings".to_string()) + })?; + + let result_shape = result.shape(); + + if batch_size != result_shape[0] { + tracing::error!( + "Transform error: expected tensor of shape [{}, *], got tensor of shape {:?}", + batch_size, + result_shape + ); + + return Err(ApiError::LuaError( + "Error postprocessing embeddings".to_string(), + )); + } + + Ok(result) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ndarray::Axis; + + #[test] + fn test_no_pooling() { + let engine = SentenceEmbeddingTransform::new(Some("")).expect("Failed to create engine"); + + let arr = ndarray::Array3::::from_elem((16, 32, 128), 2.0); + let mask = ndarray::Array2::::from_elem((16, 32), 1.0); + + let result = engine + .postprocess((arr.clone(), mask)) + .expect("Failed to compute pool"); + + assert_eq!(result.shape(), [16, 128]); + + // if all elements are the same and all mask = 1, should return mean axis array + assert_eq!(arr.mean_axis(Axis(1)), Some(result)); + } + + #[test] + fn test_successful_pool() { + let engine = SentenceEmbeddingTransform::new(Some( + r##" + function Postprocess(arr, mask) + -- sum along second axis (lol) + return arr:sum_axis(2) + end + "##, + )) + .expect("Failed to create engine"); + 
+ let arr = ndarray::Array3::::from_elem((16, 32, 128), 2.0); + let mask = ndarray::Array2::::from_elem((16, 32), 1.0); + + let result = engine + .postprocess((arr, mask)) + .expect("Failed to compute pool"); + + assert_eq!(result.shape(), [16, 128]) + } + + #[test] + fn test_bad_dim_pool() { + let engine = SentenceEmbeddingTransform::new(Some( + r##" + function Postprocess(arr, mask) + return arr + end + "##, + )) + .expect("Failed to create engine"); + + let arr = ndarray::Array3::::from_elem((16, 32, 128), 2.0); + let mask = ndarray::Array2::::from_elem((16, 32), 1.0); + + let result = engine.postprocess((arr, mask)); + + assert!(result.is_err()); + } + + #[test] + fn test_sentence_embedding_transform_bad_fn() { + let engine = SentenceEmbeddingTransform::new(Some( + r##" + function Postprocess(arr, mask) + return 1 + end + "##, + )) + .expect("Failed to create engine"); + + let arr = ndarray::Array3::::from_elem((16, 32, 128), 2.0); + let mask = ndarray::Array2::::from_elem((16, 32), 1.0); + + let result = engine.postprocess((arr.clone(), mask)); + + assert!(result.is_err()) + } + + #[test] + fn test_bad_dimensionality_transform_postprocessing() { + let engine = SentenceEmbeddingTransform::new(Some( + r##" + function Postprocess(arr, mask) + return arr + end + "##, + )) + .unwrap(); + + let arr = ndarray::Array3::::from_elem((3, 3, 3), 2.0); + let mask = ndarray::Array2::::from_elem((3, 3), 1.0); + let result = engine.postprocess((arr.clone(), mask)); + + assert!(result.is_err()); + + if let Err(e) = result { + match e { + ApiError::LuaError(s) => { + assert!(s.contains("Error postprocessing embeddings")) + } + _ => panic!("Didn't return lua error"), + } + } + } +} diff --git a/encoderfile-core/src/transforms/engine/sequence_classification.rs b/encoderfile-core/src/transforms/engine/sequence_classification.rs new file mode 100644 index 00000000..92610e08 --- /dev/null +++ b/encoderfile-core/src/transforms/engine/sequence_classification.rs @@ -0,0 +1,124 @@ +use 
crate::error::ApiError; + +use super::{super::tensor::Tensor, Postprocessor, SequenceClassificationTransform}; +use ndarray::{Array2, Ix2}; + +impl Postprocessor for SequenceClassificationTransform { + type Input = Array2; + type Output = Array2; + + fn postprocess(&self, data: Self::Input) -> Result { + let func = match self.postprocessor() { + Some(p) => p, + None => return Ok(data), + }; + + let expected_shape = data.shape().to_owned(); + + let tensor = Tensor(data.into_dyn()); + + let result = func + .call::(tensor) + .map_err(|e| ApiError::LuaError(e.to_string()))? + .into_inner() + .into_dimensionality::().map_err(|e| { + tracing::error!("Failed to cast array into Ix2: {e}. Check your lua transform to make sure it returns a tensor of shape [batch_size, num_classes]"); + ApiError::LuaError("Error postprocessing sequence classifications".to_string()) + })?; + + let result_shape = result.shape(); + + if expected_shape.as_slice() != result_shape { + tracing::error!( + "Transform error: expected tensor of shape {:?}, got tensor of shape {:?}", + expected_shape.as_slice(), + result_shape + ); + + return Err(ApiError::LuaError( + "Error postprocessing sequence classifications".to_string(), + )); + } + + Ok(result) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sequence_cls_no_transform() { + let engine = + SequenceClassificationTransform::new(Some("")).expect("Failed to create Transform"); + + let arr = ndarray::Array2::::from_elem((16, 2), 2.0); + + let result = engine.postprocess(arr.clone()).expect("Failed"); + + assert_eq!(arr, result); + } + + #[test] + fn test_seq_cls_transform() { + let engine = SequenceClassificationTransform::new(Some( + r##" + function Postprocess(arr) + return arr + end + "##, + )) + .expect("Failed to create engine"); + + let arr = ndarray::Array2::::from_elem((16, 2), 2.0); + + let result = engine.postprocess(arr.clone()).expect("Failed"); + + assert_eq!(arr, result); + } + + #[test] + fn 
test_seq_cls_transform_bad_fn() { + let engine = SequenceClassificationTransform::new(Some( + r##" + function Postprocess(arr) + return 1 + end + "##, + )) + .expect("Failed to create engine"); + + let arr = ndarray::Array2::::from_elem((16, 2), 2.0); + + let result = engine.postprocess(arr.clone()); + + assert!(result.is_err()) + } + + #[test] + fn test_bad_dimensionality_transform_postprocessing() { + let engine = SequenceClassificationTransform::new(Some( + r##" + function Postprocess(x) + return x:sum_axis(1) + end + "##, + )) + .unwrap(); + + let arr = ndarray::Array2::::from_elem((2, 2), 2.0); + let result = engine.postprocess(arr.clone()); + + assert!(result.is_err()); + + if let Err(e) = result { + match e { + ApiError::LuaError(s) => { + assert!(s.contains("Error postprocessing sequence classifications")) + } + _ => panic!("Didn't return lua error"), + } + } + } +} diff --git a/encoderfile-core/src/transforms/engine/token_classification.rs b/encoderfile-core/src/transforms/engine/token_classification.rs new file mode 100644 index 00000000..3b460675 --- /dev/null +++ b/encoderfile-core/src/transforms/engine/token_classification.rs @@ -0,0 +1,124 @@ +use crate::error::ApiError; + +use super::{super::tensor::Tensor, Postprocessor, TokenClassificationTransform}; +use ndarray::{Array3, Ix3}; + +impl Postprocessor for TokenClassificationTransform { + type Input = Array3; + type Output = Array3; + + fn postprocess(&self, data: Self::Input) -> Result { + let func = match self.postprocessor() { + Some(p) => p, + None => return Ok(data), + }; + + let expected_shape = data.shape().to_owned(); + + let tensor = Tensor(data.into_dyn()); + + let result = func + .call::(tensor) + .map_err(|e| ApiError::LuaError(e.to_string()))? + .into_inner() + .into_dimensionality::().map_err(|e| { + tracing::error!("Failed to cast array into Ix3: {e}. 
Check your lua transform to make sure it returns a tensor of shape [batch_size, seq_len, num_classes]"); + ApiError::LuaError("Error postprocessing token classifications".to_string()) + })?; + + let result_shape = result.shape(); + + if expected_shape.as_slice() != result_shape { + tracing::error!( + "Transform error: expected tensor of shape {:?}, got tensor of shape {:?}", + expected_shape.as_slice(), + result_shape + ); + + return Err(ApiError::LuaError( + "Error postprocessing token classifications".to_string(), + )); + } + + Ok(result) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_token_cls_no_transform() { + let engine = + TokenClassificationTransform::new(Some("")).expect("Failed to create Transform"); + + let arr = ndarray::Array3::::from_elem((32, 16, 2), 2.0); + + let result = engine.postprocess(arr.clone()).expect("Failed"); + + assert_eq!(arr, result); + } + + #[test] + fn test_token_cls_identity_transform() { + let engine = TokenClassificationTransform::new(Some( + r##" + function Postprocess(arr) + return arr + end + "##, + )) + .expect("Failed to create engine"); + + let arr = ndarray::Array3::::from_elem((16, 32, 2), 2.0); + + let result = engine.postprocess(arr.clone()).expect("Failed"); + + assert_eq!(arr, result); + } + + #[test] + fn test_token_cls_transform_bad_fn() { + let engine = TokenClassificationTransform::new(Some( + r##" + function Postprocess(arr) + return 1 + end + "##, + )) + .expect("Failed to create engine"); + + let arr = ndarray::Array3::::from_elem((16, 32, 2), 2.0); + + let result = engine.postprocess(arr.clone()); + + assert!(result.is_err()) + } + + #[test] + fn test_bad_dimensionality_transform_postprocessing() { + let engine = TokenClassificationTransform::new(Some( + r##" + function Postprocess(x) + return x:sum_axis(1) + end + "##, + )) + .unwrap(); + + let arr = ndarray::Array3::::from_elem((3, 3, 3), 2.0); + let result = engine.postprocess(arr.clone()); + + assert!(result.is_err()); + + if 
let Err(e) = result { + match e { + ApiError::LuaError(s) => { + assert!(s.contains("Error postprocessing token classifications")) + } + _ => panic!("Didn't return lua error"), + } + } + } +} diff --git a/encoderfile-core/src/transforms/mod.rs b/encoderfile-core/src/transforms/mod.rs index 89742cb5..f782d819 100644 --- a/encoderfile-core/src/transforms/mod.rs +++ b/encoderfile-core/src/transforms/mod.rs @@ -2,5 +2,5 @@ mod engine; mod tensor; mod utils; -pub use engine::Transform; +pub use engine::*; pub use tensor::Tensor; diff --git a/encoderfile-core/src/transforms/tensor/mod.rs b/encoderfile-core/src/transforms/tensor/mod.rs index e0f2b193..4ce8c2a7 100644 --- a/encoderfile-core/src/transforms/tensor/mod.rs +++ b/encoderfile-core/src/transforms/tensor/mod.rs @@ -9,6 +9,12 @@ mod tests; #[derive(Debug, Clone, PartialEq)] pub struct Tensor(pub ArrayD); +impl Tensor { + pub fn into_inner(self) -> ArrayD { + self.0 + } +} + impl FromLua for Tensor { fn from_lua(value: LuaValue, _lua: &Lua) -> Result { match value { @@ -77,7 +83,7 @@ impl LuaUserData for Tensor { impl Tensor { #[tracing::instrument(skip_all)] - fn mean_pool(&self, Tensor(mask): Tensor) -> Result { + pub fn mean_pool(&self, Tensor(mask): Tensor) -> Result { assert_eq!(self.0.ndim(), mask.ndim() + 1); let ndim = self.0.ndim(); diff --git a/encoderfile-core/tests/transforms/main.rs b/encoderfile-core/tests/transforms/main.rs index bedc6cfb..7912fd72 100644 --- a/encoderfile-core/tests/transforms/main.rs +++ b/encoderfile-core/tests/transforms/main.rs @@ -1,12 +1,15 @@ -use encoderfile_core::transforms::Transform; +use encoderfile_core::transforms::{ + EmbeddingTransform, Postprocessor, SequenceClassificationTransform, + TokenClassificationTransform, +}; use ndarray::{Array2, Array3, Axis}; use ort::tensor::ArrayExtensions; #[test] fn test_l2_normalization() { - let engine = Transform::new(include_str!( + let engine = EmbeddingTransform::new(Some(include_str!( 
"../../../transforms/embedding/l2_normalize_embeddings.lua" - )) + ))) .expect("Failed to create engine"); let test_arr = Array3::::from_elem((8, 16, 36), 1.0); @@ -26,9 +29,9 @@ fn test_l2_normalization() { #[test] fn test_softmax_sequence_cls() { - let engine = Transform::new(include_str!( + let engine = SequenceClassificationTransform::new(Some(include_str!( "../../../transforms/sequence_classification/softmax_logits.lua" - )) + ))) .expect("Failed to create engine"); // run on array of shape [batch_size, n_labels] @@ -45,9 +48,9 @@ fn test_softmax_sequence_cls() { #[test] fn test_softmax_token_cls() { - let engine = Transform::new(include_str!( + let engine = TokenClassificationTransform::new(Some(include_str!( "../../../transforms/token_classification/softmax_logits.lua" - )) + ))) .expect("Failed to create engine"); // run on array of shape [batch_size, n_tokens, n_labels] diff --git a/encoderfile-utils/Cargo.toml b/encoderfile-utils/Cargo.toml index 45a0989e..bb270763 100644 --- a/encoderfile-utils/Cargo.toml +++ b/encoderfile-utils/Cargo.toml @@ -2,10 +2,6 @@ name = "generate-encoderfile-config-schema" path = "bin/generate_encoderfile_config_schema.rs" -[[bin]] -name = "generate-encoderfile-cli-docs" -path = "bin/generate_cli_docs.rs" - [package] name = "encoderfile-utils" version = "0.1.2" diff --git a/encoderfile-utils/bin/generate_cli_docs.rs b/encoderfile-utils/bin/generate_cli_docs.rs deleted file mode 100644 index 092f053f..00000000 --- a/encoderfile-utils/bin/generate_cli_docs.rs +++ /dev/null @@ -1,9 +0,0 @@ -use anyhow::Result; -use clap_markdown::help_markdown; -use encoderfile::cli::Cli; - -fn main() -> Result<()> { - let markdown = help_markdown::(); - std::fs::write("docs/reference/encoderfile_util_cli.md", markdown)?; - Ok(()) -} diff --git a/encoderfile/Cargo.toml b/encoderfile/Cargo.toml index 88f81a64..bc181247 100644 --- a/encoderfile/Cargo.toml +++ b/encoderfile/Cargo.toml @@ -16,18 +16,24 @@ clap = "4.5.52" clap_derive = "4.5.49" 
directories = "6.0.0" lazy_static = "1.5.0" +rand = "0.9.2" schemars = "1.1.0" serde_json = "1.0.145" sha2 = "0.10.9" tera = "1.20.1" +ndarray = "0.16.1" [features] -bench = [] -_internal = [ "clap-markdown",] +dev-utils = [] [dev-dependencies] tempfile = "3.23.0" +[dependencies.encoderfile-core] +path = "../encoderfile-core" +features = ["transforms"] +default-features = false + [dependencies.figment] version = "0.10.19" features = [ "env", "serde_yaml", "yaml",] @@ -36,10 +42,6 @@ features = [ "env", "serde_yaml", "yaml",] version = "1.0.228" features = [ "serde_derive",] -[dependencies.clap-markdown] -version = "0.1.5" -optional = true - [dependencies.ort] version = "=2.0.0-rc.10" diff --git a/encoderfile/src/cli.rs b/encoderfile/src/cli.rs index 5992f9a4..7d437cb1 100644 --- a/encoderfile/src/cli.rs +++ b/encoderfile/src/cli.rs @@ -69,12 +69,18 @@ impl BuildArgs { config.encoderfile.build = false; } + // validate model config + let model_config = config.encoderfile.model_config()?; + // validate model config .encoderfile .model_type .validate_model(&config.encoderfile.path.model_weights_path()?)?; + // validate transform + crate::transforms::validate_transform(&config.encoderfile, &model_config)?; + // setup write directory let write_dir = config.encoderfile.get_generated_dir(); std::fs::create_dir_all(write_dir.join("src/")) diff --git a/encoderfile/src/config.rs b/encoderfile/src/config.rs index 13c4470c..3e9c8dbd 100644 --- a/encoderfile/src/config.rs +++ b/encoderfile/src/config.rs @@ -1,6 +1,11 @@ -use anyhow::{Result, bail}; +use anyhow::{Context, Result, bail}; +use encoderfile_core::common::ModelConfig; use schemars::JsonSchema; -use std::{io::Read, path::PathBuf}; +use std::{ + fs::File, + io::{BufReader, Read}, + path::PathBuf, +}; use super::model::ModelType; use figment::{ @@ -9,7 +14,6 @@ use figment::{ }; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; -use tera::Context; #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub 
struct Config { @@ -34,11 +38,23 @@ pub struct EncoderfileConfig { pub output_path: Option, pub cache_dir: Option, pub transform: Option, + #[serde(default = "default_validate_transform")] + pub validate_transform: bool, #[serde(default = "default_build")] pub build: bool, } impl EncoderfileConfig { + pub fn model_config(&self) -> Result { + let model_config_path = self.path.model_config_path()?; + + let file = File::open(model_config_path)?; + + let reader = BufReader::new(file); + + serde_json::from_reader(reader).with_context(|| "Failed to deserialize model config") + } + pub fn output_path(&self) -> PathBuf { match &self.output_path { Some(p) => p.to_path_buf(), @@ -57,8 +73,8 @@ impl EncoderfileConfig { None => default_cache_dir(), } } - pub fn to_tera_ctx(&self) -> Result { - let mut ctx = Context::new(); + pub fn to_tera_ctx(&self) -> Result { + let mut ctx = tera::Context::new(); let transform = match &self.transform { None => None, @@ -102,7 +118,7 @@ impl Transform { let mut code = String::new(); - std::fs::File::open(path)?.read_to_string(&mut code)?; + File::open(path)?.read_to_string(&mut code)?; Ok(code) } @@ -166,6 +182,10 @@ fn default_build() -> bool { true } +fn default_validate_transform() -> bool { + true +} + fn encoderfile_core_version() -> &'static str { env!("ENCODERFILE_CORE_DEP_STR") } @@ -283,6 +303,7 @@ mod tests { model_type: ModelType::Embedding, output_path: Some(base.clone()), cache_dir: Some(base.clone()), + validate_transform: false, transform: None, build: true, }; @@ -303,6 +324,7 @@ mod tests { model_type: ModelType::SequenceClassification, output_path: Some(base.clone()), cache_dir: Some(base.clone()), + validate_transform: false, transform: Some(Transform::Inline("1+1".into())), build: true, }; diff --git a/encoderfile/src/lib.rs b/encoderfile/src/lib.rs index 73fdd0af..16e4891b 100644 --- a/encoderfile/src/lib.rs +++ b/encoderfile/src/lib.rs @@ -2,3 +2,4 @@ pub mod cli; pub mod config; pub mod model; pub mod templates; +pub 
mod transforms; diff --git a/encoderfile/src/transforms/mod.rs b/encoderfile/src/transforms/mod.rs new file mode 100644 index 00000000..99bbf983 --- /dev/null +++ b/encoderfile/src/transforms/mod.rs @@ -0,0 +1,3 @@ +mod validation; + +pub use validation::validate_transform; diff --git a/encoderfile/src/transforms/validation/embedding.rs b/encoderfile/src/transforms/validation/embedding.rs new file mode 100644 index 00000000..4e08abad --- /dev/null +++ b/encoderfile/src/transforms/validation/embedding.rs @@ -0,0 +1,120 @@ +use super::{ + TransformValidatorExt, + utils::{BATCH_SIZE, HIDDEN_DIM, SEQ_LEN, random_tensor, validation_err, validation_err_ctx}, +}; +use anyhow::{Context, Result}; +use encoderfile_core::{ + common::ModelConfig, + transforms::{EmbeddingTransform, Postprocessor}, +}; + +impl TransformValidatorExt for EmbeddingTransform { + fn dry_run(&self, _model_config: &ModelConfig) -> Result<()> { + // create dummy hidden states with shape [batch_size, seq_len, hidden_dim] + let dummy_hidden_states = random_tensor(&[BATCH_SIZE, SEQ_LEN, HIDDEN_DIM], (-1.0, 1.0))?; + let shape = dummy_hidden_states.shape().to_owned(); + + let res = self.postprocess(dummy_hidden_states) + .with_context(|| { + validation_err_ctx( + format!( + "Failed to run postprocessing on dummy hidden states (randomly generated in range -1.0..1.0) of shape {:?}", + shape.as_slice(), + ) + ) + })?; + + // result must return tensor of rank 3. + if res.ndim() != 3 { + validation_err(format!( + "Transform must return tensor of rank 3. Got tensor of shape {:?}.", + res.shape() + ))? + } + + // result must have same batch_size and seq_len + if res.shape()[0] != BATCH_SIZE || res.shape()[1] != SEQ_LEN { + validation_err(format!( + "Transform must preserve batch and seq dims [{}, {}, *]. Got shape {:?}", + BATCH_SIZE, + SEQ_LEN, + res.shape() + ))? + } + + if res.shape()[2] < 1 { + validation_err(format!( + "Transform returned a tensor with last dimension 0. Shape: {:?}", + res.shape() + ))? 
+ } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use crate::{ + config::{EncoderfileConfig, ModelPath}, + model::ModelType, + }; + + use super::*; + + fn test_encoderfile_config() -> EncoderfileConfig { + EncoderfileConfig { + name: "my-model".to_string(), + version: "0.0.1".to_string(), + path: ModelPath::Directory(std::path::PathBuf::from("models/embedding")), + model_type: ModelType::Embedding, + cache_dir: None, + output_path: None, + transform: None, + validate_transform: true, + build: true, + } + } + + fn test_model_config() -> ModelConfig { + let config_json = include_str!("../../../../models/embedding/config.json"); + + serde_json::from_str(config_json).unwrap() + } + + #[test] + fn test_identity_validation() { + let encoderfile_config = test_encoderfile_config(); + let model_config = test_model_config(); + + EmbeddingTransform::new(Some("function Postprocess(arr) return arr end")) + .expect("Failed to create transform") + .validate(&encoderfile_config, &model_config) + .expect("Failed to validate"); + } + + #[test] + fn test_bad_return_type() { + let encoderfile_config = test_encoderfile_config(); + let model_config = test_model_config(); + + let result = EmbeddingTransform::new(Some("function Postprocess(arr) return 1 end")) + .expect("Failed to create transform") + .validate(&encoderfile_config, &model_config); + + assert!(result.is_err()); + } + + #[test] + fn test_bad_dimensionality() { + let encoderfile_config = test_encoderfile_config(); + let model_config = test_model_config(); + + let result = + EmbeddingTransform::new(Some("function Postprocess(arr) return arr:sum_axis(1) end")) + .expect("Failed to create transform") + .validate(&encoderfile_config, &model_config); + + assert!(result.is_err()); + } +} diff --git a/encoderfile/src/transforms/validation/mod.rs b/encoderfile/src/transforms/validation/mod.rs new file mode 100644 index 00000000..6dc45b47 --- /dev/null +++ b/encoderfile/src/transforms/validation/mod.rs @@ -0,0 +1,181 @@ +use 
anyhow::{Context, Result}; +use encoderfile_core::{common::ModelConfig, transforms::TransformSpec}; + +use crate::{config::EncoderfileConfig, model::ModelType}; + +mod embedding; +mod sentence_embedding; +mod sequence_classification; +mod token_classification; +mod utils; + +pub trait TransformValidatorExt: TransformSpec { + fn validate( + &self, + encoderfile_config: &EncoderfileConfig, + model_config: &ModelConfig, + ) -> Result<()> { + // if validate_transform set to false, return + if !encoderfile_config.validate_transform { + return Ok(()); + } + + // fail if `Postprocess` function is not found + // NOTE: This should be removed if we add any additional functions, e.g., a Preprocess function + if !self.has_postprocessor() { + utils::validation_err( + "Could not find `Postprocess` function in provided transform. Please make sure it exists.", + )? + } + + self.dry_run(model_config) + } + + fn dry_run(&self, model_config: &ModelConfig) -> Result<()>; +} + +macro_rules! validate_transform { + ($transform_type:ident, $transform_str:expr, $encoderfile_config:expr, $model_config:expr) => { + encoderfile_core::transforms::$transform_type::new($transform_str) + .with_context(|| utils::validation_err_ctx("Failed to create transform"))? 
+ .validate($encoderfile_config, $model_config) + }; +} + +pub fn validate_transform( + encoderfile_config: &EncoderfileConfig, + model_config: &ModelConfig, +) -> Result<()> { + // try to fetch transform string + // will fail if a path to a transform does not exist + let transform_string = match &encoderfile_config.transform { + Some(t) => t.transform()?, + None => return Ok(()), + }; + + let transform_str = Some(transform_string.as_ref()); + + match encoderfile_config.model_type { + ModelType::Embedding => validate_transform!( + EmbeddingTransform, + transform_str, + encoderfile_config, + model_config + ), + ModelType::SequenceClassification => validate_transform!( + SequenceClassificationTransform, + transform_str, + encoderfile_config, + model_config + ), + ModelType::TokenClassification => validate_transform!( + TokenClassificationTransform, + transform_str, + encoderfile_config, + model_config + ), + ModelType::SentenceEmbedding => validate_transform!( + SentenceEmbeddingTransform, + transform_str, + encoderfile_config, + model_config + ), + } +} + +#[cfg(test)] +mod tests { + use encoderfile_core::transforms::EmbeddingTransform; + + use crate::config::{ModelPath, Transform}; + + use super::*; + + fn test_encoderfile_config() -> EncoderfileConfig { + EncoderfileConfig { + name: "my-model".to_string(), + version: "0.0.1".to_string(), + path: ModelPath::Directory(std::path::PathBuf::from("models/embedding")), + model_type: ModelType::Embedding, + cache_dir: None, + output_path: None, + transform: None, + validate_transform: true, + build: true, + } + } + + fn test_model_config() -> ModelConfig { + let config_json = include_str!("../../../../models/embedding/config.json"); + + serde_json::from_str(config_json).unwrap() + } + + #[test] + fn test_empty_transform() { + let result = EmbeddingTransform::new(None) + .expect("Failed to make embedding transform") + .validate(&test_encoderfile_config(), &test_model_config()); + + assert!(result.is_err()) + } + + #[test] 
+ fn test_no_validation() { + let mut config = test_encoderfile_config(); + config.validate_transform = false; + + EmbeddingTransform::new(None) + .expect("Failed to make embedding transform") + .validate(&config, &test_model_config()) + .expect("Should be ok") + } + + #[test] + fn test_validate() { + let transform_str = "function Postprocess(arr) return arr end"; + + let encoderfile_config = EncoderfileConfig { + name: "my-model".to_string(), + version: "0.0.1".to_string(), + path: ModelPath::Directory(std::path::PathBuf::from("models/embedding")), + model_type: ModelType::Embedding, + cache_dir: None, + output_path: None, + transform: Some(Transform::Inline(transform_str.to_string())), + validate_transform: true, + build: true, + }; + + let model_config_str = + include_str!(concat!("../../../../models/", "embedding", "/config.json")); + + let model_config = + serde_json::from_str(model_config_str).expect("Failed to create model config"); + + validate_transform(&encoderfile_config, &model_config).expect("Failed to validate"); + } + + #[test] + fn test_validate_empty() { + let encoderfile_config = EncoderfileConfig { + name: "my-model".to_string(), + version: "0.0.1".to_string(), + path: ModelPath::Directory(std::path::PathBuf::from("models/embedding")), + model_type: ModelType::Embedding, + cache_dir: None, + output_path: None, + transform: None, + validate_transform: true, + build: true, + }; + + let model_config_str = + include_str!(concat!("../../../../models/", "embedding", "/config.json")); + + let model_config = + serde_json::from_str(model_config_str).expect("Failed to create model config"); + + validate_transform(&encoderfile_config, &model_config).expect("Failed to validate"); + } +} diff --git a/encoderfile/src/transforms/validation/sentence_embedding.rs b/encoderfile/src/transforms/validation/sentence_embedding.rs new file mode 100644 index 00000000..dc8102ec --- /dev/null +++ b/encoderfile/src/transforms/validation/sentence_embedding.rs @@ -0,0 +1,126 
@@ +use super::{ + TransformValidatorExt, + utils::{ + BATCH_SIZE, HIDDEN_DIM, SEQ_LEN, create_dummy_attention_mask, random_tensor, + validation_err, validation_err_ctx, + }, +}; +use anyhow::{Context, Result}; +use encoderfile_core::{ + common::ModelConfig, + transforms::{Postprocessor, SentenceEmbeddingTransform}, +}; + +impl TransformValidatorExt for SentenceEmbeddingTransform { + fn dry_run(&self, _model_config: &ModelConfig) -> Result<()> { + // create dummy hidden states with shape [batch_size, seq_len, hidden_dim] + let dummy_hidden_states = random_tensor(&[BATCH_SIZE, SEQ_LEN, HIDDEN_DIM], (-1.0, 1.0))?; + let dummy_attention_mask = create_dummy_attention_mask(BATCH_SIZE, SEQ_LEN, 3)?; + let shape = dummy_hidden_states.shape().to_owned(); + + let res = self.postprocess((dummy_hidden_states, dummy_attention_mask)) + .with_context(|| { + validation_err_ctx( + format!( + "Failed to run postprocessing on dummy hidden states (randomly generated in range -1.0..1.0) of shape {:?}", + shape.as_slice(), + ) + ) + })?; + + // result must return tensor of rank 2 + if res.ndim() != 2 { + validation_err(format!( + "Transform must return tensor of rank 2. Got tensor of shape {:?}.", + res.shape() + ))? + } + + // result must have same batch_size + if res.shape()[0] != BATCH_SIZE { + validation_err(format!( + "Transform must preserve batch size [{}, *]. Got shape {:?}", + BATCH_SIZE, + res.shape() + ))? + } + + if res.shape()[1] < 1 { + validation_err(format!( + "Transform returned a tensor with last dimension 0. Shape: {:?}", + res.shape() + ))? 
+ } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use crate::{ + config::{EncoderfileConfig, ModelPath}, + model::ModelType, + }; + + use super::*; + + fn test_encoderfile_config() -> EncoderfileConfig { + EncoderfileConfig { + name: "my-model".to_string(), + version: "0.0.1".to_string(), + path: ModelPath::Directory(std::path::PathBuf::from("models/embedding")), + model_type: ModelType::SentenceEmbedding, + cache_dir: None, + output_path: None, + transform: None, + validate_transform: true, + build: true, + } + } + + fn test_model_config() -> ModelConfig { + let config_json = include_str!("../../../../models/embedding/config.json"); + + serde_json::from_str(config_json).unwrap() + } + + #[test] + fn test_successful_mean_pool() { + let encoderfile_config = test_encoderfile_config(); + let model_config = test_model_config(); + + SentenceEmbeddingTransform::new(Some( + "function Postprocess(arr, mask) return arr:mean_pool(mask) end", + )) + .expect("Failed to create transform") + .validate(&encoderfile_config, &model_config) + .expect("Failed to validate"); + } + + #[test] + fn test_bad_return_type() { + let encoderfile_config = test_encoderfile_config(); + let model_config = test_model_config(); + + let result = + SentenceEmbeddingTransform::new(Some("function Postprocess(arr, mask) return 1 end")) + .expect("Failed to create transform") + .validate(&encoderfile_config, &model_config); + + assert!(result.is_err()); + } + + #[test] + fn test_bad_dimensionality() { + let encoderfile_config = test_encoderfile_config(); + let model_config = test_model_config(); + + let result = + SentenceEmbeddingTransform::new(Some("function Postprocess(arr, mask) return arr end")) + .expect("Failed to create transform") + .validate(&encoderfile_config, &model_config); + + assert!(result.is_err()); + } +} diff --git a/encoderfile/src/transforms/validation/sequence_classification.rs b/encoderfile/src/transforms/validation/sequence_classification.rs new file mode 100644 index 
00000000..c83819c2 --- /dev/null +++ b/encoderfile/src/transforms/validation/sequence_classification.rs @@ -0,0 +1,121 @@ +use super::{ + TransformValidatorExt, + utils::{BATCH_SIZE, random_tensor, validation_err, validation_err_ctx}, +}; +use anyhow::{Context, Result}; +use encoderfile_core::{ + common::ModelConfig, + transforms::{Postprocessor, SequenceClassificationTransform}, +}; + +impl TransformValidatorExt for SequenceClassificationTransform { + fn dry_run(&self, model_config: &ModelConfig) -> Result<()> { + let num_labels = match model_config.num_labels() { + Some(n) => n, + None => validation_err( + "Model config does not have `num_labels`, `id2label`, or `label2id` field. Please make sure you're using a SequenceClassification model.", + )?, + }; + + let dummy_logits = random_tensor(&[BATCH_SIZE, num_labels], (-1.0, 1.0))?; + let shape = dummy_logits.shape().to_owned(); + + let res = self.postprocess(dummy_logits) + .with_context(|| { + validation_err_ctx( + format!( + "Failed to run postprocessing on dummy logits (randomly generated in range -1.0..1.0) of shape {:?}", + shape.as_slice(), + ) + ) + })?; + + // result must return tensor of rank 2 + if res.ndim() != 2 { + validation_err(format!( + "Transform must return tensor of rank 2. Got tensor of shape {:?}.", + res.shape() + ))? + } + + // result must have same shape as original + if res.shape() != shape { + validation_err(format!( + "Transform must return Tensor of shape [batch_size, num_labels]. Expected shape [{}, {}], got shape {:?}", + BATCH_SIZE, + num_labels, + res.shape() + ))? 
+ } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use crate::{ + config::{EncoderfileConfig, ModelPath}, + model::ModelType, + }; + + use super::*; + + fn test_encoderfile_config() -> EncoderfileConfig { + EncoderfileConfig { + name: "my-model".to_string(), + version: "0.0.1".to_string(), + path: ModelPath::Directory(std::path::PathBuf::from("models/sequence_classification")), + model_type: ModelType::SequenceClassification, + cache_dir: None, + output_path: None, + transform: None, + validate_transform: true, + build: true, + } + } + + fn test_model_config() -> ModelConfig { + let config_json = include_str!("../../../../models/sequence_classification/config.json"); + + serde_json::from_str(config_json).unwrap() + } + + #[test] + fn test_identity_validation() { + let encoderfile_config = test_encoderfile_config(); + let model_config = test_model_config(); + + SequenceClassificationTransform::new(Some("function Postprocess(arr) return arr end")) + .expect("Failed to create transform") + .validate(&encoderfile_config, &model_config) + .expect("Failed to validate"); + } + + #[test] + fn test_bad_return_type() { + let encoderfile_config = test_encoderfile_config(); + let model_config = test_model_config(); + + let result = + SequenceClassificationTransform::new(Some("function Postprocess(arr) return 1 end")) + .expect("Failed to create transform") + .validate(&encoderfile_config, &model_config); + + assert!(result.is_err()); + } + + #[test] + fn test_bad_dimensionality() { + let encoderfile_config = test_encoderfile_config(); + let model_config = test_model_config(); + + let result = SequenceClassificationTransform::new(Some( + "function Postprocess(arr) return arr:sum_axis(1) end", + )) + .expect("Failed to create transform") + .validate(&encoderfile_config, &model_config); + + assert!(result.is_err()); + } +} diff --git a/encoderfile/src/transforms/validation/token_classification.rs b/encoderfile/src/transforms/validation/token_classification.rs new file mode 
100644 index 00000000..55cbcabd --- /dev/null +++ b/encoderfile/src/transforms/validation/token_classification.rs @@ -0,0 +1,122 @@ +use super::{ + TransformValidatorExt, + utils::{BATCH_SIZE, SEQ_LEN, random_tensor, validation_err, validation_err_ctx}, +}; +use anyhow::{Context, Result}; +use encoderfile_core::{ + common::ModelConfig, + transforms::{Postprocessor, TokenClassificationTransform}, +}; + +impl TransformValidatorExt for TokenClassificationTransform { + fn dry_run(&self, model_config: &ModelConfig) -> Result<()> { + let num_labels = match model_config.num_labels() { + Some(n) => n, + None => validation_err( + "Model config does not have `num_labels`, `id2label`, or `label2id` field. Please make sure you're using a TokenClassification model.", + )?, + }; + + let dummy_logits = random_tensor(&[BATCH_SIZE, SEQ_LEN, num_labels], (-1.0, 1.0))?; + let shape = dummy_logits.shape().to_owned(); + + let res = self.postprocess(dummy_logits) + .with_context(|| { + validation_err_ctx( + format!( + "Failed to run postprocessing on dummy logits (randomly generated in range -1.0..1.0) of shape {:?}", + shape.as_slice(), + ) + ) + })?; + + // result must return tensor of rank 3 + if res.ndim() != 3 { + validation_err(format!( + "Transform must return tensor of rank 3. Got tensor of shape {:?}.", + res.shape() + ))? + } + + // result must have same shape as original + if res.shape() != shape { + validation_err(format!( + "Transform must return Tensor of shape [batch_size, seq_len, num_labels]. Expected shape [{}, {}, {}], got shape {:?}", + BATCH_SIZE, + SEQ_LEN, + num_labels, + res.shape() + ))? 
+ } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use crate::{ + config::{EncoderfileConfig, ModelPath}, + model::ModelType, + }; + + use super::*; + + fn test_encoderfile_config() -> EncoderfileConfig { + EncoderfileConfig { + name: "my-model".to_string(), + version: "0.0.1".to_string(), + path: ModelPath::Directory(std::path::PathBuf::from("models/token_classification")), + model_type: ModelType::TokenClassification, + cache_dir: None, + output_path: None, + transform: None, + validate_transform: true, + build: true, + } + } + + fn test_model_config() -> ModelConfig { + let config_json = include_str!("../../../../models/token_classification/config.json"); + + serde_json::from_str(config_json).unwrap() + } + + #[test] + fn test_identity_validation() { + let encoderfile_config = test_encoderfile_config(); + let model_config = test_model_config(); + + TokenClassificationTransform::new(Some("function Postprocess(arr) return arr end")) + .expect("Failed to create transform") + .validate(&encoderfile_config, &model_config) + .expect("Failed to validate"); + } + + #[test] + fn test_bad_return_type() { + let encoderfile_config = test_encoderfile_config(); + let model_config = test_model_config(); + + let result = + TokenClassificationTransform::new(Some("function Postprocess(arr) return 1 end")) + .expect("Failed to create transform") + .validate(&encoderfile_config, &model_config); + + assert!(result.is_err()); + } + + #[test] + fn test_bad_dimensionality() { + let encoderfile_config = test_encoderfile_config(); + let model_config = test_model_config(); + + let result = TokenClassificationTransform::new(Some( + "function Postprocess(arr) return arr:sum_axis(1) end", + )) + .expect("Failed to create transform") + .validate(&encoderfile_config, &model_config); + + assert!(result.is_err()); + } +} diff --git a/encoderfile/src/transforms/validation/utils.rs b/encoderfile/src/transforms/validation/utils.rs new file mode 100644 index 00000000..ebe3b6bf --- /dev/null +++ 
b/encoderfile/src/transforms/validation/utils.rs @@ -0,0 +1,60 @@ +use anyhow::{Context, Result, bail}; +use ndarray::{Array, ArrayD, Dimension, IxDyn}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; + +pub const ERR_HEADER: &str = "❌ Transform validation failed"; +const SEED: u64 = 42; + +// test values +pub const BATCH_SIZE: usize = 32; +pub const SEQ_LEN: usize = 128; +pub const HIDDEN_DIM: usize = 384; + +pub fn random_tensor<D: Dimension>( + shape: &[usize], + (range_start, range_end): (f32, f32), +) -> Result<Array<f32, D>> { + let mut rng = StdRng::seed_from_u64(SEED); + + let total = shape.iter().product(); + + ArrayD::from_shape_vec( + shape, + (0..total) + .map(|_| rng.random_range(range_start..range_end)) + .collect(), + ) + .and_then(|i| i.into_dimensionality::<D>()) + .with_context( + || validation_err_ctx("Failed to construct random ArrayD for dry-run validation. This shouldn't happen. More details"), + ) +} + +pub fn create_dummy_attention_mask<D: Dimension>( + batch: usize, + seq: usize, + pad_up_to: usize, +) -> Result<Array<f32, D>> { + let mut data = Vec::with_capacity(batch * seq); + + for _ in 0..batch { + let real = seq - pad_up_to; + data.extend(std::iter::repeat_n(1.0, real)); + data.extend(std::iter::repeat_n(0.0, pad_up_to)); + } + + ArrayD::from_shape_vec(IxDyn(&[batch, seq]), data) + .and_then(|i| i.into_dimensionality::<D>()) + .with_context( + || validation_err_ctx("Failed to construct dummy attention_mask for dry-run validation. This shouldn't happen. 
More details"), + ) +} + +pub fn validation_err_ctx<T: std::fmt::Display>(msg: T) -> String { + format!("{}: {}", ERR_HEADER, msg) +} + +pub fn validation_err<T, E: std::fmt::Display>(msg: E) -> Result<T> { + bail!("{}: {}", ERR_HEADER, msg) +} diff --git a/encoderfile/templates/Cargo.toml.tera b/encoderfile/templates/Cargo.toml.tera index 28c763bc..01ee1359 100644 --- a/encoderfile/templates/Cargo.toml.tera +++ b/encoderfile/templates/Cargo.toml.tera @@ -4,7 +4,6 @@ version = "{{ version }}" edition = "2024" [dependencies] -clap = "4.5.52" anyhow = "1.0.100" tokio = { version = "1.48.0", features = ["full"] } diff --git a/encoderfile/templates/main.rs.tera b/encoderfile/templates/main.rs.tera index 200f976f..da8055fb 100644 --- a/encoderfile/templates/main.rs.tera +++ b/encoderfile/templates/main.rs.tera @@ -1,18 +1,9 @@ // Generated by encoderfile ❤️ DO NOT MODIFY use anyhow::Result; -use clap::Parser; use encoderfile_core::{ factory, - cli::Cli, - runtime::{ - get_model, - get_model_config, - get_tokenizer, - get_model_type, - get_transform, - }, - AppState, + cli_entrypoint, }; factory! 
{ #[tokio::main] async fn main() -> Result<()> { - let cli = Cli::parse(); - - let session = get_model(&assets::MODEL_WEIGHTS); - let config = get_model_config(assets::MODEL_CONFIG_JSON); - let tokenizer = get_tokenizer(assets::TOKENIZER_JSON, &config); - let model_type = get_model_type(assets::MODEL_TYPE_STR); - let model_id = assets::MODEL_ID.to_string(); - let transform_factory = || get_transform(assets::TRANSFORM); - - - let state = AppState { - session, - config, - tokenizer, - model_type, - model_id, - transform_factory, - }; - - cli.command.execute(state).await?; - - Ok(()) + cli_entrypoint( + &assets::MODEL_WEIGHTS, + assets::MODEL_CONFIG_JSON, + assets::TOKENIZER_JSON, + assets::MODEL_TYPE_STR, + assets::MODEL_ID, + assets::TRANSFORM, + ).await } diff --git a/schemas/encoderfile-config-schema.json b/schemas/encoderfile-config-schema.json index 99d99494..44a0e0d3 100644 --- a/schemas/encoderfile-config-schema.json +++ b/schemas/encoderfile-config-schema.json @@ -49,6 +49,10 @@ } ] }, + "validate_transform": { + "type": "boolean", + "default": true + }, "version": { "type": "string", "default": "0.1.0"