Commit fc13801

fangwei123456 committed Nov 10, 2023
1 parent abe0352 commit fc13801
Showing 1 changed file with 10 additions and 0 deletions.

spikingjelly/activation_based/layer.py: 10 additions & 0 deletions
@@ -1872,6 +1872,7 @@ def __init__(self, alpha: float, v_th: float, *args, **kwargs):
        torch.nn.init.constant_(self.weight, alpha * v_th)

    def forward(self, x_seq):
        assert self.step_mode == 'm', "ThresholdDependentBatchNormBase can only be used in the multi-step mode!"
        return functional.seq_to_ann_forward(x_seq, super().forward)
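Note (not part of the diff): functional.seq_to_ann_forward merges the time and batch dimensions before calling the wrapped BatchNorm forward, which is why the _check_input_dim overrides added below expect one dimension fewer than the usual [T, N, ...] layout. A minimal sketch of that idea, for illustration only and not the library's exact implementation:

    def _seq_to_ann_forward_sketch(x_seq, stateless_forward):
        # x_seq: [T, N, ...] -> merge time and batch into [T * N, ...]
        y = x_seq.flatten(0, 1)
        y = stateless_forward(y)
        # restore the [T, N, ...] layout
        return y.reshape(x_seq.shape[0], x_seq.shape[1], *y.shape[1:])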


@@ -1907,6 +1908,10 @@ def __init__(self, alpha: float, v_th: float, *args, **kwargs):
        """
        super().__init__(alpha, v_th, *args, **kwargs)

    def _check_input_dim(self, input):
        assert input.dim() == 4 - 1  # [T * N, C, L]


class ThresholdDependentBatchNorm2d(_ThresholdDependentBatchNormBase):
    def __init__(self, alpha: float, v_th: float, *args, **kwargs):
@@ -1940,6 +1945,8 @@ def __init__(self, alpha: float, v_th: float, *args, **kwargs):
        """
        super().__init__(alpha, v_th, *args, **kwargs)

    def _check_input_dim(self, input):
        assert input.dim() == 5 - 1  # [T * N, C, H, W]

class ThresholdDependentBatchNorm3d(_ThresholdDependentBatchNormBase):
    def __init__(self, alpha: float, v_th: float, *args, **kwargs):
@@ -1973,6 +1980,9 @@ def __init__(self, alpha: float, v_th: float, *args, **kwargs):
        """
        super().__init__(alpha, v_th, *args, **kwargs)

    def _check_input_dim(self, input):
        assert input.dim() == 6 - 1  # [T * N, C, H, W, D]


class TemporalWiseAttention(nn.Module, base.MultiStepModule):
    def __init__(self, T: int, reduction: int = 16, dimension: int = 4):
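For context, a minimal usage sketch of the tdBN layers touched by this commit in multi-step mode. The network, shapes, and hyper-parameters below are illustrative assumptions and are not taken from this commit:

    import torch
    from spikingjelly.activation_based import functional, layer, neuron

    net = torch.nn.Sequential(
        layer.Conv2d(2, 16, kernel_size=3, padding=1, bias=False),
        layer.ThresholdDependentBatchNorm2d(alpha=1.0, v_th=1.0, num_features=16),
        neuron.IFNode(),
    )
    functional.set_step_mode(net, step_mode='m')  # tdBN layers require multi-step ('m') mode

    x_seq = torch.rand(4, 8, 2, 32, 32)  # [T, N, C, H, W]
    y_seq = net(x_seq)                   # [T, N, 16, 32, 32]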
