Commit

fix some bugs (not finished)
fangwei123456 committed Jul 5, 2023
1 parent 348e0c8 commit 71023d9
Showing 1 changed file with 27 additions and 29 deletions.
@@ -1,16 +1,13 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.utils.fusion import *
-from torch.autograd import Variable, Function
+from torch.autograd import Function
 from torch import Tensor
 from collections import namedtuple
-from spikingjelly.activation_based.surrogate import SurrogateFunctionBase, heaviside
-from spikingjelly.activation_based import layer
-from spikingjelly.activation_based.neuron import LIFNode
-from spikingjelly.activation_based.base import StepModule,SingleModule
+from ...activation_based import layer
+from ..neuron import LIFNode
 from torch.nn.functional import interpolate
-from spikingjelly.activation_based.surrogate import SurrogateFunctionBase, heaviside
+from ..surrogate import SurrogateFunctionBase, heaviside
 from math import tanh
 from torch.jit import script

@@ -23,13 +20,13 @@ def network_layer_to_space(net_arch):
"""
:param net_arch: network level sample rate
0: down 1: None 2: Up
:type net_arch: numpy.array
:type net_arch: numpy.ndarray
:return: network level architecture
network_space[layer][level][sample]:
layer: 0 - 8
level: sample_level {0: 1, 1: 2, 2: 4, 3: 8}
sample: 0: down 1: None 2: Up
:rtype: numpy.array
:rtype: numpy.ndarray
Convert network level sample rate like [0,0,1,1,1,2,2,2] to network architecture.
"""
@@ -319,9 +316,9 @@ def __init__(self, steps, block_multiplier, prev_prev_fmultiplier,
         :param prev_filter_multiplier: The change factor for the channel for previous node
         :type prev_filter_multiplier: int
         :param cell_arch: cell level architecture
-        :type cell_arch: numpy.array
+        :type cell_arch: numpy.ndarray
         :param network_arch: layer level architecture
-        :type network_arch: numpy.array
+        :type network_arch: numpy.ndarray
         :param filter_multiplier: filter channel multiplier
         :type filter_multiplier: int
         :param downup_sample: sample rate, -1:downsample, 1:upsample, 0: no change
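The downup_sample convention documented above maps onto a spatial rescaling of the feature map. A minimal sketch of that mapping, as a hypothetical helper that is not part of this diff (the cell may realize it differently, e.g. via strided convolutions):

import torch
import torch.nn.functional as F

def apply_downup_sample(x: torch.Tensor, downup_sample: int) -> torch.Tensor:
    # -1: downsample by 2, 1: upsample by 2, 0: keep the spatial size unchanged.
    if downup_sample == 0:
        return x
    scale = 0.5 if downup_sample == -1 else 2.0
    return F.interpolate(x, scale_factor=scale, mode='nearest')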
@@ -410,9 +407,9 @@ def __init__(self, frame_rate, network_arch, cell_arch, cell=Cell, args=None):
         :param frame_rate: input channel
         :type frame_rate: int
         :param network_arch: layer level architecture
-        :type network_arch: numpy.array
+        :type network_arch: numpy.ndarray
         :param cell_arch: cell level architecture
-        :type cell_arch: numpy.array
+        :type cell_arch: numpy.ndarray
         :param cell: choice the type of cell, defaults to Cell
         :type cell: Cell class
         :param args: additional arguments
@@ -537,7 +534,7 @@ def __init__(self, init_channels=3, args=None):
         :type init_channels: int
         :param args: additional arguments
-        The SpikeDHS (Auto-Spikformer: Spikformer Architecture Search) implementation by Spikingjelly.
+        The SpikeDHS `Auto-Spikformer: Spikformer Architecture Search <https://arxiv.org/abs/2306.00807>`_ implementation by Spikingjelly.
         """
         super(SpikeDHS, self).__init__()
@@ -588,22 +585,23 @@ def dgs_unfreeze_weights(self):
         for name,value in self.named_parameters():
             value.requires_grad_(True)

+### Example ###
 if __name__ == "__main__":
-    ### Example ###

-    parser = argparse.ArgumentParser("cifar")
+    parser = argparse.ArgumentParser("cifar")

-    parser.add_argument('--layers', type=int, default=8, help='total number of layers')
+    parser.add_argument('--layers', type=int, default=8, help='total number of layers')

-    parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
-    parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
+    parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
+    parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')

-    parser.add_argument('--seed', type=int, default=0, help='random seed')
-    parser.add_argument('--arch', type=str, default='DARTS', help='which architecture to use')
-    parser.add_argument('--fea_num_layers', type=int, default=8)
-    parser.add_argument('--fea_filter_multiplier', type=int, default=48)
-    parser.add_argument('--fea_block_multiplier', type=int, default=3)
-    parser.add_argument('--fea_step', type=int, default=3)
-    parser.add_argument('--net_arch_fea', default=None, type=str)
-    parser.add_argument('--cell_arch_fea', default=None, type=str)
-    args = parser.parse_args()
-    spikedhs = SpikeDHS(init_channels=3, args=args)
+    parser.add_argument('--seed', type=int, default=0, help='random seed')
+    parser.add_argument('--arch', type=str, default='DARTS', help='which architecture to use')
+    parser.add_argument('--fea_num_layers', type=int, default=8)
+    parser.add_argument('--fea_filter_multiplier', type=int, default=48)
+    parser.add_argument('--fea_block_multiplier', type=int, default=3)
+    parser.add_argument('--fea_step', type=int, default=3)
+    parser.add_argument('--net_arch_fea', default=None, type=str)
+    parser.add_argument('--cell_arch_fea', default=None, type=str)
+    args = parser.parse_args()
+    spikedhs = SpikeDHS(init_channels=3, args=args)
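The example relies on argparse, which is not imported in the hunks shown here. For programmatic use, the same configuration can be built without a command line; the sketch below assumes SpikeDHS only reads from args the attributes defined by the arguments above:

import argparse

# Build the configuration directly instead of parsing sys.argv; field names
# mirror the add_argument calls in the example above.
args = argparse.Namespace(
    layers=8, auxiliary=False, auxiliary_weight=0.4, seed=0, arch='DARTS',
    fea_num_layers=8, fea_filter_multiplier=48, fea_block_multiplier=3,
    fea_step=3, net_arch_fea=None, cell_arch_fea=None,
)
spikedhs = SpikeDHS(init_channels=3, args=args)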
