In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import math
import numpy as np
from functools import reduce

In [2]:
# Problem sizes shared by the cells below
n, m = 2, 3  # n: dimension of input, m: dimension of output
batch = 4    # batch size

In [3]:
def calc_logit(input, weight, bias=None):
    """
    Pre-activation of a linear layer: input.mm(weight.t()), plus bias if given.

    Args:
        input:   [b x n]        Input row vectors (b may be 1)
        weight:  [m x n]        Weight matrix
        bias:    [m]            Optional bias vector, added to every row

    Returns:
        logit:   [b x m]        Logit = input.mm(weight.t()) + bias

    """
    pre_activation = input.mm(weight.t())
    if bias is None:
        return pre_activation

    # Add the bias in place, expanded to one copy per row of the result
    return pre_activation.add_(bias.unsqueeze(0).expand_as(pre_activation))

## ESGD

In [4]:
def esgd_arctan(grad_output, input, weight, bias, output):
    """
    Explicit-SGD gradients for a linear layer followed by arctan.

    The pre-activation is recomputed here from input, weight and bias, so
    the function no longer depends on a module-level `logit` variable
    having been defined by some other cell.

    Args:
        grad_output:  [b x m]   Gradient w.r.t. the arctan output
        input:        [b x n]   Layer input
        weight:       [m x n]   Weight matrix
        bias:         [m]       Bias vector (None means no bias)
        output:       [b x m]   Not used; kept so the signature is unchanged

    Returns:
        grad_input:   [b x n]
        grad_weight:  [m x n]
        grad_bias:    [m]

    """
    logit = input.mm(weight.t())  # [b x m]
    if bias is not None:
        logit = logit + bias.unsqueeze(0).expand_as(logit)

    # d/dx arctan(x) = 1 / (1 + x ** 2)
    grad_output_scaled = grad_output / (1 + logit ** 2)  # [b x m]
    grad_input = grad_output_scaled.mm(weight)  # [b x n]
    grad_weight = grad_output_scaled.t().mm(input)  # [m x n]
    # Summing over the batch axis already gives [m]; an extra squeeze(0)
    # would collapse the result to 0-d when m == 1.
    grad_bias = grad_output_scaled.sum(0)  # [m]

    return grad_input, grad_weight, grad_bias

## New ISGD implementation

In [5]:
def solve(coeff):
    """
    Closed-form roots of a polynomial of degree <= 3.

    Args:
        coeff:  Sequence (a, b, c, d) holding the coefficients of
                a*x**3 + b*x**2 + c*x + d = 0, highest power first.
                a == 0 selects the quadratic formula, a == b == 0 the
                linear one.

    Returns:
        roots:  numpy array with 1, 2 or 3 roots; the array is complex
                when a complex-conjugate pair is present

    Raises:
        ValueError:         coeff does not unpack into exactly four values
        ZeroDivisionError:  a == b == c == 0

    """
    # Convert to Python floats up front: numpy scalars (e.g. float32 slices
    # handed over by np.apply_along_axis) would otherwise drive the
    # arithmetic below at their own, lower precision under NumPy 2 promotion.
    a, b, c, d = (float(v) for v in coeff)

    if a == 0 and b == 0:  # Linear equation
        return np.array([-d / c])

    if a == 0:  # Quadratic equation
        disc = c * c - 4.0 * b * d
        if disc >= 0:
            root_disc = math.sqrt(disc)
        else:
            root_disc = math.sqrt(-disc) * 1j
        x1 = (-c + root_disc) / (2.0 * b)
        x2 = (-c - root_disc) / (2.0 * b)
        return np.array([x1, x2])

    f = findF(a, b, c)
    g = findG(a, b, c, d)
    h = findH(g, f)
    shift = -b / (3.0 * a)  # Added to every root of the cubic branches below

    if f == 0 and g == 0 and h == 0:  # All 3 roots are real and equal
        x = -_signed_cbrt(d / a)
        return np.array([x, x, x])

    if h <= 0:  # All 3 roots are real
        i = math.sqrt(((g ** 2.0) / 4.0) - h)
        if i == 0.0:
            # g**2 / 4 and h both evaluated to zero although f or g is not
            # exactly zero; avoid dividing by zero and return the shift.
            return np.array([shift, shift, shift])
        j = i ** (1 / 3.0)
        # Round-off can push the ratio just outside [-1, 1], which would
        # make math.acos raise a domain error; clamp it.
        cos_arg = max(-1.0, min(1.0, -(g / (2.0 * i))))
        k = math.acos(cos_arg)
        M = math.cos(k / 3.0)
        N = math.sqrt(3) * math.sin(k / 3.0)

        x1 = 2 * j * M + shift
        x2 = -j * (M + N) + shift
        x3 = -j * (M - N) + shift
        return np.array([x1, x2, x3])

    # h > 0: one real root and two complex roots
    sqrt_h = math.sqrt(h)
    S = _signed_cbrt(-(g / 2.0) + sqrt_h)
    U = _signed_cbrt(-(g / 2.0) - sqrt_h)
    imag_part = (S - U) * math.sqrt(3) * 0.5j

    x1 = (S + U) + shift
    x2 = -(S + U) / 2 + shift + imag_part
    x3 = -(S + U) / 2 + shift - imag_part
    return np.array([x1, x2, x3])


def _signed_cbrt(value):
    """Real cube root that keeps the sign of value (avoids complex results for negatives)."""
    if value >= 0:
        return value ** (1 / 3.0)
    return -((-value) ** (1 / 3.0))


# Helper function to return float value of f.
def findF(a, b, c):
    """Return f = (3*c/a - b**2/a**2) / 3 for the cubic a*x**3 + b*x**2 + c*x + d."""
    return ((3.0 * c / a) - ((b ** 2.0) / (a ** 2.0))) / 3.0


# Helper function to return float value of g.
def findG(a, b, c, d):
    """Return g = (2*b**3/a**3 - 9*b*c/a**2 + 27*d/a) / 27 for the cubic a*x**3 + b*x**2 + c*x + d."""
    return (((2.0 * (b ** 3.0)) / (a ** 3.0)) - ((9.0 * b * c) / (a ** 2.0)) + (27.0 * d / a)) / 27.0


# Helper function to return float value of h.
def findH(g, f):
    """Return h = g**2/4 + f**3/27; solve() branches on its sign."""
    return ((g ** 2.0) / 4.0 + (f ** 3.0) / 27.0)

In [6]:
def real_root_closest_to_zero_np(coeff):
    """
    Find the roots of a polynomial with np.roots and return the real
    root of smallest magnitude.

    Args:
        coeff:  List of polynomial coefficients, as accepted by np.roots

    Returns:
        root_closest_to_zero:   Real root with the smallest absolute value

    """
    def nearer_to_zero(current, candidate):
        # Strict comparison: on a tie the later candidate is kept
        return current if (abs(current) < abs(candidate)) else candidate

    # Keep the real part of every root whose imaginary part is negligible.
    # Testing imag == 0 is not enough because numpy sometimes leaves a tiny
    # imaginary component on real roots, see
    # https://stackoverflow.com/questions/28081247/print-real-roots-only-in-numpy
    real_parts = [z.real for z in np.roots(coeff) if abs(z.imag) < 1e-10]

    return reduce(nearer_to_zero, real_parts)

In [7]:
def real_root_closest_to_zero(coeff):
    """
    Find the roots of a polynomial with solve() and return the real
    root of smallest magnitude, converted to float32.

    Args:
        coeff:  List of polynomial coefficients, as accepted by solve()

    Returns:
        root_closest_to_zero:   Real root with the smallest absolute value (float32)

    """
    def nearer_to_zero(current, candidate):
        # Strict comparison: on a tie the later candidate is kept
        return current if (abs(current) < abs(candidate)) else candidate

    # Keep the real part of every root whose imaginary part is negligible
    # (a tolerance is used rather than an exact imag == 0 test).
    real_parts = [z.real for z in solve(coeff) if abs(z.imag) < 1e-10]

    closest = reduce(nearer_to_zero, real_parts)
    return closest.astype('float32')

In [11]:
def isgd_numpy_arctan(input, weight, bias, output, logit, grad_output, lr=0.000001, mu=0.0):
    """
    ISGD gradients for a linear layer followed by arctan.

    For every (sample, output unit) pair a cubic
    a3*u**3 + a2*u**2 + a1*u + a0 = 0 is solved with
    real_root_closest_to_zero, and the resulting u replaces the scaled
    output gradient in the weight and bias gradients.

    Args:
        input:        [b x n]   Layer input
        weight:       [m x n]   Weight matrix
        bias:         [m]       Bias vector
        output:       [b x m]   Not used; kept so the signature is unchanged
        logit:        [b x m]   Pre-activation, input.mm(weight.t()) + bias
        grad_output:  [b x m]   Gradient w.r.t. the arctan output
        lr:           Learning rate used in the cubic (default 0.000001)
        mu:           Hyperparameter entering through the factor 1 + lr * mu
                      and the mu-weighted weight/bias terms (default 0.0)

    Returns:
        grad_input:   [b x n]   Same expression as the explicit gradient
        grad_weight:  [m x n]
        grad_bias:    [m]

    """
    shrink = 1.0 + lr * mu

    # ISGD constants
    b = grad_output / shrink  # [b x m]
    c = logit / shrink  # [b x m]
    z_norm_squared_mat = (torch.norm(input, p=2, dim=1) ** 2 + 1.0).unsqueeze(1).expand_as(c)  # [b x m]

    # Coefficients of cubic equation for each power:
    # a3*u**3 + a2*u**2 + a1*u + a0 = 0
    a3 = ((lr * z_norm_squared_mat) ** 2)  # [b x m]
    a2 = (-2 * lr * c * z_norm_squared_mat)  # [b x m]
    a1 = (1 + c ** 2)  # [b x m]
    a0 = (- b)  # [b x m]

    # Coefficients as one big numpy matrix
    coeff = torch.stack((a3, a2, a1, a0)).numpy()  # [4 x b x m]

    # One cubic per (sample, unit): slice along axis 0 and keep the real
    # root closest to zero
    roots = np.apply_along_axis(real_root_closest_to_zero, 0, coeff)  # [b x m]
    u = torch.from_numpy(roots)  # [b x m]

    # Calculate input gradient
    grad_output_scaled = grad_output / (1 + logit ** 2)  # [b x m]
    grad_input = grad_output_scaled.mm(weight)  # [b x n]

    # Calculate grad_weight, grad_bias
    grad_weight = weight * mu / shrink + u.t().mm(input)  # [m x n]
    grad_bias = bias * mu / shrink + u.t().sum(1)  # [m]

    return grad_input, grad_weight, grad_bias

## Test the differences between ESGD and ISGD new

In [12]:
# Random data
# NOTE(review): no RNG seed is set in the visible cells, so the values
# printed below presumably change on every run -- confirm this is intended.
grad_output = torch.randn(batch, m)     # [b x m]
input = torch.randn(batch, n)           # [b x n]
weight = torch.randn(m, n)          # [m x n]
bias = torch.randn(m,)              # [m]

# Forward propagation
# Calculate logit [b x m], where logit = input.mm(weight.t()) + bias
logit = calc_logit(input, weight, bias)

# Non-linear activation function
output = torch.atan(logit)  # [b x m]

# Calculate gradients
# Each call returns the tuple (grad_input, grad_weight, grad_bias)
esgd_grads = esgd_arctan(grad_output, input, weight, bias, output)
isgd_new_grads = isgd_numpy_arctan(input, weight, bias, output, logit, grad_output)

# Element-wise ISGD minus ESGD, one tensor per gradient in the tuple
print('\nDifference between ESGD and ISGD new')
print([(x-y) for x,y in zip(isgd_new_grads, esgd_grads)])


Difference between ESGD and ISGD new
[
 0  0
 0  0
 0  0
 0  0
[torch.FloatTensor of size 4x2]
, 
1.00000e-06 *
  0.1639 -0.0298
  1.0729 -0.6258
  0.5066 -0.8047
[torch.FloatTensor of size 3x2]
, 
1.00000e-06 *
 -0.2682
 -1.1921
 -0.3576
[torch.FloatTensor of size 3]
]


In [10]:
# Test how accurate the update is in terms of the equation it is supposed to solve

def u_diff(grad_output, c, lr, z_norm_squared, u, mu=0.0):
    """
    Residual of the equation a candidate update u is supposed to satisfy:

        u - grad_output / (1 + (c - lr * z_norm_squared * u) ** 2) / (1 + lr * mu)

    A residual of zero means u solves the equation exactly.

    Args:
        grad_output:     [b x m]   Gradient w.r.t. the arctan output
        c:               [b x m]   Logit divided by (1 + lr * mu)
        lr:              Learning rate
        z_norm_squared:  [b]       Per-sample factor, multiplied into each row of u
        u:               [b x m]   Candidate update to check
        mu:              Hyperparameter in the factor 1 + lr * mu. It is an
                         explicit argument (default 0.0) rather than a global
                         lookup; pass it whenever a non-zero value is used.

    Returns:
        residual:        [b x m]

    """
    # Transposing u lets the length-b vector multiply along the batch axis
    scaled_u = torch.mul(z_norm_squared, u.t()).t()  # [b x m]
    return u - grad_output / (1.0 + (c - lr * scaled_u) ** 2) / (1.0 + lr * mu)

# Hyperparameters
# lr is much larger here than the 0.000001 used in isgd_numpy_arctan
lr = 0.01
mu = 0.0

# ISGD constants
# grad_output, logit and input are the tensors created in the random-data cell
b = grad_output / (1 + lr * mu)  # [b x m]
c = logit / (1.0 + lr * mu)  # [b x m]
z_norm_squared_mat = (torch.norm(input, p=2, dim=1) ** 2 + 1.0).unsqueeze(1).expand_as(c)  # [b x m]
z_norm_squared = torch.norm(input, p=2, dim=1) ** 2 + 1.0  # [b], unexpanded form passed to u_diff

# Coefficients of cubic equation for each power:
# a3*u**3 + a2*u**2 + a1*u + a0 = 0
# (same construction as inside isgd_numpy_arctan)
a3 = ((lr * z_norm_squared_mat) ** 2)  # [b x m]
a2 = (-2 * lr * c * z_norm_squared_mat)  # [b x m]
a1 = (1 + c ** 2)  # [b x m]
a0 = (- b)  # [b x m]

# Coefficients as one big numpy matrix
coeff = torch.stack((a3, a2, a1, a0)).numpy()  # [4 x b x m]

# Calculate roots of cubic that are real and closest to zero
roots = np.apply_along_axis(real_root_closest_to_zero, 0, coeff)  # [b x m] # Real root closest to zero
u = torch.from_numpy(roots)  # [b x m]

# Explicit (ESGD) scaled gradient, used as the baseline candidate
u_esgd = grad_output / (1 + logit ** 2)

# Residuals of the equation for the ESGD candidate, then for the cubic root u
print(u_diff(grad_output, c, lr, z_norm_squared, u_esgd))
print(u_diff(grad_output, c, lr, z_norm_squared, u))


1.00000e-02 *
 -0.2326 -0.0599 -0.0424
 -0.0658 -3.1712 -0.0014
 -0.0054  0.0547 -0.2752
 -1.1846 -2.8215 -0.0117
[torch.FloatTensor of size 4x3]


1.00000e-07 *
  0.2980  0.0000 -0.1490
  0.0000 -2.3842  0.0000
  0.0373  0.0745 -0.1490
 -0.5960  0.5960 -0.0745
[torch.FloatTensor of size 4x3]

