5 changes: 5 additions & 0 deletions backends/llamacpp/src/main.rs
@@ -157,6 +157,10 @@ struct Args {
/// Maximum payload size in bytes.
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,

/// Maximum image fetch size in bytes.
#[clap(default_value = "1073741824", long, env)]
max_image_fetch_size: usize,
}

#[tokio::main]
@@ -320,6 +324,7 @@ async fn main() -> Result<(), RouterError> {
args.max_client_batch_size,
args.usage_stats,
args.payload_limit,
args.max_image_fetch_size,
args.prometheus_port,
)
.await?;
4 changes: 4 additions & 0 deletions backends/trtllm/src/main.rs
@@ -67,6 +67,8 @@ struct Args {
usage_stats: UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
#[clap(default_value = "1073741824", long, env)]
max_image_fetch_size: usize,
}

async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option<Tokenizer> {
@@ -244,6 +246,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
executor_worker,
usage_stats,
payload_limit,
max_image_fetch_size,
} = args;

// Launch Tokio runtime
@@ -325,6 +328,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
max_client_batch_size,
usage_stats,
payload_limit,
max_image_fetch_size,
prometheus_port,
)
.await?;
4 changes: 4 additions & 0 deletions backends/v2/src/main.rs
@@ -74,6 +74,8 @@ struct Args {
usage_stats: usage_stats::UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
#[clap(default_value = "1073741824", long, env)]
max_image_fetch_size: usize,
}

#[derive(Debug, Subcommand)]
@@ -120,6 +122,7 @@ async fn main() -> Result<(), RouterError> {
max_client_batch_size,
usage_stats,
payload_limit,
max_image_fetch_size,
} = args;

if let Some(Commands::PrintSchema) = command {
@@ -201,6 +204,7 @@ async fn main() -> Result<(), RouterError> {
max_client_batch_size,
usage_stats,
payload_limit,
max_image_fetch_size,
prometheus_port,
)
.await?;
4 changes: 4 additions & 0 deletions backends/v3/src/main.rs
@@ -74,6 +74,8 @@ struct Args {
usage_stats: usage_stats::UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
#[clap(default_value = "1073741824", long, env)]
max_image_fetch_size: usize,
}

#[derive(Debug, Subcommand)]
@@ -120,6 +122,7 @@ async fn main() -> Result<(), RouterError> {
max_client_batch_size,
usage_stats,
payload_limit,
max_image_fetch_size,
} = args;

if let Some(Commands::PrintSchema) = command {
@@ -217,6 +220,7 @@ async fn main() -> Result<(), RouterError> {
max_client_batch_size,
usage_stats,
payload_limit,
max_image_fetch_size,
prometheus_port,
)
.await?;
2 changes: 1 addition & 1 deletion integration-tests/models/test_flash_llama.py
@@ -3,7 +3,7 @@

@pytest.fixture(scope="module")
def flash_llama_handle(launcher):
with launcher("huggingface/llama-7b", num_shard=2) as handle:
with launcher("huggyllama/llama-7b", num_shard=2) as handle:
yield handle


3 changes: 3 additions & 0 deletions integration-tests/models/test_flash_llama_fp8.py
@@ -13,6 +13,7 @@ async def flash_llama_fp8(flash_llama_fp8_handle):
return flash_llama_fp8_handle.client


@pytest.mark.skip(reason="Issue with the model access")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@@ -26,6 +27,7 @@ async def test_flash_llama_fp8(flash_llama_fp8, response_snapshot):
assert response == response_snapshot


@pytest.mark.skip(reason="Issue with the model access")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@@ -49,6 +51,7 @@ async def test_flash_llama_fp8_all_params(flash_llama_fp8, response_snapshot):
assert response == response_snapshot


@pytest.mark.skip(reason="Issue with the model access")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
3 changes: 3 additions & 0 deletions integration-tests/models/test_flash_llama_marlin_24.py
@@ -15,6 +15,7 @@ async def flash_llama_marlin(flash_llama_marlin24_handle):
return flash_llama_marlin24_handle.client


@pytest.mark.skip(reason="Issue with the model access")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@@ -27,6 +28,7 @@ async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
assert response == response_snapshot


@pytest.mark.skip(reason="Issue with the model access")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@@ -50,6 +52,7 @@ async def test_flash_llama_marlin24_all_params(flash_llama_marlin, response_snap
assert response == response_snapshot


@pytest.mark.skip(reason="Issue with the model access")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
2 changes: 1 addition & 1 deletion router/src/chat.rs
@@ -673,7 +673,7 @@ mod tests {
let (name, arguments) = get_tool_call_content(&events[0]);
if let Some(name) = name {
assert_eq!(name, "get_current_weather");
output_name.push_str(&name);
output_name.push_str(name);
}
output.push_str(arguments);
} else {
4 changes: 4 additions & 0 deletions router/src/server.rs
@@ -1523,6 +1523,7 @@ pub async fn run(
max_client_batch_size: usize,
usage_stats_level: usage_stats::UsageStatsLevel,
payload_limit: usize,
max_image_fetch_size: usize,
prometheus_port: u16,
) -> Result<(), WebServerError> {
// CORS allowed origins
@@ -1827,6 +1828,7 @@
compat_return_full_text,
allow_origin,
payload_limit,
max_image_fetch_size,
prometheus_port,
)
.await;
@@ -1889,6 +1891,7 @@ async fn start(
compat_return_full_text: bool,
allow_origin: Option<AllowOrigin>,
payload_limit: usize,
max_image_fetch_size: usize,
prometheus_port: u16,
) -> Result<(), WebServerError> {
// Determine the server port based on the feature and environment variable.
@@ -1920,6 +1923,7 @@
max_input_tokens,
max_total_tokens,
disable_grammar_support,
max_image_fetch_size,
);

let infer = Infer::new(
51 changes: 47 additions & 4 deletions router/src/validation.rs
@@ -12,7 +12,7 @@ use rand::{thread_rng, Rng};
use serde_json::Value;
/// Payload validation logic
use std::cmp::min;
use std::io::Cursor;
use std::io::{Cursor, Read};
use std::iter;
use std::sync::Arc;
use thiserror::Error;
@@ -51,6 +51,7 @@ impl Validation {
max_input_length: usize,
max_total_tokens: usize,
disable_grammar_support: bool,
max_image_fetch_size: usize,
) -> Self {
let workers = if let Tokenizer::Python { .. } = &tokenizer {
1
@@ -78,6 +79,7 @@
config_clone,
preprocessor_config_clone,
tokenizer_receiver,
max_image_fetch_size,
)
});
}
@@ -480,6 +482,7 @@ fn tokenizer_worker(
config: Option<Config>,
preprocessor_config: Option<HubPreprocessorConfig>,
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
max_image_fetch_size: usize,
) {
match tokenizer {
Tokenizer::Python {
@@ -503,6 +506,7 @@
&tokenizer,
config.as_ref(),
preprocessor_config.as_ref(),
max_image_fetch_size,
))
.unwrap_or(())
})
@@ -524,6 +528,7 @@
&tokenizer,
config.as_ref(),
preprocessor_config.as_ref(),
max_image_fetch_size,
))
.unwrap_or(())
})
@@ -562,10 +567,35 @@ fn format_to_mimetype(format: ImageFormat) -> String {
.to_string()
}

fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
fn fetch_image(
input: &str,
max_image_fetch_size: usize,
) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
if input.starts_with("![](http://") || input.starts_with("![](https://") {
let url = &input["![](".len()..input.len() - 1];
let data = reqwest::blocking::get(url)?.bytes()?;
let response = reqwest::blocking::get(url)?;

// Check Content-Length header if present
if let Some(content_length) = response.content_length() {
if content_length as usize > max_image_fetch_size {
return Err(ValidationError::ImageTooLarge(
content_length as usize,
max_image_fetch_size,
));
}
}

// Read the body with size limit to prevent unbounded memory allocation
let mut data = Vec::new();
let mut limited_reader = response.take((max_image_fetch_size + 1) as u64);
limited_reader.read_to_end(&mut data)?;

if data.len() > max_image_fetch_size {
return Err(ValidationError::ImageTooLarge(
data.len(),
max_image_fetch_size,
));
}

let format = image::guess_format(&data)?;
// TODO Remove this clone
@@ -787,6 +817,7 @@ fn prepare_input<T: TokenizerTrait>(
tokenizer: &T,
config: Option<&Config>,
preprocessor_config: Option<&HubPreprocessorConfig>,
max_image_fetch_size: usize,
) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
use Config::*;
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
@@ -805,7 +836,8 @@ fn prepare_input<T: TokenizerTrait>(
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let (data, mimetype, height, width) =
fetch_image(&inputs[chunk_start..chunk_end], max_image_fetch_size)?;
input_chunks.push(Chunk::Image(Image { data, mimetype }));
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
start = chunk_end;
@@ -990,6 +1022,10 @@ pub enum ValidationError {
InvalidImageContent(String),
#[error("Could not fetch image: {0}")]
FailedFetchImage(#[from] reqwest::Error),
#[error("Image size {0} bytes exceeds maximum allowed size of {1} bytes")]
ImageTooLarge(usize, usize),
#[error("Failed to read image data: {0}")]
ImageReadError(#[from] std::io::Error),
#[error("{0} modality is not supported")]
UnsupportedModality(&'static str),
}
@@ -1023,6 +1059,7 @@ mod tests {
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);

let max_new_tokens = 10;
@@ -1058,6 +1095,7 @@
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);

let max_new_tokens = 10;
@@ -1092,6 +1130,7 @@
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);
match validation
.validate(GenerateRequest {
@@ -1132,6 +1171,7 @@
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);
match validation
.validate(GenerateRequest {
@@ -1203,6 +1243,7 @@
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);
match validation
.validate(GenerateRequest {
@@ -1293,6 +1334,7 @@
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);

let chunks = match validation
@@ -1349,6 +1391,7 @@
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);

let (encoding, chunks) = match validation
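For reference, a minimal standalone sketch of the bounded-fetch pattern used in fetch_image above, assuming the blocking reqwest client; the fetch_bounded name and the 10 MB cap in main are illustrative only and not part of this PR:

use std::io::Read;

fn fetch_bounded(url: &str, max_size: usize) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
    let response = reqwest::blocking::get(url)?;

    // Reject early when the server already declares an oversized body.
    if let Some(len) = response.content_length() {
        if len as usize > max_size {
            return Err(format!("declared size {len} exceeds limit of {max_size} bytes").into());
        }
    }

    // Otherwise read at most max_size + 1 bytes so an oversized body is
    // detected without buffering it in full.
    let mut data = Vec::new();
    response.take((max_size + 1) as u64).read_to_end(&mut data)?;
    if data.len() > max_size {
        return Err(format!("body exceeds limit of {max_size} bytes").into());
    }
    Ok(data)
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let bytes = fetch_bounded("https://example.com/image.png", 10 * 1024 * 1024)?;
    println!("fetched {} bytes", bytes.len());
    Ok(())
}

The extra byte on the take limit is what lets the final length check distinguish a body that is exactly at the cap from one that exceeds it, mirroring the max_image_fetch_size + 1 read in the diff.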
2 changes: 1 addition & 1 deletion server/Makefile-flash-att-v2
@@ -3,7 +3,7 @@ flash_att_v2_commit_rocm := 47bd46e0204a95762ae48712fd1a3978827c77fd

build-flash-attention-v2-cuda:
pip install -U packaging wheel
pip install flash-attn==$(flash_att_v2_commit_cuda)
pip install --no-build-isolation flash-attn==$(flash_att_v2_commit_cuda)

install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
echo "Flash v2 installed"
2 changes: 1 addition & 1 deletion server/tests/models/test_model.py
@@ -14,7 +14,7 @@ def batch_type(self):
def generate_token(self, batch):
raise NotImplementedError

tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

model = TestModel(
"test_model_id",