
Commit

Make surrogate function parameters learnable
fangwei123456 committed Sep 12, 2020
1 parent 5101cef commit 0ac9bed
Showing 1 changed file with 108 additions and 62 deletions.
spikingjelly/clock_driven/surrogate.py: 170 changes (108 additions & 62 deletions)
@@ -2,6 +2,8 @@
import torch.nn as nn
import torch.nn.functional as F
import math


def heaviside(x: torch.Tensor):
    '''
    * :ref:`API in English <heaviside.__init__-en>`
@@ -41,6 +43,38 @@ def heaviside(x: torch.Tensor):
    '''
    return (x >= 0).to(x.dtype)


class SurrogateFunctionBase(nn.Module):
    def __init__(self, alpha, spiking=True, learnable=False):
        super().__init__()
        self.spiking = spiking
        self.learnable = learnable
        if learnable:
            self.alpha = nn.Parameter(torch.tensor(alpha, dtype=torch.float))
        else:
            self.alpha = alpha

        if spiking:
            self.f = self.spiking_function
        else:
            self.f = self.primitive_function

    @staticmethod
    def spiking_function(x, alpha):
        raise NotImplementedError

    @staticmethod
    def primitive_function(x, alpha):
        raise NotImplementedError

    def forward(self, x: torch.Tensor):
        if self.training:
            return self.f(x, self.alpha)
        else:
            # Regardless of the spiking mode, the step function is used directly during testing (inference)
            return heaviside(x)
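
# An illustrative usage sketch (not part of this commit; the tensor shapes and the choice of ``ATan``,
# defined further below, are assumptions): with ``learnable=True`` the ``alpha`` passed in is registered
# as an ``nn.Parameter``, so it is exposed to the optimizer and receives a gradient through the surrogate
# backward pass, while in ``eval()`` mode ``forward`` always falls back to the plain Heaviside step.
#
#     sg = ATan(alpha=2.0, spiking=True, learnable=True)
#     print(isinstance(sg.alpha, nn.Parameter))   # True
#     x = torch.randn(8, requires_grad=True)
#     sg.train()
#     sg(x).sum().backward()
#     print(sg.alpha.grad)                        # alpha received a gradient via atan.backward
#     sg.eval()
#     print(sg(x))                                # heaviside(x), no surrogate gradient involved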


class piecewise_quadratic(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
@@ -58,6 +92,7 @@ def backward(ctx, grad_output):
            grad_x = grad_output * ctx.saved_tensors[0]
        return grad_x, None


class PiecewiseQuadratic(nn.Module):
    def __init__(self, alpha=1.0, spiking=True):
        '''
@@ -132,13 +167,15 @@ def __init__(self, alpha=1.0, spiking=True):
            self.f = piecewise_quadratic.apply
        else:
            self.f = self.primitive_function

    def forward(self, x):
        return self.f(x, self.alpha)

    @staticmethod
    def primitive_function(x: torch.Tensor, alpha):
        mask0 = (x > 1.0 / alpha).float()
        mask1 = (x.abs() <= 1.0 / alpha).float()

        return mask0 + mask1 * (-(alpha ** 2) / 2 * x.square() * x.sign() + alpha * x + 0.5)

# plt.style.use(['science', 'muted', 'grid'])
@@ -163,6 +200,7 @@ def primitive_function(x: torch.Tensor, alpha):
# plt.grid(linestyle='--')
# plt.show()


class piecewise_leaky_relu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, w=1, c=0.01):
@@ -180,6 +218,7 @@ def backward(ctx, grad_output):
            grad_x = grad_output * ctx.saved_tensors[0]
        return grad_x, None, None, None


class PiecewiseLeakyReLU(nn.Module):
    def __init__(self, w=1, c=0.01, spiking=True):
        '''
@@ -260,7 +299,6 @@ def __init__(self, w=1, c=0.01, spiking=True):
    def forward(self, x):
        return self.f(x, self.w, self.c)

    @staticmethod
    def primitive_function(x: torch.Tensor, w, c):
        mask0 = (x < -w).float()
@@ -295,6 +333,7 @@ def primitive_function(x: torch.Tensor, w, c):
# plt.grid(linestyle='--')
# plt.show()


class piecewise_exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
@@ -310,6 +349,7 @@ def backward(ctx, grad_output):
            grad_x = grad_output * ctx.alpha / 2 * (-ctx.alpha * ctx.saved_tensors[0].abs()).exp()
        return grad_x, None


class PiecewiseExp(nn.Module):
    def __init__(self, alpha=1.0, spiking=True):
        '''
@@ -373,14 +413,16 @@ def __init__(self, alpha=1.0, spiking=True):
            self.f = piecewise_exp.apply
        else:
            self.f = self.primitive_function

    def forward(self, x):
        return self.f(x, self.alpha)

    @staticmethod
    def primitive_function(x: torch.Tensor, alpha):
        mask_nonnegative = heaviside(x)
        mask_sign = mask_nonnegative * 2 - 1
        exp_x = 0.5 * (mask_sign * x * -alpha).exp()

        return mask_nonnegative - exp_x * mask_sign

# plt.style.use(['science', 'muted', 'grid'])
@@ -405,6 +447,7 @@ def primitive_function(x: torch.Tensor, alpha):
# plt.grid(linestyle='--')
# plt.show()


class sigmoid(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
@@ -421,6 +464,7 @@ def backward(ctx, grad_output):
            grad_x = grad_output * s_x * (1 - s_x) * ctx.alpha
        return grad_x, None


class Sigmoid(nn.Module):
    def __init__(self, alpha=1.0, spiking=True):
        '''
@@ -476,8 +520,10 @@ def __init__(self, alpha=1.0, spiking=True):
            self.f = sigmoid.apply
        else:
            self.f = self.primitive_function

    def forward(self, x):
        return self.f(x, self.alpha)

    @staticmethod
    def primitive_function(x: torch.Tensor, alpha):
        return (x * alpha).sigmoid()
@@ -504,6 +550,7 @@ def primitive_function(x: torch.Tensor, alpha):
# plt.grid(linestyle='--')
# plt.show()


class soft_sign(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
@@ -519,6 +566,7 @@ def backward(ctx, grad_output):
            grad_x = 1 / 2 / ctx.alpha / (1 / ctx.alpha + ctx.saved_tensors[0].abs()).square() * grad_output
        return grad_x, None


class SoftSign(nn.Module):
    def __init__(self, alpha=2.0, spiking=True):
        '''
@@ -576,6 +624,7 @@ def __init__(self, alpha=2.0, spiking=True):
            self.f = soft_sign.apply
        else:
            self.f = self.primitive_function

    def forward(self, x):
        return self.f(x, self.alpha)

@@ -605,86 +654,77 @@ def primitive_function(x: torch.Tensor, alpha):
# plt.grid(linestyle='--')
# plt.show()


class atan(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        if x.requires_grad:
            if isinstance(alpha, torch.Tensor):
                ctx.save_for_backward(x, alpha)
            else:
                ctx.save_for_backward(x)
                ctx.alpha = alpha
        return heaviside(x)

    @staticmethod
    def backward(ctx, grad_output):
        grad_x = None
        grad_alpha = None
        if ctx.saved_tensors.__len__() == 1:
            grad_x = ctx.alpha / 2 / (1 + (ctx.alpha * math.pi / 2 * ctx.saved_tensors[0]).square()) * grad_output
        else:
            # shared factor, computed once to avoid duplicate work
            shared_c = grad_output / (1 + (ctx.saved_tensors[1] * math.pi / 2 * ctx.saved_tensors[0]).square())
            if ctx.needs_input_grad[0]:
                grad_x = ctx.saved_tensors[1] / 2 * shared_c
            if ctx.needs_input_grad[1]:
                # alpha contains a single element, so its gradient is summed into a scalar
                grad_alpha = (ctx.saved_tensors[0] / 2 * shared_c).sum()
        return grad_x, grad_alpha


class ATan(SurrogateFunctionBase):
    '''
    * :ref:`API in English <ATan.__init__-en>`
    .. _ATan.__init__-cn:

    Spiking function that uses the gradient of the arc tangent function in the backward pass. The backward pass is

    .. math::
        g'(x) = \\frac{\\alpha}{2(1 + (\\frac{\\pi}{2}\\alpha x)^2)}

    and the corresponding primitive function is

    .. math::
        g(x) = \\frac{1}{\\pi} \\arctan(\\frac{\\pi}{2}\\alpha x) + \\frac{1}{2}

    .. image:: ./_static/API/clock_driven/surrogate/ATan.*
        :width: 100%

    * :ref:`中文API <ATan.__init__-cn>`
    .. _ATan.__init__-en:

    The arc tangent surrogate spiking function. The gradient is defined by

    .. math::
        g'(x) = \\frac{\\alpha}{2(1 + (\\frac{\\pi}{2}\\alpha x)^2)}

    The primitive function is defined by

    .. math::
        g(x) = \\frac{1}{\\pi} \\arctan(\\frac{\\pi}{2}\\alpha x) + \\frac{1}{2}

    .. image:: ./_static/API/clock_driven/surrogate/ATan.*
        :width: 100%
    '''

    @staticmethod
    def spiking_function(x, alpha):
        return atan.apply(x, alpha)

    @staticmethod
    def primitive_function(x: torch.Tensor, alpha):
        return (math.pi / 2 * alpha * x).atan() / math.pi + 0.5
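
# A quick numerical sanity check (illustrative only, not part of this commit): differentiating the
# primitive function g(x) = arctan(pi / 2 * alpha * x) / pi + 1 / 2 with respect to x reproduces the
# surrogate gradient alpha / (2 * (1 + (pi / 2 * alpha * x) ** 2)) used in ``atan.backward``.
#
#     alpha = 2.0
#     x = torch.randn(4)
#     xg = x.clone().requires_grad_(True)
#     ATan.primitive_function(xg, alpha).sum().backward()
#     analytic = alpha / 2 / (1 + (math.pi / 2 * alpha * x).square())
#     print(torch.allclose(xg.grad, analytic))    # True up to floating-point error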

# plt.style.use(['science', 'muted', 'grid'])
# fig = plt.figure(dpi=200)
@@ -708,6 +748,7 @@ def primitive_function(x: torch.Tensor, coefficient, half_pi_alpha):
# plt.grid(linestyle='--')
# plt.show()


class nonzero_sign_log_abs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, inv_alpha):
@@ -723,6 +764,7 @@ def backward(ctx, grad_output):
            grad_x = grad_output / (ctx.saved_tensors[0].abs() + ctx.inv_alpha)
        return grad_x, None


class NonzeroSignLogAbs(nn.Module):
    def __init__(self, alpha=1.0, spiking=True):
        '''
@@ -838,6 +880,7 @@ def primitive_function(x: torch.Tensor, alpha):
# plt.grid(linestyle='--')
# plt.show()


class erf(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
@@ -850,8 +893,10 @@ def forward(ctx, x, alpha):
    def backward(ctx, grad_output):
        grad_x = None
        if ctx.needs_input_grad[0]:
            grad_x = grad_output * ctx.alpha / math.sqrt(math.pi) * (
                -((ctx.saved_tensors[0] * ctx.alpha).square())).exp()
        return grad_x, None


class Erf(nn.Module):
    def __init__(self, alpha=2.0, spiking=True):
@@ -920,6 +965,7 @@ def __init__(self, alpha=2.0, spiking=True):
            self.f = erf.apply
        else:
            self.f = self.primitive_function

    def forward(self, x):
        return self.f(x, self.alpha)

@@ -947,4 +993,4 @@ def primitive_function(x: torch.Tensor, alpha):
# plt.xlabel('Input')
# plt.ylabel('Output')
# plt.grid(linestyle='--')
# plt.show()
