Stateful in Pytorch #8
From what I can tell,
@sharvil Not a backprop mechanism; gradients do not flow between batches, including for stateful=True. The states are reset to zero per layer via reset_states(). So basically, implementing stateful requires building a dedicated tensor that captures the last timestep's hidden state and is resettable to zero via a method. When and how does the LSTM "pass states" in stateful mode? Pasting relevant excerpts from here:
Per above, this should not be done:

# sampleNM = sample N at timestep(s) M
batch1 = [sample10, sample20, sample30, sample40]
batch2 = [sample21, sample41, sample11, sample31]

This implies:

batch1 = [sample10, sample20, sample30, sample40]
batch2 = [sample11, sample21, sample31, sample41]
As for how an "entirely unrelated batch" is detected in practice: it can be done via callbacks with a counter in the training loop.
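To make the above concrete, here is a minimal tf.keras sketch (an illustration added here, not from the thread) of the described behavior: the hidden state persists across train_on_batch calls and is cleared with reset_states() when a new, unrelated sequence begins, using a plain counter rather than a callback.

# Editorial sketch, not from the thread: tf.keras stateful LSTM where state carries
# across train_on_batch calls and is cleared at sequence boundaries.
import numpy as np
import tensorflow as tf

BATCH_SIZE, TIMESTEPS, FEATURES = 4, 10, 3

model = tf.keras.Sequential([
    tf.keras.layers.LSTM(8, stateful=True, return_sequences=True,
                         batch_input_shape=(BATCH_SIZE, TIMESTEPS, FEATURES)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

def reset_states(model):  # same pattern as the tf.keras helper shown further down
    for layer in model.layers:
        if hasattr(layer, 'reset_states'):
            layer.reset_states()

chunks_per_sequence = 3
for i in range(9):  # 9 consecutive chunks = 3 independent sequences of 3 chunks each
    x = np.random.rand(BATCH_SIZE, TIMESTEPS, FEATURES)
    y = np.random.rand(BATCH_SIZE, TIMESTEPS, 1)
    model.train_on_batch(x, y)          # hidden state carries over from the previous chunk
    if (i + 1) % chunks_per_sequence == 0:
        reset_states(model)             # next chunk starts an unrelated sequence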
Thanks so much for the detailed writeup. I think this is fairly easy to achieve in PyTorch with the existing Haste API. Here's some sample code that demonstrates what a stateful LSTM would look like:

import torch
import torch.nn as nn
import haste_pytorch as haste
class StatefulLSTM(nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self.state = None

    def forward(self, x, state=None, lengths=None, reset=False):
        if reset:
            # On reset, fall back to the provided state (None, i.e. zero initial state, by default).
            self.state = state
        print(f'Using state {self.state}')
        y, state = self.layer(x, self.state, lengths)
        # Detach so gradients don't flow across batch boundaries.
        self.state = (state[0].detach(), state[1].detach())
        return y, state

SEQ_LEN = 250
BATCH_SIZE = 1
INPUT_SIZE = 3
HIDDEN_SIZE = 5

lstm = haste.LSTM(INPUT_SIZE, HIDDEN_SIZE)
model = StatefulLSTM(lstm)

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for t in range(10):
    x = torch.rand([SEQ_LEN, BATCH_SIZE, INPUT_SIZE])
    y, _ = model(x, reset=(t % 3 == 0))
    loss = loss_fn(y, torch.zeros_like(y))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f'Step {t}, loss {loss}')

This code trains an LSTM on sequences of length 750 which are fed in chunks of 250. The LSTM state is squirreled away in self.state between calls. Does this solution work for your application?
Thanks for the demo. From what I understand from lstm.py, the Keras-style reset would be:

def reset_states(model):
    for layer in model.layers:  # tf.keras
        if hasattr(layer, 'reset_states'):
            layer.reset_states()

but I can't tell from your snippet how one would reset states before passing an input. Also, it'd work best if stateful were part of the base implementation rather than a dedicated one, else it'll complicate extending functionality (e.g. each layer type will also need its own stateful counterpart).
The PyTorch equivalent to your reset_states function is:

def reset_states(model):
    for layer in model.modules():
        if hasattr(layer, 'reset_states'):  # could use `isinstance` instead to check if it's an RNN layer
            layer.reset_states()

I think the right answer here is to use a generic stateful wrapper for RNNs. It could be as simple as:

class StatefulRNN(nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self.state = None

    def reset_states(self):
        self.state = None

    def forward(self, x, state=None, lengths=None):
        if state is not None:
            self.state = state
        print(f'Using state {self.state}')
        y, state = self.layer(x, self.state, lengths)
        # Detach so gradients don't flow across batch boundaries.
        self.state = self._detach_state(state)
        return y, state

    def _detach_state(self, state):
        if isinstance(state, tuple):
            return tuple(s.detach() for s in state)
        if isinstance(state, list):
            return [s.detach() for s in state]
        return state.detach()

You could wrap any of the Haste RNN layers with this StatefulRNN class. Here's a complete example:

import torch
import torch.nn as nn
import haste_pytorch as haste

class StatefulRNN(nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self.state = None

    def reset_states(self):
        self.state = None

    def forward(self, x, state=None, lengths=None):
        if state is not None:
            self.state = state
        print(f'Using state {self.state}')
        y, state = self.layer(x, self.state, lengths)
        self.state = self._detach_state(state)
        return y, state

    def _detach_state(self, state):
        if isinstance(state, tuple):
            return tuple(s.detach() for s in state)
        if isinstance(state, list):
            return [s.detach() for s in state]
        return state.detach()

def reset_states(model):
    for layer in model.modules():
        if hasattr(layer, 'reset_states'):  # could use `isinstance` instead to check if it's an RNN layer
            layer.reset_states()

SEQ_LEN = 250
BATCH_SIZE = 10
INPUT_SIZE = 3
HIDDEN_SIZE = 5

lstm = haste.GRU(INPUT_SIZE, HIDDEN_SIZE)  # or haste.LSTM or ...
model = StatefulRNN(lstm)

learning_rate = 1e-3
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for t in range(10):
    x = torch.rand([SEQ_LEN, BATCH_SIZE, INPUT_SIZE])
    if t % 3 == 0:
        reset_states(model)
    y, _ = model(x)
    loss = loss_fn(y, torch.zeros_like(y))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f'Step {t}, loss {loss}')
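One small addendum (an editorial aside reusing the definitions from the example above, not part of the original reply): since reset_states(model) walks model.modules(), a wrapped layer nested anywhere inside a larger module is found and reset automatically.

# Editorial sketch: no hand-kept references needed; modules() finds the wrapper.
class SeqModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = StatefulRNN(haste.GRU(INPUT_SIZE, HIDDEN_SIZE))
        self.head = nn.Linear(HIDDEN_SIZE, 1)  # hypothetical downstream head

    def forward(self, x):
        y, _ = self.rnn(x)
        return self.head(y)

seq_model = SeqModel()
reset_states(seq_model)  # clears seq_model.rnn.state via the modules() traversal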
@sharvil Looks excellent - a wrapper is a fair-enough alternative. To be sure all works as expected, I'll get back to you once I get haste_pytorch working.
@sharvil Ran the script - looks good; I'll compare it more extensively vs. Keras later, but currently all seems to work as expected. Thank you.
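As an aside for that comparison (an editorial sketch, not from the thread): one way to sanity-check the stateful pattern itself, independent of Haste, is to verify that chunked processing with carried-over state reproduces full-sequence outputs, e.g. with torch.nn.LSTM.

# Editorial sketch: feeding a 750-step sequence in three 250-step chunks while
# carrying (h, c) forward should produce the same outputs as one full pass.
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=3, hidden_size=5)      # expects (seq_len, batch, features)
x = torch.rand(750, 1, 3)

y_full, _ = lstm(x)                              # full sequence in one shot

state = None
chunk_outputs = []
for chunk in x.split(250, dim=0):                # stateful-style chunked processing
    y_chunk, state = lstm(chunk, state)
    state = tuple(s.detach() for s in state)     # block gradient flow between chunks
    chunk_outputs.append(y_chunk)
y_chunked = torch.cat(chunk_outputs, dim=0)

print(torch.allclose(y_full, y_chunked, atol=1e-6))  # expected: True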
Inspecting Pytorch's source code, I don't think stateful=True is supported, though there's a custom implementation, so it appears doable. Any planned support? It's crucial in my application.