
Lstm panics when batch size is greater than 1 #872

Closed
mosheduminer opened this issue Oct 17, 2023 · 3 comments


mosheduminer commented Oct 17, 2023

Describe the bug
The LSTM module provided by burn seems to always fail (panic) when batch size is not equal to 1.

To Reproduce
To reproduce, you can run this code:

use burn::{
    backend::TchBackend,
    nn::LstmConfig,
    tensor::{Data, Tensor},
};

type B = TchBackend<f32>;

fn main() {
    // This part works
    let lstm = LstmConfig::new(1, 1, false, 1).init::<B>();

    let a = Tensor::<B, 3>::from_data(Data::from([[[0.0]]]));
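    // Shape [1, 1, 1]: batch size 1, sequence length 1, 1 feature.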
    println!("shape: {:?}", a.shape());
    lstm.forward(a, None);
    println!("success");

    // This part panics
    let lstm = LstmConfig::new(1, 1, false, 2).init::<B>();

    let a = Tensor::<B, 3>::from_data(Data::from([[[0.0]], [[0.0]]]));
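    // Shape [2, 1, 1]: batch size 2, sequence length 1, 1 feature.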
    println!("shape: {:?}", a.shape());
    lstm.forward(a, None); // Panic
}

Expected behavior
I expect the LSTM to handle any batch size correctly, without panicking.

Desktop (please complete the following information):

  • OS: Windows 11

Burn version

  • 0.9.0 (using burn = { version = "0.9.0", features = ["train", "tch"] })

Additional context
I see that the batch size parameter was recently removed from the LstmConfig (though this change hasn't been released yet), and the batch size is now inferred dynamically at runtime. It's possible that this bug was fixed as part of that change, but I haven't checked whether that is the case.
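For reference, here is a minimal sketch of the same reproduction against that post-change API, assuming LstmConfig::new now takes only (d_input, d_hidden, bias) and the batch size is read from the input tensor's first dimension at forward time; this reuses the imports and the B alias from the snippet above and is not verified against main:

// Hypothetical post-change usage: no batch size in the config.
let lstm = LstmConfig::new(1, 1, false).init::<B>();

// Batch size 2 would be inferred from the [2, 1, 1] input shape.
let input = Tensor::<B, 3>::from_data(Data::from([[[0.0]], [[0.0]]]));
lstm.forward(input, None);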


agelas commented Oct 17, 2023

@mosheduminer Based on the description, this bug is almost certainly still present in development/main.

@nathanielsimard Based on the Discord snippets, if I had to guess, it probably has to do with the squeezing/unsqueezing, so either this:

for (t, input_t) in batched_input.iter_dim(1).enumerate() {
    let input_t = input_t.squeeze(1);

or this:

// store the state for this timestep
batched_cell_state = batched_cell_state.slice_assign(
  [0..batch_size, t..(t + 1), 0..self.d_hidden],
  cell_state.clone().unsqueeze(),
);
batched_hidden_state = batched_hidden_state.slice_assign(
  [0..batch_size, t..(t + 1), 0..self.d_hidden],
  hidden_state.clone().unsqueeze(),
);
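
To make that guess concrete, here is a sketch of the shape mismatch I would expect in the second snippet, assuming unsqueeze() prepends the new dimension (so a [batch_size, d_hidden] tensor becomes [1, batch_size, d_hidden]). The reshape shown at the end is only an illustrative fix, not a tested patch:

// Suspected shapes inside the timestep loop:
//   cell_state / hidden_state:  [batch_size, d_hidden]
//   cell_state.unsqueeze():     [1, batch_size, d_hidden]  (new leading dim)
//   slice being assigned:       [batch_size, 1, d_hidden]  (0..batch_size, t..t + 1, 0..d_hidden)
// The last two only match when batch_size == 1, which would explain the panic.
// An explicit reshape would sidestep the ambiguity:
batched_cell_state = batched_cell_state.slice_assign(
    [0..batch_size, t..(t + 1), 0..self.d_hidden],
    cell_state.clone().reshape([batch_size, 1, self.d_hidden]),
);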

agelas mentioned this issue Oct 17, 2023

agelas commented Oct 25, 2023

@nathanielsimard I think this issue can be closed!

mosheduminer (Author)

I guess you're waiting on me to close this?
