# Gradient reversal pytorch

Inspired by the following tweets:

* https://twitter.com/mat_kelcey/status/932149793765261313
* https://twitter.com/ericjang11/status/932073259721359363

Basic idea:

```python
# Add g'(x) to the gradient; the forward value stays f(x)
f(x) + g(x) - tf.stop_gradient(g(x))

# Reverse the gradient; the forward value stays f(x)
tf.stop_gradient(f(x)*2) - f(x)
```

In [1]:
import torch
import tensorflow as tf
from torch.autograd import Variable

import numpy as np

In [2]:
def f(X):
    return X*X

def g(X):
    return X**3

In [3]:
X = np.random.randn(10)
X

array([-2.55617105,  0.13785352,  0.20610955, -0.8133813 , -1.14285471,
        1.13275964,  0.79103318,  0.14766171, -1.03148017,  0.36393108])

## Tensorflow implementation

In [4]:
sess = tf.InteractiveSession()

In [5]:
tf_X = tf.Variable(X)
init_op = tf.global_variables_initializer()

In [6]:
sess.run(init_op)
sess.run(tf_X)

array([-2.55617105,  0.13785352,  0.20610955, -0.8133813 , -1.14285471,
        1.13275964,  0.79103318,  0.14766171, -1.03148017,  0.36393108])

In [7]:
forward_op = f(tf_X)

In [8]:
sess.run(forward_op)

array([ 6.53401042,  0.01900359,  0.04248115,  0.66158914,  1.3061169 ,
        1.28314441,  0.62573349,  0.02180398,  1.06395134,  0.13244583])

In [9]:
gradient_op = tf.gradients(forward_op, tf_X)

In [10]:
sess.run(gradient_op)

[array([-5.1123421 ,  0.27570704,  0.4122191 , -1.6267626 , -2.28570943,
         2.26551928,  1.58206636,  0.29532342, -2.06296033,  0.72786216])]

In [11]:
X*2 # This should match the gradient above

array([-5.1123421 ,  0.27570704,  0.4122191 , -1.6267626 , -2.28570943,
        2.26551928,  1.58206636,  0.29532342, -2.06296033,  0.72786216])

### Modify the gradients
Keep the forward pass the same.
The trick is to add $g(x)$ during the forward pass and subtract it again, where $g'(x)$ is the desired gradient modifier, while stopping gradients from flowing through the subtracted term.

In the forward pass, $f(x) + g(x) - g(x)$ evaluates to $f(x)$. Without the stop, the gradient would be $f'(x) + g'(x) - g'(x) = f'(x)$. Since no gradient flows through the subtracted $g(x)$, its $-g'(x)$ contribution vanishes and the new gradient is $f'(x) + g'(x)$.
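
To see the mechanics without any framework, here is a minimal sketch of a toy forward-mode autodiff in plain Python. The `Dual` class and this `stop_gradient` helper are hypothetical names made up for illustration; they are not TensorFlow or PyTorch APIs. `g` is written as `x * x * x` rather than `X**3` only so the toy class does not need a power operator.

```python
# Toy forward-mode autodiff: each Dual carries a value and its derivative w.r.t. x.
# Dual and stop_gradient are illustrative names, not part of any library.
class Dual:
    def __init__(self, val, grad=0.0):
        self.val = val
        self.grad = grad

    @staticmethod
    def _lift(other):
        # Treat plain numbers as constants (derivative 0)
        return other if isinstance(other, Dual) else Dual(other)

    def __add__(self, other):
        other = Dual._lift(other)
        return Dual(self.val + other.val, self.grad + other.grad)

    def __sub__(self, other):
        other = Dual._lift(other)
        return Dual(self.val - other.val, self.grad - other.grad)

    def __mul__(self, other):
        other = Dual._lift(other)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.val * other.val,
                    self.grad * other.val + self.val * other.grad)


def stop_gradient(d):
    # Keep the forward value, drop the derivative
    return Dual(d.val, 0.0)


def f(x):
    return x * x


def g(x):
    return x * x * x


x = Dual(1.5, 1.0)  # seed dx/dx = 1

plain = f(x)
modified = f(x) + g(x) - stop_gradient(g(x))

print(plain.val, plain.grad)        # 2.25 3.0
print(modified.val, modified.grad)  # 2.25 9.75
```

The forward value is unchanged at $f(1.5) = 2.25$, while the derivative goes from $f'(1.5) = 3$ to $f'(1.5) + g'(1.5) = 3 + 6.75 = 9.75$.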

In [12]:
gradient_modifier_op = g(tf_X)

In [13]:
sess.run(gradient_modifier_op)

array([ -1.67020483e+01,   2.61971220e-03,   8.75577023e-03,
        -5.38124238e-01,  -1.49270185e+00,   1.45349420e+00,
         4.94975949e-01,   3.21961304e-03,  -1.09744470e+00,
         4.82011548e-02])

In [14]:
modified_forward_op = (f(tf_X) + g(tf_X) - tf.stop_gradient(g(tf_X)))
modified_backward_op = tf.gradients(modified_forward_op, tf_X)

In [15]:
sess.run(modified_forward_op)

array([ 6.53401042,  0.01900359,  0.04248115,  0.66158914,  1.3061169 ,
        1.28314441,  0.62573349,  0.02180398,  1.06395134,  0.13244583])

In [16]:
sess.run(modified_backward_op)

[array([ 14.48968918,   0.33271782,   0.53966255,   0.35800482,
          1.63264126,   6.1149525 ,   3.45926682,   0.36073536,
          1.12889367,   1.12519966])]

In [17]:
2*X + 3*(X**2) # This should match the gradients above

array([ 14.48968918,   0.33271782,   0.53966255,   0.35800482,
         1.63264126,   6.1149525 ,   3.45926682,   0.36073536,
         1.12889367,   1.12519966])

### Gradient reversal

Here the modifying function is simply $g(x) = -2f(x)$, which makes the gradient $f'(x) - 2f'(x) = -f'(x)$ while the forward value stays $f(x)$.
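
Spelled out step by step, the substitution looks as follows. Here $\mathrm{sg}(\cdot)$ is shorthand introduced for this note (it is not notation from the tweets) for `tf.stop_gradient`, which acts as the identity in the forward pass and contributes zero gradient:

$$
\begin{aligned}
y &= f(x) + g(x) - \mathrm{sg}(g(x)), \qquad g(x) = -2 f(x) \\
  &= f(x) - 2 f(x) + \mathrm{sg}(2 f(x)) \\
  &= \mathrm{sg}(2 f(x)) - f(x) \\
\text{forward: } y &= 2 f(x) - f(x) = f(x) \\
\text{backward: } \frac{dy}{dx} &= 0 - f'(x) = -f'(x)
\end{aligned}
$$

For $f(x) = x^2$ this gives a gradient of $-2x$.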

In [18]:
gradient_reversal_op = (tf.stop_gradient(2*f(tf_X)) - f(tf_X))
gradient_reversal_grad_op = tf.gradients(gradient_reversal_op, tf_X)

In [19]:
sess.run(gradient_reversal_op)

array([ 6.53401042,  0.01900359,  0.04248115,  0.66158914,  1.3061169 ,
        1.28314441,  0.62573349,  0.02180398,  1.06395134,  0.13244583])

In [20]:
sess.run(gradient_reversal_grad_op)

[array([ 5.1123421 , -0.27570704, -0.4122191 ,  1.6267626 ,  2.28570943,
        -2.26551928, -1.58206636, -0.29532342,  2.06296033, -0.72786216])]

In [21]:
sess.run((gradient_op[0] + gradient_reversal_grad_op[0])) # This should be zero, confirming that the gradient is reversed.

array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.])
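
The cells above use the TensorFlow 1.x session API (`tf.InteractiveSession`, `tf.gradients`). Below is a minimal sketch of the same reversal check under TensorFlow 2.x eager execution, assuming `tf.GradientTape` is available in your install. The input values are made up for illustration and are not the random `X` from this notebook.

```python
import numpy as np
import tensorflow as tf


def f(x):
    return x * x


# Illustrative inputs (not the notebook's random X)
x = tf.constant(np.array([-2.0, 0.5, 1.0, 3.0]))

# persistent=True allows more than one gradient call on the same tape
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)  # x is a constant, so it has to be watched explicitly
    plain = f(x)
    reversed_ = tf.stop_gradient(2 * f(x)) - f(x)

plain_grad = tape.gradient(plain, x)
reversed_grad = tape.gradient(reversed_, x)
del tape  # release the resources held by the persistent tape

# Forward values agree, gradients have opposite sign
print(np.allclose(plain.numpy(), reversed_.numpy()))
print(np.allclose(plain_grad.numpy(), -reversed_grad.numpy()))
```

Both checks should print `True` if the trick carries over unchanged.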

## PyTorch case

In [22]:
def zero_grad(X):
    if X.grad is not None:
        X.grad.data.zero_()

In [23]:
torch_X = Variable(torch.FloatTensor(X), requires_grad=True)

In [24]:
torch_X.data.numpy()

array([-2.55617094,  0.13785352,  0.20610955, -0.81338131, -1.14285469,
        1.13275969,  0.79103315,  0.14766172, -1.03148019,  0.36393109], dtype=float32)

In [25]:
f(torch_X).data.numpy()

array([ 6.53400993,  0.01900359,  0.04248115,  0.66158915,  1.30611682,
        1.28314447,  0.62573344,  0.02180398,  1.06395137,  0.13244584], dtype=float32)

In [26]:
g(torch_X).data.numpy()

array([ -1.67020454e+01,   2.61971215e-03,   8.75577051e-03,
        -5.38124263e-01,  -1.49270177e+00,   1.45349431e+00,
         4.94975895e-01,   3.21961334e-03,  -1.09744477e+00,
         4.82011586e-02], dtype=float32)

In [27]:
zero_grad(torch_X)
f_X = f(torch_X)
f_X.backward(torch.ones(f_X.size()))
torch_X.grad.data.numpy()

array([-5.11234188,  0.27570704,  0.41221911, -1.62676263, -2.28570938,
        2.26551938,  1.5820663 ,  0.29532343, -2.06296039,  0.72786218], dtype=float32)

In [28]:
2*X

array([-5.1123421 ,  0.27570704,  0.4122191 , -1.6267626 , -2.28570943,
        2.26551928,  1.58206636,  0.29532342, -2.06296033,  0.72786216])

### Modify gradients

In [29]:
modified_gradients_forward = lambda x: f(x) + g(x) - g(x).detach()

In [30]:
zero_grad(torch_X)
modified_grad = modified_gradients_forward(torch_X)
modified_grad.backward(torch.ones(modified_grad.size()))
torch_X.grad.data.numpy()

array([ 14.48968792,   0.33271781,   0.53966254,   0.35800481,
         1.63264108,   6.11495304,   3.45926666,   0.36073539,
         1.12889361,   1.12519968], dtype=float32)

In [31]:
2*X + 3*(X*X) # This should match the gradients above

array([ 14.48968918,   0.33271782,   0.53966255,   0.35800482,
         1.63264126,   6.1149525 ,   3.45926682,   0.36073536,
         1.12889367,   1.12519966])

### Gradient reversal

In [32]:
gradient_reversal = lambda x: (2*f(x)).detach() - f(x)

In [33]:
zero_grad(torch_X)
grad_reverse = gradient_reversal(torch_X)
grad_reverse.backward(torch.ones(grad_reverse.size()))
torch_X.grad.data.numpy()

array([ 5.11234188, -0.27570704, -0.41221911,  1.62676263,  2.28570938,
       -2.26551938, -1.5820663 , -0.29532343,  2.06296039, -0.72786218], dtype=float32)

In [34]:
-2*X # This should match the gradients above

array([ 5.1123421 , -0.27570704, -0.4122191 ,  1.6267626 ,  2.28570943,
       -2.26551928, -1.58206636, -0.29532342,  2.06296033, -0.72786216])
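
The PyTorch cells above use `Variable`, which was merged into `Tensor` in PyTorch 0.4. On a recent version, both tricks can be sketched roughly as below. This is only a sketch: the helper `grad_of` is a hypothetical name introduced here, and the input values are made up rather than taken from this notebook's `X`.

```python
import torch


def f(x):
    return x * x


def g(x):
    return x ** 3


# Illustrative inputs (not the notebook's random X)
x = torch.tensor([-2.0, 0.5, 1.0, 3.0], requires_grad=True)


def grad_of(fn, x):
    # Hypothetical helper: sum to a scalar and return d(sum)/dx.
    # torch.autograd.grad returns the gradient directly instead of
    # accumulating into x.grad, so no zero_grad step is needed.
    (grad,) = torch.autograd.grad(fn(x).sum(), x)
    return grad


plain_grad = grad_of(f, x)
modified_grad = grad_of(lambda t: f(t) + g(t) - g(t).detach(), x)
reversed_grad = grad_of(lambda t: (2 * f(t)).detach() - f(t), x)

xv = x.detach()
print(torch.allclose(plain_grad, 2 * xv))                   # f'(x)
print(torch.allclose(modified_grad, 2 * xv + 3 * xv ** 2))  # f'(x) + g'(x)
print(torch.allclose(reversed_grad, -2 * xv))               # -f'(x)
```

Using `torch.autograd.grad` here is a design choice of the sketch; calling `.backward()` and reading `x.grad`, as in the cells above, works as well so long as the gradient is zeroed between calls.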