Skip to content

Commit 35843c7

Browse files
authored
fix: optimize the handling of embedding weight (#859)
1 parent 6ad46bb commit 35843c7

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

clip.hpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -553,10 +553,9 @@ class CLIPEmbeddings : public GGMLBlock {
553553
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
554554
enum ggml_type token_wtype = GGML_TYPE_F32;
555555
if (!force_clip_f32) {
556-
auto tensor_type = tensor_types.find(prefix + "token_embedding.weight");
557-
std::set<ggml_type> allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0};
558-
if (tensor_type != tensor_types.end() && allow_types.find(tensor_type->second) != allow_types.end()) {
559-
token_wtype = tensor_type->second;
556+
token_wtype = get_type(prefix + "token_embedding.weight", tensor_types, GGML_TYPE_F32);
557+
if (!support_get_rows(token_wtype)) {
558+
token_wtype = GGML_TYPE_F32;
560559
}
561560
}
562561
enum ggml_type position_wtype = GGML_TYPE_F32;

ggml_extend.hpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1967,13 +1967,24 @@ class Linear : public UnaryBlock {
19671967
}
19681968
};
19691969

1970+
__STATIC_INLINE__ bool support_get_rows(ggml_type wtype) {
1971+
std::set<ggml_type> allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0};
1972+
if (allow_types.find(wtype) != allow_types.end()) {
1973+
return true;
1974+
}
1975+
return false;
1976+
}
1977+
19701978
class Embedding : public UnaryBlock {
19711979
protected:
19721980
int64_t embedding_dim;
19731981
int64_t num_embeddings;
19741982
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
19751983
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
1976-
params["weight"] = ggml_new_tensor_2d(ctx, wtype, embedding_dim, num_embeddings);
1984+
if (!support_get_rows(wtype)) {
1985+
wtype = GGML_TYPE_F32;
1986+
}
1987+
params["weight"] = ggml_new_tensor_2d(ctx, wtype, embedding_dim, num_embeddings);
19771988
}
19781989

19791990
public:

0 commit comments

Comments
 (0)