Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions clip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -641,10 +641,10 @@ class CLIPVisionEmbeddings : public GGMLBlock {
// concat(patch_embedding, class_embedding) + position_embedding
struct ggml_tensor* patch_embedding;
int64_t N = pixel_values->ne[3];
patch_embedding = ggml_nn_conv_2d(ctx, pixel_values, patch_embed_weight, nullptr, patch_size, patch_size); // [N, embed_dim, image_size // pacht_size, image_size // pacht_size]
patch_embedding = ggml_reshape_3d(ctx, patch_embedding, num_patches, embed_dim, N); // [N, embed_dim, num_patches]
patch_embedding = ggml_cont(ctx, ggml_permute(ctx, patch_embedding, 1, 0, 2, 3)); // [N, num_patches, embed_dim]
patch_embedding = ggml_reshape_4d(ctx, patch_embedding, 1, embed_dim, num_patches, N); // [N, num_patches, embed_dim, 1]
patch_embedding = ggml_ext_conv_2d(ctx, pixel_values, patch_embed_weight, nullptr, patch_size, patch_size); // [N, embed_dim, image_size // pacht_size, image_size // pacht_size]
patch_embedding = ggml_reshape_3d(ctx, patch_embedding, num_patches, embed_dim, N); // [N, embed_dim, num_patches]
patch_embedding = ggml_cont(ctx, ggml_permute(ctx, patch_embedding, 1, 0, 2, 3)); // [N, num_patches, embed_dim]
patch_embedding = ggml_reshape_4d(ctx, patch_embedding, 1, embed_dim, num_patches, N); // [N, num_patches, embed_dim, 1]

struct ggml_tensor* class_embedding = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, embed_dim, N);
class_embedding = ggml_repeat(ctx, class_embed_weight, class_embedding); // [N, embed_dim]
Expand Down Expand Up @@ -736,7 +736,7 @@ class CLIPTextModel : public GGMLBlock {
auto text_projection = params["text_projection"];
ggml_tensor* pooled = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx);
if (text_projection != nullptr) {
pooled = ggml_nn_linear(ctx, pooled, text_projection, nullptr);
pooled = ggml_ext_linear(ctx, pooled, text_projection, nullptr);
} else {
LOG_DEBUG("identity projection");
}
Expand Down Expand Up @@ -836,7 +836,7 @@ class CLIPProjection : public UnaryBlock {
if (transpose_weight) {
w = ggml_cont(ctx, ggml_transpose(ctx, w));
}
return ggml_nn_linear(ctx, x, w, nullptr);
return ggml_ext_linear(ctx, x, w, nullptr);
}
};

Expand Down
8 changes: 4 additions & 4 deletions common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ class GEGLU : public UnaryBlock {
auto gate_b = ggml_view_1d(ctx, b, b->ne[0] / 2, b->nb[0] * b->ne[0] / 2); // [dim_out, ]

auto x_in = x;
x = ggml_nn_linear(ctx, x_in, x_w, x_b); // [ne3, ne2, ne1, dim_out]
auto gate = ggml_nn_linear(ctx, x_in, gate_w, gate_b); // [ne3, ne2, ne1, dim_out]
x = ggml_ext_linear(ctx, x_in, x_w, x_b); // [ne3, ne2, ne1, dim_out]
auto gate = ggml_ext_linear(ctx, x_in, gate_w, gate_b); // [ne3, ne2, ne1, dim_out]

gate = ggml_gelu_inplace(ctx, gate);

Expand Down Expand Up @@ -325,7 +325,7 @@ class CrossAttention : public GGMLBlock {
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]

x = ggml_nn_attention_ext(ctx, backend, q, k, v, n_head, nullptr, false, false, flash_attn); // [N, n_token, inner_dim]
x = ggml_ext_attention_ext(ctx, backend, q, k, v, n_head, nullptr, false, false, flash_attn); // [N, n_token, inner_dim]

x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
return x;
Expand Down Expand Up @@ -492,7 +492,7 @@ class AlphaBlender : public GGMLBlock {
float get_alpha() {
// image_only_indicator is always tensor([0.]) and since mix_factor.shape is [1,]
// so learned_with_images is same as learned
float alpha = ggml_backend_tensor_get_f32(params["mix_factor"]);
float alpha = ggml_ext_backend_tensor_get_f32(params["mix_factor"]);
return sigmoid(alpha);
}

Expand Down
92 changes: 46 additions & 46 deletions conditioner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
clip_skip,
&chunk_hidden_states2, work_ctx);
// concat
chunk_hidden_states = ggml_tensor_concat(work_ctx, chunk_hidden_states1, chunk_hidden_states2, 0);
chunk_hidden_states = ggml_ext_tensor_concat(work_ctx, chunk_hidden_states1, chunk_hidden_states2, 0);

if (chunk_idx == 0) {
text_model2->compute(n_threads,
Expand All @@ -484,18 +484,18 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
ggml_tensor* result = ggml_dup_tensor(work_ctx, chunk_hidden_states);
{
float original_mean = ggml_tensor_mean(chunk_hidden_states);
float original_mean = ggml_ext_tensor_mean(chunk_hidden_states);
for (int i2 = 0; i2 < chunk_hidden_states->ne[2]; i2++) {
for (int i1 = 0; i1 < chunk_hidden_states->ne[1]; i1++) {
for (int i0 = 0; i0 < chunk_hidden_states->ne[0]; i0++) {
float value = ggml_tensor_get_f32(chunk_hidden_states, i0, i1, i2);
float value = ggml_ext_tensor_get_f32(chunk_hidden_states, i0, i1, i2);
value *= chunk_weights[i1];
ggml_tensor_set_f32(result, value, i0, i1, i2);
ggml_ext_tensor_set_f32(result, value, i0, i1, i2);
}
}
}
float new_mean = ggml_tensor_mean(result);
ggml_tensor_scale(result, (original_mean / new_mean));
float new_mean = ggml_ext_tensor_mean(result);
ggml_ext_tensor_scale_inplace(result, (original_mean / new_mean));
}
if (zero_out_masked) {
float* vec = (float*)result->data;
Expand Down Expand Up @@ -874,18 +874,18 @@ struct SD3CLIPEmbedder : public Conditioner {
work_ctx);
{
auto tensor = chunk_hidden_states_l;
float original_mean = ggml_tensor_mean(tensor);
float original_mean = ggml_ext_tensor_mean(tensor);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
value *= chunk_weights[i1];
ggml_tensor_set_f32(tensor, value, i0, i1, i2);
ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
}
}
}
float new_mean = ggml_tensor_mean(tensor);
ggml_tensor_scale(tensor, (original_mean / new_mean));
float new_mean = ggml_ext_tensor_mean(tensor);
ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
}

if (chunk_idx == 0) {
Expand Down Expand Up @@ -932,18 +932,18 @@ struct SD3CLIPEmbedder : public Conditioner {

{
auto tensor = chunk_hidden_states_g;
float original_mean = ggml_tensor_mean(tensor);
float original_mean = ggml_ext_tensor_mean(tensor);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
value *= chunk_weights[i1];
ggml_tensor_set_f32(tensor, value, i0, i1, i2);
ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
}
}
}
float new_mean = ggml_tensor_mean(tensor);
ggml_tensor_scale(tensor, (original_mean / new_mean));
float new_mean = ggml_ext_tensor_mean(tensor);
ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
}

if (chunk_idx == 0) {
Expand Down Expand Up @@ -984,18 +984,18 @@ struct SD3CLIPEmbedder : public Conditioner {
work_ctx);
{
auto tensor = chunk_hidden_states_t5;
float original_mean = ggml_tensor_mean(tensor);
float original_mean = ggml_ext_tensor_mean(tensor);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
value *= chunk_weights[i1];
ggml_tensor_set_f32(tensor, value, i0, i1, i2);
ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
}
}
}
float new_mean = ggml_tensor_mean(tensor);
ggml_tensor_scale(tensor, (original_mean / new_mean));
float new_mean = ggml_ext_tensor_mean(tensor);
ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
}
} else {
chunk_hidden_states_t5 = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, chunk_len);
Expand All @@ -1013,19 +1013,19 @@ struct SD3CLIPEmbedder : public Conditioner {
for (int i0 = 0; i0 < chunk_hidden_states_lg_pad->ne[0]; i0++) {
float value = 0.f;
if (i0 < chunk_hidden_states_l->ne[0]) {
value = ggml_tensor_get_f32(chunk_hidden_states_l, i0, i1, i2);
value = ggml_ext_tensor_get_f32(chunk_hidden_states_l, i0, i1, i2);
} else if (i0 < chunk_hidden_states_l->ne[0] + chunk_hidden_states_g->ne[0]) {
value = ggml_tensor_get_f32(chunk_hidden_states_g, i0 - chunk_hidden_states_l->ne[0], i1, i2);
value = ggml_ext_tensor_get_f32(chunk_hidden_states_g, i0 - chunk_hidden_states_l->ne[0], i1, i2);
}
ggml_tensor_set_f32(chunk_hidden_states_lg_pad, value, i0, i1, i2);
ggml_ext_tensor_set_f32(chunk_hidden_states_lg_pad, value, i0, i1, i2);
}
}
}

chunk_hidden_states = ggml_tensor_concat(work_ctx, chunk_hidden_states_lg_pad, chunk_hidden_states_t5, 1); // [n_token*2, 4096]
chunk_hidden_states = ggml_ext_tensor_concat(work_ctx, chunk_hidden_states_lg_pad, chunk_hidden_states_t5, 1); // [n_token*2, 4096]

if (chunk_idx == 0) {
pooled = ggml_tensor_concat(work_ctx, pooled_l, pooled_g, 0); // [768 + 1280]
pooled = ggml_ext_tensor_concat(work_ctx, pooled_l, pooled_g, 0); // [768 + 1280]
}

int64_t t1 = ggml_time_ms();
Expand Down Expand Up @@ -1269,18 +1269,18 @@ struct FluxCLIPEmbedder : public Conditioner {
work_ctx);
{
auto tensor = chunk_hidden_states;
float original_mean = ggml_tensor_mean(tensor);
float original_mean = ggml_ext_tensor_mean(tensor);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
value *= chunk_weights[i1];
ggml_tensor_set_f32(tensor, value, i0, i1, i2);
ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
}
}
}
float new_mean = ggml_tensor_mean(tensor);
ggml_tensor_scale(tensor, (original_mean / new_mean));
float new_mean = ggml_ext_tensor_mean(tensor);
ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
}
} else {
chunk_hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, chunk_len);
Expand Down Expand Up @@ -1483,18 +1483,18 @@ struct T5CLIPEmbedder : public Conditioner {
work_ctx);
{
auto tensor = chunk_hidden_states;
float original_mean = ggml_tensor_mean(tensor);
float original_mean = ggml_ext_tensor_mean(tensor);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
value *= chunk_weights[i1];
ggml_tensor_set_f32(tensor, value, i0, i1, i2);
ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
}
}
}
float new_mean = ggml_tensor_mean(tensor);
ggml_tensor_scale(tensor, (original_mean / new_mean));
float new_mean = ggml_ext_tensor_mean(tensor);
ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
}

int64_t t1 = ggml_time_ms();
Expand All @@ -1505,7 +1505,7 @@ struct T5CLIPEmbedder : public Conditioner {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
if (chunk_mask[i1] < 0.f) {
ggml_tensor_set_f32(tensor, 0.f, i0, i1, i2);
ggml_ext_tensor_set_f32(tensor, 0.f, i0, i1, i2);
}
}
}
Expand Down Expand Up @@ -1664,7 +1664,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
image.data = nullptr;

ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
sd_image_f32_to_tensor(resized_image, image_tensor, false);
sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false);
free(resized_image.data);
resized_image.data = nullptr;

Expand Down Expand Up @@ -1709,18 +1709,18 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
work_ctx);
{
auto tensor = hidden_states;
float original_mean = ggml_tensor_mean(tensor);
float original_mean = ggml_ext_tensor_mean(tensor);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
value *= weights[i1];
ggml_tensor_set_f32(tensor, value, i0, i1, i2);
ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
}
}
}
float new_mean = ggml_tensor_mean(tensor);
ggml_tensor_scale(tensor, (original_mean / new_mean));
float new_mean = ggml_ext_tensor_mean(tensor);
ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
}

GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
Expand All @@ -1731,9 +1731,9 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
hidden_states->ne[1] - prompt_template_encode_start_idx,
hidden_states->ne[2]);

ggml_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
ggml_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
});

int64_t t1 = ggml_time_ms();
Expand Down
2 changes: 1 addition & 1 deletion control.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ class ControlNetBlock : public GGMLBlock {

auto middle_block_out = std::dynamic_pointer_cast<Conv2d>(blocks["middle_block_out.0"]);

auto t_emb = ggml_nn_timestep_embedding(ctx, timesteps, model_channels); // [N, model_channels]
auto t_emb = ggml_ext_timestep_embedding(ctx, timesteps, model_channels); // [N, model_channels]

auto emb = time_embed_0->forward(ctx, t_emb);
emb = ggml_silu_inplace(ctx, emb);
Expand Down
Loading
Loading