In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import math

In [2]:
"""
n = dimension of input
m = dimension of output
"""

'\nn = dimension of input\nm = dimension of output\n'

# Linear

In [3]:
def calc_logit(input, weight, bias=None):
    """
    Affine pre-activation of a linear layer.

    Args:
        input:  [1 x n]         Input vector
        weight:  [m x n]        Weight matrix
        bias:  [m]              Optional bias vector; skipped when None

    Returns:
        logit: [1 x m]          input.mm(weight.t()), plus bias when one is given

    """
    logit = torch.mm(input, weight.t())
    if bias is None:
        return logit

    # Stretch the bias to the shape of the product and add it in place
    logit += bias.unsqueeze(0).expand_as(logit)
    return logit

## ESGD

In [17]:
def esgd_hardtanh(grad_output, input, weight, bias, output):
    """
    Standard back-propagation through a linear layer followed by hardtanh.

    This is the "ESGD" baseline that the ISGD variant in this notebook is
    compared against.

    Args:
        grad_output: [1 x m]    The gradient that has been back-propagated to this layer
        input: [1 x n]          Input vector
        weight: [m x n]         Weight matrix
        bias: [m]               Bias vector (not read; kept so the signature mirrors the layer's tensors)
        output: [1 x m]         Hardtanh output, i.e. the logit clamped to [-1, 1]

    Returns:
        grad_input: [1 x n]     Gradient with respect to the input
        grad_weight: [m x n]    Gradient with respect to the weight matrix
        grad_bias: [m]          Gradient with respect to the bias vector
    """

    # Only outputs strictly inside (-1, 1) let gradient through.
    # type_as keeps the mask on the same dtype/device as grad_output instead of
    # forcing a CPU FloatTensor.
    non_clamped = ((output > -1) * (output < 1)).type_as(grad_output)  # [1 x m]
    grad_output_masked = non_clamped * grad_output  # [1 x m]

    # Calculate gradients
    grad_input = grad_output_masked.mm(weight)  # [1 x n]
    grad_weight = grad_output_masked.t().mm(input)  # [m x n]
    # view(-1) yields shape [m] whether or not sum() keeps the reduced dimension;
    # sum(0).squeeze(0) collapsed to a 0-dim tensor for m == 1 when sum() drops it.
    grad_bias = grad_output_masked.sum(0).view(-1)  # [m]

    return grad_input, grad_weight, grad_bias

## ISGD

In [22]:
def calc_backwards_variables(saved_tensors, logit, grad_output, lr, mu):
    """
    Unpack the forward-pass tensors and derive the ISGD constants used during
    back-propagation.

    Args:
        saved_tensors:          Sequence (input, weight, bias, output) stored during forward-propagation
        logit: [1 x m]          Logit stored during forward-propagation
        grad_output: [1 x m]    The gradient that has been back-propagated to this layer
        lr: [1]                 Learning rate
        mu: [1]                 Ridge-regularization constant

    Returns:
        input: [1 x n]          Input vector
        weight: [m x n]         Weight matrix
        bias: [m]               Bias vector
        output [1 x m]          Input to the next layer = logit put through the non-linear activation function
        logit: [1 x m]          Logit (returned unchanged)
        s: [1 x m]              Sign of back-propagated gradient
        z_norm: [1]             2-norm of (input, 1), i.e. sqrt(||input||^2 + 1)
        d: [1 x m]              z_norm / sqrt(1 + lr * mu) * sqrt(abs(grad_output))
        c: [1 x m]              logit / (1 + lr * mu)
    """

    input, weight, bias, output = saved_tensors

    # Factor shared by d and c
    ridge_factor = 1.0 + lr * mu

    s = grad_output.sign()
    z_norm = math.sqrt(torch.norm(input) ** 2 + 1.0)
    d = z_norm / math.sqrt(ridge_factor) * torch.sqrt(grad_output.abs())
    c = logit / ridge_factor

    return input, weight, bias, output, logit, s, z_norm, d, c

In [27]:
def calc_weigh_bias_grad(weight, mu, lr, a, d, input, z_norm, bias):
    """
    Build the weight and bias gradients from the ISGD solution `a`.

    Both gradients are a ridge term (parameter * mu / (1 + lr * mu)) minus the
    data term (a * d) / z_norm^2; for the weight the data term is the outer
    product of (a * d) with the input.

    Args:
        weight: [m x n]         Weight matrix
        mu: [1]                 Ridge-regularization constant
        lr: [1]                 Learning rate
        a: [1 x m]              Solution of ISGD update
        d: [1 x m]              Weighted constant, proportional to the sqrt(abs(back-propagated gradient))
        input: [1 x n]          Input vector
        z_norm: [1]             2-norm of (input, 1)
        bias: [m]               Bias vector

    Returns:
        grad_weight: [m x n]    Gradient of the weight matrix
        grad_bias: [m]          Gradient of the bias vector

    """

    ridge_denominator = 1.0 + lr * mu
    z_norm_sq = z_norm ** 2
    a_times_d = a * d  # [1 x m]

    weight_ridge = weight * mu / ridge_denominator
    bias_ridge = bias * mu / ridge_denominator

    grad_weight = weight_ridge - a_times_d.t().mm(input) / z_norm_sq
    grad_bias = bias_ridge - a_times_d.squeeze() / z_norm_sq
    return grad_weight, grad_bias

In [31]:
def isgd_hardtanh(grad_output, input, weight, bias, output, logit, lr=0.01, mu=0.0):
    """
    ISGD back-propagation through a linear layer followed by hardtanh.

    The weight and bias gradients come from the ISGD solution `a`; the input
    gradient is the ordinary hardtanh-masked gradient.

    Args:
        grad_output: [1 x m]    The gradient that has been back-propagated to this layer
        input: [1 x n]          Input vector
        weight: [m x n]         Weight matrix
        bias: [m]               Bias vector
        output: [1 x m]         Hardtanh output, i.e. the logit clamped to [-1, 1]
        logit: [1 x m]          Logit from forward-propagation
        lr: [1]                 Learning rate (default 0.01, the value previously hard-coded here)
        mu: [1]                 Ridge-regularization constant (default 0.0, previously hard-coded)

    Returns:
        grad_input: [1 x n]     Gradient with respect to the input
        grad_weight: [m x n]    ISGD gradient of the weight matrix
        grad_bias: [m]          ISGD gradient of the bias vector
    """

    input, weight, bias, output, logit, s, z_norm, d, c = calc_backwards_variables(
        [input, weight, bias, output], logit, grad_output, lr, mu)

    # Calculate a: each element is assigned to one of three cases, each with its
    # own closed form for a. type_as keeps the masks on grad_output's dtype/device
    # instead of forcing a CPU FloatTensor.
    sc = s * c
    cond1 = ((sc <= -1)
             + (sc >= (1 + lr * d ** 2 / 2))).type_as(grad_output)
    cond2 = ((sc >= (lr * d ** 2 - 1)) * (sc <= 1)
             + ((sc >= torch.clamp(lr * d ** 2 / 2 - 1, min=1))
                * (sc <= (1 + lr * d ** 2 / 2)))).type_as(grad_output)
    cond3 = ((sc >= -1) * (sc <= torch.clamp(lr * d ** 2 - 1, max=1))
             + (sc >= 1) * (sc <= (lr * d ** 2 / 2 - 1))).type_as(grad_output)

    # TODO: the three cases use non-strict inequalities on both sides, so an element
    # sitting exactly on a shared boundary is selected by more than one mask and its
    # contributions to `a` are added together. The cases still need to be made
    # mutually exclusive.

    a = (0.0 * cond1
         - s * d * cond2
         - (1 + sc) / (lr * d) * cond3
         )

    # Where d == 0 the division in (1 + sc) / (lr * d) is not finite, and a
    # non-finite value times a zero mask is NaN.
    # The operation below sets all NaNs to zero (NaN is the only value with x != x).
    # This is the appropriate value for ISGD.
    a[a != a] = 0

    # Calculate grad_weight, grad_bias from the ISGD solution
    grad_weight, grad_bias = calc_weigh_bias_grad(weight, mu, lr, a, d, input, z_norm, bias)

    # Calculate input gradient: only outputs strictly inside (-1, 1) pass gradient
    non_clamped = ((output > -1) * (output < 1)).type_as(grad_output)  # [1 x m]
    grad_output_masked = non_clamped * grad_output  # [1 x m]
    grad_input = grad_output_masked.mm(weight)  # [1 x n]

    return grad_input, grad_weight, grad_bias



In [182]:
# Scratch copy of the first half of isgd_hardtanh, run at top level so that the three
# condition masks (cond1, cond2, cond3) can be inspected in the cells that follow.
# NOTE(review): this reads input, weight, bias, output, logit and grad_output from the
# kernel; in this notebook they are only created in the later "Test the differences
# between ESGD and ISGD" cell, so a fresh Restart & Run All would fail here — confirm
# the intended cell order.
lr, mu = 0.01, 0.0
input, weight, bias, output, logit, s, z_norm, d, c = calc_backwards_variables([input, weight, bias, output], logit,
                                                                               grad_output, lr, mu)

# Calculate a
# sc is the contracted logit c signed by the direction of the incoming gradient
sc = s * c
cond1 = ((sc <= -1)
         + (sc >= (1 + lr * d ** 2 / 2))).type(torch.FloatTensor)
cond2 = ((sc >= (lr * d ** 2 - 1)) * (sc <= 1)
         + ((sc >= torch.clamp(lr * d ** 2 / 2 - 1, min=1))
            * (sc <= (1 + lr * d ** 2 / 2)))).type(torch.FloatTensor)
cond3 = ((sc >= -1) * (sc <= torch.clamp(lr * d ** 2 - 1, max=1))
         + (sc >= 1) * (sc <= (lr * d ** 2 / 2 - 1))).type(torch.FloatTensor)

In [218]:
# Every element must be selected by at least one of the three masks,
# i.e. the element-wise sum of the masks has to be >= 1 everywhere.
cond_sum = cond1 + cond2 + cond3
print(cond_sum)
assert torch.mean(
    (cond_sum >= 1).type(torch.FloatTensor)
) == 1.0


 1  1  1
[torch.FloatTensor of size 1x3]



In [215]:
# Scratch check of element-wise division by a tensor that contains a zero:
# the recorded output of this cell shows inf at the position of the zero divisor.
cc = torch.Tensor([1,0,1])
print(cond_sum / cc)


  1 inf   1
[torch.FloatTensor of size 1x3]



## Test the differences between ESGD and ISGD

In [237]:
# Seed the generator so the random sample, and therefore the comparison printed
# below, is the same after every kernel restart
torch.manual_seed(0)

# Random data
grad_output = torch.randn(1, 3)     # [1 x m]
input = torch.randn(1, 2)           # [1 x n]
weight = torch.randn(3, 2)          # [m x n]
bias = torch.randn(3,)              # [m]

# Forward propagation
# Calculate logit [1 x m], where logit = input.mm(weight.t()) + bias
logit = calc_logit(input, weight, bias)

# Non-linear activation function (hardtanh)
output = torch.clamp(logit, min=-1.0, max=1)  # [1 x m]

# Back-propagate the same sample with both rules
isgd_grads = isgd_hardtanh(grad_output, input, weight, bias, output, logit)
esgd_grads = esgd_hardtanh(grad_output, input, weight, bias, output)

# Element-wise differences for (grad_input, grad_weight, grad_bias);
# values near zero mean the two rules agree on this sample
print([(x-y) for x,y in zip(isgd_grads, esgd_grads)])

[
 0  0
[torch.FloatTensor of size 1x2]
, 
1.00000e-09 *
 -1.8626  1.8626
  0.0000  0.0000
  3.7253 -3.7253
[torch.FloatTensor of size 3x2]
, 
1.00000e-08 *
 -2.9802
 -0.0000
  5.9605
[torch.FloatTensor of size 3]
]


# Full Pytorch module

## Forward propagation

In [None]:
# Toy layer with n = 3 inputs and m = 5 outputs
input = torch.randn(1, 3)
weight = torch.randn(5, 3)
bias = torch.randn(5,)

# Affine map (bias added in place), then ReLU as a clamp at zero
output = torch.mm(input, weight.t())
output.add_(bias.unsqueeze(0).expand_as(output))
relu = output.clamp(min=0.0)

# Show the input itself and the shapes of everything else
print('input: ', input)
print('weight: ', weight.size())
print('bias: ', bias.size())
print('output: ', output.size())
print('relu: ', relu.size())

## Back-propagation

## ISGD

In [None]:
# NOTE(review): this cell has no execution count, i.e. it was never run in the saved
# notebook. It reads lr, mu and grad_output from earlier cells and calls alpha_relu,
# which is only defined in the next cell — that definition must be run first.
# NOTE(review): grad_output is created as [1 x 3] in the ESGD/ISGD test cell while the
# forward-propagation cell above produces a [1 x 5] output — confirm the shapes match
# before running.

# Constants
s = torch.sign(grad_output)
abs_grad_output = torch.abs(grad_output)
# Note that torch.norm outputs a float instead of a tensor
z_norm = math.sqrt((torch.norm(input) ** 2 + 1.0))
# NOTE(review): unlike calc_backwards_variables, this d includes a sqrt(lr) factor —
# confirm which definition is intended
d = z_norm * math.sqrt(lr/(1.0+lr*mu)) * torch.sqrt(abs_grad_output)
# Here `output` is the pre-activation value computed in the forward-propagation cell
c = output/(1.0+lr*mu)
# print('s: ', s)
# print('delta: ', d) 
# print(c)

# Calculate alpha
alpha = alpha_relu(s,d,c)

# Calculate gradients
# The gradient is recovered from the implied update: grad = (old - new) / lr
# NOTE(review): assuming alpha and d are [1 x m] and weight is [m x n], alpha.mul(d).mm(weight)
# is [1 x n] and gets broadcast over the rows of weight; calc_weigh_bias_grad uses
# (a * d).t().mm(input) for the corresponding term — confirm which is intended
new_weight = weight / (1.0 + lr * mu) + alpha.mul(d).mm(weight) / z_norm**2
grad_weight = (weight - new_weight) / lr
# print(weight)
# print(new_weight)
# print(grad_weight)

# NOTE(review): this term is multiplied element-wise by bias, whereas calc_weigh_bias_grad
# has no such factor — confirm
new_bias = bias / (1.0 + lr * mu) + alpha.mul(d).squeeze().mul(bias) / z_norm**2
grad_bias = (bias - new_bias) / lr
# print(bias)
# print(new_bias)
# print(grad_bias)

# Input gradient: mask of non-negative pre-activations applied to grad_output
sgn_output = (output >= 0).type(torch.FloatTensor)
grad_input = (grad_output.mul(sgn_output)).mm(weight)
print(grad_input)

In [None]:
def alpha_relu(s,d,c):
    """
    Piecewise ISGD solution alpha for a ReLU layer.

    Args:
        s:      Sign of the back-propagated gradient (entries -1, 0 or 1)
        d:      Weighted constant, same shape as s
        c:      Contracted logit, same shape as s

    Returns:
        alpha:  Same shape as s, with
                    -c / d   where s == 1 and 0 < c <= d**2
                    -d       where s == 1 and c > d**2
                    +d       where s == -1 and c > -d**2 / 2
                     0       everywhere else (s == 1 with c <= 0,
                             s == -1 with c <= -d**2 / 2, and s == 0)
    """
    d_sq = d ** 2
    positive = (s == 1)
    negative = (s == -1)

    # type_as keeps the masks on c's dtype/device instead of forcing a CPU FloatTensor
    cond2 = positive.mul(c > 0).mul(c <= d_sq).type_as(c)
    cond3 = positive.mul(c > d_sq).type_as(c)
    cond5 = negative.mul(c > -d_sq / 2.0).type_as(c)

    alpha = (0.0
             - (c.div(d)).mul(cond2)
             - d.mul(cond3)
             + d.mul(cond5)
             )

    # Where d == 0, c / d is inf or NaN, and multiplying it by the zero mask leaves NaN
    # in alpha. Reset those entries to zero, as isgd_hardtanh does for its own solution
    # (NaN is the only value with x != x).
    alpha[alpha != alpha] = 0

    return alpha

In [None]:
# Display alpha for the s, d, c assigned in the ISGD constants cell
alpha_relu(s,d,c)

In [None]:
# Understand grad_output_pos_out.sum(0).squeeze(0)
# Prints the tensor, then its sum over dimension 0, then that sum with dimension 0 squeezed.
# NOTE(review): grad_output_pos_out is only assigned in the "Standard RELU" cell further
# down, so that cell has to be run before this one.
print(grad_output_pos_out)
print(grad_output_pos_out.sum(0))
print(grad_output_pos_out.sum(0).squeeze(0))

## Standard RELU

In [None]:
# Mask that is 1 where output >= 0 and 0 elsewhere; it gates the incoming gradient
pos_out = (output >= 0).type(torch.FloatTensor)
grad_output_pos_out = grad_output * pos_out

# Gradients with respect to the input, the weight matrix and the bias
grad_input = torch.mm(grad_output_pos_out, weight)
grad_weight = torch.mm(grad_output_pos_out.t(), input)
grad_bias = grad_output_pos_out.sum(0).squeeze(0)

# Report the shape of every tensor involved
print('grad_output: ', grad_output.size())
print('output: ', output.size())
print('pos_out: ', pos_out.size())
print('grad_output_pos_out: ', grad_output_pos_out.size())
print('grad_input: ', grad_input.size())
print('grad_bias: ', grad_bias.size())
print('grad_weight: ', grad_weight.size())