embedding.🔥
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2024, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Neural network embedding layers."""

from sys.info import sizeof

from tensor import Tensor, TensorShape

from max.graph import ops, Dim, StaticDim, Symbol, TensorType
from max.graph.quantization import (
    Float32Encoding,
    QuantizationEncoding,
    Q4_0Encoding,
    Q4_KEncoding,
    Q6_KEncoding,
)

from pipelines.nn import Linear
from pipelines.weights.ggml_quants import BlockQ40, BlockQ4K, BlockQ6K


def _out_dim(quantized_dim: Dim, encoding_id: String) -> Int:
    if encoding_id == Q4_0Encoding.id():
        num_blocks = quantized_dim.value[StaticDim].dim // sizeof[BlockQ40]()
        return int(BlockQ40.elements_per_block() * num_blocks)
    if encoding_id == Q4_KEncoding.id():
        num_blocks = quantized_dim.value[StaticDim].dim // sizeof[BlockQ4K]()
        return int(BlockQ4K.elements_per_block() * num_blocks)
    if encoding_id == Q6_KEncoding.id():
        num_blocks = quantized_dim.value[StaticDim].dim // sizeof[BlockQ6K]()
        return int(BlockQ6K.elements_per_block() * num_blocks)
    raise "unsupported quantization encoding in Embedding: " + encoding_id


def _dequantize(reshaped_tokens: Symbol, encoding_id: String) -> Symbol:
    if encoding_id == Q4_0Encoding.id():
        return ops.custom["ggml_q4_0_dequantize"](
            reshaped_tokens,
            TensorType(DType.float32, Dim.dynamic(), Dim.dynamic()),
        )
    if encoding_id == Q4_KEncoding.id():
        return ops.custom["ggml_q4_k_dequantize"](
            reshaped_tokens,
            TensorType(DType.float32, Dim.dynamic(), Dim.dynamic()),
        )
    if encoding_id == Q6_KEncoding.id():
        return ops.custom["ggml_q6_k_dequantize"](
            reshaped_tokens,
            TensorType(DType.float32, Dim.dynamic(), Dim.dynamic()),
        )
    raise "unsupported quantization encoding in Embedding: " + encoding_id


@value
struct Embedding:
    """Quantized embedding (can be in GGML Q4_0 packed format)."""

    var weights: Symbol
    """The embedding weights, which are possibly quantized."""

    var encoding_id: String
    """ID of the quantization encoding of this embedding."""

    def __init__(inout self, weights: Symbol):
        self.weights = weights
        # Default to float32 if not provided.
        self.encoding_id = Float32Encoding.id()

    def __init__(inout self, encoded_weight: Tuple[Symbol, String]):
        self.weights, self.encoding_id = encoded_weight

    def __call__(borrowed self, indices: Symbol) -> Symbol:
        """Gathers and dequantizes rows of the quantized embedding, as needed.

        Args:
            indices: Rows of the embedding to return.

        Returns:
            Embedding rows corresponding to `indices`, dequantized as needed.
        """
        # If there is no quantization, then just gather and return early.
        if self.encoding_id == Float32Encoding.id():
            return ops.gather(self.weights, indices, axis=0)

        # Otherwise, gather and dequantize the rows.
        g = self.weights.graph()
        quantized_tokens = ops.gather(self.weights, indices, axis=0)
        tokens_type = quantized_tokens.tensor_type()
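        # Shape walk-through (hypothetical sizes): with weights of shape
        # [vocab, qbytes] and indices of shape [batch, seq], the gather above
        # yields quantized_tokens of shape [batch, seq, qbytes]. The code
        # below flattens that to [-1, qbytes], dequantizes it to
        # [batch * seq, out_dim], then reshapes back to [batch, seq, out_dim].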
        # Compute the dequantized output dim as the number of quantized blocks
        # times the number of elements (quants) per block.
        out_dim = _out_dim(
            quantized_dim=self.weights.tensor_type().dims[-1],
            encoding_id=self.encoding_id,
        )

        # Compute shapes to reshape the embeddings to the matrix input
        # expected by the dequantize custom ops.
        final_dims = List[Dim]()
        for i in range(tokens_type.rank() - 1):
            final_dims.append(tokens_type.dim(i))
        final_dims.append(out_dim)

        tokens_shape = ops.shape_of(quantized_tokens)
        last_tokens_axis = tokens_type.rank() - 1
        reshape_shape = ops.stack(
            List(g.scalar(Int64(-1)), tokens_shape[last_tokens_axis])
        )
        dequantize_dims = List[Dim]()
        dequantize_dims.append(Dim.dynamic())
        dequantize_dims.append(tokens_type.dim(-1))

        final_shape = ops.concat(
            List(
                tokens_shape[:last_tokens_axis],
                g.constant[DType.int64](
                    Tensor[DType.int64](TensorShape(1), out_dim)
                ),
            )
        )

        # Reshape gathered token embeddings to a matrix expected by dequantize.
        reshaped_tokens = ops.reshape(
            quantized_tokens, reshape_shape, dequantize_dims
        )

        # Dequantize to floating point.
        dequantized_tokens = _dequantize(reshaped_tokens, self.encoding_id)

        # Restore the original `rank(quantized_tokens) - 1` gather dimensions.
        return ops.reshape(dequantized_tokens, final_shape, final_dims)
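

# Usage sketch (illustrative only): `float_weights`, `quantized_weights`, and
# `token_ids` are hypothetical graph symbols assumed to already exist in a
# max.graph staging context; they are not defined in this file.
#
#     # Float32 weights: __call__ reduces to a plain gather.
#     embedding = Embedding(float_weights)
#     rows = embedding(token_ids)
#
#     # GGML Q4_0-packed weights: gather packed rows, then dequantize.
#     q_embedding = Embedding((quantized_weights, Q4_0Encoding.id()))
#     q_rows = q_embedding(token_ids)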