In [1]:
import pandas as pd
import os
import math
import numpy as np
import pandas as pd
import matplotlib as plt
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.autograd as autograd
from torch.autograd import Variable

In [2]:
# set seed for experiment
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices), and requests deterministic cuDNN kernels so that GPU runs are
    repeatable even if they are slower.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe to call on CPU-only machines; it is ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # NOTE: hash randomisation of the already-running interpreter is fixed at
    # start-up, so this only influences subprocesses spawned from here on.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # making sure GPU runs are deterministic even if they are slower
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print("Seeded everything: {}".format(seed))

In [3]:
# Seed all random number generators once, up front, with a fixed value.
set_seed(42)

Seeded everything: 42


# Test how the straight through estimator works

$$\text{With indicators:}\quad f(x_1, x_2) = \mathbb{1}_{x_1 > 5}x_1^2 + \mathbb{1}_{x_2 > 5}x_2^2$$
$$\text{Without indicators:}\quad g(x_1, x_2) = x_1^2 + x_2^2$$

In [4]:
# returns 1_{x > 5}
class GetIndicators(autograd.Function):
    """Indicator 1_{x > 5} with a straight-through estimator (STE) backward.

    The forward pass returns the hard 0/1 mask. The backward pass treats the
    operation as the identity and hands the incoming gradient back unchanged.
    """

    @staticmethod
    def forward(ctx, x):
        """Return a float mask that is 1 where ``x > 5`` and 0 elsewhere.

        Args:
            ctx: Autograd context object; unused because backward needs no
                saved state.
            x: Input tensor.

        Returns:
            Float tensor with the same shape as ``x`` holding the indicator.
        """
        # The mask has to be floating point: an integer result (e.g. ``.int()``)
        # cannot carry gradients, so autograd would never call ``backward``.
        return (x > 5).float()

    @staticmethod
    def backward(ctx, g):
        """Pass the gradient of the loss w.r.t. the mask straight through.

        Args:
            ctx: Autograd context object; unused.
            g: Gradient of the loss with respect to the forward output.

        Returns:
            ``g`` unchanged, used as the gradient with respect to ``x``. Only
            one value is returned because ``forward`` takes one input.
        """
        print("Incoming grad = Outgoing grad = {}".format(g))
        # send the gradient g straight-through on the backward pass.
        return g

## See if the indicator function works as expected (it does)

In [5]:
# One entry below the threshold and one above it.
x = torch.tensor([1.0, 6.0])
GetIndicators.apply(x)

tensor([0., 1.])

In [6]:
def f(x):
    """Objective with indicators: the sum of ``GetIndicators(x) * x**2``."""
    indicator = GetIndicators.apply(x)
    squares = x ** 2
    return (indicator * squares).sum()

def g(x):
    """Objective without indicators: the sum of the squared entries of ``x``."""
    return x.pow(2).sum()

$$\nabla g(x) = \begin{bmatrix} 2x_1\\ 2x_2 \end{bmatrix}$$

### Assuming STE
$$\frac{\partial f(x)}{\partial x_i} = \mathbb{1}_{x_i > 5}\frac{\partial x_i^2}{\partial x_i} + x_i^2 \frac{\partial \mathbb{1}_{x_i > 5}}{\partial x_i} = 2x_i\mathbb{1}_{x_i > 5} + x_i^2 \frac{\partial \mathbb{1}_{x_i > 5}}{\partial x_i} = 2x_i\mathbb{1}_{x_i > 5} + x_i^2$$

#### (The grad of indicator is just 1 for STE. If you make it 0 whenever the indicator is off, that is not STE, but may also be interesting!)

### Now check if that is what we get

In [7]:
# torch.tensor(..., requires_grad=True) creates a leaf tensor that tracks gradients.
x = torch.tensor([2.0, 6.0], requires_grad=True)
f_out = f(x)
f_out.backward()
# x.grad is already a plain tensor, so no `.data` is needed to print it.
print("grad(f(x)) = {}".format(x.grad))

Incoming grad = Outgoing grad = tensor([ 4., 36.])
grad(f(x)) = tensor([ 4., 48.])


In [8]:
# torch.tensor(..., requires_grad=True) creates a leaf tensor that tracks gradients.
x = torch.tensor([2.0, 6.0], requires_grad=True)
g_out = g(x)
g_out.backward()
# x.grad is already a plain tensor, so no `.data` is needed to print it.
print("grad(g(x)) = {}".format(x.grad))

grad(g(x)) = tensor([ 4., 12.])


## It works as expected now! STE is a little weird, but it is consistent.

# Now, try a more typical usage. Threshold of a function like Bengio's example

$$f(x_1, x_2) = \mathbf{1}^T \mathbb{1}_{x^2 > 25} = \mathbb{1}_{x_1^2 > 25} + \mathbb{1}_{x_2^2 > 25}$$ 

In [9]:
# returns 1_{x^2 > 25} (equivalent to |x| > 5; it matches x > 5 only for non-negative x)
class GetIndicators(autograd.Function):
    """Indicator 1_{x^2 > 25} with a straight-through estimator (STE) backward.

    The forward pass thresholds a function of the input (``x**2``). The
    backward pass ignores that and returns the incoming gradient unchanged.
    """

    @staticmethod
    def forward(ctx, x):
        """Return a float mask that is 1 where ``x**2 > 25`` and 0 elsewhere.

        Args:
            ctx: Autograd context object; unused because backward needs no
                saved state.
            x: Input tensor.

        Returns:
            Float tensor with the same shape as ``x`` holding the indicator.
        """
        # Keep the mask floating point so that autograd can attach a gradient
        # function to it; an integer mask would cut the graph.
        return (x ** 2 > 25).float()

    @staticmethod
    def backward(ctx, g):
        """Pass the gradient of the loss w.r.t. the mask straight through.

        Args:
            ctx: Autograd context object; unused.
            g: Gradient of the loss with respect to the forward output.

        Returns:
            ``g`` unchanged, used as the gradient with respect to ``x``. Only
            one value is returned because ``forward`` takes one input.
        """
        # send the gradient g straight-through on the backward pass.
        print("Incoming grad = Outgoing grad = {}".format(g))
        return g

def f(x):
    """Count the active indicators: the sum of ``GetIndicators(x)``.

    Args:
        x: Input tensor.

    Returns:
        Scalar tensor holding the sum of the indicator mask.
    """
    mask = GetIndicators.apply(x)
    # GetIndicators.forward already returns a float mask, so the sum is
    # differentiable as-is and needs no extra cast.
    out = mask.sum()
    return out

In [10]:
# The input must track gradients, otherwise f_out.backward() has nothing to
# differentiate with respect to.
x = torch.tensor([2.0, 6.0], requires_grad=True)
f_out = f(x)
f_out.backward()
# x.grad is already a plain tensor, so no `.data` is needed to print it.
print("grad(f(x)) = {}".format(x.grad))

Incoming grad = Outgoing grad = tensor([1., 1.])
grad(f(x)) = tensor([1., 1.])


## Gradient is just 1. As expected, it ignores the Indicator in the backward.

# Let's see how this plays with clamp

$$ f(x_1, x_2) = \mathbf{1}^T((P_{[1, 4]}(x))^2)$$

In [11]:
def f(x):
    """Clamp ``x`` to the interval [1, 4], then sum the squared entries."""
    projected = x.clamp(min=1, max=4)
    return projected.pow(2).sum()

In [12]:
# The input must track gradients, otherwise f_out.backward() has nothing to
# differentiate with respect to.
x = torch.tensor([2.0, 6.0], requires_grad=True)
f_out = f(x)
f_out.backward()
# x.grad is already a plain tensor, so no `.data` is needed to print it.
print("grad(f(x)) = {}".format(x.grad))

grad(f(x)) = tensor([4., 0.])


## As expected, any projected value just gets 0 gradient

## The right way to project imo

In [13]:
x = torch.tensor([2.0, 6.0], requires_grad=True)
# Project outside the autograd graph: under no_grad the clamp is not recorded,
# so the result is a fresh leaf tensor that does not track gradients.
with torch.no_grad():
    x = torch.clamp(x, 1, 4)
# need to reset requires_grad so gradients are taken w.r.t. the projected point
x.requires_grad_(True)
f_out = f(x)
f_out.backward()
# x.grad is already a plain tensor, so no `.data` is needed to print it.
print("grad(f(x)) = {}".format(x.grad))

grad(f(x)) = tensor([4., 8.])


## EP's implementation. Turns out their STE works the right way!
### Note that this function is really $$f(x_1, x_2) = \mathbb{0}\cdot x_1^2 + \mathbb{1}\cdot x_2^2 $$

In [14]:
# returns a copy of x whose first two flattened entries are overwritten with 0 and 1
class GetIndicators(autograd.Function):
    """Hard-coded mask built by in-place writes, with an STE backward.

    The forward pass copies the input and overwrites the first two entries (in
    flattened order) with 0 and 1. The backward pass returns the incoming
    gradient unchanged.
    """

    @staticmethod
    def forward(ctx, x):
        """Return a copy of ``x`` with flattened entries 0 and 1 set to 0 and 1.

        Args:
            ctx: Autograd context object; unused because backward needs no
                saved state.
            x: Input tensor with at least two elements.

        Returns:
            New tensor shaped like ``x``; the input itself is not modified.

        Raises:
            IndexError: If ``x`` has fewer than two elements.
        """
        # Force a contiguous copy: a plain clone() keeps the input's strides,
        # and flattening a non-contiguous tensor would produce a copy, so the
        # writes below would never reach `out`.
        out = x.clone(memory_format=torch.contiguous_format)

        # view(-1) always aliases `out`, so flat_out and out access the same memory.
        flat_out = out.view(-1)
        flat_out[0] = 0
        flat_out[1] = 1
        return out

    @staticmethod
    def backward(ctx, g):
        """Pass the gradient of the loss w.r.t. the mask straight through.

        Args:
            ctx: Autograd context object; unused.
            g: Gradient of the loss with respect to the forward output.

        Returns:
            ``g`` unchanged, used as the gradient with respect to ``x``. Only
            one value is returned because ``forward`` takes one input.
        """
        print("Incoming grad = Outgoing grad = {}".format(g))
        # send the gradient g straight-through on the backward pass.
        return g

In [15]:
def f(x):
    """Weight the squared entries of ``x`` by ``GetIndicators(x)`` and sum them."""
    weights = GetIndicators.apply(x)
    return (weights * x ** 2).sum()

# torch.tensor(..., requires_grad=True) creates a leaf tensor that tracks gradients.
x = torch.tensor([2.0, 6.0], requires_grad=True)
f_out = f(x)
f_out.backward()
# x.grad is already a plain tensor, so no `.data` is needed to print it.
print("grad(f(x)) = {}".format(x.grad))

Incoming grad = Outgoing grad = tensor([ 4., 36.])
grad(f(x)) = tensor([ 4., 48.])


## Trying out our implementation of HC. That is using STE the right way too!

In [16]:
# returns 1_{x > 5}
class GetIndicators(autograd.Function):
    """Elementwise indicator of ``x > 5`` whose backward is the identity (STE)."""

    @staticmethod
    def forward(ctx, x):
        """Compute the 0/1 mask of entries strictly greater than 5.

        ``ctx`` is the autograd context object; nothing is stashed on it
        because the backward pass does not need the input. The boolean
        comparison is cast to int and then to float before being returned.
        """
        above_threshold = x > torch.full_like(x, 5)
        return above_threshold.int().float()

    @staticmethod
    def backward(ctx, g):
        """Return the gradient w.r.t. the output as the gradient w.r.t. ``x``.

        ``g`` is the gradient of the loss with respect to the mask. A single
        value is returned because ``forward`` has a single input.
        """
        print("Incoming grad = Outgoing grad = {}".format(g))
        # straight-through: the outgoing gradient equals the incoming one.
        return g

In [17]:
def f(x):
    """Sum of ``x**2`` gated elementwise by ``GetIndicators(x)``."""
    gate = GetIndicators.apply(x)
    gated_squares = gate * x ** 2
    return gated_squares.sum()

# torch.tensor(..., requires_grad=True) creates a leaf tensor that tracks gradients.
x = torch.tensor([2.0, 6.0], requires_grad=True)
f_out = f(x)
f_out.backward()
# x.grad is already a plain tensor, so no `.data` is needed to print it.
print("grad(f(x)) = {}".format(x.grad))

Incoming grad = Outgoing grad = tensor([ 4., 36.])
grad(f(x)) = tensor([ 4., 48.])
