# Written by Dr. Hicham Badri @Mobius Labs GmbH - 2023
#####################################################
import torch
from torch import uint8, int32, Tensor
import numpy as np
# Bit packing logic. format: pack/unpack_nBits_target-<uint8 or int32>
class BitPack:
    """Bit-packing helpers for quantized weight tensors.

    Naming convention: ``pack/unpack_<nbits>_<container>`` where the container
    is ``u8`` (uint8) or ``32`` (int32). ``pack_*`` compresses a 2-D tensor of
    small integer codes along dim 0 by stacking several codes into each stored
    element; ``unpack_*`` restores the codes. All methods are pure functions of
    their input (no state on the class).

    NOTE(review): the 4/2/1-bit packers assume ``W_q.shape[0]`` is divisible by
    2/4/8 respectively, and the unpackers assume 2-D input — confirmed by the
    slicing/indexing below; callers are expected to guarantee this.
    """

    # 8-bit
    ################################################
    @staticmethod
    def pack_8bit_u8(W_q: Tensor) -> Tensor:
        """Identity pack: each 8-bit code already fills one uint8."""
        return W_q.to(uint8)

    @staticmethod
    def unpack_8bit_u8(W_q: Tensor, dtype=uint8) -> Tensor:
        """Identity unpack; casts to `dtype` (default uint8)."""
        return W_q.to(dtype)

    # 4-bit
    ################################################
    @staticmethod
    def pack_4bit_u8(W_q: Tensor) -> Tensor:  # uint8 > uint8/2
        """Pack two 4-bit codes per byte, halving dim 0.

        Output row ``i`` stores input row ``i`` in the high nibble and input
        row ``i + step`` in the low nibble. Assumes values < 16 and an even
        number of rows.
        """
        W_q = W_q.to(uint8)
        # Integer floor division: int(len/2) round-trips through float and can
        # lose precision for very large sizes; // is exact.
        _step = len(W_q) // 2
        return (W_q[:_step] << 4) | W_q[_step:]

    @staticmethod
    def unpack_4bit_u8(W_q: Tensor, dtype=uint8) -> Tensor:  # uint8/2 > uint8
        """Inverse of :meth:`pack_4bit_u8`: split each byte into two rows."""
        _step = W_q.shape[0]
        tmp = torch.empty([2 * _step, W_q.shape[1]], dtype=dtype, device=W_q.device)
        tmp[:_step] = (W_q & 0b11110000) >> 4
        tmp[_step:] = W_q & 0b00001111
        return tmp

    # 2-bit
    ################################################
    @staticmethod
    def pack_2bit_u8(W_q: Tensor) -> Tensor:  # uint8 > uint8/4
        """Pack four 2-bit codes per byte, quartering dim 0.

        Row blocks [0, step), [step, 2*step), ... land in bit positions
        7-6, 5-4, 3-2, 1-0 respectively. Assumes values < 4 and a row count
        divisible by 4.
        """
        W_q = W_q.to(uint8)
        _step = len(W_q) // 4
        return (
            W_q[:_step] << 6
            | W_q[_step : 2 * _step] << 4
            | W_q[2 * _step : 3 * _step] << 2
            | W_q[3 * _step :]
        )

    @staticmethod
    def unpack_2bit_u8(W_q: Tensor, dtype=uint8) -> Tensor:
        """Inverse of :meth:`pack_2bit_u8`: split each byte into four rows."""
        _step = W_q.shape[0]
        tmp = torch.empty([4 * _step, W_q.shape[1]], dtype=dtype, device=W_q.device)
        tmp[0 * _step : 1 * _step] = (W_q & 0b11000000) >> 6
        tmp[1 * _step : 2 * _step] = (W_q & 0b00110000) >> 4
        tmp[2 * _step : 3 * _step] = (W_q & 0b00001100) >> 2
        tmp[3 * _step : 4 * _step] = W_q & 0b00000011
        return tmp

    # 3-bit
    ################################################
    @staticmethod
    def pack_3bit_32(W_q_in: Tensor) -> Tensor:
        """Pack ten 3-bit codes per int32 (bits 29..0, 3 bits each).

        Dim 0 is first zero-padded up to a multiple of 10, so
        :meth:`unpack_3bit_32` may return trailing zero rows that the caller
        must slice off. Assumes values < 8.
        """
        # Pure-integer ceiling to a multiple of 10: replaces
        # int(10 * np.ceil(n / 10.0)), which is exact only while the float
        # math is, and needlessly pulled in numpy.
        padded_rows = ((W_q_in.shape[0] + 9) // 10) * 10
        W_q = torch.zeros(
            [padded_rows, W_q_in.shape[1]],
            device=W_q_in.device,
            dtype=int32,
        )
        W_q[: len(W_q_in)] = W_q_in
        _step = len(W_q) // 10
        W_q = (
            (W_q[:_step] << 27)
            | (W_q[1 * _step : 2 * _step] << 24)
            | (W_q[2 * _step : 3 * _step] << 21)
            | (W_q[3 * _step : 4 * _step] << 18)
            | (W_q[4 * _step : 5 * _step] << 15)
            | (W_q[5 * _step : 6 * _step] << 12)
            | (W_q[6 * _step : 7 * _step] << 9)
            | (W_q[7 * _step : 8 * _step] << 6)
            | (W_q[8 * _step : 9 * _step] << 3)
            | (W_q[9 * _step : 10 * _step])
        )
        return W_q

    # A bit faster than _cat version
    @staticmethod
    def unpack_3bit_32(W_q: Tensor, dtype=uint8) -> Tensor:
        """Inverse of :meth:`pack_3bit_32`: split each int32 into ten rows.

        Output has 10x the input rows, including any zero padding added by
        the packer.
        """
        _step = W_q.shape[0]
        tmp = torch.empty([10 * _step, W_q.shape[1]], dtype=dtype, device=W_q.device)
        tmp[0 * _step : 1 * _step] = (W_q & 0b00111000000000000000000000000000) >> 27
        tmp[1 * _step : 2 * _step] = (W_q & 0b00000111000000000000000000000000) >> 24
        tmp[2 * _step : 3 * _step] = (W_q & 0b00000000111000000000000000000000) >> 21
        tmp[3 * _step : 4 * _step] = (W_q & 0b00000000000111000000000000000000) >> 18
        tmp[4 * _step : 5 * _step] = (W_q & 0b00000000000000111000000000000000) >> 15
        tmp[5 * _step : 6 * _step] = (W_q & 0b00000000000000000111000000000000) >> 12
        tmp[6 * _step : 7 * _step] = (W_q & 0b00000000000000000000111000000000) >> 9
        tmp[7 * _step : 8 * _step] = (W_q & 0b00000000000000000000000111000000) >> 6
        tmp[8 * _step : 9 * _step] = (W_q & 0b00000000000000000000000000111000) >> 3
        tmp[9 * _step : 10 * _step] = W_q & 0b00000000000000000000000000000111
        return tmp

    # 1-bit
    ################################################
    @staticmethod
    def pack_1bit_u8(W_q: Tensor) -> Tensor:
        """Pack eight 1-bit codes per byte; dim 0 shrinks by 8x.

        Assumes values in {0, 1} and a row count divisible by 8.
        """
        W_q = W_q.to(uint8)
        _step = len(W_q) // 8
        return (
            W_q[:_step] << 7
            | W_q[1 * _step : 2 * _step] << 6
            | W_q[2 * _step : 3 * _step] << 5
            | W_q[3 * _step : 4 * _step] << 4
            | W_q[4 * _step : 5 * _step] << 3
            | W_q[5 * _step : 6 * _step] << 2
            | W_q[6 * _step : 7 * _step] << 1
            | W_q[7 * _step : 8 * _step]
        )

    @staticmethod
    def unpack_1bit_u8(W_q: Tensor, dtype=uint8) -> Tensor:
        """Inverse of :meth:`pack_1bit_u8`: split each byte into eight rows."""
        _step = W_q.shape[0]
        tmp = torch.empty([8 * _step, W_q.shape[1]], dtype=dtype, device=W_q.device)
        tmp[0 * _step : 1 * _step] = (W_q & 0b10000000) >> 7
        tmp[1 * _step : 2 * _step] = (W_q & 0b01000000) >> 6
        tmp[2 * _step : 3 * _step] = (W_q & 0b00100000) >> 5
        tmp[3 * _step : 4 * _step] = (W_q & 0b00010000) >> 4
        tmp[4 * _step : 5 * _step] = (W_q & 0b00001000) >> 3
        tmp[5 * _step : 6 * _step] = (W_q & 0b00000100) >> 2
        tmp[6 * _step : 7 * _step] = (W_q & 0b00000010) >> 1
        tmp[7 * _step : 8 * _step] = W_q & 0b00000001
        return tmp