release : v1.0.4
ggerganov committed Dec 17, 2022
1 parent dd58b25 commit 1502317
Showing 5 changed files with 438 additions and 72 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -25,7 +25,7 @@ publish: publish-trigger
\n\
cd /path/to/whisper.cpp/bindings/ios\n\
git commit\n\
git tag 1.0.3\n\
git tag 1.0.4\n\
git push origin master --tags\n\
"

176 changes: 147 additions & 29 deletions Sources/whisper/ggml.c
@@ -14,6 +14,12 @@
#include <stdint.h>
#include <stdio.h>

// if C99 - static_assert is noop
// ref: https://stackoverflow.com/a/53923785/4039976
#ifndef static_assert
#define static_assert(cond, msg) struct global_scope_noop_trick
#endif
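// note (not part of the diff): with this fallback, a C99 build turns a
// check such as
//
//     static_assert(sizeof(uint16_t) == 2, "uint16_t must be 2 bytes");
//
// into the harmless declaration struct global_scope_noop_trick, so the
// file still compiles; the assertion is only enforced where a real
// static_assert exists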

#if defined _MSC_VER || defined(__MINGW32__)

#if !defined(__MINGW32__)
@@ -135,9 +141,6 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
#include <immintrin.h>
#endif

// FP16 <-> FP32
// ref: https://github.com/Maratyszcza/FP16

#ifdef __F16C__
float ggml_fp16_to_fp32(ggml_fp16_t h) {
return _cvtsh_ss(h);
@@ -151,6 +154,9 @@ ggml_fp16_t ggml_fp32_to_fp16(float f) {

#else

// FP16 <-> FP32
// ref: https://github.com/Maratyszcza/FP16
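// note (not part of the diff): the pure-C fallback below round-trips floats
// through their IEEE-754 bit patterns, e.g.
//
//     fp32_from_bits(0x3f800000u) == 1.0f
//     fp32_to_bits(1.0f)          == 0x3f800000u
//
// assuming fp32_to_bits is the inverse helper defined alongside it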

static inline float fp32_from_bits(uint32_t w) {
union {
uint32_t as_bits;
@@ -434,10 +440,10 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
y2 = _mm256_loadu_ps(y + i + 16);
y3 = _mm256_loadu_ps(y + i + 24);

sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
}

sum0 = _mm256_add_ps(sum0, sum1);
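// note (not part of the diff): _mm256_add_ps(_mm256_mul_ps(x, y), acc)
// computes the same update, up to rounding, as the FMA form
//
//     acc = _mm256_fmadd_ps(x, y, acc);
//
// but requires only AVX, so the presumed rationale is to keep this path
// usable on CPUs without FMA3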
@@ -675,10 +681,10 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));

sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
}

const __m256 sum01 = _mm256_add_ps(sum0, sum1);
@@ -844,10 +850,10 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
y2 = _mm256_loadu_ps(y + i + 16);
y3 = _mm256_loadu_ps(y + i + 24);

y0 = _mm256_add_ps(_mm256_mul_ps(x0, v4), y0);
y1 = _mm256_add_ps(_mm256_mul_ps(x1, v4), y1);
y2 = _mm256_add_ps(_mm256_mul_ps(x2, v4), y2);
y3 = _mm256_add_ps(_mm256_mul_ps(x3, v4), y3);

_mm256_storeu_ps(y + i + 0, y0);
_mm256_storeu_ps(y + i + 8, y1);
@@ -1041,10 +1047,10 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));

y0 = _mm256_add_ps(_mm256_mul_ps(x0, v8), y0);
y1 = _mm256_add_ps(_mm256_mul_ps(x1, v8), y1);
y2 = _mm256_add_ps(_mm256_mul_ps(x2, v8), y2);
y3 = _mm256_add_ps(_mm256_mul_ps(x3, v8), y3);

_mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0));
_mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0));
@@ -1112,7 +1118,45 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
#endif
}

inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
#if defined(__AVX__) || defined(__AVX2__)
// AVX 256-bit
const int n32 = (n & ~31);

const __m256 v4 = _mm256_set1_ps(v);

__m256 y0, y1, y2, y3;

for (int i = 0; i < n32; i += 32) {
y0 = _mm256_loadu_ps(y + i + 0);
y1 = _mm256_loadu_ps(y + i + 8);
y2 = _mm256_loadu_ps(y + i + 16);
y3 = _mm256_loadu_ps(y + i + 24);

y0 = _mm256_mul_ps(y0, v4);
y1 = _mm256_mul_ps(y1, v4);
y2 = _mm256_mul_ps(y2, v4);
y3 = _mm256_mul_ps(y3, v4);

_mm256_storeu_ps(y + i + 0, y0);
_mm256_storeu_ps(y + i + 8, y1);
_mm256_storeu_ps(y + i + 16, y2);
_mm256_storeu_ps(y + i + 24, y3);
}

// leftovers
for (int i = n32; i < n; ++i) {
y[i] *= v;
}
#else
// scalar
for (int i = 0; i < n; ++i) {
y[i] *= v;
}
#endif
}
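// usage sketch (not part of the diff; sizes are hypothetical):
//
//     float buf[100] = { /* ... */ };
//     ggml_vec_scale_f32(100, buf, 0.5f); // 96 elements via the AVX loop,
//                                         // the last 4 via the scalar tail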

inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s); }
inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); }
@@ -3172,22 +3216,96 @@ void ggml_compute_forward_dup_f16(
return;
}

//const int ne00 = src0->ne[0];
//const int ne01 = src0->ne[1];
//const int ne02 = src0->ne[2];
//const int ne03 = src0->ne[3];
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];

//const size_t nb00 = src0->nb[0];
//const size_t nb01 = src0->nb[1];
//const size_t nb02 = src0->nb[2];
//const size_t nb03 = src0->nb[3];
const size_t nb00 = src0->nb[0];
const size_t nb01 = src0->nb[1];
const size_t nb02 = src0->nb[2];
const size_t nb03 = src0->nb[3];
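// note (not part of the diff): ne[] holds elements per dimension and nb[]
// holds byte strides, so element (i00, i01, i02, i03) lives at byte offset
// i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03; for a hypothetical contiguous
// 3x2 f16 tensor, nb = {2, 6, 12, 12} and element (2, 1) sits at
// 2*2 + 1*6 = 10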

if (ggml_is_contiguous(src0) && src0->type == dst->type) {
memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
return;
}

GGML_ASSERT(false); // TODO: implement
if (src0->nb[0] == sizeof(ggml_fp16_t)) {
if (dst->type == GGML_TYPE_F16) {
int id = 0;
const size_t rs = ne00*nb00;

for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
for (int i01 = 0; i01 < ne01; i01++) {
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
char * dst_ptr = (char *) dst->data + id*rs;

memcpy(dst_ptr, src0_ptr, rs);

id++;
}
}
}
} else if (dst->type == GGML_TYPE_F32) {
int id = 0;
float * dst_ptr = (float *) dst->data;

for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
for (int i01 = 0; i01 < ne01; i01++) {
for (int i00 = 0; i00 < ne00; i00++) {
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
id++;
}
}
}
}
} else {
GGML_ASSERT(false); // TODO: implement
}
} else {
//printf("%s: this is not optimal - fix me\n", __func__);

if (dst->type == GGML_TYPE_F32) {
int id = 0;
float * dst_ptr = (float *) dst->data;

for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
for (int i01 = 0; i01 < ne01; i01++) {
for (int i00 = 0; i00 < ne00; i00++) {
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
id++;
}
}
}
}
} else if (dst->type == GGML_TYPE_F16) {
int id = 0;
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;

for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
for (int i01 = 0; i01 < ne01; i01++) {
for (int i00 = 0; i00 < ne00; i00++) {
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

dst_ptr[id] = *src0_ptr;
id++;
}
}
}
}
} else {
GGML_ASSERT(false); // TODO: implement
}
}
}
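// summary (not part of the diff): the dup paths above are, in order:
// same-type contiguous tensors take a single memcpy; row-contiguous f16
// sources are copied one row-memcpy at a time (f16 dst) or widened per
// element via GGML_FP16_TO_FP32 (f32 dst); fully strided sources fall back
// to walking all four index dimensions element by element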

void ggml_compute_forward_dup_f32(
54 changes: 26 additions & 28 deletions Sources/whisper/ggml.h
@@ -681,34 +681,32 @@ struct ggml_opt_params {
bool print_forward_graph;
bool print_backward_graph;

union {
// ADAM parameters
struct {
int n_iter;

float alpha; // learning rate
float beta1;
float beta2;
float eps; // epsilon for numerical stability
float eps_f; // epsilon for convergence test
float eps_g; // epsilon for convergence test
} adam;

// LBFGS parameters
struct {
int m; // number of corrections to approximate the inv. Hessian
int n_iter;
int max_linesearch;

float eps; // convergence tolerance
float ftol; // line search tolerance
float wolfe;
float min_step;
float max_step;

enum ggml_linesearch linesearch;
} lbfgs;
};
// ADAM parameters
struct {
int n_iter;

float alpha; // learning rate
float beta1;
float beta2;
float eps; // epsilon for numerical stability
float eps_f; // epsilon for convergence test
float eps_g; // epsilon for convergence test
} adam;

// LBFGS parameters
struct {
int m; // number of corrections to approximate the inv. Hessian
int n_iter;
int max_linesearch;

float eps; // convergence tolerance
float ftol; // line search tolerance
float wolfe;
float min_step;
float max_step;

enum ggml_linesearch linesearch;
} lbfgs;
};
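// usage sketch (not part of the diff; assumes GGML_OPT_ADAM from the
// surrounding header): with the union replaced by two plain members, both
// parameter sets can be populated independently, e.g.
//
//     struct ggml_opt_params p = ggml_opt_default_params(GGML_OPT_ADAM);
//     p.adam.alpha = 1e-3f; // with the union gone, p.lbfgs stays intact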

struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
34 changes: 34 additions & 0 deletions Sources/whisper/include/whisper.h
@@ -139,9 +139,41 @@ extern "C" {
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);

// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
// Returns -1 on failure
// TODO: not sure if correct
WHISPER_API int whisper_tokenize(
struct whisper_context * ctx,
const char * text,
whisper_token * tokens,
int n_max_tokens);
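// usage sketch (not part of the diff; the buffer size is hypothetical):
//
//     whisper_token toks[512];
//     const int n = whisper_tokenize(ctx, "hello world", toks, 512);
//     if (n == -1) { /* tokenization failed or result did not fit */ }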

// Largest language id (i.e. number of available languages - 1)
WHISPER_API int whisper_lang_max_id();

// Return the id of the specified language, returns -1 if not found
// Examples:
// "de" -> 2
// "german" -> 2
WHISPER_API int whisper_lang_id(const char * lang);

// Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
WHISPER_API const char * whisper_lang_str(int id);

// Use mel data at offset_ms to try and auto-detect the spoken language
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
// Returns the top language id or negative on failure
// If not null, fills the lang_probs array with the probabilities of all languages
// The array must be whisper_lang_max_id() + 1 in size
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
WHISPER_API int whisper_lang_auto_detect(
struct whisper_context * ctx,
int offset_ms,
int n_threads,
float * lang_probs);
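// usage sketch (not part of the diff; the probs array must hold
// whisper_lang_max_id() + 1 entries, 512 is just a safe hypothetical bound):
//
//     float probs[512];
//     const int lang_id = whisper_lang_auto_detect(ctx, 0, 4, probs);
//     if (lang_id >= 0) {
//         printf("detected language: %s\n", whisper_lang_str(lang_id));
//     }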

WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
@@ -160,6 +192,7 @@ extern "C" {
WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);
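// (illustrative, not part of the diff: whisper_token_lang(ctx,
// whisper_lang_id("de")) would return the id of the multilingual "<|de|>"
// token used during decoding, assuming the language id is valid)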

// Task tokens
WHISPER_API whisper_token whisper_token_translate (void);
@@ -225,6 +258,7 @@ extern "C" {
const whisper_token * prompt_tokens;
int prompt_n_tokens;

// for auto-detection, set to nullptr, "" or "auto"
const char * language;

struct {
