Describe the bug
The LSTM module provided by Burn consistently panics when the batch size is not equal to 1.
To Reproduce
To reproduce, you can run this code:
use burn::{
    backend::TchBackend,
    nn::LstmConfig,
    tensor::{Data, Tensor},
};

type B = TchBackend<f32>;

fn main() {
    // This part works
    let lstm = LstmConfig::new(1, 1, false, 1).init::<B>();
    let a = Tensor::<B, 3>::from_data(Data::from([[[0.0]]]));
    println!("shape: {:?}", a.shape());
    lstm.forward(a, None);
    println!("success");

    // This part panics
    let lstm = LstmConfig::new(1, 1, false, 2).init::<B>();
    let a = Tensor::<B, 3>::from_data(Data::from([[[0.0]], [[0.0]]]));
    println!("shape: {:?}", a.shape());
    lstm.forward(a, None); // Panic
}
Expected behavior
I expect the LSTM to handle arbitrary batch sizes correctly, without panicking.
Desktop (please complete the following information):
OS: Windows 11
Burn version
0.9.0 (using burn = { version = "0.9.0", features = ["train", "tch"] })
Additional context
I see that the batch size parameter was recently removed from the LstmConfig (though this change hasn't been released yet) and is instead inferred dynamically at runtime. It's possible this bug was fixed as part of that change, but I have not checked whether that is the case.
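For reference, here is a minimal sketch of what the repro would look like against that unreleased API, assuming the only relevant change is that LstmConfig::new drops the trailing batch-size parameter (untested; I have not verified this against the current main branch):

use burn::{
    backend::TchBackend,
    nn::LstmConfig,
    tensor::{Data, Tensor},
};

type B = TchBackend<f32>;

fn main() {
    // Hypothetical: batch size is no longer part of the config, so a single
    // module should accept any batch size at runtime.
    let lstm = LstmConfig::new(1, 1, false).init::<B>();
    let a = Tensor::<B, 3>::from_data(Data::from([[[0.0]], [[0.0]]])); // batch size 2
    lstm.forward(a, None);
}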
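Digging into the 0.9.0 source, the panic appears to originate in this part of Lstm::forward, where the per-timestep state is written back into the batched output: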
// store the state for this timestep
batched_cell_state = batched_cell_state.slice_assign(
    [0..batch_size, t..(t + 1), 0..self.d_hidden],
    cell_state.clone().unsqueeze(),
);
batched_hidden_state = batched_hidden_state.slice_assign(
    [0..batch_size, t..(t + 1), 0..self.d_hidden],
    hidden_state.clone().unsqueeze(),
);
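If I read this correctly, cell_state and hidden_state have shape [batch_size, d_hidden] at this point, and unsqueeze() pads new dimensions at the front, producing [1, batch_size, d_hidden], while the slice expects [batch_size, 1, d_hidden]. The two only coincide when batch_size == 1, which would explain why only larger batches panic. Under that assumption, a minimal, untested sketch of a fix is to make the inserted dimension explicit:

// Untested sketch: reshape to [batch_size, 1, d_hidden] instead of relying
// on unsqueeze(), which inserts the new dimension at the front.
batched_cell_state = batched_cell_state.slice_assign(
    [0..batch_size, t..(t + 1), 0..self.d_hidden],
    cell_state.clone().reshape([batch_size, 1, self.d_hidden]),
);
batched_hidden_state = batched_hidden_state.slice_assign(
    [0..batch_size, t..(t + 1), 0..self.d_hidden],
    hidden_state.clone().reshape([batch_size, 1, self.d_hidden]),
);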