Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
08fcf86
add encoderfile core dep
besaleli Dec 3, 2025
6a51af0
add transform validator
besaleli Dec 3, 2025
3e94228
update
besaleli Dec 3, 2025
9410066
update
besaleli Dec 3, 2025
3b44179
update
besaleli Dec 3, 2025
07cfc94
update
besaleli Dec 3, 2025
338a105
update
besaleli Dec 3, 2025
1e3ea10
update
besaleli Dec 3, 2025
38601c9
update
besaleli Dec 3, 2025
9109959
seq cls validation
besaleli Dec 3, 2025
ab5cfbf
fmt
besaleli Dec 3, 2025
557ce5f
add token classification validation
besaleli Dec 3, 2025
0bde95e
add sentence embedding validation
besaleli Dec 3, 2025
72dc2b4
add sentence embedding validation
besaleli Dec 3, 2025
e0f818f
update configs
besaleli Dec 3, 2025
ae483cc
update docs
besaleli Dec 3, 2025
cb6bb36
update makefile
besaleli Dec 3, 2025
834dd7e
update makefile
besaleli Dec 3, 2025
e157b8e
update
besaleli Dec 3, 2025
32746ca
clippy
besaleli Dec 3, 2025
9769eb3
update
besaleli Dec 4, 2025
9447673
simplify
besaleli Dec 4, 2025
5c34e3c
simplify
besaleli Dec 4, 2025
7d19be0
simplify
besaleli Dec 4, 2025
2128168
simplify
besaleli Dec 4, 2025
031257e
remove old code
besaleli Dec 4, 2025
5b16b52
update
besaleli Dec 4, 2025
e5b2cae
fmt
besaleli Dec 4, 2025
0710e1d
update Makefile
besaleli Dec 4, 2025
c917425
update Makefile
besaleli Dec 4, 2025
0f5255d
add tests
besaleli Dec 4, 2025
82dcc7d
fmt
besaleli Dec 4, 2025
f3eb55e
more tests
besaleli Dec 4, 2025
7e2f3c2
fmt
besaleli Dec 4, 2025
aaa42dc
update
besaleli Dec 4, 2025
f3bd438
add bad dimensionality tests
besaleli Dec 4, 2025
b06639f
add bad dimensionality tests
besaleli Dec 4, 2025
c6bb0d9
update
besaleli Dec 4, 2025
f00d16c
update
besaleli Dec 4, 2025
100b001
update
besaleli Dec 5, 2025
bb08b82
update
besaleli Dec 5, 2025
7d5c00c
update
besaleli Dec 5, 2025
7d79685
add tests
besaleli Dec 5, 2025
b323fb9
more tests
besaleli Dec 5, 2025
33017d1
update
besaleli Dec 5, 2025
f9e0f78
dry
besaleli Dec 5, 2025
922f7a9
dry
besaleli Dec 5, 2025
faf1fab
add entrypoint
besaleli Dec 5, 2025
fc0e997
clippy
besaleli Dec 5, 2025
1cb9583
fix feature gate
besaleli Dec 5, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ jobs:
run: cargo binstall cargo-llvm-cov --force --no-confirm

- name: Generate coverage
run: cargo llvm-cov --workspace --all-features --lcov --output-path lcov.info
run: make coverage

- name: Upload to codecov
uses: codecov/codecov-action@v2
Expand Down Expand Up @@ -163,7 +163,10 @@ jobs:
with:
toolchain: nightly

- name: Install test dependencies
run: make setup

- run: rustup component add --toolchain nightly-x86_64-unknown-linux-gnu clippy

- name: Run Clippy
run: cargo clippy --all-features --all-targets -- -D warnings
run: make lint
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"rust-analyzer.cargo.features": ["bench"],
"rust-analyzer.cargo.features": ["dev-utils"],
"rust-analyzer.procMacro.enable": true,
"rust-analyzer.cargo.buildScripts.enable": true,
"rust-analyzer.linkedProjects": [
Expand Down
4 changes: 3 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 16 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,22 @@ format:
@echo "Formatting rust..."
@cargo fmt

.PHONY: lint
lint:
cargo clippy \
--all-features \
--all-targets \
-- \
-D warnings

.PHONY: coverage
coverage:
cargo llvm-cov \
--workspace \
--all-features \
--lcov \
--output-path lcov.info

.PHONY: licenses
licenses:
@echo "Generating licenses..."
Expand Down Expand Up @@ -59,5 +75,3 @@ generate-docs:
# generate JSON schema for encoderfile config
@cargo run \
--bin generate-encoderfile-config-schema
# generate CLI docs for encoderfile build
@cargo run --bin generate-encoderfile-cli-docs --features="_internal"
5 changes: 4 additions & 1 deletion docs/reference/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ encoderfile:
transform:
path: ./transforms/normalize.lua
# OR inline transform:
# transform: "return normalize(output)"
# transform: "function Postprocess(logits) return logits:lp_normalize(2.0, 2.0) end"

# Whether to validate transform with a dry-run (optional, defaults to true)
validate_transform: true

# Whether to build the binary (optional, defaults to true)
build: true
Expand Down
13 changes: 7 additions & 6 deletions encoderfile-core/benches/benchmark_transforms.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use encoderfile_core::transforms::Postprocessor;
use ndarray::{Array2, Array3};
use rand::Rng;

Expand All @@ -21,9 +22,9 @@ fn get_random_3d(x: usize, y: usize, z: usize) -> Array3<f32> {

#[divan::bench(args = [(16, 16, 16), (32, 128, 384), (32, 256, 768)])]
fn bench_embedding_l2_normalization(bencher: divan::Bencher, (x, y, z): (usize, usize, usize)) {
let engine = encoderfile_core::transforms::Transform::new(include_str!(
let engine = encoderfile_core::transforms::EmbeddingTransform::new(Some(include_str!(
"../../transforms/embedding/l2_normalize_embeddings.lua"
))
)))
.unwrap();

let test_tensor = get_random_3d(x, y, z);
Expand All @@ -35,8 +36,8 @@ fn bench_embedding_l2_normalization(bencher: divan::Bencher, (x, y, z): (usize,

#[divan::bench(args = [(16, 2), (32, 8), (128, 32)])]
fn bench_seq_cls_softmax(bencher: divan::Bencher, (x, y): (usize, usize)) {
let engine = encoderfile_core::transforms::Transform::new(include_str!(
"../../transforms/sequence_classification/softmax_logits.lua"
let engine = encoderfile_core::transforms::SequenceClassificationTransform::new(Some(
include_str!("../../transforms/sequence_classification/softmax_logits.lua"),
))
.unwrap();

Expand All @@ -49,8 +50,8 @@ fn bench_seq_cls_softmax(bencher: divan::Bencher, (x, y): (usize, usize)) {

#[divan::bench(args = [(16, 16, 2), (32, 128, 8), (128, 256, 32)])]
fn bench_tok_cls_softmax(bencher: divan::Bencher, (x, y, z): (usize, usize, usize)) {
let engine = encoderfile_core::transforms::Transform::new(include_str!(
"../../transforms/token_classification/softmax_logits.lua"
let engine = encoderfile_core::transforms::TokenClassificationTransform::new(Some(
include_str!("../../transforms/token_classification/softmax_logits.lua"),
))
.unwrap();

Expand Down
34 changes: 33 additions & 1 deletion encoderfile-core/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,50 @@ use crate::{
EmbeddingRequest, ModelType, SentenceEmbeddingRequest, SequenceClassificationRequest,
TokenClassificationRequest,
},
runtime::AppState,
runtime::{AppState, get_model, get_model_config, get_model_type, get_tokenizer},
server::{run_grpc, run_http, run_mcp},
services::{embedding, sentence_embedding, sequence_classification, token_classification},
};
use anyhow::Result;
use clap::Parser;
use clap_derive::{Parser, Subcommand, ValueEnum};
use opentelemetry::trace::TracerProvider as _;
use opentelemetry_otlp::{Protocol, WithExportConfig};
use opentelemetry_sdk::trace::SdkTracerProvider;
use std::{fmt::Display, io::Write};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

/// Entry point for the command-line interface.
///
/// Parses the process arguments into a [`Cli`], assembles an [`AppState`] from
/// the supplied model assets, and runs the selected subcommand against it.
///
/// # Arguments
///
/// * `model_bytes` - raw model bytes handed to `get_model`.
/// * `config_str` - model configuration text handed to `get_model_config`.
/// * `tokenizer_json` - tokenizer definition handed to `get_tokenizer`
///   together with the parsed config.
/// * `model_type` - model type name handed to `get_model_type`.
/// * `model_id` - identifier stored on the state as an owned `String`.
/// * `transform_str` - optional transform source, stored as an owned `String`.
///
/// # Errors
///
/// Propagates any error returned by the executed command.
pub async fn cli_entrypoint(
    model_bytes: &[u8],
    config_str: &str,
    tokenizer_json: &str,
    model_type: &str,
    model_id: &str,
    transform_str: Option<&str>,
) -> Result<()> {
    // Arguments are parsed before any model asset is touched.
    let args = Cli::parse();

    let session = get_model(model_bytes);
    let config = get_model_config(config_str);
    // The tokenizer is built against the parsed config, so it must come after it.
    let tokenizer = get_tokenizer(tokenizer_json, &config);

    let state = AppState {
        session,
        config,
        tokenizer,
        model_type: get_model_type(model_type),
        transform_str: transform_str.map(String::from),
        model_id: String::from(model_id),
    };

    args.command.execute(state).await?;
    Ok(())
}

macro_rules! generate_cli_route {
($req:ident, $fn:path, $format:ident, $out_dir:expr, $state:expr) => {{
let result = $fn($req, &$state)?;
Expand Down
2 changes: 2 additions & 0 deletions encoderfile-core/src/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod embedding;
mod model_config;
mod model_metadata;
mod model_type;
mod sentence_embedding;
Expand All @@ -7,6 +8,7 @@ mod token;
mod token_classification;

pub use embedding::*;
pub use model_config::*;
pub use model_metadata::*;
pub use model_type::*;
pub use sentence_embedding::*;
Expand Down
87 changes: 87 additions & 0 deletions encoderfile-core/src/common/model_config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Model configuration that can be serialized and deserialized with serde.
///
/// Per the author's review note on this struct, the field layout is taken from
/// the Hugging Face `ModelConfig` schema, so field names and nesting are not
/// free to change.
#[derive(Debug, Serialize, Deserialize)]
pub struct ModelConfig {
/// Model type name as given in the config.
pub model_type: String,
/// Padding token id (presumably used when building the tokenizer; confirm against `get_tokenizer`).
pub pad_token_id: u32,
/// Explicit label count, if the config provides one. The `num_labels()` accessor
/// falls back to the sizes of the label maps when this is `None`.
pub num_labels: Option<usize>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect this is done for classification? Does it make sense to group class-related items under a classification key?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@javiermtorres this is taken from the ModelConfig schema from huggingface. unfortunately can't change it :')

/// Optional map from label id to label name.
pub id2label: Option<HashMap<u32, String>>,
/// Optional map from label name to label id.
pub label2id: Option<HashMap<String, u32>>,
}

impl ModelConfig {
    /// Returns the label name stored for `id`.
    ///
    /// Yields `None` when the `id2label` map is absent or has no entry for `id`.
    pub fn id2label(&self, id: u32) -> Option<&str> {
        self.id2label
            .as_ref()
            .and_then(|names| names.get(&id))
            .map(String::as_str)
    }

    /// Returns the numeric id stored for `label`.
    ///
    /// Yields `None` when the `label2id` map is absent or has no entry for `label`.
    pub fn label2id(&self, label: &str) -> Option<u32> {
        match &self.label2id {
            Some(ids) => ids.get(label).copied(),
            None => None,
        }
    }

    /// Returns the number of labels known to this config.
    ///
    /// Sources are consulted in order and the first one present wins:
    /// the explicit `num_labels` field, then the size of `id2label`,
    /// then the size of `label2id`. Yields `None` when all three are absent.
    pub fn num_labels(&self) -> Option<usize> {
        self.num_labels
            .or_else(|| self.id2label.as_ref().map(|names| names.len()))
            .or_else(|| self.label2id.as_ref().map(|ids| ids.len()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ModelConfig` in which only the label-related fields vary;
    /// `model_type` and `pad_token_id` are fixed placeholders.
    fn config_with(
        num_labels: Option<usize>,
        id2label: Option<HashMap<u32, String>>,
        label2id: Option<HashMap<String, u32>>,
    ) -> ModelConfig {
        ModelConfig {
            model_type: "MyModel".to_string(),
            pad_token_id: 0,
            num_labels,
            id2label,
            label2id,
        }
    }

    /// Checks every branch of `ModelConfig::num_labels`, including the
    /// precedence between the three possible sources and the empty case.
    #[test]
    fn test_num_labels() {
        let pairs = [("a", 1_u32), ("b", 2), ("c", 3)];

        let label2id: HashMap<String, u32> = pairs
            .iter()
            .map(|&(label, id)| (label.to_string(), id))
            .collect();
        let id2label: HashMap<u32, String> = pairs
            .iter()
            .map(|&(label, id)| (id, label.to_string()))
            .collect();

        // Explicit count that agrees with the maps.
        let config = config_with(Some(3), Some(id2label.clone()), Some(label2id.clone()));
        assert_eq!(config.num_labels(), Some(3));

        // Explicit count wins even when it disagrees with the map sizes;
        // a value equal to the map size could not tell the branches apart.
        let config = config_with(Some(7), Some(id2label.clone()), Some(label2id.clone()));
        assert_eq!(config.num_labels(), Some(7));

        // Without an explicit count, the size of `id2label` is used.
        let config = config_with(None, Some(id2label.clone()), Some(label2id.clone()));
        assert_eq!(config.num_labels(), Some(3));

        // `id2label` is consulted before `label2id`: a smaller `label2id`
        // must not affect the result.
        let mut partial_label2id: HashMap<String, u32> = HashMap::new();
        partial_label2id.insert("a".to_string(), 1);
        let config = config_with(None, Some(id2label), Some(partial_label2id));
        assert_eq!(config.num_labels(), Some(3));

        // Only `label2id` available.
        let config = config_with(None, None, Some(label2id));
        assert_eq!(config.num_labels(), Some(3));

        // Nothing available at all.
        let config = config_with(None, None, None);
        assert_eq!(config.num_labels(), None);
    }
}
9 changes: 5 additions & 4 deletions encoderfile-core/src/common/model_type.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)]
#[serde(rename_all = "snake_case")]
#[repr(u8)]
pub enum ModelType {
Embedding,
SequenceClassification,
TokenClassification,
SentenceEmbedding,
Embedding = 1,
SequenceClassification = 2,
TokenClassification = 3,
SentenceEmbedding = 4,
}
7 changes: 3 additions & 4 deletions encoderfile-core/src/dev_utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::{
common::ModelType,
runtime::{AppState, ModelConfig},
transforms::Transform,
common::{ModelConfig, ModelType},
runtime::AppState,
};
use ort::session::Session;
use parking_lot::Mutex;
Expand All @@ -22,7 +21,7 @@ pub fn get_state(dir: &str, model_type: ModelType) -> AppState {
config,
model_type,
model_id: "test-model".to_string(),
transform_factory: || Transform::new("").unwrap(),
transform_str: None,
}
}

Expand Down
3 changes: 2 additions & 1 deletion encoderfile-core/src/inference/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
common::{TokenEmbedding, TokenEmbeddingSequence, TokenInfo},
error::ApiError,
runtime::AppState,
transforms::{EmbeddingTransform, Postprocessor},
};

#[tracing::instrument(skip_all)]
Expand All @@ -24,7 +25,7 @@ pub fn embedding<'a>(
.expect("Model does not return tensor of shape [n_batch, n_tokens, hidden_dim]")
.into_owned();

outputs = state.transform().postprocess(outputs)?;
outputs = EmbeddingTransform::new(state.transform_str())?.postprocess(outputs)?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would allow changing the transform at runtime. Do we aim to do that at some point?

Copy link
Member Author

@besaleli besaleli Dec 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was basically already the case. We can improve though, I think this is a good flag


let embeddings = postprocess(outputs, encodings);

Expand Down
11 changes: 8 additions & 3 deletions encoderfile-core/src/inference/sentence_embedding.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use ndarray::{Array2, Axis, Ix2, Ix3};
use tokenizers::Encoding;

use crate::{common::SentenceEmbedding, error::ApiError, runtime::AppState};
use crate::{
common::SentenceEmbedding,
error::ApiError,
runtime::AppState,
transforms::{Postprocessor, SentenceEmbeddingTransform},
};

#[tracing::instrument(skip_all)]
pub fn sentence_embedding<'a>(
Expand All @@ -28,9 +33,9 @@ pub fn sentence_embedding<'a>(
.expect("Model does not return tensor of shape [n_batch, n_tokens, hidden_dim]")
.into_owned();

let transform = state.transform();
let transform = SentenceEmbeddingTransform::new(state.transform_str())?;

let pooled_outputs = transform.pool(outputs, a_mask_arr)?;
let pooled_outputs = transform.postprocess((outputs, a_mask_arr))?;

let embeddings = postprocess(pooled_outputs, encodings);

Expand Down
7 changes: 4 additions & 3 deletions encoderfile-core/src/inference/sequence_classification.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::{
common::SequenceClassificationResult,
common::{ModelConfig, SequenceClassificationResult},
error::ApiError,
runtime::{AppState, ModelConfig},
runtime::AppState,
transforms::{Postprocessor, SequenceClassificationTransform},
};
use ndarray::{Array2, Axis, Ix2};
use ndarray_stats::QuantileExt;
Expand All @@ -24,7 +25,7 @@ pub fn sequence_classification<'a>(
.expect("Model does not return tensor of shape [n_batch, n_labels]")
.into_owned();

outputs = state.transform().postprocess(outputs)?;
outputs = SequenceClassificationTransform::new(state.transform_str())?.postprocess(outputs)?;

let results = postprocess(outputs, &state.config);

Expand Down
7 changes: 4 additions & 3 deletions encoderfile-core/src/inference/token_classification.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::{
common::{TokenClassification, TokenClassificationResult, TokenInfo},
common::{ModelConfig, TokenClassification, TokenClassificationResult, TokenInfo},
error::ApiError,
runtime::{AppState, ModelConfig},
runtime::AppState,
transforms::{Postprocessor, TokenClassificationTransform},
};
use ndarray::{Array3, Axis, Ix3};
use ndarray_stats::QuantileExt;
Expand All @@ -24,7 +25,7 @@ pub fn token_classification<'a>(
.expect("Model does not return tensor of shape [n_batch, n_tokens, n_labels]")
.into_owned();

outputs = state.transform().postprocess(outputs)?;
outputs = TokenClassificationTransform::new(state.transform_str())?.postprocess(outputs)?;

let predictions = postprocess(outputs, encodings, &state.config);

Expand Down
Loading
Loading