add compress functions and tutorials: #385
fangwei123456 committed May 31, 2023
1 parent 6dca147 commit 542cea3
Showing 3 changed files with 373 additions and 34 deletions.
148 changes: 148 additions & 0 deletions docs/source/activation_based/monitor.rst
@@ -374,4 +374,152 @@
[tensor(4.3944), tensor(2.4396), tensor(0.8996), tensor(0.4376), tensor(0.0640), tensor(0.0122), tensor(0.0053), tensor(0.0016), tensor(0.0013), tensor(0.0005)]
Reduce Memory Consumption
-------------------------------------------
If we need to record large amounts of data and the recorded data are spikes, we can reduce memory consumption in several ways.
Although spikes only take the values 0/1, they are stored in floating-point format so that they can take part in floating-point computation. Hence the dtype of a spike tensor is still float32, or float16 when mixed-precision training is used.

Converting float32 to bool reduces memory consumption. However, since a bool in C++ actually occupies 8 bits, this only shrinks the memory to 1/4 of the original:

.. code-block:: python

    import torch

    def tensor_memory(x: torch.Tensor):
        return x.element_size() * x.numel()

    N = 1 << 10
    spike = torch.randint(0, 2, [N]).float()
    print('float32 size =', tensor_memory(spike))
    print('torch.bool size =', tensor_memory(spike.to(torch.bool)))

The outputs are:

.. code-block:: shell

    float32 size = 4096
    torch.bool size = 1024

:class:`spikingjelly.activation_based.tensor_cache` provides functions that compress a float32/float16 spike tensor into a uint8 spike tensor, each 8-bit element of which stores 8 spikes. The uint8 tensor is therefore a "true bool" type. Here is an example:

.. code-block:: python

    import torch
    from spikingjelly.activation_based import tensor_cache

    def tensor_memory(x: torch.Tensor):
        return x.element_size() * x.numel()

    N = 1 << 10
    spike = torch.randint(0, 2, [N]).float()
    print('float32 size =', tensor_memory(spike))
    print('torch.bool size =', tensor_memory(spike.to(torch.bool)))

    spike_b, s_dtype, s_shape, s_padding = tensor_cache.float_spike_to_bool(spike)
    print('bool size =', tensor_memory(spike_b))
    spike_recover = tensor_cache.bool_spike_to_float(spike_b, s_dtype, s_shape, s_padding)
    print('spike == spike_recover?', torch.equal(spike, spike_recover))

The outputs are:

.. code-block:: shell

    float32 size = 4096
    torch.bool size = 1024
    bool size = 128
    spike == spike_recover? True

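The packing itself is easy to picture: every 8 spikes become the 8 bits of one uint8 value. Below is a minimal sketch of this idea, written purely for illustration; ``pack_spikes`` is a hypothetical helper, not ``tensor_cache``'s actual implementation, whose bit order and padding scheme may differ:

.. code-block:: python

    import torch

    def pack_spikes(spike: torch.Tensor):
        # flatten and zero-pad so that the number of spikes is a multiple of 8
        flat = spike.flatten().to(torch.uint8)
        s_padding = (8 - flat.numel() % 8) % 8
        flat = torch.cat([flat, torch.zeros(s_padding, dtype=torch.uint8)])
        # each group of 8 spikes becomes the 8 bits of a single uint8 value
        weights = 2 ** torch.arange(8, dtype=torch.uint8)
        packed = (flat.view(-1, 8) * weights).sum(dim=1, dtype=torch.uint8)
        return packed, spike.dtype, spike.shape, s_padding

    spike = torch.randint(0, 2, [1 << 10]).float()
    packed, s_dtype, s_shape, s_padding = pack_spikes(spike)
    print(packed.element_size() * packed.numel())  # 128 bytes, 1/32 of the float32 tensor
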
To use this together with a monitor, we only need to set the compression function as the monitor's custom function:

.. code-block:: python

    spike_seq_monitor = monitor.OutputMonitor(net, neuron.IFNode, function_on_output=tensor_cache.float_spike_to_bool)

When accessing the recorded data, we just decompress them on the fly:


.. code-block:: python

    for item in spike_seq_monitor.records:
        print(tensor_cache.bool_spike_to_float(*item))

In addition, for sparse spikes, libraries such as ``zlib`` can be used for further compression. Here is an example of further compressing spikes with a firing rate of 0.2:

.. code-block:: python

    import torch
    import zlib
    from spikingjelly.activation_based import tensor_cache

    def tensor_memory(x: torch.Tensor):
        return x.element_size() * x.numel()

    N = 1 << 20
    spike = (torch.rand([N]) > 0.8).float()
    spike_b, s_dtype, s_shape, s_padding = tensor_cache.float_spike_to_bool(spike)
    arr = spike_b.numpy()
    compressed_arr = zlib.compress(arr.tobytes())
    print("compressed ratio:", len(compressed_arr) / arr.nbytes * tensor_memory(spike_b) / tensor_memory(spike))

The outputs are:

.. code-block:: shell

    compressed ratio: 0.024264097213745117

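The printed value measures the final compressed size against the original float32 tensor. It is the product of two factors, which we can print separately to see where the savings come from (continuing with the variables defined above):

.. code-block:: python

    # zlib's own ratio on the packed uint8 buffer
    print('zlib ratio:', len(compressed_arr) / arr.nbytes)
    # the 1/32 saving from packing float32 spikes into uint8
    print('packing ratio:', tensor_memory(spike_b) / tensor_memory(spike))
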
To combine this with a monitor, we again just put the functions into the monitor's custom functions. A complete example is as follows:

.. code-block:: python

    import torch
    import torch.nn as nn
    import zlib
    from spikingjelly.activation_based import monitor, neuron, functional, layer, tensor_cache

    def compress(spike: torch.Tensor):
        spike_b, s_dtype, s_shape, s_padding = tensor_cache.float_spike_to_bool(spike)
        spike_cb = zlib.compress(spike_b.cpu().numpy().tobytes())
        return spike_cb, s_dtype, s_shape, s_padding

    def decompress(spike_cb, s_dtype, s_shape, s_padding):
        spike_b = torch.frombuffer(zlib.decompress(spike_cb), dtype=torch.uint8)
        return tensor_cache.bool_spike_to_float(spike_b, s_dtype, s_shape, s_padding)

    net = nn.Sequential(
        layer.Linear(8, 4),
        neuron.IFNode(),
        layer.Linear(4, 2),
        neuron.IFNode()
    )
    for param in net.parameters():
        param.data.abs_()

    functional.set_step_mode(net, 'm')
    spike_seq_monitor = monitor.OutputMonitor(net, neuron.IFNode, function_on_output=compress)

    T = 4
    N = 1
    x_seq = torch.rand([T, N, 8])
    with torch.no_grad():
        net(x_seq)

    for item in spike_seq_monitor.records:
        print(decompress(*item))

Note that ``zlib`` compression can only run on the CPU. If the original data are on the GPU, transferring data between the two devices will slow down execution considerably.
150 changes: 150 additions & 0 deletions docs/source/activation_based_en/monitor.rst
@@ -349,3 +349,153 @@ The outputs are:
alpha=8, input_grad_monitor.records=
[tensor(4.3944), tensor(2.4396), tensor(0.8996), tensor(0.4376), tensor(0.0640), tensor(0.0122), tensor(0.0053), tensor(0.0016), tensor(0.0013), tensor(0.0005)]
Reduce Memory Consumption
-------------------------------------------
If we need to record large amounts of data and the recorded data are spikes, we can reduce memory consumption in several ways.

Although spike tensors only contain 0 and 1, they are stored in float format so that they can take part in floating-point computation; their dtype is float32, or float16 when mixed-precision training is used. Converting them to bool reduces memory consumption, but only to 1/4 rather than 1/32 of the original, because a bool in C++ occupies 8 bits rather than 1 bit:

.. code-block:: python

    import torch

    def tensor_memory(x: torch.Tensor):
        return x.element_size() * x.numel()

    N = 1 << 10
    spike = torch.randint(0, 2, [N]).float()
    print('float32 size =', tensor_memory(spike))
    print('torch.bool size =', tensor_memory(spike.to(torch.bool)))

The outputs are:

.. code-block:: shell

    float32 size = 4096
    torch.bool size = 1024

:class:`spikingjelly.activation_based.tensor_cache` provides functions to compress a float32/float16 spike tensor into a uint8 spike tensor, each 8-bit element of which stores 8 spikes. This uint8 tensor can be regarded as a "true bool" tensor. Here is an example:

.. code-block:: python

    import torch
    from spikingjelly.activation_based import tensor_cache

    def tensor_memory(x: torch.Tensor):
        return x.element_size() * x.numel()

    N = 1 << 10
    spike = torch.randint(0, 2, [N]).float()
    print('float32 size =', tensor_memory(spike))
    print('torch.bool size =', tensor_memory(spike.to(torch.bool)))

    spike_b, s_dtype, s_shape, s_padding = tensor_cache.float_spike_to_bool(spike)
    print('bool size =', tensor_memory(spike_b))
    spike_recover = tensor_cache.bool_spike_to_float(spike_b, s_dtype, s_shape, s_padding)
    print('spike == spike_recover?', torch.equal(spike, spike_recover))

The outputs are:

.. code-block:: shell

    float32 size = 4096
    torch.bool size = 1024
    bool size = 128
    spike == spike_recover? True

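Unpacking just tests each of the 8 bits of every byte. Below is a minimal sketch of this idea, written purely for illustration; ``unpack_spikes`` is a hypothetical helper, not ``tensor_cache``'s actual implementation, whose bit order may differ:

.. code-block:: python

    import torch

    def unpack_spikes(packed: torch.Tensor, s_dtype, s_shape, s_padding):
        # test each of the 8 bits of every uint8 value via broadcasting
        weights = 2 ** torch.arange(8, dtype=torch.uint8)
        bits = packed.unsqueeze(1).bitwise_and(weights).ne(0)
        flat = bits.flatten().to(s_dtype)
        if s_padding > 0:
            flat = flat[:-s_padding]  # drop the zero-padding added during packing
        return flat.reshape(s_shape)

    packed = torch.tensor([0b00000101], dtype=torch.uint8)
    print(unpack_spikes(packed, torch.float32, (8,), 0))
    # tensor([1., 0., 1., 0., 0., 0., 0., 0.])  (this sketch stores bit 0 first)
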
To compress data recorded by a monitor, we only need to set the compression function as the monitor's custom function:

.. code-block:: python

    spike_seq_monitor = monitor.OutputMonitor(net, neuron.IFNode, function_on_output=tensor_cache.float_spike_to_bool)

When we access the recorded data, we decompress them on the fly:

.. code-block:: python

    for item in spike_seq_monitor.records:
        print(tensor_cache.bool_spike_to_float(*item))

For sparse spikes, we can also use libraries such as ``zlib`` for further compression. Here is an example of further compressing spikes with a firing rate of 0.2:

.. code-block:: python

    import torch
    import zlib
    from spikingjelly.activation_based import tensor_cache

    def tensor_memory(x: torch.Tensor):
        return x.element_size() * x.numel()

    N = 1 << 20
    spike = (torch.rand([N]) > 0.8).float()
    spike_b, s_dtype, s_shape, s_padding = tensor_cache.float_spike_to_bool(spike)
    arr = spike_b.numpy()
    compressed_arr = zlib.compress(arr.tobytes())
    print("compressed ratio:", len(compressed_arr) / arr.nbytes * tensor_memory(spike_b) / tensor_memory(spike))

The outputs are:

.. code-block:: shell

    compressed ratio: 0.024264097213745117

To combine this with a monitor, we again just put the functions into the monitor's custom functions. Here is a complete example:

.. code-block:: python

    import torch
    import torch.nn as nn
    import zlib
    from spikingjelly.activation_based import monitor, neuron, functional, layer, tensor_cache

    def compress(spike: torch.Tensor):
        spike_b, s_dtype, s_shape, s_padding = tensor_cache.float_spike_to_bool(spike)
        spike_cb = zlib.compress(spike_b.cpu().numpy().tobytes())
        return spike_cb, s_dtype, s_shape, s_padding

    def decompress(spike_cb, s_dtype, s_shape, s_padding):
        spike_b = torch.frombuffer(zlib.decompress(spike_cb), dtype=torch.uint8)
        return tensor_cache.bool_spike_to_float(spike_b, s_dtype, s_shape, s_padding)

    net = nn.Sequential(
        layer.Linear(8, 4),
        neuron.IFNode(),
        layer.Linear(4, 2),
        neuron.IFNode()
    )
    for param in net.parameters():
        param.data.abs_()

    functional.set_step_mode(net, 'm')
    spike_seq_monitor = monitor.OutputMonitor(net, neuron.IFNode, function_on_output=compress)

    T = 4
    N = 1
    x_seq = torch.rand([T, N, 8])
    with torch.no_grad():
        net(x_seq)

    for item in spike_seq_monitor.records:
        print(decompress(*item))

Note that ``zlib`` compression only runs on the CPU. If the original data are on the GPU, transferring data between the two devices will slow down execution considerably.
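If the spikes live on the GPU, packing them into uint8 before the transfer shrinks the payload by a factor of 32, which mitigates (but does not remove) this cost. A minimal sketch, assuming ``float_spike_to_bool`` accepts CUDA tensors; if it does not, move the spike tensor to the CPU first:

.. code-block:: python

    import torch
    import zlib
    from spikingjelly.activation_based import tensor_cache

    if torch.cuda.is_available():
        spike = (torch.rand([1 << 20], device='cuda') > 0.8).float()
        # assumption: float_spike_to_bool supports CUDA tensors; then only the
        # 32x smaller uint8 buffer has to cross the device boundary for zlib
        spike_b, s_dtype, s_shape, s_padding = tensor_cache.float_spike_to_bool(spike)
        spike_cb = zlib.compress(spike_b.cpu().numpy().tobytes())
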
