hidden state raises NotImplementedError() #21
Comments
Hi, so in short: I think the variable last_state_list can be reused as hidden_state in the next forward pass, provided the flag return_all_layers was set to True during initialization (because you need to pass the hidden and cell states of all layers, not just the last one). As a clean fix I would just use an if for the case where hidden_state is not provided (see the sketch below). I will try it myself in the next few days and update this post if needed.
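A minimal sketch of that if, assuming the forward signature and the _init_hidden helper from this repo (untested, and assuming batch_first input):

```python
def forward(self, input_tensor, hidden_state=None):
    # input_tensor: (batch, time, channels, height, width), batch_first=True
    b, _, _, h, w = input_tensor.size()

    if hidden_state is None:
        # no state provided: initialize to zeros, as the code already does
        hidden_state = self._init_hidden(batch_size=b, image_size=(h, w))
    # else: reuse the provided per-layer list of (h, c) tensors, e.g. the
    # last_state_list returned by the previous forward pass (this needs
    # return_all_layers=True so states for every layer are present)
    ...
```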
Hi, I actually tried feeding last_state_list back in as hidden_state, and it actually works okay... I am still experimenting and seeing whether it increases performance. I have not tried setting return_all_layers=True yet. Let me know if you find a good solution for dealing with the hidden state! :)
In general your solution works only because, with a single layer, setting return_all_layers to True or False makes no difference. In the general case, if you set it to True it runs without any error (I still have to check whether it learns properly, but I guess it will).
To be honest I am massively confused :D It is the first time I am working with LSTMs (for my bachelor thesis), and so far they have given me a lot of headaches 😄
Hi,
The stacking was just done because the tensors are fed in a loop into the LSTM cell anyway (lines 158-161). Yes, but detaching was (so far) the only way for me to get the model to run without hitting memory issues. I guess that if I want to use the information from [t-2, t-1, t] correctly, I would need to detach only after every 2 batches and specify retain_graph=True (see https://discuss.pytorch.org/t/convolutional-lstm-retain-graph-error/85200/4 at the bottom).
EDIT: I think I see what you meant with stacking: at t the window is [t-2, t-1, t], at t+1 it is [t-1, t, t+1], and at t+2 it is [t, t+1, t+2], so frame t will be processed 3x.
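For reference, a minimal sketch of the detach-between-batches training loop being described; model, loader, criterion, and optimizer are placeholder names, and it assumes forward has been patched to accept hidden_state:

```python
hidden = None
for x, y in loader:  # x: (batch, time, channels, height, width)
    layer_outputs, last_state_list = model(x, hidden_state=hidden)
    loss = criterion(layer_outputs[-1], y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # carry the state to the next batch, but cut it from the old graph
    # so the graph (and memory) does not grow across batches
    hidden = [(h.detach(), c.detach()) for h, c in last_state_list]
```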
Hi, my implementation is in reinforcement learning, so it is somewhat different, but in one case I am feeding an input tensor x of shape (time, batch, channels, resolution, resolution) together with the initial hidden state (detached from the graph because that is how I get it; I never tried keeping its graph). In the other case I am using the network under torch.no_grad(), basically feeding in a loop the updated hidden state together with the new input of shape (1, batch, channels, resolution, resolution). But if you look at the implementation, there is nothing that makes me think the input or the hidden state can't have the computational graph attached. Usually the gradients accumulated on the network's weights are reset by calling optimizer.zero_grad(), whereas the variables of a batch, along with their graph, are freed automatically once nothing references them anymore. To wrap up: just try passing the hidden state straight through, and reset it to zeros (hidden_state = a bunch of zeros) whenever you have batches of videos that do not need previous context. And that's it, no detach at all. Let me know how it goes :)
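A sketch of the second pattern described above (the rollout loop under torch.no_grad()); frames and model are placeholder names, and again this assumes forward accepts a hidden_state:

```python
import torch

hidden = None  # zeros will be used internally when no context is needed
with torch.no_grad():
    for frame in frames:  # each frame: (1, batch, channels, res, res)
        layer_outputs, hidden = model(frame, hidden_state=hidden)
        # hidden is fed straight back in on the next step; no detach is
        # needed because no_grad() builds no computational graph at all
```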
Hi, I am actually in the middle of testing and waiting for the results :D There are 2 possibilities for the hidden state:

1. Resetting it every batch (hidden_state = None, i.e. zeros).
2. Passing the previous last_state_list back in --> no error + keeps the hidden values from the previous run.

I am testing right now whether there is a performance difference between:

A.) Inserting [t-2, t-1, t] in a loop each batch and detaching, as described in 2., or
B.) Inserting only [t] and detaching the hidden state every 3 batches, keeping the graph in between (it will be small enough to stay within my memory limit).

I have already tested Option A with hidden-state option 1 vs. Option A with hidden-state option 2, and the second one is ~2-8 percentage points better (depending on the metric). But this might very well be due to my use case, since the frames are so similar... Cheers,
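One way to realize something like option B without the retain_graph pitfalls from the linked thread is to accumulate the loss over the window and call backward() once per window. A minimal sketch with placeholder names (k, loader, criterion, optimizer are not from this repo):

```python
k = 3  # window length in batches
hidden = None
window_loss = 0.0

for i, (x, y) in enumerate(loader):
    layer_outputs, last_state_list = model(x, hidden_state=hidden)
    window_loss = window_loss + criterion(layer_outputs[-1], y)
    hidden = last_state_list  # keep the graph alive inside the window

    if (i + 1) % k == 0:
        optimizer.zero_grad()
        window_loss.backward()  # one backward through the whole k-batch graph
        optimizer.step()
        # cut the graph every k batches so memory stays bounded
        hidden = [(h.detach(), c.detach()) for h, c in hidden]
        window_loss = 0.0
```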
Hey,
I am trying to integrate your ConvLSTM cell into an existing model of mine. I did this in the following way:
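(The original code block did not survive; a minimal sketch of the kind of integration being described, with the state stored on the module between forward passes — the surrounding model and its settings are hypothetical:)

```python
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # hypothetical settings; only the ConvLSTM call and the stored
        # state matter for this question
        self.convlstm = ConvLSTM(input_dim=3, hidden_dim=[64],
                                 kernel_size=(3, 3), num_layers=1,
                                 batch_first=True, return_all_layers=True)
        self.hidden = None  # carried over between forward passes

    def forward(self, x):  # x: (batch, time, channels, height, width)
        layer_outputs, self.hidden = self.convlstm(x, hidden_state=self.hidden)
        return layer_outputs[-1]
```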
self.hidden is None in the first run, but not None in the second, which trips the NotImplementedError raised in your ConvLSTM module (lines 141-142).
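(For context, those lines in ConvLSTM.forward look roughly like this — quoted from memory, so treat the exact wording as approximate:)

```python
if hidden_state is not None:
    raise NotImplementedError()
else:
    # the state is always re-initialized to zeros inside forward
    hidden_state = self._init_hidden(batch_size=b, image_size=(h, w))
```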
Would this be implemented just by something like the following?
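(The proposed snippet is missing from the post; presumably it was along these lines — an assumption, not the author's actual code:)

```python
if hidden_state is None:
    hidden_state = self._init_hidden(batch_size=b, image_size=(h, w))
# otherwise simply keep using the hidden_state that was passed in
```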
Or am I misunderstanding something?
What is supposed to happen here?
Thanks for any help :-D