# Extending PyTorch differentiable functions

In this notebook you'll see how to add your own custom differentiable function, for which you need to specify
both the `forward` and the `backward` passes.

In [1]:
# Import some libraries
import torch
import numpy

For a gentle introduction, see the [PyTorch extension](https://pytorch.org/docs/stable/notes/extending.html) tutorial.

The source for `torch.autograd.Function` is available [here](https://github.com/pytorch/pytorch/blob/master/torch/autograd/function.py).
These are the two static methods that we have to override:

```python
@staticmethod
def forward(ctx, *args, **kwargs):
    """Performs the operation.
    This function is to be overridden by all subclasses.
    It must accept a context ctx as the first argument, followed by any
    number of arguments (tensors or other types).
    The context can be used to store tensors that can be then retrieved
    during the backward pass.
    """
    raise NotImplementedError

@staticmethod
def backward(ctx, *grad_outputs):
    """Defines a formula for differentiating the operation.
    This function is to be overridden by all subclasses.
    It must accept a context :attr:`ctx` as the first argument, followed by
    as many outputs as :func:`forward` returned, and it should return as many
    tensors, as there were inputs to :func:`forward`. Each argument is the
    gradient w.r.t the given output, and each returned value should be the
    gradient w.r.t. the corresponding input.
    The context can be used to retrieve tensors saved during the forward
    pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple
    of booleans representing whether each input needs gradient. E.g.,
    :func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the
    first input to :func:`forward` needs gradient computed w.r.t. the
    output.
    """
    raise NotImplementedError
```    

In [2]:
# Custom addition function
class MyAdd(torch.autograd.Function):
    """Element-wise addition with a hand-written backward pass.

    Since d(x1 + x2)/dx1 = d(x1 + x2)/dx2 = 1, each input simply receives
    the gradient that arrives at the output.
    """

    @staticmethod
    def forward(ctx, x1, x2):
        """Return ``x1 + x2`` and record what backward will need.

        Args:
            ctx: context object on which state for backward can be stored.
            x1: first tensor operand.
            x2: second tensor operand.

        Returns:
            The (possibly broadcast) sum of the two operands.
        """
        # The derivative of a sum does not depend on the input values, so
        # only the shapes are kept. Saving the tensors themselves would keep
        # them alive needlessly and would make backward fail if one of them
        # were modified in place after this call.
        ctx.x1_shape = x1.shape
        ctx.x2_shape = x2.shape
        return x1 + x2

    @staticmethod
    def backward(ctx, grad_output):
        """Route the output gradient to both operands.

        Args:
            ctx: the context populated in ``forward``.
            grad_output: gradient w.r.t. the value returned by ``forward``.

        Returns:
            A pair ``(grad_x1, grad_x2)``; an entry is ``None`` when the
            matching input does not require a gradient.
        """
        grad_x1 = grad_x2 = None
        # Only build the gradients that autograd is actually going to use.
        if ctx.needs_input_grad[0]:
            # sum_to_size undoes any broadcasting that happened in forward
            grad_x1 = grad_output.sum_to_size(ctx.x1_shape)
        if ctx.needs_input_grad[1]:
            grad_x2 = grad_output.sum_to_size(ctx.x2_shape)
        # need to return grads in order
        # of inputs to forward (excluding ctx)
        return grad_x1, grad_x2

In [3]:
# Exercise MyAdd end to end: forward, reduce to a scalar, then backward
myadd = MyAdd.apply  # aliasing the apply method

x1 = torch.randn(3, requires_grad=True)
x2 = torch.randn(3, requires_grad=True)
print(f'x1: {x1}')
print(f'x2: {x2}')

y = myadd(x1, x2)
print(f' y: {y}')

# Reducing to a scalar lets backward() run without an explicit gradient
z = torch.mean(y)
print(f' z: {z}, z.grad_fn: {z.grad_fn}')

z.backward()
# Each input should see d(mean)/dy_i passed straight through the addition
print(f'x1.grad: {x1.grad}')
print(f'x2.grad: {x2.grad}')

x1: tensor([ 0.2215,  1.2990, -1.7461], requires_grad=True)
x2: tensor([ 0.7587, -1.2437,  1.3185], requires_grad=True)
 y: tensor([ 0.9803,  0.0553, -0.4276], grad_fn=<MyAddBackward>)
 z: 0.20264236629009247, z.grad_fn: <MeanBackward0 object at 0x7f71301fdfa0>
x1.grad: tensor([0.3333, 0.3333, 0.3333])
x2.grad: tensor([0.3333, 0.3333, 0.3333])


In [4]:
# Custom split function
class MySplit(torch.autograd.Function):
    """Fan a single tensor out into two independent copies.

    Because the same input feeds both outputs, its gradient is the sum of
    the gradients that flow back through the two copies.
    """

    @staticmethod
    def forward(ctx, x):
        """Return two separate clones of ``x``.

        Nothing is stored on ``ctx``: the backward pass only adds the two
        incoming gradients and never reads the input. Saving ``x`` would
        keep it alive needlessly and would make backward fail if ``x`` were
        modified in place after this call.

        Args:
            ctx: context object (unused here).
            x: tensor to duplicate.

        Returns:
            A pair of clones of ``x``.
        """
        x1 = x.clone()
        x2 = x.clone()
        return x1, x2

    @staticmethod
    def backward(ctx, grad_x1, grad_x2):
        """Merge the gradients of both branches into the input gradient.

        Args:
            ctx: context object (unused here).
            grad_x1: gradient w.r.t. the first clone.
            grad_x2: gradient w.r.t. the second clone.

        Returns:
            ``grad_x1 + grad_x2``, the gradient w.r.t. ``x``.
        """
        # Printed on purpose so the notebook shows what reaches each branch.
        print(f'grad_x1: {grad_x1}')
        print(f'grad_x2: {grad_x2}')
        return grad_x1 + grad_x2

In [5]:
# Exercise MySplit: duplicate a tensor, recombine the copies, then backward
split = MySplit.apply

x = torch.randn(4, requires_grad=True)
print(f' x: {x}')

x1, x2 = split(x)
print(f'x1: {x1}')
print(f'x2: {x2}')

# Summing the two copies makes both branches contribute to x.grad
y = torch.add(x1, x2)
print(f' y: {y}')

z = torch.mean(y)
print(f' z: {z}, z.grad_fn: {z.grad_fn}')

z.backward()
print(f' x.grad: {x.grad}')

 x: tensor([ 0.6151, -0.7128,  0.5547,  0.1158], requires_grad=True)
x1: tensor([ 0.6151, -0.7128,  0.5547,  0.1158], grad_fn=<MySplitBackward>)
x2: tensor([ 0.6151, -0.7128,  0.5547,  0.1158], grad_fn=<MySplitBackward>)
 y: tensor([ 1.2301, -1.4256,  1.1094,  0.2316], grad_fn=<AddBackward0>)
 z: 0.2863945960998535, z.grad_fn: <MeanBackward0 object at 0x7f71301c6940>
grad_x1: tensor([0.2500, 0.2500, 0.2500, 0.2500])
grad_x2: tensor([0.2500, 0.2500, 0.2500, 0.2500])
 x.grad: tensor([0.5000, 0.5000, 0.5000, 0.5000])


In [7]:
# Custom max function
class MyMax(torch.autograd.Function):
    """Maximum over all elements of a tensor, computed outside of torch.

    The forward pass deliberately goes through NumPy to show that autograd
    only needs a correct ``backward`` and does not care how the value was
    produced.
    """

    @staticmethod
    def forward(ctx, x):
        """Return the largest element of ``x`` as a 0-dim tensor.

        Args:
            ctx: context object; receives the per-element gradient weights.
            x: input tensor (NumPy raises ``ValueError`` if it is empty).

        Returns:
            A 0-dim tensor with the dtype and device of ``x``.
        """
        # example where we explicitly use non-torch code
        # (.cpu() is a no-op on CPU and makes .numpy() legal elsewhere)
        maximum = x.detach().cpu().numpy().max()
        is_max = x.detach().eq(maximum)
        # Tied maxima share the gradient equally; giving each of them the
        # full gradient would multiply the total by the number of ties.
        # clamp keeps the all-False case (NaN maximum) at a zero gradient.
        weights = is_max.to(x.dtype) / is_max.sum().clamp(min=1)
        ctx.save_for_backward(weights)
        # Build the result with the input's dtype and device.
        return torch.tensor(maximum, dtype=x.dtype, device=x.device)

    @staticmethod
    def backward(ctx, grad_output):
        """Send the output gradient to the position(s) of the maximum.

        Args:
            ctx: the context populated in ``forward``.
            grad_output: gradient w.r.t. the 0-dim maximum.

        Returns:
            A tensor shaped like the input: zero everywhere except at the
            maximal element(s), which split ``grad_output`` evenly.
        """
        weights = ctx.saved_tensors[0]
        return grad_output * weights

In [8]:
# Exercise MyMax: only the largest entry of x should receive a gradient
mymax = MyMax.apply

x = torch.randn(5, requires_grad=True)
print(f'x: {x}')

y = mymax(x)
print(f'y: {y}, y.grad_fn: {y.grad_fn}')

# y is already a scalar, so backward() needs no explicit gradient
y.backward()
print(f'x.grad: {x.grad}')

x: tensor([-1.3772,  1.7799, -0.8777,  1.5753,  0.0288], requires_grad=True)
y: 1.779890537261963, y.grad_fn: <torch.autograd.function.MyMaxBackward object at 0x7f6fd7ba3ba0>
x.grad: tensor([0., 1., 0., 0., 0.])
