feat: add integration tests #101

Merged · 3 commits · Nov 30, 2023
230 changes: 195 additions & 35 deletions Cargo.lock

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions Makefile
@@ -0,0 +1,12 @@
integration-tests:
cargo test --release

cuda-integration-tests:
cargo test -F text-embeddings-backend-candle/cuda -F text-embeddings-backend-candle/flash-attn -F text-embeddings-router/candle-cuda --release

integration-tests-review:
cargo insta test --review --release

cuda-integration-tests-review:
cargo insta test --review --features "text-embeddings-backend-candle/cuda text-embeddings-backend-candle/flash-attn text-embeddings-router/candle-cuda" --release

9 changes: 8 additions & 1 deletion backends/candle/Cargo.toml
@@ -15,7 +15,6 @@ candle-flash-attn = { version = "0.3.0", optional = true }
candle-flash-attn-v1 = { git = "https://github.com/huggingface/candle-flash-attn-v1", rev = "62b75f1ea4e0961fad7b983ee8d723ed6fd68be5", optional = true }
candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "58684e116aae248c353f87846ddf0b2a8a7ed855", optional = true }
candle-layer-norm = { git = "https://github.com/huggingface/candle-layer-norm", rev = "5ed96012a693dff9685320765dd55a57fdaecdd6", optional = true }
lazy_static = "^1.4"
text-embeddings-backend-core = { path = "../core" }
tracing = "^0.1"
safetensors = "^0.4"
@@ -24,6 +23,14 @@ serde = { version = "^1.0", features = ["serde_derive"] }
serde_json = "^1.0"
memmap2 = "^0.9"

[dev-dependencies]
insta = { git = "https://github.com/OlivierDehaene/insta", rev = "f4f98c0410b91fb5a28b10df98e4422955be9c2c", features = ["yaml"] }
is_close = "0.1.3"
hf-hub = "0.3.2"
anyhow = "1.0.75"
tokenizers = { version = "^0.15.0", default-features = false, features = ["onig", "esaxx_fast"] }
serial_test = "2.0.0"

[build-dependencies]
anyhow = { version = "1", features = ["backtrace"] }

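The new `[dev-dependencies]` above are what the integration tests build on: `insta` for snapshot assertions (reviewed via `cargo insta test --review`, see the Makefile targets), `serial_test` to keep GPU tests from running concurrently, `hf-hub` and `tokenizers` to fetch and tokenize real model inputs, and `is_close` for tolerant float comparisons. A minimal, hypothetical sketch of how they might fit together — the model id, test name, and snapshot contents are assumptions, not code from this PR:

```rust
use anyhow::Result;
use hf_hub::api::sync::Api;
use tokenizers::Tokenizer;

#[test]
#[serial_test::serial] // avoid concurrent downloads / GPU contention
fn embedding_inputs_match_snapshot() -> Result<()> {
    // Fetch a tokenizer from the Hugging Face Hub (model id is an assumption).
    let repo = Api::new()?.model("sentence-transformers/all-MiniLM-L6-v2".to_string());
    let tokenizer =
        Tokenizer::from_file(repo.get("tokenizer.json")?).map_err(anyhow::Error::msg)?;

    let encoding = tokenizer
        .encode("What is Deep Learning?", true)
        .map_err(anyhow::Error::msg)?;

    // Snapshot the token ids; `make integration-tests-review` re-records it.
    insta::assert_yaml_snapshot!(encoding.get_ids());
    Ok(())
}
```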
50 changes: 35 additions & 15 deletions backends/candle/src/compute_cap.rs
@@ -2,20 +2,40 @@ use candle::cuda_backend::cudarc::driver::sys::CUdevice_attribute::{
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
};
use candle::cuda_backend::cudarc::driver::CudaDevice;
use lazy_static::lazy_static;
use std::sync::Once;

lazy_static! {
pub static ref RUNTIME_COMPUTE_CAP: usize = {
let device = CudaDevice::new(0).expect("cuda is not available");
let major = device
.attribute(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)
.unwrap();
let minor = device
.attribute(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)
.unwrap();
(major * 10 + minor) as usize
};
pub static ref COMPILE_COMPUTE_CAP: usize = env!("CUDA_COMPUTE_CAP").parse::<usize>().unwrap();
static INIT: Once = Once::new();
static mut RUNTIME_COMPUTE_CAP: usize = 0;
static mut COMPILE_COMPUTE_CAP: usize = 0;

fn init_compute_caps() {
unsafe {
INIT.call_once(|| {
let device = CudaDevice::new(0).expect("cuda is not available");
let major = device
.attribute(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)
.unwrap();
let minor = device
.attribute(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)
.unwrap();
RUNTIME_COMPUTE_CAP = (major * 10 + minor) as usize;
COMPILE_COMPUTE_CAP = env!("CUDA_COMPUTE_CAP").parse::<usize>().unwrap();
});
}
}

pub fn get_compile_compute_cap() -> usize {
unsafe {
init_compute_caps();
COMPILE_COMPUTE_CAP
}
}

pub fn get_runtime_compute_cap() -> usize {
unsafe {
init_compute_caps();
RUNTIME_COMPUTE_CAP
}
}

fn compute_cap_matching(runtime_compute_cap: usize, compile_compute_cap: usize) -> bool {
@@ -30,8 +50,8 @@ fn compute_cap_matching(runtime_compute_cap: usize, compile_compute_cap: usize)
}

pub fn incompatible_compute_cap() -> bool {
let compile_compute_cap = *COMPILE_COMPUTE_CAP;
let runtime_compute_cap = *RUNTIME_COMPUTE_CAP;
let compile_compute_cap = get_compile_compute_cap();
let runtime_compute_cap = get_runtime_compute_cap();
!compute_cap_matching(runtime_compute_cap, compile_compute_cap)
}

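The change above swaps the `lazy_static` globals for `std::sync::Once` plus `static mut`, with accessor functions hiding the `unsafe` reads. For comparison only (not what this PR does), the same one-shot initialization can be written without `unsafe` using `std::sync::OnceLock`, assuming the crate can target Rust 1.70 or later and reusing the imports already present in `compute_cap.rs`:

```rust
use std::sync::OnceLock;

static RUNTIME_COMPUTE_CAP: OnceLock<usize> = OnceLock::new();

pub fn get_runtime_compute_cap() -> usize {
    *RUNTIME_COMPUTE_CAP.get_or_init(|| {
        // Same probe as in the PR: query the first CUDA device's capability.
        let device = CudaDevice::new(0).expect("cuda is not available");
        let major = device
            .attribute(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)
            .unwrap();
        let minor = device
            .attribute(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)
            .unwrap();
        (major * 10 + minor) as usize
    })
}
```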
12 changes: 7 additions & 5 deletions backends/candle/src/flash_attn.rs
@@ -1,4 +1,4 @@
use crate::compute_cap::RUNTIME_COMPUTE_CAP;
use crate::compute_cap::get_runtime_compute_cap;
use candle::Tensor;

#[allow(clippy::too_many_arguments, unused)]
Expand All @@ -13,7 +13,9 @@ pub(crate) fn flash_attn_varlen(
softmax_scale: f32,
causal: bool,
) -> Result<Tensor, candle::Error> {
if *RUNTIME_COMPUTE_CAP == 75 {
let runtime_compute_cap = get_runtime_compute_cap();

if runtime_compute_cap == 75 {
#[cfg(feature = "flash-attn-v1")]
{
use candle_flash_attn_v1::flash_attn_varlen;
Expand All @@ -31,7 +33,7 @@ pub(crate) fn flash_attn_varlen(
}
#[cfg(not(feature = "flash-attn-v1"))]
candle::bail!("Flash attention v1 is not installed. Use `flash-attn-v1` feature.")
} else if (80..90).contains(&*RUNTIME_COMPUTE_CAP) {
} else if (80..90).contains(&runtime_compute_cap) {
#[cfg(feature = "flash-attn")]
{
use candle_flash_attn::flash_attn_varlen;
Expand All @@ -49,7 +51,7 @@ pub(crate) fn flash_attn_varlen(
}
#[cfg(not(feature = "flash-attn"))]
candle::bail!("Flash attention is not installed. Use `flash-attn-v1` feature.")
} else if *RUNTIME_COMPUTE_CAP == 90 {
} else if runtime_compute_cap == 90 {
#[cfg(feature = "flash-attn")]
{
use candle_flash_attn::flash_attn_varlen;
Expand All @@ -70,6 +72,6 @@ pub(crate) fn flash_attn_varlen(
}
candle::bail!(
"GPU with CUDA capability {} is not supported",
*RUNTIME_COMPUTE_CAP
runtime_compute_cap
);
}
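As a quick summary of the dispatch in `flash_attn_varlen` above, the runtime compute capability selects the flash-attention backend roughly as follows (illustrative helper, not part of the PR):

```rust
/// Which flash-attention feature the dispatch above expects for a given
/// CUDA compute capability (illustrative only).
fn flash_attn_flavor(compute_cap: usize) -> Option<&'static str> {
    match compute_cap {
        75 => Some("flash-attn-v1"),   // Turing (T4, RTX 20xx)
        80..=89 => Some("flash-attn"), // Ampere / Ada (A100, RTX 30xx/40xx, L4)
        90 => Some("flash-attn"),      // Hopper (H100)
        _ => None,                     // bails: "GPU with CUDA capability {} is not supported"
    }
}

#[test]
fn dispatch_table() {
    assert_eq!(flash_attn_flavor(75), Some("flash-attn-v1"));
    assert_eq!(flash_attn_flavor(86), Some("flash-attn"));
    assert_eq!(flash_attn_flavor(70), None);
}
```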
2 changes: 1 addition & 1 deletion backends/candle/src/layers.rs
@@ -3,6 +3,6 @@ mod cublaslt;
mod layer_norm;
mod linear;

pub use cublaslt::CUBLASLT;
pub use cublaslt::get_cublas_lt_wrapper;
pub use layer_norm::LayerNorm;
pub use linear::{HiddenAct, Linear};
42 changes: 24 additions & 18 deletions backends/candle/src/layers/cublaslt.rs
@@ -1,28 +1,34 @@
use crate::layers::HiddenAct;
use candle::{Device, Result, Tensor};
use lazy_static::lazy_static;
use std::sync::Once;

#[cfg(feature = "cuda")]
use candle_cublaslt::{fused_batch_matmul, fused_matmul, Activation, CublasLt};

lazy_static! {
pub static ref CUBLASLT: Option<CublasLtWrapper> = {
match Device::cuda_if_available(0) {
Ok(device) => {
#[cfg(feature = "cuda")]
{
Some(CublasLtWrapper {
cublaslt: CublasLt::new(&device).unwrap(),
})
}
#[cfg(not(feature = "cuda"))]
{
None
static INIT: Once = Once::new();
static mut CUBLASLT: Option<CublasLtWrapper> = None;

pub fn get_cublas_lt_wrapper() -> Option<&'static CublasLtWrapper> {
unsafe {
INIT.call_once(|| {
CUBLASLT = match Device::cuda_if_available(0) {
Ok(device) => {
#[cfg(feature = "cuda")]
{
Some(CublasLtWrapper {
cublaslt: CublasLt::new(&device).unwrap(),
})
}
#[cfg(not(feature = "cuda"))]
{
None
}
}
}
Err(_) => None,
}
};
Err(_) => None,
};
});
CUBLASLT.as_ref()
}
}

#[derive(Debug, Clone)]
42 changes: 23 additions & 19 deletions backends/candle/src/layers/linear.rs
@@ -1,5 +1,5 @@
use crate::layers::CUBLASLT;
use candle::{Device, Result, Tensor, D};
use crate::layers::cublaslt::get_cublas_lt_wrapper;
use candle::{Device, Result, Tensor};
use serde::Deserialize;

#[derive(Debug, Deserialize, PartialEq, Clone)]
@@ -33,23 +33,27 @@ impl Linear {
let _enter = self.span.enter();

#[allow(unused)]
if let (Device::Cuda(_), Some(cublaslt)) = (x.device(), &*CUBLASLT) {
// fused matmul requires x to be dims2
let mut final_shape = x.dims().to_vec();
final_shape.pop();
final_shape.push(self.weight.dims()[0]);

let x = x.flatten_to(D::Minus2)?;
let result = cublaslt.matmul(
&self.weight,
&x,
None,
None,
None,
self.bias.as_ref(),
self.act.clone(),
)?;
result.reshape(final_shape)
if let (Device::Cuda(_), Some(cublaslt)) = (x.device(), get_cublas_lt_wrapper()) {
match x.dims() {
&[bsize, _, _] => cublaslt.batch_matmul(
&self.weight.broadcast_left(bsize)?,
x,
None,
None,
None,
self.bias.as_ref(),
self.act.clone(),
),
_ => cublaslt.matmul(
&self.weight,
x,
None,
None,
None,
self.bias.as_ref(),
self.act.clone(),
),
}
} else {
let w = match x.dims() {
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
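The rewritten `Linear::forward` no longer flattens 3D inputs before the cuBLASLt call; instead it matches on the input rank and uses `batch_matmul` (with the weight broadcast over the batch dimension) for 3D tensors. A small standalone sketch of the same shape logic on the non-cuBLASLt path, using plain candle ops (shapes and values are illustrative):

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // x: [batch, seq_len, hidden]; weight: [out_features, hidden]
    let x = Tensor::randn(0f32, 1.0, (2, 7, 16), &device)?;
    let weight = Tensor::randn(0f32, 1.0, (32, 16), &device)?;

    // Broadcast the weight over the batch dimension, transpose the last two
    // dims, then matmul: [2, 7, 16] x [2, 16, 32] -> [2, 7, 32].
    let w = weight.broadcast_left(2)?.t()?;
    let y = x.matmul(&w)?;
    assert_eq!(y.dims(), &[2, 7, 32]);
    Ok(())
}
```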
6 changes: 4 additions & 2 deletions backends/candle/src/lib.rs
@@ -7,7 +7,9 @@ mod layers;
mod models;

#[cfg(feature = "cuda")]
use crate::compute_cap::{incompatible_compute_cap, COMPILE_COMPUTE_CAP, RUNTIME_COMPUTE_CAP};
use crate::compute_cap::{
get_compile_compute_cap, get_runtime_compute_cap, incompatible_compute_cap,
};
#[cfg(feature = "cuda")]
use crate::models::FlashBertModel;
use crate::models::{BertModel, JinaBertModel, Model, PositionEmbeddingType};
@@ -94,7 +96,7 @@ impl CandleBackend {
#[cfg(feature = "cuda")]
{
if incompatible_compute_cap() {
return Err(BackendError::Start(format!("Runtime compute cap {} is not compatible with compile time compute cap {}", *RUNTIME_COMPUTE_CAP, *COMPILE_COMPUTE_CAP)));
return Err(BackendError::Start(format!("Runtime compute cap {} is not compatible with compile time compute cap {}", get_runtime_compute_cap(), get_compile_compute_cap())));
}

if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
6 changes: 4 additions & 2 deletions backends/candle/src/models/bert.rs
@@ -1,4 +1,4 @@
use crate::layers::{HiddenAct, LayerNorm, Linear, CUBLASLT};
use crate::layers::{get_cublas_lt_wrapper, HiddenAct, LayerNorm, Linear};
use crate::models::Model;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Embedding, VarBuilder};
@@ -185,7 +185,9 @@ impl BertAttention {
let value_layer = &qkv[2];

#[allow(unused_variables)]
let context_layer = if let (Device::Cuda(_), Some(cublaslt)) = (device, &*CUBLASLT) {
let context_layer = if let (Device::Cuda(_), Some(cublaslt)) =
(device, get_cublas_lt_wrapper())
{
#[cfg(feature = "cuda")]
{
// cuBLASLt batch matmul implementation requires inputs to be dims3
6 changes: 4 additions & 2 deletions backends/candle/src/models/jina.rs
@@ -1,5 +1,5 @@
use crate::alibi::build_alibi_tensor;
use crate::layers::{HiddenAct, LayerNorm, Linear, CUBLASLT};
use crate::layers::{get_cublas_lt_wrapper, HiddenAct, LayerNorm, Linear};
use crate::models::Model;
use crate::models::{Config, PositionEmbeddingType};
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
@@ -159,7 +159,9 @@ impl BertAttention {
let value_layer = &qkv[2];

#[allow(unused_variables)]
let context_layer = if let (Device::Cuda(_), Some(cublaslt)) = (device, &*CUBLASLT) {
let context_layer = if let (Device::Cuda(_), Some(cublaslt)) =
(device, get_cublas_lt_wrapper())
{
#[cfg(feature = "cuda")]
{
// cuBLASLt batch matmul implementation requires inputs to be dims3