634 changes: 363 additions & 271 deletions convert_hf_to_gguf.py

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions convert_lora_to_gguf.py
@@ -24,7 +24,7 @@
import gguf

# reuse model definitions from convert_hf_to_gguf.py
-from convert_hf_to_gguf import LazyTorchTensor, Model
+from convert_hf_to_gguf import LazyTorchTensor, ModelBase

logger = logging.getLogger("lora-to-gguf")

@@ -340,11 +340,11 @@ def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
sys.exit(1)
else:
logger.info(f"Loading base model: {dir_base_model.name}")
-hparams = Model.load_hparams(dir_base_model)
+hparams = ModelBase.load_hparams(dir_base_model)

with torch.inference_mode():
try:
-model_class = Model.from_model_architecture(hparams["architectures"][0])
+model_class = ModelBase.from_model_architecture(hparams["architectures"][0])
except NotImplementedError:
logger.error(f"Model {hparams['architectures'][0]} is not supported")
sys.exit(1)
3 changes: 0 additions & 3 deletions examples/llava/clip-impl.h
@@ -50,7 +50,6 @@
// tensor name constants
//

-#define TN_TOKEN_EMBD "%s.token_embd.weight"
#define TN_POS_EMBD "%s.position_embd.weight"
#define TN_CLASS_EMBD "v.class_embd"
#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backward compat
@@ -66,8 +65,6 @@
#define TN_LN_2 "%s.blk.%d.ln2.%s"
#define TN_LN_PRE "%s.pre_ln.%s"
#define TN_LN_POST "%s.post_ln.%s"
-#define TN_TEXT_PROJ "text_projection.weight"
-#define TN_VIS_PROJ "visual_projection.weight"
#define TN_LLAVA_PROJ "mm.%d.%s"
#define TN_MVLM_PROJ_MLP "mm.model.mlp.%d.%s"
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
5 changes: 3 additions & 2 deletions examples/llava/clip.h
@@ -30,12 +30,13 @@ struct clip_image_size {
int height;
};

+struct clip_image_f32;
struct clip_image_u8_batch;
struct clip_image_f32_batch;

struct clip_context_params {
bool use_gpu;
-ggml_log_level verbosity;
+enum ggml_log_level verbosity;
};

// deprecated, use clip_init
@@ -84,7 +85,7 @@ CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
CLIP_API size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size()
CLIP_API size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx
CLIP_API size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny
-CLIP_API clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
+CLIP_API struct clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data

/**
* Build image from pixels decoded by other libraries instead of stb_image.h for better performance.
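The clip.h changes above read as C-compatibility fixes: clip.h is consumed from C, and unlike C++, C does not let a struct or enum type be named without its tag, and any type used only through pointers must be forward-declared. A minimal sketch of the rule; the type and function names below are illustrative, not part of clip.h:

    #include <stdbool.h>
    #include "ggml.h"                       // for enum ggml_log_level

    struct my_image;                        // opaque forward declaration, needed
                                            // before the prototype further down

    struct my_params {
        bool use_gpu;
        enum ggml_log_level verbosity;      // C requires the `enum` tag here
    };

    struct my_image * my_get_image(void);   // a pointer to an opaque struct is
                                            // fine without the full definition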
15 changes: 15 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.m
@@ -481,6 +481,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_SQRT,
GGML_METAL_KERNEL_TYPE_SIN,
GGML_METAL_KERNEL_TYPE_COS,
+GGML_METAL_KERNEL_TYPE_NEG,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -1159,6 +1160,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
+GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
@@ -1320,6 +1322,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_ELU:
+case GGML_UNARY_OP_NEG:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
default:
return false;
@@ -2010,6 +2013,18 @@ static void ggml_metal_encode_node(

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
+case GGML_UNARY_OP_NEG:
+{
+id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline;
+
+[encoder setComputePipelineState:pipeline];
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+const int64_t n = ggml_nelements(dst);
+
+[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+} break;
default:
{
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
7 changes: 7 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.metal
@@ -949,6 +949,13 @@ kernel void kernel_cos(
dst[tpig] = cos(src0[tpig]);
}

+kernel void kernel_neg(
+device const float * src0,
+device float * dst,
+uint tpig[[thread_position_in_grid]]) {
+dst[tpig] = -src0[tpig];
+}

kernel void kernel_sum_rows(
device const float * src0,
device float * dst,
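Together, the Metal changes register a NEG kernel, advertise GGML_UNARY_OP_NEG as supported for contiguous F32 input, and dispatch one thread per element, mirroring the neighboring sin/cos kernels. As a rough sketch of how such an op reaches the backend, using the public ggml C API (the shape and memory sizes below are arbitrary):

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16 * 1024 * 1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
        struct ggml_tensor * y = ggml_neg(ctx, x); // lowers to GGML_UNARY_OP_NEG

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, y);
        // evaluating gf on a Metal backend should now hit kernel_neg

        ggml_free(ctx);
        return 0;
    }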
22 changes: 18 additions & 4 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -2397,7 +2397,7 @@ static void ggml_vk_load_shaders(vk_device& device) {

ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);

@@ -6006,6 +6006,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_ROPE:
+case GGML_OP_RMS_NORM:
return true;
default:
return false;
@@ -6216,7 +6217,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co

switch (op) {
case GGML_OP_NORM:
-case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_L2_NORM:
case GGML_OP_SOFT_MAX:
@@ -6233,6 +6233,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements = { nr, 1, 1 };
}
} break;
+case GGML_OP_RMS_NORM:
+elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+break;

case GGML_OP_SUM:
// We use GGML_OP_SUM_ROWS with 1 row.
elements = { 1, 1, 1 };
@@ -6883,7 +6887,17 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx

static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
-ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+const uint32_t src0_type_size = ggml_type_size(src0->type);
+const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
+(uint32_t)ggml_nelements(src0),
+(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+0,
+op_params[0], 0.0f,
+0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+}, dryrun);
}

static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9388,10 +9402,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
+case GGML_OP_RMS_NORM:
return true;
case GGML_OP_NORM:
case GGML_OP_GROUP_NORM:
-case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_ADD:
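The Vulkan changes switch RMS_NORM from the small vk_op_push_constants block to vk_op_unary_push_constants, which carries the full shape and strides of both tensors; that is what lets the op be listed in ggml_vk_op_supports_incontiguous and advertised unconditionally in supports_op. One detail worth noting: ggml strides (tensor->nb) are byte strides, while the shader indexes buffers in elements, hence the division by the type size before packing. A tiny illustrative helper, not from the PR:

    #include <stdint.h>
    #include <stddef.h>

    // ggml stores nb[i] as byte strides; the shader walks data_a[] in
    // elements, so strides are converted before entering push constants.
    static uint32_t elem_stride(size_t nb_bytes, size_t type_size) {
        return (uint32_t)(nb_bytes / type_size);
    }
    // e.g. an f32 row stride of 320 bytes: elem_stride(320, 4) == 80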
32 changes: 21 additions & 11 deletions ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
@@ -1,26 +1,36 @@
#version 450

#include "generic_head.comp"
#include "generic_unary_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

shared FLOAT_TYPE sum[BLOCK_SIZE];

void main() {
-const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
-const uint tid = gl_LocalInvocationID.x;
+const uint ncols = p.ne00;
+const uint nrows = gl_NumWorkGroups.x;
+const uint nchannels = gl_NumWorkGroups.y;
+
+const uint row = gl_WorkGroupID.x;
+const uint channel = gl_WorkGroupID.y;
+const uint samp = gl_WorkGroupID.z;
+const uint tid = gl_LocalInvocationID.x;
+
+const uint stride_row = p.nb01;
+const uint stride_channel = p.nb02;
+const uint stride_sample = p.nb03;
+
+uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
+uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();

sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp

-[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
+[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
sum[tid] += xi * xi;
}

@@ -33,10 +43,10 @@ void main() {
barrier();
}

-const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX);
+const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));

-[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
+[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
}
}
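The reworked shader assigns one workgroup to each (row, channel, sample) triple, matching the elements = { ne01, ne02, ne03 } dispatch set up in ggml_vk_op_f32. Source offsets follow the arbitrary element strides nb01/nb02/nb03, while the destination is written contiguously, so its offset is plain row-major flattening. A small C sanity check of the two offset formulas, with made-up sizes (rows padded from 64 to 80 elements to simulate a non-contiguous source):

    #include <assert.h>
    #include <stdint.h>

    int main(void) {
        const uint32_t ncols = 64, nrows = 4, nchannels = 3, nsamples = 2;
        const uint32_t nb01 = 80;                // padded row stride, in elements
        const uint32_t nb02 = nb01 * nrows;      // channel stride
        const uint32_t nb03 = nb02 * nchannels;  // sample stride

        uint32_t expected_d = 0;
        for (uint32_t samp = 0; samp < nsamples; samp++)
        for (uint32_t channel = 0; channel < nchannels; channel++)
        for (uint32_t row = 0; row < nrows; row++) {
            const uint32_t a_offset = samp*nb03 + channel*nb02 + row*nb01;
            const uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols;

            assert(a_offset % nb01 == 0);   // source reads start at padded rows
            assert(d_offset == expected_d); // destination rows tile densely
            expected_d += ncols;
        }
        return 0;
    }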