Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -260,16 +260,18 @@ jobs:
architecture: ${{ matrix.arch == 'x64' && 'win64' || 'win64a' }}

- name: Build
shell: cmd
env:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
cmake -S . -B build -G "Ninja Multi-Config" `
-D CMAKE_TOOLCHAIN_FILE=cmake/${{ matrix.arch }}-windows-llvm.cmake `
-DGGML_NATIVE=OFF `
-DGGML_BACKEND_DL=ON `
-DGGML_CPU_ALL_VARIANTS=ON `
-DGGML_OPENMP=OFF `
-DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include" `
call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" ${{ matrix.arch }}
cmake -S . -B build -G "Ninja Multi-Config" ^
-D CMAKE_TOOLCHAIN_FILE=cmake/${{ matrix.arch }}-windows-llvm.cmake ^
-DGGML_NATIVE=OFF ^
-DGGML_BACKEND_DL=ON ^
-DGGML_CPU_ALL_VARIANTS=${{ matrix.arch == 'x64' && 'ON' || 'OFF' }} ^
-DGGML_OPENMP=ON ^
-DCURL_LIBRARY="%CURL_PATH%/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="%CURL_PATH%/include" ^
${{ env.CMAKE_ARGS }}
cmake --build build --config Release

Expand All @@ -279,6 +281,7 @@ jobs:
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
run: |
Copy-Item $env:CURL_PATH\bin\libcurl-${{ matrix.arch }}.dll .\build\bin\Release\
Copy-Item "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Redist\MSVC\14.42.34433\debug_nonredist\${{ matrix.arch }}\Microsoft.VC143.OpenMP.LLVM\libomp140.${{ matrix.arch == 'x64' && 'x86_64' || 'aarch64' }}.dll" .\build\bin\Release\
7z a llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\*

- name: Upload artifacts
Expand Down
42 changes: 42 additions & 0 deletions .github/workflows/winget.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
name: Update Winget Package

on:
workflow_dispatch: # allows manual triggering
schedule:
- cron: '28 5 * * *' # Update every day at 5:28 UTC

jobs:
update:
name: Update Winget Package
runs-on: ubuntu-latest

steps:
- name: Install cargo binstall
uses: cargo-bins/cargo-binstall@268643a6b5ea099f5718ee5cd3ff7dc89a5eb49b

- name: Install komac
run: |
cargo binstall komac@2.11.2 -y

- name: Find latest release
id: find_latest_release
uses: actions/github-script@v6
with:
script: |
const { data: releases } = await github.rest.repos.listReleases({
owner: context.repo.owner,
repo: context.repo.repo,
});
console.log("Latest release:", releases[0].tag_name);
return releases[0].tag_name;

- name: Update manifest
env:
VERSION: ${{ steps.find_latest_release.outputs.result }}
run: |
echo "Updating manifest..."
komac update --version ${{ env.VERSION }} \
--urls "https://github.com/ggml-org/llama.cpp/releases/download/${{ env.VERSION }}/llama-${{ env.VERSION }}-bin-win-vulkan-x64.zip" \
--token ${{ secrets.WINGET_GITHUB_TOKEN }} \
--submit \
ggml.llamacpp
13 changes: 13 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -3484,6 +3484,19 @@ void ggml_cpu_init(void) {
const uint64_t t_end = ggml_time_us(); UNUSED(t_end);

GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);

#ifdef GGML_USE_OPENMP
//if (!getenv("OMP_WAIT_POLICY")) {
// // set the wait policy to active, so that OpenMP threads don't sleep
// putenv("OMP_WAIT_POLICY=active");
//}

if (!getenv("KMP_BLOCKTIME")) {
// set the time to wait before sleeping a thread
// this is less aggressive than setting the wait policy to active, but should achieve similar results in most cases
putenv("KMP_BLOCKTIME=200"); // 200ms
}
#endif
}

#if defined(__ARM_ARCH)
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda/fattn-vec-f16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ static __global__ void flash_attn_vec_ext_f16(
}
}
if (__all_sync(0xFFFFFFFF, skip)) {
__syncthreads();
continue;
}
#endif // GGML_USE_HIP
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda/fattn-vec-f32.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ static __global__ void flash_attn_vec_ext_f32(
}
}
if (__all_sync(0xFFFFFFFF, skip)) {
__syncthreads();
continue;
}
#endif // GGML_USE_HIP
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2192,6 +2192,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_SILU:
ggml_cuda_op_silu(ctx, dst);
break;
case GGML_UNARY_OP_GELU_ERF:
ggml_cuda_op_gelu_erf(ctx, dst);
break;
case GGML_UNARY_OP_GELU_QUICK:
ggml_cuda_op_gelu_quick(ctx, dst);
break;
Expand Down Expand Up @@ -2977,6 +2980,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:
Expand Down
10 changes: 10 additions & 0 deletions ggml/src/ggml-cuda/unary.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ static __device__ __forceinline__ float op_gelu(float x) {
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

static __device__ __forceinline__ float op_gelu_erf(float x) {
const float SQRT_2_INV = 0.70710678118654752440084436210484f;

return 0.5f*x*(1.0f + erff(x*SQRT_2_INV));
}

static __device__ __forceinline__ float op_gelu_quick(float x) {
const float GELU_QUICK_COEF = -1.702f;

Expand Down Expand Up @@ -134,6 +140,10 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_gelu>(ctx, dst);
}

void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_gelu_erf>(ctx, dst);
}

void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_gelu_quick>(ctx, dst);
}
Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml-cuda/unary.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
Expand Down
8 changes: 4 additions & 4 deletions src/llama-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1287,6 +1287,10 @@ ggml_tensor * llm_graph_context::build_attn(

if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4) {
// GLM4 seems to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}

if (wo_b) {
Expand Down Expand Up @@ -1367,10 +1371,6 @@ ggml_tensor * llm_graph_context::build_attn(

if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4) {
// GLM4 seems to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}

if (wo_b) {
Expand Down
8 changes: 4 additions & 4 deletions src/llama-vocab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ struct llm_tokenizer_ugm_session {
}

// initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -FLT_MAX});
std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -DBL_MAX});
// at the beginning tokenization score is zero
tokenization_results[0] = { vocab.token_unk(), 0, 0 };

Expand Down Expand Up @@ -867,7 +867,7 @@ struct llm_tokenizer_ugm_session {
const double challenger_score = current_best.score_sum + token_score;
struct best_tokenization & current_champ = tokenization_results[prefix_offset];
if (challenger_score > current_champ.score_sum) {
struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score };
struct best_tokenization challenger = { token_id, input_offset, challenger_score };
current_champ = challenger;
}
}
Expand All @@ -881,7 +881,7 @@ struct llm_tokenizer_ugm_session {
prefix_offset = input_offset + n_utf8_code_units;
struct best_tokenization & current_champ = tokenization_results[prefix_offset];
if (challenger_score > current_champ.score_sum) {
struct best_tokenization challenger = { vocab.token_unk(), input_offset, (float) challenger_score };
struct best_tokenization challenger = { vocab.token_unk(), input_offset, challenger_score };
current_champ = challenger;
}
}
Expand Down Expand Up @@ -1007,7 +1007,7 @@ struct llm_tokenizer_ugm_session {
struct best_tokenization {
llama_token token_id;
size_t input_offset;
float score_sum;
double score_sum;
};

struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
Expand Down