Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [X] [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1)
- [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
- [x] [DBRX](https://huggingface.co/databricks/dbrx-instruct)
- [x] [Jamba](https://huggingface.co/ai21labs)
- [X] [Falcon](https://huggingface.co/models?search=tiiuae/falcon)
- [X] [Chinese LLaMA / Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) and [Chinese LLaMA-2 / Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2)
- [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne)
Expand Down
23 changes: 17 additions & 6 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,12 @@ def set_gguf_parameters(self):
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
self.gguf_writer.add_expert_used_count(n_experts_used)
logger.info(f"gguf: experts used count = {n_experts_used}")
if (n_expert_groups := self.hparams.get("n_group")) is not None:
self.gguf_writer.add_expert_group_count(n_expert_groups)
logger.info(f"gguf: expert groups count = {n_expert_groups}")
if (n_group_used := self.hparams.get("topk_group")) is not None:
self.gguf_writer.add_expert_group_used_count(n_group_used)
logger.info(f"gguf: expert groups used count = {n_group_used}")

if (head_dim := self.hparams.get("head_dim")) is not None:
self.gguf_writer.add_key_length(head_dim)
Expand Down Expand Up @@ -1497,6 +1503,17 @@ def get_audio_config(self) -> dict[str, Any] | None:
def set_type(self):
self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)

def prepare_metadata(self, vocab_only: bool):
super().prepare_metadata(vocab_only=vocab_only)

output_type: str = self.ftype.name.partition("_")[2]

if self.fname_out.is_dir():
fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, size_label=None, output_type=output_type, model_type=None)
self.fname_out = self.fname_out / f"mmproj-{fname_default}.gguf"
else:
self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)

def set_gguf_parameters(self):
self.gguf_writer.add_file_type(self.ftype)

Expand Down Expand Up @@ -8222,8 +8239,6 @@ def set_gguf_parameters(self):
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
self.gguf_writer.add_expert_count(hparams["num_experts"])
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
self.gguf_writer.add_expert_group_count(hparams["n_group"])
self.gguf_writer.add_expert_group_used_count(hparams["topk_group"])
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])

if hparams["score_function"] == "sigmoid":
Expand Down Expand Up @@ -9729,10 +9744,6 @@ def main() -> None:

logger.info(f"Loading model: {dir_model.name}")

if args.mmproj:
if "mmproj" not in fname_out.name:
fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")

is_mistral_format = args.mistral_format
if is_mistral_format and not _mistral_common_installed:
raise ImportError(_mistral_import_error_msg)
Expand Down
15 changes: 11 additions & 4 deletions ggml/src/ggml-alloc.c
Original file line number Diff line number Diff line change
Expand Up @@ -226,16 +226,23 @@ static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * al
}

if (best_fit_block == -1) {
// no suitable block found, try the last block (this will grow a chunks size)
// no suitable block found, try the last block (this may grow a chunks size)
int64_t best_reuse = INT64_MIN;
for (int c = 0; c < alloc->n_chunks; ++c) {
struct tallocr_chunk * chunk = alloc->chunks[c];
if (chunk->n_free_blocks > 0) {
struct free_block * block = &chunk->free_blocks[chunk->n_free_blocks - 1];
max_avail = MAX(max_avail, block->size);
if (block->size >= size) {
int64_t reuse_factor = chunk->max_size - block->offset - size;
// reuse_factor < 0 : amount of extra memory that needs to be allocated
// reuse_factor = 0 : allocated free space exactly matches tensor size
// reuse_factor > 0 : superfluous memory that will remain unused
bool better_reuse = best_reuse < 0 && reuse_factor > best_reuse;
bool better_fit = reuse_factor >= 0 && reuse_factor < best_reuse;
if (block->size >= size && (better_reuse || better_fit)) {
best_fit_chunk = c;
best_fit_block = chunk->n_free_blocks - 1;
break;
best_reuse = reuse_factor;
}
}
}
Expand Down Expand Up @@ -268,7 +275,7 @@ static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * al
#ifdef GGML_ALLOCATOR_DEBUG
add_allocated_tensor(alloc, addr, tensor);
size_t cur_max = addr.offset + size;
if (cur_max > alloc->max_size[addr.chunk]) {
if (cur_max > chunk->max_size) {
// sort allocated_tensors by chunk/offset
for (int i = 0; i < 1024; i++) {
for (int j = i + 1; j < 1024; j++) {
Expand Down
13 changes: 13 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1005,3 +1005,16 @@ struct ggml_backend_cuda_context {
return pool(device);
}
};

struct ggml_cuda_mm_fusion_args_host {
const ggml_tensor * x_bias = nullptr;
const ggml_tensor * gate = nullptr;
const ggml_tensor * gate_bias = nullptr;
ggml_glu_op glu_op;
};
struct ggml_cuda_mm_fusion_args_device {
const void * x_bias = nullptr;
const void * gate = nullptr;
const void * gate_bias = nullptr;
ggml_glu_op glu_op;
};
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda/convert.cuh
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#pragma once
#include "common.cuh"

#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
Expand Down
80 changes: 69 additions & 11 deletions ggml/src/ggml-cuda/cpy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,30 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
cpy_blck(cx + x_offset, cdst + dst_offset);
}

template<typename src_t, typename dst_t>
static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) {
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= ne) {
return;
}

const src_t * x = (const src_t *) cx;
dst_t * dst = (dst_t *) cdst;

dst[i] = ggml_cuda_cast<dst_t>(x[i]);
}

template<typename src_t, typename dst_t>
static void ggml_cpy_flt_contiguous_cuda(
const char * cx, char * cdst, const int64_t ne,
cudaStream_t stream) {

const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne);
}

template<typename src_t, typename dst_t>
static void ggml_cpy_flt_cuda(
const char * cx, char * cdst, const int ne,
Expand Down Expand Up @@ -285,7 +309,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
char * src0_ddc = (char *) src0->data;
char * src1_ddc = (char *) src1->data;

if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);

if (src0->type == src1->type && contiguous_srcs) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
Expand All @@ -296,11 +322,19 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<float, half> (src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
Expand All @@ -327,21 +361,45 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<half, float> (src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
if (contiguous_srcs) {
ggml_cpy_flt_contiguous_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, main_stream);
} else {
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
}
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
Expand Down
Loading