Latest tensor squeeze impl makes CUDA matmul fail #1948

Closed
yinqiwen opened this issue Mar 27, 2024 · 13 comments

@yinqiwen
Contributor

The test code below runs successfully with candle 0.4.1, where x gets a different stride ([2048, 1]), but it fails on the main branch.
PyTorch produces the same x stride as the candle main branch, yet its matmul runs successfully.

use candle::{DType, Device, IndexOp, Tensor};

#[test]
fn test_squeeze() -> candle::Result<()> {
    let device = Device::new_cuda(0)?;
    // a: shape (1, 8, 2048), row-major strides [16384, 2048, 1]
    let a = Tensor::zeros((1, 8, 2048), DType::F32, &device)?;
    let seq_len = 8_usize;
    // Select the last position along dim 1; the result is a view into a.
    let x = a.i((.., seq_len - 1, ..))?;
    println!(
        "x shape:{:?}, stride:{:?}, is_contiguous:{}",
        x.shape(),
        x.stride(),
        x.is_contiguous()
    );

    // w: a (32, 2048) tensor transposed into a (2048, 32) view
    let w = Tensor::zeros((32, 2048), DType::F32, &device)?.t()?;
    println!(
        "w shape:{:?}, stride:{:?}, is_contiguous:{}",
        w.shape(),
        w.stride(),
        w.is_contiguous()
    );
    let _y = x.matmul(&w)?;
    Ok(())
}

x shape:[1, 2048], stride:[16384, 1], is_contiguous:true
w shape:[2048, 32], stride:[1, 2048], is_contiguous:false
Error: WithBacktrace { inner: Cuda(MatMulNonContiguous { lhs_stride: [16384, 1], rhs_stride: [1, 2048], mnk: (1, 32, 2048) }),
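The strides printed above can be reproduced with a small std-only model of a tensor layout. This is a sketch, assuming (as candle's indexer appears to do) that selecting a single index is a narrow of length 1 followed by a squeeze, and that both only rewrite shape, strides and start offset without copying. The `Layout` type and its methods are hypothetical, not candle's API.

```rust
/// Row-major (C-order) strides for a shape: the last dim has stride 1.
fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

/// Hypothetical view descriptor: shape, strides (in elements) and start offset.
#[derive(Debug, Clone, PartialEq)]
struct Layout {
    shape: Vec<usize>,
    stride: Vec<usize>,
    offset: usize,
}

impl Layout {
    fn contiguous(shape: &[usize]) -> Self {
        Layout { shape: shape.to_vec(), stride: contiguous_strides(shape), offset: 0 }
    }

    /// Pick index `idx` along `dim` and drop that dim (narrow + squeeze), no copy.
    /// The remaining dims keep their original strides.
    fn select(&self, dim: usize, idx: usize) -> Self {
        assert!(idx < self.shape[dim]);
        let mut shape = self.shape.clone();
        let mut stride = self.stride.clone();
        let offset = self.offset + idx * stride[dim];
        shape.remove(dim);
        stride.remove(dim);
        Layout { shape, stride, offset }
    }

    /// Swap the two dims of a 2-D view, no copy.
    fn t(&self) -> Self {
        assert_eq!(self.shape.len(), 2);
        Layout {
            shape: vec![self.shape[1], self.shape[0]],
            stride: vec![self.stride[1], self.stride[0]],
            offset: self.offset,
        }
    }
}
```

Selecting position 7 along dim 1 of a (1, 8, 2048) layout keeps the leading size-1 dim with its original stride 16384, which is exactly the lhs_stride in the error, and transposing (32, 2048) yields the rhs_stride [1, 2048].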
@LaurentMazare
Collaborator

That's actually expected: #1884 is the change that made squeeze/unsqueeze more efficient at the expense of breaking some existing use cases. To fix this, just call .contiguous() manually before the matmul.

@yinqiwen
Contributor Author

@LaurentMazare the test code is adapted from code that uses candle_nn::Linear. There, the input tensor x is already contiguous, and w is a transposed view of the loaded weights, on which I have no way to call .contiguous().


@LaurentMazare
Collaborator

I'm not sure I understand: where is w coming from? If it's loaded from a var-builder it should be contiguous; if it's provided manually, e.g. via Linear::new, then you should be able to pass a contiguous version. Could you point me at the code that is actually breaking?

@yinqiwen
Contributor Author

You can see that w is a transposed version of the loaded weight, which is non-contiguous:
https://github.com/huggingface/candle/blob/main/candle-nn/src/linear.rs#L47

I found this issue with the code at:
https://github.com/yinqiwen/lmsf/blob/rust/core/src/model_executor/models/gemma/gemma.rs#L282-L283
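A BLAS-style GEMM (cuBLAS, for instance, takes per-operand transpose flags) can consume a 2-D operand either in row-major form or as a transposed view, so the transposed w is not a problem by itself. The sketch below is a hypothetical, simplified operand check that compares strides exactly, roughly in the spirit of the MatMulNonContiguous error; it is not candle's actual code.

```rust
/// How a 2-D operand could be passed to a BLAS-style GEMM without a copy.
#[derive(Debug, PartialEq)]
enum GemmLayout {
    /// strides [cols, 1]: plain row-major
    RowMajor,
    /// strides [1, rows]: a transposed view, passed with a transpose flag
    Transposed,
    /// anything else would need a copy first
    Unsupported,
}

/// Hypothetical strict check: strides are compared exactly,
/// even for dims of size 1.
fn classify_strict(shape: [usize; 2], stride: [usize; 2]) -> GemmLayout {
    let [rows, cols] = shape;
    if stride == [cols, 1] {
        GemmLayout::RowMajor
    } else if stride == [1, rows] {
        GemmLayout::Transposed
    } else {
        GemmLayout::Unsupported
    }
}
```

Under this check w ([2048, 32], stride [1, 2048]) is accepted as a transposed operand, while x ([1, 2048], stride [16384, 1]) is rejected because the stride of its size-1 row dim is 16384 rather than 2048.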

@LaurentMazare
Collaborator

Couldn't you just make x contiguous at the place that you pointed out? I think the issue might be coming from x and not from w in this case.

@yinqiwen
Contributor Author

@LaurentMazare but x is already contiguous, so invoking .contiguous() does nothing.

@LaurentMazare
Collaborator

Oh I see, the notion of contiguity is actually not in sync between the matmul check and the is_contiguous bit; I'll make a fix for this. In the meantime, can you try forcing a copy (you're right that calling .contiguous() won't be enough), e.g. by just adding 0 to it, let x = (x + 0.0)?;, and let me know if it fixes it.
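The mismatch can be illustrated with two contiguity predicates. Both are hypothetical sketches: the relaxed one skips size-1 dims (consistent with the is_contiguous:true printed for x earlier in the thread), while the strict one requires the exact row-major strides, as a matmul stride check might.

```rust
/// Strict: every stride must equal the row-major stride, size-1 dims included.
fn is_contiguous_strict(shape: &[usize], stride: &[usize]) -> bool {
    let mut expected = 1;
    for (&dim, &s) in shape.iter().zip(stride).rev() {
        if s != expected {
            return false;
        }
        expected *= dim;
    }
    true
}

/// Relaxed: a size-1 dim is only ever indexed at 0, so its stride never
/// contributes to an address and is ignored.
fn is_contiguous_relaxed(shape: &[usize], stride: &[usize]) -> bool {
    let mut expected = 1;
    for (&dim, &s) in shape.iter().zip(stride).rev() {
        if dim != 1 && s != expected {
            return false;
        }
        expected *= dim;
    }
    true
}
```

For x ([1, 2048], stride [16384, 1]) the relaxed predicate says yes and the strict one says no, so .contiguous() sees nothing to do while the matmul still rejects the tensor. A freshly allocated result with stride [2048, 1] satisfies both.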

@yinqiwen
Contributor Author

let x = (x + 0.0)?; produces a tensor with shape [1, 2048] and stride [2048, 1], but there seems to be another GPU memory issue caused by (x + 0.0)?:

let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
let x = (x + 0.0)?;
println!("#### x shape:{:?}, stride:{:?}, is_contiguous:{}",
        x.shape(),
        x.stride(),
        x.is_contiguous()
);
let logits = self.lm_head.forward(&x)?;
#### x shape:[1, 2048], stride:[2048, 1], is_contiguous:true
thread '<unnamed>' panicked at /root/.cargo/registry/src/index.crates.io-6f17d22bba15001f/cudarc-0.10.0/src/driver/safe/core.rs:208:76:
called `Result::unwrap()` on an `Err` value: DriverError(CUDA_ERROR_ILLEGAL_ADDRESS, "an illegal memory access was encountered")
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
thread '<unnamed>' panicked at /root/.cargo/registry/src/index.crates.io-6f17d22bba15001f/cudarc-0.10.0/src/driver/safe/core.rs:208:76:
called `Result::unwrap()` on an `Err` value: DriverError(CUDA_ERROR_ILLEGAL_ADDRESS, "an illegal memory access was encountered")
stack backtrace:
   0:     0x563d5ce115a6 - std::backtrace_rs::backtrace::libunwind::trace::hbee8a7973eeb6c93
                               at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library/std/src/../../backtrace/src/backtrace/libunwind.rs:104:5
   1:     0x563d5ce115a6 - std::backtrace_rs::backtrace::trace_unsynchronized::hc8ac75eea3aa6899
                               at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library/std/src/../../backtrace/src/backtrace/mod.rs:66:5
   2:     0x563d5ce115a6 - std::sys_common::backtrace::_print_fmt::hc7f3e3b5298b1083
                               at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library/std/src/sys_common/backtrace.rs:68:5
   3:     0x563d5ce115a6 - <std::sys_common::backtrace::_print::DisplayBacktrace as core::fmt::Display>::fmt::hbb235daedd7c6190
                               at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library/std/src/sys_common/backtrace.rs:44:22
   4:     0x563d5ce3e5c0 - core::fmt::rt::Argument::fmt::h76c38a80d925a410
                               at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library/core/src/fmt/rt.rs:142:9
   5:     0x563d5ce3e5c0 - core::fmt::write::h3ed6aeaa977c8e45
                               at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library/core/src/fmt/mod.rs:1120:17
   6:     0x563d5ce0ebaf - std::io::Write::write_fmt::h78b18af5775fedb5
                               at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library/std/src/io/mod.rs:1810:15
   7:     0x563d5ce11384 - std::sys_common::backtrace::_print::h5d645a07e0fcfdbb
                               at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library/std/src/sys_common/backtrace.rs:47:5
   8:     0x563d5ce11384 - std::sys_common::backtrace::print::h85035a511aafe7a8
                               at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library/std/src/sys_common/backtrace.rs:34:9
   9:     0x563d5ce12c07 - std::panicking::default_hook::{{closure}}::hcce8cea212785a25
  10:     0x563d5ce12969 - std::panicking::default_hook::hf5fcb0f213fe709a
                               at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library/std/src/panicking.rs:292:9
  11:     0x563d5ce13098 - std::panicking::rust_panic_with_hook::h095fccf1dc9379ee
                               at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library/std/src/panicking.rs:779:13
  12:     0x563d5ce12f72 - std::panicking::begin_panic_handler::{{closure}}::h032ba12139b353db
                               at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library/std/src/panicking.rs:657:13
  13:     0x563d5ce11aa6 - std::sys_common::backtrace::__rust_end_short_backtrace::h9259bc2ff8fd0f76
                               at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library/std/src/sys_common/backtrace.rs:171:18
  14:     0x563d5ce12cd0 - rust_begin_unwind
                               at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library/std/src/panicking.rs:645:5
  15:     0x563d5c558e85 - core::panicking::panic_fmt::h784f20a50eaab275
                               at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library/core/src/panicking.rs:72:14
  16:     0x563d5c5593d3 - core::result::unwrap_failed::h03d8a5018196e1cd
                               at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library/core/src/result.rs:1649:5
  17:     0x563d5c619994 - <cudarc::driver::safe::core::CudaSlice<T> as core::ops::drop::Drop>::drop::h84a62617683e748a
  18:     0x563d5c62649c - alloc::sync::Arc<T,A>::drop_slow::h48352ef3f3a7aa1c
  19:     0x563d5c6275c0 - alloc::sync::Arc<T,A>::drop_slow::hc81f0778974720ba
  20:     0x563d5c670bb3 - lmsf_core::model_executor::layers::sampler::Sampler::forward::hc8d671348d974142

@yinqiwen
Contributor Author

Sorry, the crash seems to be caused by some other unsqueeze code; I'll check it again. let x = (x + 0.0)?; works here.

@LaurentMazare
Collaborator

Cool. It's a bit worrying, though, if unsqueeze can cause such errors, as the API is supposed to be safe; I'm pretty curious whether you find out what is going on here.
In the meantime, I'm minting #1949 which should make for more flexible checks for matmul in the cuda backend, I've added your snippet as a test, thanks for making it. This also introduced a force_contiguous method on tensors so that in the future it's possible to trigger a copy (this would have worked as a quick fix instead of the + 0 hack).
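Forcing a copy amounts to gathering the elements of a strided view into a fresh row-major buffer. Below is a minimal CPU sketch of such a gather, using a hypothetical materialize function over a flat f32 slice; it illustrates the idea and is not candle's force_contiguous implementation.

```rust
/// Copy the elements of a strided view (shape, strides, start offset, all in
/// elements) into a new buffer laid out in row-major order.
fn materialize(data: &[f32], shape: &[usize], stride: &[usize], offset: usize) -> Vec<f32> {
    let numel: usize = shape.iter().product();
    let mut out = Vec::with_capacity(numel);
    // Current multi-index into the view, starting at all zeros.
    let mut index = vec![0usize; shape.len()];
    for _ in 0..numel {
        let src: usize = offset + index.iter().zip(stride).map(|(i, s)| i * s).sum::<usize>();
        out.push(data[src]);
        // Advance the multi-index in row-major order (last dim fastest).
        for d in (0..shape.len()).rev() {
            index[d] += 1;
            if index[d] < shape[d] {
                break;
            }
            index[d] = 0;
        }
    }
    out
}
```

Whatever the input strides, the output is densely packed, so a [1, 2048] view with stride [16384, 1] comes back as a buffer whose implied stride is [2048, 1].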

@yinqiwen
Contributor Author

yinqiwen commented Mar 28, 2024

@LaurentMazare the crash seems to be a separate issue; it can be reproduced simply with this code:

use candle::{DType, Device, Tensor};

#[test]
fn test_mul() -> candle::Result<()> {
    let device = Device::new_cuda(0)?;
    // Two (1, 256000) tensors of ones; the elementwise product should be all ones.
    let a = Tensor::ones((1, 256000), DType::F32, &device)?;
    let b = Tensor::ones((1, 256000), DType::F32, &device)?;

    let c = b.mul(&a)?;
    println!("{}", c.to_string());

    Ok(())
}

After a full cargo clean, the crash is gone; maybe it's an issue with the CUDA kernel rebuild process.

@LaurentMazare
Collaborator

The test worked fine for me. And indeed the build process is a bit flaky and could have caused this: there was a recent change in a .cuh file used by binary ops like mul, and that change wouldn't trigger a rebuild of the kernels.
I just minted #1954 to get around this. It doesn't properly track the dependency, so it might rebuild a bit too often, but at least it should avoid the issue you had going forward.
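One common way to make a Cargo build script notice header edits is to emit a cargo:rerun-if-changed directive for every kernel source and header. A minimal sketch, assuming the .cu and .cuh files sit flat in a single directory; the rerun_directives function is hypothetical and this is not necessarily what #1954 does.

```rust
use std::fs;
use std::io;
use std::path::Path;

/// Build one `cargo:rerun-if-changed=<path>` line for every CUDA source (.cu)
/// and header (.cuh) found directly inside `dir`.
fn rerun_directives(dir: &Path) -> io::Result<Vec<String>> {
    let mut lines = Vec::new();
    for entry in fs::read_dir(dir)? {
        let path = entry?.path();
        let is_cuda = matches!(
            path.extension().and_then(|e| e.to_str()),
            Some("cu") | Some("cuh")
        );
        if is_cuda {
            lines.push(format!("cargo:rerun-if-changed={}", path.display()));
        }
    }
    // read_dir order is platform-dependent; sort for stable output.
    lines.sort();
    Ok(lines)
}
```

In a build.rs, main would print each returned line. Because headers are listed alongside kernels and nothing tracks which kernel includes which header, any header edit reruns the whole kernel build: simple, at the cost of occasionally rebuilding more than strictly needed.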

@yinqiwen
Contributor Author

Tested successfully with the latest commit.

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants