Stateful in Pytorch #8

Closed
OverLordGoldDragon opened this issue Mar 23, 2020 · 8 comments

Comments

@OverLordGoldDragon

Inspecting PyTorch's source code, I don't think stateful=True is supported, though there's a custom implementation, so it appears doable. Any planned support? It's crucial for my application.

@sharvil
Contributor

sharvil commented Mar 23, 2020

From what I can tell, stateful in Keras-land is a flag that aids in implementing truncated backprop through time. It stores the last state of the RNN layer for each batch item and uses that as the initial state for the next batch that passes through the RNN. Is that correct? How does the state get reset when an entirely unrelated batch of data comes along?

@OverLordGoldDragon
Author

OverLordGoldDragon commented Mar 23, 2020

@sharvil It's not a backprop mechanism; gradients do not flow between batches, even with stateful=True. The rest of your description is correct; details below.

In Keras, states are reset to zero per layer via layer.reset_states(), or for all layers in the model via model.reset_states(); if reset_states() isn't called, the states never reset themselves to zero. States are built and initialized via an overridable method, get_initial_state() - e.g. tailored for LSTM (which uses another overridable method).

So basically, an implementation requires a dedicated tensor that captures the last timestep's hidden state and is resettable to zero via a method.
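
As a minimal tf.keras sketch of the above (layer sizes and shapes are arbitrary; batch_input_shape is the layer-level spelling of batch_shape):

import numpy as np
import tensorflow as tf

# stateful=True requires a fixed batch size, hence batch_input_shape=(batch, timesteps, features)
model = tf.keras.Sequential([
    tf.keras.layers.LSTM(8, stateful=True, batch_input_shape=(4, 10, 3)),
    tf.keras.layers.Dense(1),
])
model.compile('adam', 'mse')

x = np.random.rand(4, 10, 3)
model.train_on_batch(x, np.zeros((4, 1)))  # final hidden/cell states are kept...
model.train_on_batch(x, np.zeros((4, 1)))  # ...and used as the initial states here
model.reset_states()                       # states are zeroed again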


When and how does LSTM "pass states" in stateful? Pasting relevant excerpts from here:

  • When: only batch-to-batch; samples are entirely independent
  • How: in Keras, only batch-sample to batch-sample: stateful=True requires specifying batch_shape instead of input_shape, because Keras builds batch_size separate LSTM states at compile time

Per above, this should not be done:

# sampleNM = sample N at timestep(s) M
batch1 = [sample10, sample20, sample30, sample40]
batch2 = [sample21, sample41, sample11, sample31]

This implies sample21 causally follows sample10 - and will wreck training. Instead do:

batch1 = [sample10, sample20, sample30, sample40]
batch2 = [sample11, sample21, sample31, sample41]
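
A small NumPy sketch of slicing one long sequence per sample into batches with that ordering (shapes are illustrative):

import numpy as np

n_samples, total_steps, window, n_feats = 4, 8, 2, 3
data = np.random.rand(n_samples, total_steps, n_feats)  # one long sequence per sample

# each batch holds the same samples, in the same slot order, at consecutive windows
batches = [data[:, t:t + window] for t in range(0, total_steps, window)]
# batches[0][n] covers timesteps 0-1 of sample n; batches[1][n] covers timesteps 2-3
# of the SAME sample n - so slot n of the LSTM state always continues sample n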

@OverLordGoldDragon
Author

OverLordGoldDragon commented Mar 23, 2020

As for how an "entirely unrelated batch" is detected in practice: it can be done via callbacks with a counter in on_batch_end(), e.g. calling reset_states() every 4th batch - or via a custom train loop (my approach).
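
A rough sketch of the callback approach in tf.keras (the period of 4 batches is just an example):

import tensorflow as tf

class ResetStatesCallback(tf.keras.callbacks.Callback):
    """Zero all RNN states every `period` batches, i.e. whenever a new,
    unrelated group of sequences begins."""
    def __init__(self, period=4):
        super().__init__()
        self.period = period
        self.batch_count = 0

    def on_batch_end(self, batch, logs=None):
        self.batch_count += 1
        if self.batch_count % self.period == 0:
            self.model.reset_states()

# usage: model.fit(x, y, shuffle=False, callbacks=[ResetStatesCallback(period=4)])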

@sharvil
Contributor

sharvil commented Mar 25, 2020

Thanks so much for the detailed writeup. I think this is fairly easy to achieve in PyTorch with the existing Haste API. Here's some sample code that demonstrates what a stateful LSTM would look like:

import torch
import torch.nn as nn
import haste_pytorch as haste


class StatefulLSTM(nn.Module):
  def __init__(self, layer):
    super().__init__()
    self.layer = layer
    self.state = None

  def forward(self, x, state=None, lengths=None, reset=False):
    if reset:
      self.state = state
    print(f'Using state {self.state}')
    y, state = self.layer(x, self.state, lengths)
    self.state = (state[0].detach(), state[1].detach())
    return y, state


SEQ_LEN = 250
BATCH_SIZE = 1
INPUT_SIZE = 3
HIDDEN_SIZE = 5

lstm = haste.LSTM(INPUT_SIZE, HIDDEN_SIZE)
model = StatefulLSTM(lstm)

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for t in range(10):
  x = torch.rand([SEQ_LEN, BATCH_SIZE, INPUT_SIZE])
  y, _ = model(x, reset=(t % 3 == 0))
  loss = loss_fn(y, torch.zeros_like(y))

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  print(f'Step {t}, loss {loss}')

This code trains an LSTM on sequences of length 750, fed in chunks of 250. The LSTM state is squirreled away in self.state of StatefulLSTM and is cleared whenever reset=True is passed to forward(...).

Does this solution work for your application?

@OverLordGoldDragon
Author

OverLordGoldDragon commented Mar 25, 2020

Thanks for the demo. From what I understand from lstm.py, state=None makes the layer initialize the cell and hidden states to zero, which StatefulLSTM exposes via the additional reset argument - that matches the Keras behavior. However, it's unclear how to reset states if the layers are part of a larger network (e.g. with conv layers). I don't suppose model.reset_states() is doable unless the model class itself is modified; a workaround is something like:

def reset_states(model):
    for layer in model.layers:  # tf.keras
        if hasattr(layer, 'reset_states'):
            layer.reset_states()

but I can't tell from your snippet how one would reset states before passing an input.

Also, it'd work best if stateful were part of the base implementation rather than a dedicated class; otherwise extending functionality gets complicated (e.g. a StatefulLayerNormLSTM would also be needed).

@sharvil
Contributor

sharvil commented Mar 25, 2020

The PyTorch equivalent to your reset_states function looks like this:

def reset_states(model):
  for layer in model.modules():
    if hasattr(layer, 'reset_states'):  # could use `isinstance` instead to check if it's an RNN layer
      layer.reset_states()

I think the right answer here is to use a generic stateful wrapper for RNNs. It could be as simple as:

class StatefulRNN(nn.Module):
  def __init__(self, layer):
    super().__init__()
    self.layer = layer
    self.state = None

  def reset_states(self):
    self.state = None

  def forward(self, x, state=None, lengths=None):
    if state is not None:
      self.state = state
    print(f'Using state {self.state}')
    y, state = self.layer(x, self.state, lengths)
    self.state = self._detach_state(state)
    return y, state

  def _detach_state(self, state):
    if isinstance(state, tuple):
      return tuple(s.detach() for s in state)
    if isinstance(state, list):
      return [s.detach() for s in state]
    return state.detach()

You could wrap any of the Haste RNN layers with this StatefulRNN decorator class and you'd get the stateful behavior you're looking for.

Here's a complete example:

import torch
import torch.nn as nn
import haste_pytorch as haste


class StatefulRNN(nn.Module):
  def __init__(self, layer):
    super().__init__()
    self.layer = layer
    self.state = None

  def reset_states(self):
    self.state = None

  def forward(self, x, state=None, lengths=None):
    if state is not None:
      self.state = state
    print(f'Using state {self.state}')
    y, state = self.layer(x, self.state, lengths)
    self.state = self._detach_state(state)
    return y, state

  def _detach_state(self, state):
    if isinstance(state, tuple):
      return tuple(s.detach() for s in state)
    if isinstance(state, list):
      return [s.detach() for s in state]
    return state.detach()


def reset_states(model):
  for layer in model.modules():
    if hasattr(layer, 'reset_states'):  # could use `isinstance` instead to check if it's an RNN layer
      layer.reset_states()


SEQ_LEN = 250
BATCH_SIZE = 10
INPUT_SIZE = 3
HIDDEN_SIZE = 5

rnn = haste.GRU(INPUT_SIZE, HIDDEN_SIZE)  # or haste.LSTM or ...
model = StatefulRNN(rnn)

learning_rate = 1e-3
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(10):
  x = torch.rand([SEQ_LEN, BATCH_SIZE, INPUT_SIZE])
  if t % 3 == 0:
    reset_states(model)
  y, _ = model(x)
  loss = loss_fn(y, torch.zeros_like(y))

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  print(f'Step {t}, loss {loss}')
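
To address the earlier point about RNN layers living inside a larger network, here's a sketch of embedding the wrapper in a model with a conv front-end (the ConvRNN module and its shapes are just illustrative); reset_states(model) reaches the wrapped layer through model.modules():

class ConvRNN(nn.Module):
  """Illustrative larger model: a conv front-end feeding a stateful Haste RNN."""
  def __init__(self, input_size, hidden_size):
    super().__init__()
    self.conv = nn.Conv1d(input_size, input_size, kernel_size=3, padding=1)
    self.rnn = StatefulRNN(haste.GRU(input_size, hidden_size))

  def forward(self, x):                  # x: [seq_len, batch, input_size]
    z = self.conv(x.permute(1, 2, 0))    # Conv1d expects [batch, channels, seq_len]
    z = z.permute(2, 0, 1)               # back to [seq_len, batch, input_size]
    y, _ = self.rnn(z)
    return y

bigger_model = ConvRNN(INPUT_SIZE, HIDDEN_SIZE)
reset_states(bigger_model)  # finds the wrapped RNN via modules() and zeroes its state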

@OverLordGoldDragon
Author

@sharvil Looks excellent - a wrapper is a fair-enough alternative. To be sure all works as expected, I'll get back to you once I get haste_pytorch working.

@OverLordGoldDragon
Author

@sharvil Ran the script - looks good; I'll compare it more extensively vs. Keras later, but currently all seems to work as expected. Thank you.
