Skip to content

Commit a4f942f

Browse files
committed
Make GeLU go 10x faster (take two)
Clearing the lowest bit of 𝑥's 32-bit floating-point representation before evaluating GeLU solves the issues we had earlier. See 2781abf
1 parent 835985b commit a4f942f

File tree

9 files changed

+101
-6
lines changed

9 files changed

+101
-6
lines changed

llama.cpp/common.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
205205
return true;
206206
}
207207
if (arg == "--fast") {
208-
FLAG_precise = false;
208+
FLAG_fast = true;
209209
return true;
210210
}
211211
if (arg == "--precise") {

llama.cpp/ggml-vector.inc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "ggml-vector.h"
55
#include "ggml-impl.h"
66
#include "llamafile/llamafile.h"
7+
#include "llamafile/tanhf.h"
78
#include <stdatomic.h>
89

910
void ggml_once(atomic_uint *, void (*)(void));
@@ -1330,8 +1331,14 @@ void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (i
13301331
//
13311332

13321333
static inline float ggml_gelu_f32(float x) { // GeLU via the tanh form: .5 * x * (1 + tanh(sqrt(2/pi) * (x + .044715 * x^3)))
1333-
// GeLU approximation that goes slower and we seem to be stuck with.
1334-
return .5f * x * (1.f + tanhf(sqrtf(M_2_PI) * (x + .044715f * x * x * x)));
1334+
union { // reinterpret the float's bits as an integer
1335+
float f;
1336+
uint32_t i;
1337+
} u = {x};
1338+
u.i &= 0xfffffffe; // clear the lowest bit of the 32-bit representation of x
1339+
x = u.f; // continue with the bit-cleared value
1340+
x = .5f * x * (1.f + Tanhf(sqrtf(M_2_PI) * (x + .044715f * x * x * x))); // same formula as before, with Tanhf in place of tanhf
1341+
return x;
13351342
}
13361343

13371344
void ggml_vec_gelu_f32(const int n, float * y, const float * x) {

llama.cpp/llama.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
33
#define LLAMA_API_INTERNAL
44
#include "llamafile/log.h"
5+
#include "llamafile/latency.h"
56
#include "llamafile/debug.h"
67

78
#include "llama-impl.h"

llama.cpp/server/server.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2546,7 +2546,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
25462546
}
25472547
else if (arg == "--fast")
25482548
{
2549-
FLAG_precise = false;
2549+
FLAG_fast = true;
25502550
}
25512551
else if (arg == "--precise")
25522552
{

llamafile/flags.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "llama.cpp/llama.h"
2828

2929
bool FLAGS_READY = false;
30+
bool FLAG_fast = false;
3031
bool FLAG_log_disable = false;
3132
bool FLAG_mlock = false;
3233
bool FLAG_mmap = true;
@@ -220,7 +221,7 @@ void llamafile_get_flags(int argc, char **argv) {
220221
// cpu flags
221222

222223
if (!strcmp(flag, "--fast")) {
223-
FLAG_precise = false;
224+
FLAG_fast = true;
224225
continue;
225226
}
226227

llamafile/llamafile.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ extern "C" {
77
#endif
88

99
extern bool FLAGS_READY;
10+
extern bool FLAG_fast;
1011
extern bool FLAG_log_disable;
1112
extern bool FLAG_mlock;
1213
extern bool FLAG_mmap;

llamafile/tanhf.h

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
2+
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
3+
//
4+
// Copyright 2023 Arm Limited
5+
//
6+
// Licensed under the Apache License, Version 2.0 (the "License");
7+
// you may not use this file except in compliance with the License.
8+
// You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing, software
13+
// distributed under the License is distributed on an "AS IS" BASIS,
14+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
// See the License for the specific language governing permissions and
16+
// limitations under the License.
17+
18+
#pragma once
19+
20+
/* Helper routine for calculating exp(x) - 1.
21+
Copied from expm1f_1u6.c, with several simplifications:
22+
- No special-case handling for tiny or special values, instead return early
23+
from the main routine.
24+
- No special handling for large values:
25+
- No early return for infinity.
26+
- Simpler combination of p and t in final stage of algorithm.
27+
- |i| < 27, so can calculate t by simpler shift-and-add, instead of ldexpf.
28+
From Optimized Routines by Arm Limited. */
29+
static inline float Expm1f(float x) {
30+
/* Reduce argument: f in [-ln2/2, ln2/2], i is exact. */
31+
float Shift = 0x1.8p23f; /* 1.5 * 2^23: adding then subtracting it leaves an integer-valued float */
32+
float j = fmaf(0x1.715476p+0f, x, Shift) - Shift; /* j = x * (1/ln2) rounded to an integer value */
33+
int i = j; /* j is already integer-valued, so the conversion loses nothing */
34+
float f = fmaf(j, -0x1.62e4p-1f, x); /* f = x - j * (leading part of ln2) */
35+
f = fmaf(j, -0x1.7f7d1cp-20f, f); /* subtract j * (remaining low part of ln2) */
36+
37+
/* Approximate expm1(f) with polynomial P, expm1(f) ~= f + f^2 * P(f).
38+
Uses Estrin scheme, where the main expm1f routine uses Horner. */
39+
float f2 = f * f;
40+
float p_01 = fmaf(f, 0x1.5554aep-3, 0x1.fffffep-2); /* c0 + c1 * f */
41+
float p_23 = fmaf(f, 0x1.12287cp-7, 0x1.555736p-5); /* c2 + c3 * f */
42+
float p = fmaf(f2, p_23, p_01); /* (c0 + c1 * f) + f^2 * (c2 + c3 * f) */
43+
p = fmaf(f2 * f2, 0x1.6b55a2p-10, p); /* add the c4 * f^4 term, completing P(f) */
44+
p = fmaf(f2, p, f); /* p = f + f^2 * P(f) */
45+
46+
/* t = 2^i. */
47+
union {
48+
unsigned i;
49+
float f;
50+
} u = {(i + 127) << 23}; /* biased exponent i + 127 placed above a zero 23-bit mantissa */
51+
float t = u.f;
52+
53+
/* expm1(x) ~= p * t + (t - 1). */
54+
return fmaf(p, t, t - 1);
55+
}
56+
57+
/* Single-precision tanh(x) approximation.
58+
The maximum error is 2.58 ULP.
59+
Designed by Arm Limited. */
60+
static inline float Tanhf(float x) {
61+
union {
62+
float f;
63+
unsigned i;
64+
} u = {x};
65+
unsigned iax = u.i & 0x7fffffff; /* bit pattern of x with the sign bit cleared */
66+
unsigned sign = u.i & ~0x7fffffff; /* the sign bit of x alone */
67+
68+
/* Above 0x1.205966p+3 tanhf rounds to 1 (or -1 for negative). */
69+
if (iax > 0x41102cb3) { /* 0x41102cb3 is the bit pattern of 0x1.205966p+3f */
70+
if (iax > 0x7f800000) /* above the infinity bit pattern, so x is NaN */
71+
return (x - x) / (x - x); /* evaluates to NaN for a NaN input */
72+
u.i = 0x3f800000 | sign; /* bit pattern of 1.0f combined with the sign of x */
73+
return u.f;
74+
}
75+
if (iax < 0x34000000) /* 0x34000000 is the bit pattern of 2^-23: for smaller |x|, x is returned as-is */
76+
return x;
77+
78+
/* tanh(x) = (e^2x - 1) / (e^2x + 1). */
79+
float q = Expm1f(2 * x); /* q = e^(2x) - 1 */
80+
return q / (q + 2); /* q + 2 = e^(2x) + 1, giving the identity above */
81+
}

stable-diffusion.cpp/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
216216

217217
// [jart]
218218
if (arg == "--fast") {
219-
FLAG_precise = false;
219+
FLAG_fast = true;
220220
} else if (arg == "--precise") {
221221
FLAG_precise = true;
222222
} else if (arg == "--trace") {

whisper.cpp/main.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,10 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
126126

127127
if (arg == "--log-disable") {
128128
FLAG_log_disable = true;
129+
} else if (arg == "--fast") {
130+
FLAG_fast = true;
131+
} else if (arg == "--precise") {
132+
FLAG_precise = true;
129133
} else if (arg == "--trace") {
130134
FLAG_trace = true;
131135
} else if (arg == "--trap") {

0 commit comments

Comments
 (0)