
Replacing AbstractModule and reuse_variables decorator when upgrading to Sonnet 2 #236

Closed
isabellahuang opened this issue Mar 15, 2022 · 4 comments


isabellahuang commented Mar 15, 2022

I want to upgrade my TF version from 1.15 to 2.5, and therefore need to use Sonnet 2 instead of Sonnet 1. I need to convert all my classes that inherit from snt.AbstractModule and want to know how to do it correctly.

My current Sonnet 1 module is here:

class Normalizer(snt.AbstractModule):
  """Feature normalizer that accumulates statistics online."""

  def __init__(self, size, max_accumulations=10**10, std_epsilon=1e-8,
               name='Normalizer'):
    super(Normalizer, self).__init__(name=name)
    self._max_accumulations = max_accumulations
    self._std_epsilon = std_epsilon

    with self._enter_variable_scope(): 
      self._acc_count = tf.Variable(0, dtype=tf.float32, trainable=False)
      self._num_accumulations = tf.Variable(0, dtype=tf.float32,
                                            trainable=False)
      self._acc_sum = tf.Variable(tf.zeros(size, tf.float32), trainable=False)
      self._acc_sum_squared = tf.Variable(tf.zeros(size, tf.float32),
                                          trainable=False)

  def _build(self, batched_data, accumulate=False): # Used to be True by default
    """Normalizes input data and accumulates statistics."""
    update_op = tf.no_op()
    if accumulate:
      # stop accumulating after max_accumulations updates, to prevent accuracy issues
      update_op = tf.cond(self._num_accumulations < self._max_accumulations,
                          lambda: self._accumulate(batched_data),
                          tf.no_op)
    with tf.control_dependencies([update_op]):
      return (batched_data - self._mean()) / self._std_with_epsilon()

  @snt.reuse_variables
  def inverse(self, normalized_batch_data):
    """Inverse transformation of the normalizer."""
    return normalized_batch_data * self._std_with_epsilon() + self._mean()

  def _accumulate(self, batched_data):
    """Function to perform the accumulation of the batch_data statistics."""
    count = tf.cast(tf.shape(batched_data)[0], tf.float32)
    data_sum = tf.reduce_sum(batched_data, axis=0)
    squared_data_sum = tf.reduce_sum(batched_data**2, axis=0)
    return tf.group(
        tf.assign_add(self._acc_sum, data_sum),
        tf.assign_add(self._acc_sum_squared, squared_data_sum),
        tf.assign_add(self._acc_count, count),
        tf.assign_add(self._num_accumulations, 1.))

  def _mean(self):
    safe_count = tf.maximum(self._acc_count, 1.)
    return self._acc_sum / safe_count

  def _std_with_epsilon(self):
    safe_count = tf.maximum(self._acc_count, 1.)
    std = tf.sqrt(self._acc_sum_squared / safe_count - self._mean()**2)
    return tf.math.maximum(std, self._std_epsilon)

and here is what I believe the Sonnet 2-compatible module should look like:

class Normalizer(snt.Module):
  """Feature normalizer that accumulates statistics online."""

  def __init__(self, size, max_accumulations=10**10, std_epsilon=1e-8,
               name='Normalizer'):
    super(Normalizer, self).__init__(name=name)
    self._size = size  # saved here so _initialize below can use it
    self._max_accumulations = max_accumulations
    self._std_epsilon = std_epsilon

  @snt.once
  def _initialize(self):
    self._acc_count = tf.Variable(0, dtype=tf.float32, trainable=False)
    self._num_accumulations = tf.Variable(0, dtype=tf.float32,
                                          trainable=False)
    self._acc_sum = tf.Variable(tf.zeros(self._size, tf.float32),
                                trainable=False)
    self._acc_sum_squared = tf.Variable(tf.zeros(self._size, tf.float32),
                                        trainable=False)

  def _build(self, batched_data, accumulate=False): # Used to be True by default
    """Normalizes input data and accumulates statistics."""
    update_op = tf.no_op()
    if accumulate:
      # stop accumulating after max_accumulations updates, to prevent accuracy issues
      update_op = tf.cond(self._num_accumulations < self._max_accumulations,
                          lambda: self._accumulate(batched_data),
                          tf.no_op)
    with tf.control_dependencies([update_op]):
      return (batched_data - self._mean()) / self._std_with_epsilon()

  def inverse(self, normalized_batch_data):
    """Inverse transformation of the normalizer."""
    return normalized_batch_data * self._std_with_epsilon() + self._mean()

  def _accumulate(self, batched_data):
    """Function to perform the accumulation of the batch_data statistics."""
    count = tf.cast(tf.shape(batched_data)[0], tf.float32)
    data_sum = tf.reduce_sum(batched_data, axis=0)
    squared_data_sum = tf.reduce_sum(batched_data**2, axis=0)
    # tf.assign_add was removed in TF2; use the Variable method instead.
    return tf.group(
        self._acc_sum.assign_add(data_sum),
        self._acc_sum_squared.assign_add(squared_data_sum),
        self._acc_count.assign_add(count),
        self._num_accumulations.assign_add(1.))

  def _mean(self):
    safe_count = tf.maximum(self._acc_count, 1.)
    return self._acc_sum / safe_count

  def _std_with_epsilon(self):
    safe_count = tf.maximum(self._acc_count, 1.)
    std = tf.sqrt(self._acc_sum_squared / safe_count - self._mean()**2)
    return tf.math.maximum(std, self._std_epsilon)

  def __call__(self, batched_data, accumulate=False):
    self._initialize()
    return self._build(batched_data, accumulate=False)

  1. The main changes I made are to add a __call__() method that initializes all the variables once (with the @snt.once decorator) and then returns the value using the old _build() method. Is this correct?

  2. Also, how do we deal with the previous _enter_variable_scope() call and the @snt.reuse_variables decorator? The self.inverse method is called from outside the module, which I understand is a reason for using @snt.reuse_variables on it. I simply got rid of both in the new module, but I'm not sure this is right.

  3. Lastly, is the @snt.reuse_variables decorator in the original module even used, seeing as there are no tf.get_variable calls in the self.inverse method? Every existing example I've seen of @snt.reuse_variables has a tf.get_variable call in the method (the typical pattern is sketched below).
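
For reference, the typical @snt.reuse_variables pattern looks something like this (a minimal Sonnet 1 / TF1 sketch, not from my codebase):

import sonnet as snt
import tensorflow as tf

class AddBias(snt.AbstractModule):
  """Toy Sonnet 1 module using @snt.reuse_variables with tf.get_variable."""

  def _build(self, x):
    return self.add(x)

  @snt.reuse_variables
  def add(self, x):
    # The first call creates b under this module's variable scope;
    # subsequent calls (from inside or outside the module) reuse it.
    b = tf.get_variable('b', shape=x.shape[-1:],
                        initializer=tf.zeros_initializer())
    return x + b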

Thanks

isabellahuang changed the title from "Replacing AbstractModule to Module when upgrading to Sonnet 2" to "Replacing AbstractModule and reuse_variables decorator when upgrading to Sonnet 2" on Mar 15, 2022
tomhennigan (Collaborator) commented:

The main changes I made are to add a __call__() method that initializes all the variables once (with the @snt.once decorator) and then returns the value using the old _build() method. Is this correct?

In your use case you can actually move variable creation to the constructor if you prefer. The @snt.once pattern is used because many module parameters depend on the shape of the input to the module, so it is convenient to create them the first time __call__ runs.
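
For example, a minimal sketch with a variable created directly in the constructor (assuming size is known up front):

import sonnet as snt
import tensorflow as tf

class Normalizer(snt.Module):
  def __init__(self, size, name=None):
    super().__init__(name=name)
    # size is a constructor argument, so there is no need to defer
    # variable creation to the first __call__.
    self._acc_sum = tf.Variable(tf.zeros(size, tf.float32),
                                trainable=False, name='acc_sum')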

Also, how do we deal with the previous _enter_variable_scope() call and the @snt.reuse_variables decorator? The self.inverse method is called from outside the module, which I understand is a reason for using @snt.reuse_variables on it. I simply got rid of both in the new module, but I'm not sure this is right.

No need for this: in Sonnet 2, variable sharing is done with regular Python objects. Create your modules once and create the parameters of your module once; to reuse parameters, just refer to the ones you made before.
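
For example (a minimal sketch):

import sonnet as snt
import tensorflow as tf

linear = snt.Linear(16)        # create the module once
y1 = linear(tf.ones([4, 8]))   # the first call creates the parameters w and b
y2 = linear(tf.ones([4, 8]))   # the same Python object reuses the same w and b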

Lastly, is the @snt.reuse_variables decorator in the original module even used, seeing as there are no tf.get_variable calls in the self.inverse method? Every existing example I've seen for how to use @snt.reuse_variables has a tf.get_variable call in the method.

It isn't needed for reusing variables (as you say, there aren't any in that method). In Sonnet 1 it has the side effect of entering a name scope, so operations get clear names in the TF graph; this doesn't affect correctness or performance, but it can be useful for debugging.

In Sonnet 2 all module methods enter a name scope, so if you are debugging a tf.function you should find that your code already has sensible names for operations.
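
For example, with the module above (the exact names depend on the module and variable names you pass, so treat these as illustrative):

norm = Normalizer(3)
norm(tf.ones([2, 3]))
# Ops and variables created inside module methods are prefixed with the
# module's name scope, e.g. 'normalizer/...'.
print([v.name for v in norm.variables])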

want to know how to do it correctly

There are a few other changes you can make (a sketch applying all of them follows this list):

  1. No need for control deps or tf.group, since in TF2 side-effecting operations (e.g. assign) happen in program order.
  2. It would be useful to add a name to your variables for debugging.
  3. You can skip the tf.cond in your __call__: in TF2 you can use tf.function, whose "autograph" feature converts Python control flow to TF control flow.
  4. I'd suggest name=None in the constructor; Sonnet will pick a name based on your class name if the user does not pass one.
  5. You don't seem to use the accumulate argument in __call__.
Here is my version of your module, which seems to work: https://colab.research.google.com/gist/tomhennigan/520f004cc781e8231f53ea9dd62bea86/example-of-porting-to-sonnet-2.ipynb

isabellahuang (Author) commented:

Hey @tomhennigan, thanks so much! I'm trying to migrate the rest of the code before I can fully confirm these changes work, and am working through some issues now. One of them is that Sonnet 2 throws an error when eager execution is disabled. Is there a way to get Sonnet 2 working without eager execution?

tomhennigan (Collaborator) commented Mar 24, 2022

Sonnet 2 throws an error when eager execution is disabled. Is there a way to get Sonnet 2 working without eager execution?

Hi @isabellahuang, Sonnet 2 is designed to work both eagerly and inside @tf.function. I would guess the error you are getting is because you are using TensorFlow 1, which defaults to graph mode (we do check to make sure you are using TF2). What version of TensorFlow do you have installed? You can check with:

import tensorflow as tf
print(tf.__version__)
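
You can also confirm that eager execution is enabled (this should print True with TF2 defaults):

import tensorflow as tf
print(tf.executing_eagerly())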

isabellahuang (Author) commented Mar 24, 2022

Thanks @tomhennigan, I think some leftover import tensorflow.compat.v1 as tf imports threw that error. I've successfully migrated to TF 2.5 now.
