Commit

fix bug
fangwei123456 committed Dec 1, 2023
1 parent 16f24ba commit 915abc5
Showing 1 changed file with 15 additions and 8 deletions.
spikingjelly/activation_based/cuda_utils.py (23 changes: 15 additions & 8 deletions)
@@ -118,20 +118,27 @@ def cal_fun_t(n: int, device: str or torch.device or int, f: Callable, *args, **
     :rtype: float
     """
-    if device == 'cpu':
-        c_timer = cpu_timer
-    else:
-        c_timer = cuda_timer
-
     if n == 1:
-        return c_timer(device, f, *args, **kwargs)
+        if device == 'cpu':
+            return cpu_timer(f, *args, **kwargs)
+        else:
+            return cuda_timer(device, f, *args, **kwargs)

     # warm up
-    c_timer(device, f, *args, **kwargs)
+    if device == 'cpu':
+        cpu_timer(f, *args, **kwargs)
+    else:
+        cuda_timer(device, f, *args, **kwargs)

     t_list = []
     for _ in range(n * 2):
-        t_list.append(c_timer(device, f, *args, **kwargs))
+        if device == 'cpu':
+            ti = cpu_timer(f, *args, **kwargs)
+        else:
+            ti = cuda_timer(device, f, *args, **kwargs)
+        t_list.append(ti)
+
+
     t_list = np.asarray(t_list)
     return t_list[n:].mean()
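
For context, a minimal usage sketch of the timing helper touched by this commit. It is not part of the commit: the workload function forward and the tensors x and w below are made up for illustration, while the cal_fun_t signature, the warm-up call, and the 2 * n timed runs are taken from the diff above. Judging from the diff, the bug was that the old code routed both backends through a single c_timer(device, f, ...) call, although cpu_timer expects the callable as its first argument, so the CPU path received the device where the function was expected.

import torch
from spikingjelly.activation_based import cuda_utils

# hypothetical workload to time; any callable works
def forward(x, w):
    return (x @ w).relu()

x = torch.rand(256, 512)
w = torch.rand(512, 512)

# cal_fun_t(n, device, f, *args, **kwargs) performs one warm-up call, times f
# for 2 * n runs, and returns the mean of the last n timings as a float
t_cpu = cuda_utils.cal_fun_t(16, 'cpu', forward, x, w)
print('cpu:', t_cpu)

# on a CUDA device the fixed code dispatches to cuda_timer, which takes the
# device as its first argument
if torch.cuda.is_available():
    device = 'cuda:0'
    t_gpu = cuda_utils.cal_fun_t(16, device, forward, x.to(device), w.to(device))
    print('cuda:0:', t_gpu)

With n=16 the first 16 timings are discarded as additional warm-up and the remaining 16 are averaged, which is why the loop in the diff runs 2 * n times.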

