-
Notifications
You must be signed in to change notification settings - Fork 3
feat: validate transforms on compile #151
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
08fcf86
6a51af0
3e94228
9410066
3b44179
07cfc94
338a105
1e3ea10
38601c9
9109959
ab5cfbf
557ce5f
0bde95e
72dc2b4
e0f818f
ae483cc
cb6bb36
834dd7e
e157b8e
32746ca
9769eb3
9447673
5c34e3c
7d19be0
2128168
031257e
5b16b52
e5b2cae
0710e1d
c917425
0f5255d
82dcc7d
f3eb55e
7e2f3c2
aaa42dc
f3bd438
b06639f
c6bb0d9
f00d16c
100b001
bb08b82
7d5c00c
7d79685
b323fb9
33017d1
f9e0f78
922f7a9
faf1fab
fc0e997
1cb9583
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,87 @@ | ||
| use serde::{Deserialize, Serialize}; | ||
| use std::collections::HashMap; | ||
|
|
||
| #[derive(Debug, Serialize, Deserialize)] | ||
| pub struct ModelConfig { | ||
| pub model_type: String, | ||
| pub pad_token_id: u32, | ||
| pub num_labels: Option<usize>, | ||
| pub id2label: Option<HashMap<u32, String>>, | ||
| pub label2id: Option<HashMap<String, u32>>, | ||
| } | ||
|
|
||
| impl ModelConfig { | ||
| pub fn id2label(&self, id: u32) -> Option<&str> { | ||
| self.id2label.as_ref()?.get(&id).map(|s| s.as_str()) | ||
| } | ||
|
|
||
| pub fn label2id(&self, label: &str) -> Option<u32> { | ||
| self.label2id.as_ref()?.get(label).copied() | ||
| } | ||
|
|
||
| pub fn num_labels(&self) -> Option<usize> { | ||
| if self.num_labels.is_some() { | ||
| return self.num_labels; | ||
| } | ||
|
|
||
| if let Some(id2label) = &self.id2label { | ||
| return Some(id2label.len()); | ||
| } | ||
|
|
||
| if let Some(label2id) = &self.label2id { | ||
| return Some(label2id.len()); | ||
| } | ||
|
|
||
| None | ||
| } | ||
| } | ||
|
|
||
| #[cfg(test)] | ||
| mod tests { | ||
| use super::*; | ||
|
|
||
| #[test] | ||
| fn test_num_labels() { | ||
| let test_labels: Vec<(String, u32)> = vec![("a", 1), ("b", 2), ("c", 3)] | ||
| .into_iter() | ||
| .map(|(i, j)| (i.to_string(), j)) | ||
| .collect(); | ||
|
|
||
| let label2id: HashMap<String, u32> = test_labels.clone().into_iter().collect(); | ||
| let id2label: HashMap<u32, String> = test_labels | ||
| .clone() | ||
| .into_iter() | ||
| .map(|(i, j)| (j, i)) | ||
| .collect(); | ||
|
|
||
| let config = ModelConfig { | ||
| model_type: "MyModel".to_string(), | ||
| pad_token_id: 0, | ||
| num_labels: Some(3), | ||
| id2label: Some(id2label.clone()), | ||
| label2id: Some(label2id.clone()), | ||
| }; | ||
|
|
||
| assert_eq!(config.num_labels(), Some(3)); | ||
|
|
||
| let config = ModelConfig { | ||
| model_type: "MyModel".to_string(), | ||
| pad_token_id: 0, | ||
| num_labels: None, | ||
| id2label: Some(id2label.clone()), | ||
| label2id: Some(label2id.clone()), | ||
| }; | ||
|
|
||
| assert_eq!(config.num_labels(), Some(3)); | ||
|
|
||
| let config = ModelConfig { | ||
| model_type: "MyModel".to_string(), | ||
| pad_token_id: 0, | ||
| num_labels: None, | ||
| id2label: None, | ||
| label2id: Some(label2id.clone()), | ||
| }; | ||
|
|
||
| assert_eq!(config.num_labels(), Some(3)); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,9 @@ | ||
| #[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)] | ||
| #[serde(rename_all = "snake_case")] | ||
| #[repr(u8)] | ||
| pub enum ModelType { | ||
| Embedding, | ||
| SequenceClassification, | ||
| TokenClassification, | ||
| SentenceEmbedding, | ||
| Embedding = 1, | ||
| SequenceClassification = 2, | ||
| TokenClassification = 3, | ||
| SentenceEmbedding = 4, | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,7 @@ use crate::{ | |
| common::{TokenEmbedding, TokenEmbeddingSequence, TokenInfo}, | ||
| error::ApiError, | ||
| runtime::AppState, | ||
| transforms::{EmbeddingTransform, Postprocessor}, | ||
| }; | ||
|
|
||
| #[tracing::instrument(skip_all)] | ||
|
|
@@ -24,7 +25,7 @@ pub fn embedding<'a>( | |
| .expect("Model does not return tensor of shape [n_batch, n_tokens, hidden_dim]") | ||
| .into_owned(); | ||
|
|
||
| outputs = state.transform().postprocess(outputs)?; | ||
| outputs = EmbeddingTransform::new(state.transform_str())?.postprocess(outputs)?; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This would allow changing the transform at runtime. Do we aim to do that at some point?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was basically already the case. We can improve though, I think this is a good flag |
||
|
|
||
| let embeddings = postprocess(outputs, encodings); | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect this is done for classification? Does it make sense to group class-related items under a
classificationkey?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@javiermtorres this is taken from the ModelConfig schema from huggingface. unfortunately can't change it :')