In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import math
import numpy as np
from functools import reduce

In [2]:
n =  2  # dimension of input
m =  3  # dimension of output
batch =  4  # batch size

In [3]:
def calc_logit(input, weight, bias=None):
    """
    Calculate logit = input.mm(weight.t()) + bias

    Args:
        input:  [1 x n]         Input vector
        weight:  [m x n]        Weight matrix
        bias:  [m]              Bias vector

    Returns:
        logit: [1 x n]          Logit = input.mm(weight.t()) + bias

    """

    logit = input.mm(weight.t())
    if bias is not None:
        logit += bias.unsqueeze(0).expand_as(logit)

    return logit

## ESGD

In [4]:
def esgd_arctan(grad_output, input, weight, bias, output):

    grad_output_scaled = grad_output / (1 + logit ** 2)  # [1 x m]
    grad_input = grad_output_scaled.mm(weight)  # [1 x n]
    grad_weight = grad_output_scaled.t().mm(input)  # [m x n]
    grad_bias = grad_output_scaled.sum(0).squeeze(0)  # [m]

    return grad_input, grad_weight, grad_bias

## Old ISGD implementation

In [100]:
def calc_backwards_variables(saved_tensors, logit, grad_output, lr, mu):
    """
    Calculate the variables required for back-propagation
    
    Args:
        saved_tensors:          Stores from forward-propagation the input, weight, bias, output
        logit: [1 x n]          Stores from forward-propagation the logit
        grad_output: [1 x m]    The gradient that has been back-propagated to this layer
        lr: [1]                 Learning rate
        mu: [1]                 Ridge-regularization constant

    Returns:
        input: [1 x n]          Input vector
        weight: [m x n]         Weight matrix
        bias: [m]               Bias vector
        output [1 x m]          Input to the next layer = logit put through the non-linear activation function
        logit: [1 x n]          Logit
        s: [1 x m]              Sign of back-propagated gradient
        z_norm: [1]             2-norm of (input, 1)
        d: [1 x m]              Weighted constant, proportional to the sqrt(abs(back-propagated gradient))
        c: [1 x m]              Logit contracted by ridge-regularization
    """

    # Unpack saved values
    input, weight, bias, output = saved_tensors

    # ISGD constants
    s = torch.sign(grad_output)
    z_norm = math.sqrt((torch.norm(input) ** 2 + 1.0))
    d = z_norm / math.sqrt(1.0 + lr * mu) * torch.sqrt(torch.abs(grad_output))
    c = logit / (1.0 + lr * mu)

    return input, weight, bias, output, logit, s, z_norm, d, c

In [99]:
def calc_weigh_bias_grad(weight, mu, lr, a, d, input, z_norm, bias):
    """
    Calculate the gradient of the weight matrix and bias vector

    Args:
        weight: [m x n]         Weight matrix
        mu: [1]                 Ridge-regularization constant
        lr: [1]                 Learning rate
        a: [1 x m]              Solution of ISGD update
        d: [1 x m]              Weighted constant, proportional to the sqrt(abs(back-propagated gradient))
        input: [1 x n]          Input vector
        z_norm: [1]             2-norm of (input, 1)
        bias: [m]               Bias vector

    Returns:
        grad_weight: [m x n]    Gradient of the weight matrix
        grad_bias: [m]          Gradient of the bias vector

    """

    grad_weight = weight * mu / (1.0 + lr * mu) - (a * d).t().mm(input) / z_norm ** 2
    grad_bias = bias * mu / (1.0 + lr * mu) - (a * d).squeeze() / z_norm ** 2
    return grad_weight, grad_bias

In [97]:
def real_root_closest_to_zero(coeff):
    roots = np.roots(coeff)
    real_roots = [root.real for root in roots if root.imag == 0]
    root_closest_to_zero = reduce((lambda x, y: x if (abs(x) < abs(y)) else y), real_roots)
    return root_closest_to_zero

In [98]:
def isgd_arctan(grad_output, input, weight, bias, output, logit):
    
    lr, mu = 0.00000001, 0.0
    input, weight, bias, output, logit, s, z_norm, d, c = calc_backwards_variables([input, weight, bias, output], logit,
                                                                                       grad_output, lr, mu)

    # Calculate a
    coeff = np.array([((lr * d) ** 2).numpy()[0],
                      (2 * lr * d * c).numpy()[0],
                      (c ** 2 + 1).numpy()[0],
                      (s * d).numpy()[0]])

    root_closest_to_zero = np.apply_along_axis(real_root_closest_to_zero, 0, coeff)
    a = torch.from_numpy(root_closest_to_zero).unsqueeze(1).t().type(torch.FloatTensor)

    # Calculate grad_weight, grad_bias and return all gradients
    grad_weight, grad_bias = calc_weigh_bias_grad(weight, mu, lr, a, d, input, z_norm, bias)

    # Calculate input gradient
    grad_output_scaled = grad_output / (1 + logit ** 2)  # [1 x m]
    grad_input = grad_output_scaled.mm(weight)  # [1 x n]

    return grad_input, grad_weight, grad_bias

In [96]:
def isgd_arctan2(grad_output, input, weight, bias, output, logit):
    
    lr, mu = 0.0, 0.0
    input, weight, bias, output, logit, s, z_norm, d, c = calc_backwards_variables([input, weight, bias, output], logit,
                                                                                       grad_output, lr, mu)

    # Calculate a
    d_d = d.double()
    s_m_d = (s * d).double()
    c_d = c.double()

    a = d_d * 0  # [b x m]
    a_diff = 1  # Norm difference between previous and current a values
    iter_count = 0  # Count of number of a iterations
    while a_diff > 1e-15:
        a_new = - s_m_d / (1.0 + (lr * d_d * a + c_d) ** 2)  # [b x m]
        a_diff = torch.norm(a - a_new)
        a = a_new  # [b x m]
        iter_count += 1
        if iter_count >= 50:
            assert (iter_count < 50), 'Arctan update has failed to converge'

    # Make a float so that can be operated on with other tensors
    a = a.float()

    # Calculate grad_weight, grad_bias and return all gradients
    grad_weight, grad_bias = calc_weigh_bias_grad(weight, mu, lr, a, d, input, z_norm, bias)

    # Calculate input gradient
    grad_output_scaled = grad_output / (1 + logit ** 2)  # [b x m]
    grad_input = grad_output_scaled.mm(weight)  # [b x n]

    return grad_input, grad_weight, grad_bias

## New ISGD implementation

In [152]:
def isgd_new_arctan(input, weight, bias, output, logit, grad_output):
    
    # Hyperparameters
    lr = 0.00000001
    mu = 0.0

    # ISGD constants
    z_norm_squared = (torch.norm(input, p=2, dim=1) ** 2 + 1.0).double()  # [b]
    c = (logit / (1.0 + lr * mu)).double()  # [b x m]

    # Calculate u
    # Calculate conditions for u

#     b = torch.mul(z_norm_squared, grad_output.t()).t() * lr / (1 + lr * mu)
#     v = cube_solver(b, c)
#     u = torch.div(v.t(), lr * z_norm_squared).t()

    b = torch.mul(z_norm_squared, grad_output.t()).t() * lr / (1 + lr * mu)
    v = cube_solver(b, c)
    u = torch.div(v.t(), lr * z_norm_squared).t()

    u = u_arctan()

    print('v_diff: ', v_diff(grad_output, c, lr, z_norm_squared, v))
    print('u_diff: ', u_diff(grad_output, c, lr, z_norm_squared, u))

    # Calculate input gradient
    grad_output_scaled = grad_output / (1 + logit ** 2)  # [b x m]
    grad_input = grad_output_scaled.mm(weight)  # [b x n]

    # Calculate grad_weight, grad_bias
    grad_weight = weight * mu / (1.0 + lr * mu) + u.t().mm(input)  # [m x n]
    grad_bias = bias * mu / (1.0 + lr * mu) + u.t().sum(1)  # [m]
    
    return grad_input, grad_weight, grad_bias

In [None]:
def u_arctan(z_norm_squared, grad_output, lr, mu):
    b = torch.mul(z_norm_squared, grad_output.t()).t() * lr / (1 + lr * mu)
    v = cube_solver(b, c)
    u = torch.div(v.t(), lr * z_norm_squared).t()

    return u

In [145]:
def cube_solver(b, c):
    """ Solves for v in the equation:  v * (1 + (v - c)**2) = b
    """
    b = b.double()
    c = c.double()
    
    delta = 27 * (b ** 2) - 4 * b * (c ** 3) - 36 * b * c + 4 * (c ** 4) + 8 * (c ** 2) + 4  # [b x m]
    gamma = 27 * b - 2 * (c ** 3) - 18 * c  # [b x m]
    beta = (3 * ((3 * delta) ** (1/2)) + gamma)
    cr_beta = torch.sign(beta) * (torch.abs(beta) ** (1/3))
    v = cr_beta / (3 * (2 ** (1/3))) - (2 ** (1/3)) * (3 - c**2) / (3 * cr_beta) + 2 * c / 3
    
#     # Check that the solution is valid
#     diff = x * (1 + (x - c)**2) - b
#     print(diff)

    return v.float()

## Test the differences between ESGD, ISGD old and ISGD_new

In [7]:
def u_diff(grad_output, c, lr, z_norm_squared, u):
    return u - grad_output / (1.0 + (c - lr * torch.mul(z_norm_squared, u.t()).t()) ** 2) / (1.0 + lr * mu)

In [151]:
def v_diff(grad_output, c, lr, z_norm_squared, v):
    b = torch.mul(z_norm_squared, grad_output.t()).t() * lr / (1 + lr * mu)
    return v - b / (1.0 + (c - v) ** 2)

In [153]:
# Random data
grad_output = torch.randn(batch, m)     # [b x m]
input = torch.randn(batch, n)           # [b x n]
weight = torch.randn(m, n)          # [m x n]
bias = torch.randn(m,)              # [m]

# Forward propagation
# Calculate logit [1 x m], where logit = input.mm(weight.t()) + bias
logit = calc_logit(input, weight, bias)

# Non-linear activation function
output = torch.atan(logit)  # [1 x m]

# Calculate gradients
esgd_grads = esgd_arctan(grad_output, input, weight, bias, output)
isgd_grads = isgd_arctan2(grad_output, input, weight, bias, output, logit)
isgd_new_grads = isgd_new_arctan(input, weight, bias, output, logit, grad_output)

# print(esgd_grads)
# print(isgd_new_grads)
# Print difference
# print('Difference between ESGD and ISGD old')
# print([(x-y) for x,y in zip(isgd_grads, esgd_grads)])

print('\nDifference between ESGD and ISGD new')
print([(x-y) for x,y in zip(isgd_new_grads, esgd_grads)])


1.00000e-13 *
  2.4799 -0.0142 -0.0006
 -0.0012 -0.0093  0.0022
 -0.0108 -0.0038 -0.0014
  0.0047  0.0220 -0.0003
[torch.DoubleTensor of size 4x3]

v_diff:  
1.00000e-14 *
  5.4179 -0.0222  0.0000
  0.0000  0.0444 -0.0888
 -0.0444 -0.0888  0.0000
  0.0444  0.0222 -0.0444
[torch.FloatTensor of size 4x3]

u_diff:  
1.00000e-06 *
  2.4438 -0.0149  0.0000
 -0.1192  0.0149 -0.0298
 -0.0298  0.0000  0.0298
  0.0149  0.0075  0.0000
[torch.FloatTensor of size 4x3]


Difference between ESGD and ISGD new
[
 0  0
 0  0
 0  0
 0  0
[torch.FloatTensor of size 4x2]
, 
1.00000e-06 *
 -2.6226 -1.4305
  0.0298  0.0298
 -0.0596 -0.0596
[torch.FloatTensor of size 3x2]
, 
1.00000e-06 *
  2.2650
  0.0000
 -0.0149
[torch.FloatTensor of size 3]
]


In [113]:
# Hyperparameters
lr = 0.0001
mu = 0.0

# ISGD constants
s = torch.sign(grad_output)  # [b x m]
z_norm_squared = torch.norm(input, p=2, dim=1) ** 2 + 1.0  # [b]
c = logit / (1.0 + lr * mu)  # [b x m]

b = torch.mul(z_norm_squared, grad_output.t()).t() * lr / (1 + lr * mu)
v = cube_solver(b, c)
u = torch.div(grad_output.t(), lr * z_norm_squared).t()
print(u)

# print('delta: ', delta)
# print('gamma: ', gamma)
# print('beta: ', beta)
# print('u: ', u)

grad_weight = weight * mu / (1.0 + lr * mu) + u.t().mm(input)  # [m x n]
grad_bias = bias * mu / (1.0 + lr * mu) + u.t().sum(1)  # [m]

# u if it were created using esgd
u_esgd = grad_output / (1 + c ** 2)

# print('grad_weight: ', grad_weight)
# print('grad_bias: ', grad_bias)


 -208.8467   114.0256  2164.5085
-3141.7490 -5193.1665 -4957.5005
 5208.5259  7695.1826 -3666.8020
 2619.1262 -1915.1833 -3137.2427
[torch.FloatTensor of size 4x3]



In [108]:
v_diff(grad_output, c, lr, z_norm_squared, v)


1.00000e-11 *
  0.0000  0.0000  0.0000
  0.0000  0.3638 -0.3638
  0.0000  1.4552  0.3638
  0.0000  0.0000 -0.0909
[torch.FloatTensor of size 4x3]

In [114]:
b = torch.mul(z_norm_squared, grad_output.t()).t() * lr / (1 + lr * mu)
print(v - b / (1.0 + (c - v) ** 2))
u = torch.div(v.t(), lr * z_norm_squared).t()
# print(u - torch.div((b / (1.0 + (c - v) ** 2) ).t(), lr * z_norm_squared).t())
print(u - grad_output / (1.0 + (c - v) ** 2) / (1 + lr * mu) )
print(u)


1.00000e-11 *
  0.0000  0.0000  0.0000
  0.0000  0.3638 -0.3638
  0.0000  1.4552  0.3638
  0.0000  0.0000 -0.0909
[torch.FloatTensor of size 4x3]


1.00000e-08 *
  0.0000  0.1863  0.0000
  2.9802  0.0000 -2.9802
  0.0000  5.9605  0.0000
  0.0000 -2.9802 -0.7451
[torch.FloatTensor of size 4x3]


-0.0437  0.0156  0.0824
-0.3537 -0.3011 -0.3721
 0.8572  0.6016 -0.2674
 0.3294 -0.4565 -0.0494
[torch.FloatTensor of size 4x3]



In [93]:
# If accurate these should all be zero
print('explicit u_diff: ', u_diff(grad_output, c, lr, z_norm_squared, u_esgd))
print('implicit u_diff: ', u_diff(grad_output, c, lr, z_norm_squared, u))

explicit u_diff:  
1.00000e-04 *
 -0.0490  0.0799  0.1079
 -0.8857  0.1895  0.0288
  0.9292  0.1058  0.0031
 -0.0018 -0.0533 -2.1857
[torch.FloatTensor of size 4x3]

implicit u_diff:  
1.00000e-08 *
  0.0000  2.9802  0.0000
  0.0000  0.0000  0.0000
  0.0000  1.4901  0.0000
  0.1863  0.0000  5.9605
[torch.FloatTensor of size 4x3]



In [80]:
delta = 27 * (b ** 2) - 4 * b * (c ** 3) - 36 * b * c + 4 * (c ** 4) + 8 * (c ** 2) + 4  # [b x m]
gamma = 27 * b - 2 * (c ** 3) - 18 * c  # [b x m]
beta = (3 * ((3 * delta) ** (1/2)) + gamma)
cr_beta = torch.sign(beta) * (torch.abs(beta) ** (1/3))
v = cr_beta / (3 * (2 ** (1/3))) - (2 ** (1/3)) * (3 - c**2) / (3 * cr_beta) + 2 * c / 3

print('delta: ', delta)
print('gamma: ', gamma)
print('beta: ', beta)
print('cr_beta: ', cr_beta)
print('v: ', v)
print(cube_solver(b, c))

delta:  
   56.1127     8.9971    40.9222
  110.5813     4.0082  2267.0413
    6.0503   179.2418   415.8600
 1525.8562    34.0047     4.9125
[torch.FloatTensor of size 4x3]

gamma:  
 -38.9227   13.4318   33.2091
 -54.7136    0.5707  303.7489
   8.8526   70.1259  110.3616
-237.0251  -30.2165   -5.9979
[torch.FloatTensor of size 4x3]

beta:  
 8.4686e-04  2.9018e+01  6.6449e+01
-7.2071e-02  1.0974e+01  5.5116e+02
 2.1634e+01  1.3969e+02  2.1632e+02
-3.4052e+01  8.4171e-02  5.5189e+00
[torch.FloatTensor of size 4x3]

cr_beta:  
 0.0946  3.0729  4.0504
-0.4162  2.2222  8.1989
 2.7864  5.1887  6.0030
-3.2413  0.4382  1.7672
[torch.FloatTensor of size 4x3]

v:  
1.00000e-03 *
 -1.5748 -0.0314 -0.0369
 -0.2987  0.5080  0.0782
 -0.1945  0.0688  0.0129
 -0.0153  0.0650  0.4473
[torch.FloatTensor of size 4x3]


1.00000e-03 *
 -1.5748 -0.0314 -0.0369
 -0.2987  0.5080  0.0782
 -0.1945  0.0688  0.0129
 -0.0153  0.0650  0.4473
[torch.FloatTensor of size 4x3]



In [8]:
def f1(u, c, lr, z_norm_squared, grad_output, mu):
    return (u * (1.0 + c**2)
                -2 * c * lr * torch.mul(z_norm_squared, (u**2).t()).t()
                + (lr ** 2 * torch.mul(z_norm_squared** 2, (u** 3).t()).t())  
      - grad_output / (1.0 + lr * mu))

def f2(u, c, lr, z_norm_squared, grad_output, mu):
    return ((1.0 + c**2)
                -4 * c * lr * torch.mul(z_norm_squared, (u_esgd).t()).t()
                + 3 *(lr ** 2 * torch.mul(z_norm_squared** 2, (u_esgd** 2).t()).t()))

In [20]:
# print(u_diff(grad_output, c, lr, z_norm_squared, u_esgd))
# print(u_esgd - grad_output / (1.0 + (c - lr * torch.mul(z_norm_squared, u_esgd.t()).t()) ** 2) / (1.0 + lr * mu))
# print(u_esgd * (1.0 + (c - lr * torch.mul(z_norm_squared, u_esgd.t()).t()) ** 2) - grad_output / (1.0 + lr * mu))
lr = 0.01
# print(u_esgd * (1.0 + c**2
#                 -2 * c * lr * torch.mul(z_norm_squared, u_esgd.t()).t()
#                 + (lr * torch.mul(z_norm_squared, u_esgd.t()).t()) ** 2) 
#       - grad_output / (1.0 + lr * mu))
f1 = f1(u_esgd, c, lr, z_norm_squared, grad_output, mu)
f2 = f2(u_esgd, c, lr, z_norm_squared, grad_output, mu)

u = u_esgd
for i in range(10):
#     print(u_diff(grad_output, c, lr, z_norm_squared, u))
    u = u - f1 / f2
