In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import math

In [167]:
# Problem sizes used when drawing the random test tensors further down
n = 2  # dimension of input (columns of `input` and `weight`)
m = 3  # dimension of output (rows of `weight`, length of `bias`)
b = 1  # batch size (rows of `input` and `grad_output`)

# Linear

In [3]:
def calc_logit(input, weight, bias=None):
    """
    Affine map of a linear layer: input.mm(weight.t()), plus bias if given.

    Args:
        input:  [b x n]         Input vector
        weight: [m x n]         Weight matrix
        bias:   [m]             Bias vector, or None to skip the bias term

    Returns:
        logit:  [b x m]         input.mm(weight.t()) (+ bias broadcast over rows)

    """

    product = torch.mm(input, weight.t())
    if bias is None:
        return product

    # Add the same bias row to every sample in the batch
    product += bias.unsqueeze(0).expand_as(product)
    return product

In [4]:
# Forward propagation through a linear layer followed by ReLU
def forward(input, weight, bias):
    """
    Arguments:
    input [b x n]      The input vector(s) to the layer, one row per sample
    weight [m x n]     The weight matrix of the layer
    bias [m]           The bias vector of the layer, or None for no bias

    Returns:
    output [b x m]     The input to the next layer: the logit clamped below at zero (ReLU)
    """

    # Affine part: input.mm(weight.t()), then the bias row added to every sample
    pre_activation = torch.mm(input, weight.t())
    if bias is not None:
        pre_activation = pre_activation + bias.unsqueeze(0).expand_as(pre_activation)

    # ReLU: negative logits become zero
    return pre_activation.clamp(min=0.0)

## ESGD

In [5]:
def esgd_relu(grad_output, input, weight, bias, output):
    """
    Explicit (standard) SGD gradients for a linear layer followed by ReLU.

    Args:
        grad_output: [b x m]    The gradient that has been back-propagated to this layer
        input: [b x n]          Input vector
        weight: [m x n]         Weight matrix
        bias: [m]               Bias vector (not read; kept for a uniform signature)
        output: [b x m]         Layer output from forward-propagation

    Returns:
        grad_input: [b x n]     Gradient with respect to the input
        grad_weight: [m x n]    Gradient with respect to the weight matrix
        grad_bias: [m]          Gradient with respect to the bias vector
    """

    # 1 where the ReLU was active (output strictly greater than 0), else 0.
    # type_as keeps the mask in grad_output's tensor type instead of forcing
    # a CPU FloatTensor.
    active = (output > 0).type_as(grad_output)  # [b x m]

    # Mask the back-propagated gradient to zero out elements where the output is zero.
    grad_output_masked = active * grad_output  # [b x m]

    # Calculate gradients
    grad_input = grad_output_masked.mm(weight)  # [b x n]
    grad_weight = grad_output_masked.t().mm(input)  # [m x n]
    # Flatten with view(-1) rather than squeeze(0): when sum(0) already drops
    # the batch dimension, squeeze(0) would also remove the output dimension
    # for m == 1 and return a 0-dim tensor.
    grad_bias = grad_output_masked.sum(0).view(-1)  # [m]

    return grad_input, grad_weight, grad_bias

## Old ISGD implementation

In [6]:
def calc_backwards_variables(input, weight, bias, output, logit, grad_output, lr, mu):
    """
    Compute the per-node constants needed by the old ISGD backward pass.

    The forward-pass tensors are passed straight through so that the caller can
    unpack everything from a single call.

    Args:
        input: [b x n]          Input vector
        weight: [m x n]         Weight matrix
        bias: [m]               Bias vector
        output: [b x m]         Layer output from forward-propagation
        logit: [b x m]          Logit from forward-propagation
        grad_output: [b x m]    The gradient that has been back-propagated to this layer
        lr: [1]                 Learning rate
        mu: [1]                 Ridge-regularization constant

    Returns:
        input: [b x n]          Input vector (unchanged)
        weight: [m x n]         Weight matrix (unchanged)
        bias: [m]               Bias vector (unchanged)
        output: [b x m]         Layer output (unchanged)
        logit: [b x m]          Logit (unchanged)
        s: [b x m]              Sign of back-propagated gradient
        z_norm: [b]             2-norm of (input, 1)
        d: [b x m]              z_norm * sqrt(abs(grad_output)) / sqrt(1 + lr * mu)
        c: [b x m]              Logit contracted by ridge-regularization, logit / (1 + lr * mu)
    """

    shrink = 1.0 + lr * mu

    grad_sign = torch.sign(grad_output)  # [b x m]

    # Norm of each input row with a constant 1 appended (the bias feature)
    row_norm = torch.norm(input, p=2, dim=1)  # [b]
    augmented_norm = torch.sqrt(row_norm ** 2 + 1.0)  # [b]

    # Scale every row of sqrt(|grad_output|) by that sample's augmented norm
    root_abs_grad = torch.sqrt(torch.abs(grad_output))  # [b x m]
    weighted = root_abs_grad * augmented_norm.unsqueeze(1) / math.sqrt(shrink)  # [b x m]

    contracted_logit = logit / shrink  # [b x m]

    return (input, weight, bias, output, logit,
            grad_sign, augmented_norm, weighted, contracted_logit)

In [7]:
def calc_weigh_bias_grad(weight, mu, lr, a, d, input, z_norm, bias):
    """
    Turn the old-ISGD solution `a` into gradients for the weight matrix and bias vector.

    Args:
        weight: [m x n]         Weight matrix
        mu: [1]                 Ridge-regularization constant
        lr: [1]                 Learning rate
        a: [b x m]              Solution of ISGD update
        d: [b x m]              Weighted constant, proportional to the sqrt(abs(back-propagated gradient))
        input: [b x n]          Input vector
        z_norm: [b]             2-norm of (input, 1)
        bias: [m]               Bias vector

    Returns:
        grad_weight: [m x n]    Gradient of the weight matrix
        grad_bias: [m]          Gradient of the bias vector

    """

    shrink = 1.0 + lr * mu

    # Per-node step a * d, divided by each sample's squared augmented norm.
    # Shared by the weight and the bias gradient.
    node_step = torch.mul(z_norm ** -2, (a * d).t())  # [m x b]

    # Ridge term minus the data term
    grad_weight = weight * mu / shrink - node_step.mm(input)  # [m x n]
    grad_bias = bias * mu / shrink - node_step.sum(1)  # [m]

    return grad_weight, grad_bias

In [8]:
def isgd_relu(input, weight, bias, output, logit, grad_output, lr=0.01, mu=0.0):
    """
    Old implicit-SGD (ISGD) gradients for a linear layer followed by ReLU.

    Args:
        input: [b x n]          Input vector
        weight: [m x n]         Weight matrix
        bias: [m]               Bias vector
        output: [b x m]         Layer output from forward-propagation
        logit: [b x m]          Logit from forward-propagation
        grad_output: [b x m]    The gradient that has been back-propagated to this layer
        lr:                     Learning rate (defaults to the value previously hard-coded here)
        mu:                     Ridge-regularization constant (defaults to the previous hard-coded value)

    Returns:
        grad_input: [b x n]     Gradient with respect to the input (same as the explicit ReLU gradient)
        grad_weight: [m x n]    ISGD gradient of the weight matrix
        grad_bias: [m]          ISGD gradient of the bias vector

    Raises:
        AssertionError: if some node does not satisfy exactly one of the update conditions.
    """

    input, weight, bias, output, logit, s, z_norm, d, c = calc_backwards_variables(
        input, weight, bias, output, logit, grad_output, lr, mu)

    lr_d_sq = lr * d ** 2  # [b x m]

    def to_mask(condition):
        # Cast a comparison result to grad_output's tensor type so the masks can
        # be combined with the other tensors without forcing a CPU FloatTensor.
        return condition.type_as(grad_output)

    # Calculate conditions for a
    cond0 = to_mask(s == 0)  # [b x m]
    cond1 = to_mask((s == +1) * (c <= 0))  # [b x m]
    cond2 = to_mask((s == +1) * (c > 0) * (c <= lr_d_sq))  # [b x m]
    cond3 = to_mask((s == +1) * (c > lr_d_sq))  # [b x m]
    cond4 = to_mask((s == -1) * (c <= -lr_d_sq / 2.0))  # [b x m]
    cond5 = to_mask((s == -1) * (c > -lr_d_sq / 2.0))  # [b x m]

    # Check that exactly one condition satisfied for each node
    cond_sum = cond0 + cond1 + cond2 + cond3 + cond4 + cond5  # [b x m]
    assert (cond_sum == 1).all(), 'No implicit update condition was satisfied'

    # Calculate a. Nodes matching cond0, cond1 or cond4 contribute zero.
    a = (- (c / (lr * d)) * cond2
         - d * cond3
         + d * cond5
         )  # [b x m]

    # a might contain NaN values if d = 0 at certain elements, due to dividing by d
    # in (c / (lr * d)) * cond2. NaN is the only value not equal to itself, so the
    # line below sets all NaNs to zero, which is the appropriate value for ISGD.
    a[a != a] = 0

    # Calculate input gradient: only nodes with a strictly positive output pass gradient back
    active = to_mask(output > 0)  # [b x m]
    grad_output_masked = active * grad_output  # [b x m]
    grad_input = grad_output_masked.mm(weight)  # [b x n]

    # Calculate grad_weight, grad_bias and return all gradients
    grad_weight, grad_bias = calc_weigh_bias_grad(weight, mu, lr, a, d, input, z_norm, bias)

    return grad_input, grad_weight, grad_bias

## New ISGD implementation

In [202]:
def isgd_new_relu(input, weight, bias, output, logit, grad_output, lr=0.00001, mu=0.0):
    """
    New implicit-SGD (ISGD) gradients for a linear layer followed by ReLU.

    Args:
        input: [b x n]          Input vector
        weight: [m x n]         Weight matrix
        bias: [m]               Bias vector
        output: [b x m]         Layer output from forward-propagation
        logit: [b x m]          Logit from forward-propagation
        grad_output: [b x m]    The gradient that has been back-propagated to this layer
        lr:                     Learning rate (defaults to the value previously hard-coded here)
        mu:                     Ridge-regularization constant (defaults to the previous hard-coded value)

    Returns:
        grad_input: [b x n]     Gradient with respect to the input (same as the explicit ReLU gradient)
        grad_weight: [m x n]    ISGD gradient of the weight matrix
        grad_bias: [m]          ISGD gradient of the bias vector

    Raises:
        AssertionError: if some node does not satisfy exactly one of the update conditions.
    """

    shrink = 1.0 + lr * mu

    def to_mask(condition):
        # Cast a comparison result to grad_output's tensor type so the masks can
        # be combined with the other tensors without forcing a CPU FloatTensor.
        return condition.type_as(grad_output)

    # ISGD constants
    s = torch.sign(grad_output)  # [b x m]
    # Squared 2-norm of (input, 1), kept as a column so it broadcasts over the m nodes
    z_norm_squared = (torch.norm(input, p=2, dim=1) ** 2 + 1.0).unsqueeze(1)  # [b x 1]
    c = logit / shrink  # [b x m]

    # Calculate conditions for u
    threshold = lr * (z_norm_squared * grad_output) / shrink  # [b x m]

    cond0 = to_mask(s == 0)  # [b x m]
    cond1 = to_mask((s == +1) * (c <= 0))  # [b x m]
    cond2 = to_mask((s == +1) * (c > 0) * (c <= threshold))  # [b x m]
    cond3 = to_mask((s == +1) * (c > threshold))  # [b x m]
    cond4 = to_mask((s == -1) * (c <= threshold / 2.0))  # [b x m]
    cond5 = to_mask((s == -1) * (c > threshold / 2.0))  # [b x m]

    # Check that exactly one condition satisfied for each node
    cond_sum = cond0 + cond1 + cond2 + cond3 + cond4 + cond5  # [b x m]
    assert (cond_sum == 1).all(), 'No implicit update condition was satisfied'

    # Calculate u. Nodes matching cond0, cond1 or cond4 contribute zero.
    u = ((c / z_norm_squared) / lr * cond2
         + grad_output / shrink * (cond3 + cond5)
         )  # [b x m]

    # The cond2 term is evaluated for every node before being masked, so a
    # non-finite quotient (for example when lr is 0) multiplied by a zero mask
    # gives NaN. NaN is the only value not equal to itself; set those entries
    # to zero, which is the appropriate value for ISGD.
    u[u != u] = 0

    # Calculate input gradient: only nodes with a strictly positive output pass gradient back
    active = to_mask(output > 0)  # [b x m]
    grad_output_masked = active * grad_output  # [b x m]
    grad_input = grad_output_masked.mm(weight)  # [b x n]

    # Calculate grad_weight, grad_bias: ridge term plus the data term
    grad_weight = weight * mu / shrink + u.t().mm(input)  # [m x n]
    grad_bias = bias * mu / shrink + u.t().sum(1)  # [m]

    return grad_input, grad_weight, grad_bias

## Test the differences between ESGD, old ISGD and new ISGD

In [232]:
# Draw a random problem instance (sizes come from b, n, m defined near the top)
grad_output = torch.randn(b, m)  # [b x m]
input = torch.randn(b, n)        # [b x n]
weight = torch.randn(m, n)       # [m x n]
bias = torch.randn(m)            # [m]

# Forward pass: keep both the pre-activation logit and the ReLU output,
# since the ISGD routines need the logit and all routines need the output
logit = calc_logit(input, weight, bias)
output = forward(input, weight, bias)

# Backward pass under each of the three schemes
esgd_grads = esgd_relu(grad_output, input, weight, bias, output)
isgd_new_grads = isgd_new_relu(input, weight, bias, output, logit, grad_output)
isgd_grads = isgd_relu(input, weight, bias, output, logit, grad_output)


# Print the element-wise differences (grad_input, grad_weight, grad_bias) against ESGD.
# The comparison with the old ISGD implementation is currently disabled:
# print('Difference between ESGD and ISGD old')
# print([(x-y) for x,y in zip(isgd_grads, esgd_grads)])

print('\nDifference between ESGD and ISGD new')
print([isgd_grad - esgd_grad for isgd_grad, esgd_grad in zip(isgd_new_grads, esgd_grads)])


Difference between ESGD and ISGD new
[
 0  0
[torch.FloatTensor of size 1x2]
, 
 0  0
 0  0
 0  0
[torch.FloatTensor of size 3x2]
, 
 0
 0
 0
[torch.FloatTensor of size 3]
]
