In [25]:
from __future__ import print_function
import torch
import torch.nn as nn
import math

In [2]:
# Shape notation used in the docstrings and shape comments of the cells below.
"""
n = dimension of input
m = dimension of output
"""

'\nn = dimension of input\nm = dimension of output\n'

# Linear layer

In [3]:
# Forward propagation
def forward_linear(input, weight, bias):
    """
    Forward pass of a linear (fully-connected) layer.

    Arguments:
    input [1 x n]      Row vector fed into the layer
    weight [m x n]     Weight matrix, one row per output unit
    bias [m]           Bias added to every output unit

    Returns:
    [output, logit]    Both [1 x m]. No non-linearity is applied (identity
                       activation), so `output` and `logit` are the SAME tensor
                       object, not copies.
    """
    # Affine map: input . weight^T, then add the bias to every row.
    pre_activation = torch.mm(input, weight.t())
    bias_rows = bias.unsqueeze(0).expand_as(pre_activation)
    pre_activation.add_(bias_rows)

    # Identity activation: the logit doubles as the layer output.
    return [pre_activation, pre_activation]


## ESGD (explicit SGD): standard back-propagated gradients

In [4]:
def esgd_linear(grad_output, weight, input):
    """
    Explicit (standard back-propagation) gradients of a linear layer.

    Arguments:
    grad_output [1 x m]  Gradient of the loss w.r.t. the layer output
    weight [m x n]       Weight matrix used in the forward pass
    input [1 x n]        Input used in the forward pass

    Returns:
    [grad_input, grad_weight, grad_bias]
        grad_input [1 x n], grad_weight [m x n], grad_bias [m]
    """
    grad_input = grad_output.mm(weight)          # [1 x n]
    grad_weight = grad_output.t().mm(input)      # [m x n]
    # Sum over the row dimension, then flatten to [m]. `.view(-1)` replaces the old
    # `.squeeze(0)`: when `sum` already drops the reduced dim, `.squeeze(0)` would turn
    # a single-output layer's [1] gradient into a 0-dim scalar.
    grad_bias = grad_output.sum(0).view(-1)      # [m]

    return [grad_input, grad_weight, grad_bias]

## ISGD (implicit SGD): gradients implied by the implicit update

In [36]:
def a_linear(s,d,c):
    """
    Per-output step coefficient of the ISGD update for a linear layer.

    Arguments:
    s [1 x m]      Sign of the back-propagated gradient
    d [1 x m]      Scaling constant, proportional to sqrt(abs(back-propagated gradient))
    c [1 x m]      Logit contracted by ridge-regularization; NOT used by this closed form

    Return
    alpha [1 x m]  Element-wise -s * d
    """
    # NOTE(review): `c` is accepted to keep the call signature stable but is ignored.
    scaled_sign = s * d          # element-wise product
    return scaled_sign.neg()

In [6]:
## Old version which seems unstable wrt the learning rate
def isgd_linear_old(grad_output, weight, input, lr, mu, bias=None, logit=None):
    """
    Earlier ISGD backward pass for a linear layer, kept for reference only.

    Arguments:
    grad_output [1 x m]  Back-propagated gradient w.r.t. the layer output
    weight [m x n]       Current weight matrix
    input [1 x n]        Layer input from the forward pass
    lr                   Learning rate
    mu                   Ridge-regularization strength
    bias [m]             Current bias. If omitted, the notebook-level global `bias`
                         is used (the original cell silently relied on it).
                         TODO: always pass it explicitly.
    logit [1 x m]        Forward-pass logit. If omitted it is recomputed as
                         input.mm(weight.t()) + bias (the original read a global).

    Returns:
    [grad_input, grad_weight, grad_bias]  with shapes [1 x n], [m x n], [m]
    """
    if bias is None:
        # Legacy fallback so existing calls relying on the notebook global keep working.
        if "bias" not in globals():
            raise TypeError("isgd_linear_old() needs `bias`; pass it explicitly")
        bias = globals()["bias"]
    if logit is None:
        logit = input.mm(weight.t())
        logit = logit + bias.unsqueeze(0).expand_as(logit)

    shrink = 1.0 + lr * mu   # ridge contraction factor

    # ISGD constants
    s = torch.sign(grad_output)                         # [1 x m]
    # NOTE(review): the +1 presumably accounts for the bias acting on a constant input of 1.
    z_norm = math.sqrt(torch.norm(input) ** 2 + 1.0)   # scalar
    d = z_norm * math.sqrt(lr / shrink) * torch.sqrt(torch.abs(grad_output))  # [1 x m]
    c = logit / shrink                                  # [1 x m]

    # Previously called `alpha_linear`, which is not defined in this notebook
    # (the helper is `a_linear`), so the cell failed with NameError on a fresh kernel.
    alpha = a_linear(s, d, c)                           # [1 x m]
    step = alpha * d                                    # [1 x m]

    # Calculate new weight, bias, and the implied gradients
    new_weight = weight / shrink + step.t().mm(input) / z_norm ** 2   # [m x n]
    grad_weight = (weight - new_weight) / lr                          # [m x n]

    # view(-1) instead of squeeze(): keeps shape [m] even when m == 1.
    new_bias = bias / shrink + step.view(-1) / z_norm ** 2            # [m]
    grad_bias = (bias - new_bias) / lr                                # [m]

    grad_input = grad_output.mm(weight)                               # [1 x n]

    # Return the results
    return [grad_input, grad_weight, grad_bias]

In [37]:
def isgd_linear(grad_output, weight, input, logit, bias=None, lr=0.001, mu=0.0):
    """
    Implied gradients of a linear layer under the ISGD update.

    With mu == 0 these agree with the explicit back-propagation gradients up to
    float rounding (see the comparison cell below the definitions).

    Arguments:
    grad_output [1 x m]  Back-propagated gradient w.r.t. the layer output
    weight [m x n]       Current weight matrix
    input [1 x n]        Layer input from the forward pass
    logit [1 x m]        Forward-pass logit
    bias [m]             Current bias. If omitted, falls back to the notebook-level
                         global `bias`, which the original cell read implicitly.
                         TODO: pass it explicitly at call sites.
    lr                   Learning rate (was hard-coded to 0.001)
    mu                   Ridge-regularization strength (was hard-coded to 0.0)

    Returns:
    (grad_input, grad_weight, grad_bias)  with shapes [1 x n], [m x n], [m]
    """
    if bias is None:
        # Legacy fallback so the existing call `isgd_linear(grad_output, weight, input, logit)` still works.
        if "bias" not in globals():
            raise TypeError("isgd_linear() needs `bias`; pass it explicitly")
        bias = globals()["bias"]

    shrink = 1.0 + lr * mu   # ridge contraction factor

    # ISGD constants
    s = torch.sign(grad_output)                         # [1 x m]
    z_norm = math.sqrt(torch.norm(input) ** 2 + 1.0)   # scalar
    d = z_norm / math.sqrt(shrink) * torch.sqrt(torch.abs(grad_output))  # [1 x m]
    c = logit / shrink                                  # [1 x m]

    # Calculate alpha
    a = a_linear(s, d, c)                               # [1 x m]
    step = a * d                                        # [1 x m]

    # Implied gradients of the weight and bias
    grad_weight = weight * mu / shrink - step.t().mm(input) / z_norm ** 2   # [m x n]
    # view(-1) instead of squeeze(): keeps shape [m] even when m == 1.
    grad_bias = bias * mu / shrink - step.view(-1) / z_norm ** 2            # [m]

    grad_input = grad_output.mm(weight)                                     # [1 x n]

    # Return the results
    return grad_input, grad_weight, grad_bias

## Test the differences between ESGD and ISGD (with mu = 0 they should agree up to float rounding)

In [38]:
# Random data, seeded so that Restart & Run All reproduces the printed differences
SEED = 0
torch.manual_seed(SEED)
grad_output = torch.randn(1, 3)     # [1 x m]
input = torch.randn(1, 2)           # [1 x n]
weight = torch.randn(3, 2)          # [m x n]
bias = torch.randn(3,)              # [m]

# Check that forward propagation makes sense
output, logit = forward_linear(input, weight, bias)
assert logit.size() == (1, 3), "logit should be [1 x m]"
assert output.size() == (1, 3), "output should be [1 x m]"

In [40]:
isgd_grads = isgd_linear(grad_output, weight, input, logit)
esgd_grads = esgd_linear(grad_output, weight, input)

diffs = [x - y for x, y in zip(isgd_grads, esgd_grads)]
print(diffs)

# isgd_linear runs with mu = 0 here, so its implied gradients should match the
# explicit ones up to float32 rounding; fail loudly if they drift apart.
max_abs_diff = max(float(diff.abs().max()) for diff in diffs)
assert max_abs_diff < 1e-4, "ISGD and ESGD gradients differ by %g" % max_abs_diff

[
 0  0
[torch.FloatTensor of size 1x2]
, 
1.00000e-07 *
  0.5960 -0.5960
  0.0000 -1.1921
  0.5960  0.0000
[torch.FloatTensor of size 3x2]
, 
1.00000e-07 *
  0.5960
  1.1921
  0.5960
[torch.FloatTensor of size 3]
]
