Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
290 changes: 258 additions & 32 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

#include "rte.glsl"
#include "utils.glsl"
#if RMS_NORM_ROPE_FUSION
#include "rope_params.glsl"
#endif

layout (push_constant) uniform parameter
{
Expand All @@ -12,11 +15,16 @@ layout (push_constant) uniform parameter
uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
uint misalign_offsets;
float param1; float param2; int param3;
#if RMS_NORM_ROPE_FUSION
rope_params rope;
#endif
} p;

#if !RMS_NORM_ROPE_FUSION
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#endif

// true if src0/src1 are the same shape and the indices can be reused without additional modulus
layout(constant_id = 0) const bool norepeat = false;
Expand Down
44 changes: 43 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,32 @@
#include "generic_binary_head.glsl"
#include "types.glsl"

// Declarations used only by the fused RMS_NORM + ROPE variant of this shader.
// generic_binary_head.glsl skips its own A/B/D buffer declarations when
// RMS_NORM_ROPE_FUSION is set, so the inputs are declared here instead.
#if RMS_NORM_ROPE_FUSION

// rms_norm inputs (same binding slots as the non-fused declarations).
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};

// data is passed from rms_norm -> rope through shared memory.
// rms_norm calls this data_d, rope calls this rope_data_a.
// Binding 2 is not used
// rms_norm writes starting at offset 0 and the rope helpers index it by column only.
// NOTE(review): capacity is fixed at 1024 elements - presumably the host only
// selects this variant when a row fits; confirm against the pipeline selection code.
shared FLOAT_TYPE rope_data_a[1024];
#define data_d rope_data_a

// Buffers consumed/produced by the rope helpers in rope_funcs.glsl:
//   rope_data_pos - integer positions multiplied into the rotation angle
//   rope_data_ff  - per-pair frequency factors, read only when has_ff != 0
//   rope_data_d   - rope output
layout (binding = 3) readonly buffer R_Y {int rope_data_pos[];};
layout (binding = 4) readonly buffer R_Z {float rope_data_ff[];};
layout (binding = 5) writeonly buffer R_D {ROPE_D_TYPE rope_data_d[];};
layout (binding = 6) readonly buffer R_I {uvec2 rope_data_i[];}; // indices for set_rows

// Included after the declarations above because rope_funcs.glsl references
// rope_data_a / rope_data_pos / rope_data_ff / rope_data_d / rope_data_i by name.
#include "rope_params.glsl"
#include "rope_funcs.glsl"

// Values compared against rope_params.rope_mode to choose the rope helper.
// NOTE(review): the fused path visible in rms_norm() only dispatches NORMAL and
// NEOX; MROPE and VISION are defined here but not handled there - confirm intended.
#define GGML_ROPE_TYPE_NORMAL 0
#define GGML_ROPE_TYPE_NEOX 2
#define GGML_ROPE_TYPE_MROPE 8
#define GGML_ROPE_TYPE_VISION 24

#endif

#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512

Expand All @@ -28,8 +54,12 @@ void rms_norm(uint num_iters) {

uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
#if RMS_NORM_ROPE_FUSION
// Per-row offset in shared memory
uint32_t d_offset = 0;
#else
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();

#endif
FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp

[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
Expand Down Expand Up @@ -79,6 +109,18 @@ void rms_norm(uint num_iters) {
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
}
}
#if RMS_NORM_ROPE_FUSION
barrier();
rope_params rp = p.rope;
uint rope_row = (samp*nchannels + channel)*nrows + row;
for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) {
if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) {
rope_neox(t, rope_row, rp);
} else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) {
rope_norm(t, rope_row, rp);
}
}
#endif
}

void main() {
Expand Down
227 changes: 227 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@

// Linear ramp used to blend interpolated and extrapolated rotation angles.
// Returns 1 while the pair index (i0 / 2, integer division) is at or below
// `low`, falls linearly to 0 at `high`, and stays 0 above it. The ramp width
// is floored at 0.001 so low == high cannot divide by zero.
float rope_yarn_ramp(const float low, const float high, const uint i0) {
    const float pair_idx = float(i0 / 2);
    const float width    = max(0.001f, high - low);
    const float t        = (pair_idx - low) / width;
    return 1.0f - min(1.0f, max(0.0f, t));
}

// Index into rope_data_a for element i0 of the row addressed by (i01, i02).
// In the fused build the index is just i0 (i01/i02 are ignored); otherwise
// the row is located through the nb01 / nb02 strides in `p`.
uint rope_a_coord(const uint i0, const uint i01, const uint i02, rope_params p) {
#if RMS_NORM_ROPE_FUSION
    // Per-row offset in shared memory
    return i0;
#else
    return i02*p.nb02 + i01*p.nb01 + i0;
#endif
}

// Computes the scaled cosine/sine pair used to rotate one element pair.
//   theta_extrap - unscaled (extrapolated) angle
//   i0           - column index; forwarded to rope_yarn_ramp
//   cos_theta /
//   sin_theta    - outputs: cos/sin of the final angle times the magnitude scale
//   p            - rope parameters (freq_scale, ext_factor, attn_factor,
//                  corr_dims, is_back are read here)
void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta, rope_params p) {
    // Start from the purely interpolated angle and the base magnitude.
    const float theta_interp = p.freq_scale * theta_extrap;
    float angle     = theta_interp;
    float magnitude = p.attn_factor;

    if (p.ext_factor != 0.0f) {
        // Blend interpolated and extrapolated angles by the ramp weight.
        const float mix_w = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
        angle = theta_interp * (1 - mix_w) + theta_extrap * mix_w;

        // Magnitude correction applied whenever the blend is active.
        magnitude *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
    }

    // The backward pass rotates in the opposite direction.
    if (p.is_back != 0) {
        angle = -angle;
    }

    cos_theta = cos(angle) * magnitude;
    sin_theta = sin(angle) * magnitude;
}

// Rotates the adjacent pair (i0, i0 + 1) of row i1 and writes it to rope_data_d.
//   i0 - column of the first element of the pair
//   i1 - flattened row index; split by p.p_delta_rows into (i01, i02)
//   p  - rope parameters
// Columns at or beyond p.n_dims are copied through unrotated.
void rope_norm(const uint i0, const uint i1, rope_params p) {
    const uint row_len = p.ncols;
    if (i0 >= row_len) {
        return;
    }

    // i1 is a flattened row index; the rows are contiguous.
    const uint row_split = p.p_delta_rows;
    const uint i01 = i1 % row_split;
    const uint i02 = i1 / row_split;

    const uint src = rope_a_coord(i0, i01, i02, p);

    // Fusion optimization: ROPE + VIEW + SET_ROWS.
    // With a non-zero set_rows_stride the output is addressed through the row
    // index stored in rope_data_i[i02].x instead of the flattened row i1.
    uint dst;
    if (p.set_rows_stride != 0) {
        dst = i01*row_len + i0 + rope_data_i[i02].x * p.set_rows_stride;
    } else {
        dst = i1*row_len + i0;
    }

    // Outside the rotated range: plain copy of the pair.
    if (i0 >= p.n_dims) {
        for (uint k = 0; k < 2; ++k) {
            rope_data_d[dst + k] = ROPE_D_TYPE(rope_data_a[src + k]);
        }
        return;
    }

    const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);

    float freq_factor = 1.0f;
    if (p.has_ff != 0) {
        freq_factor = rope_data_ff[i0/2];
    }

    float c, s;
    rope_yarn(theta_base / freq_factor, i0, c, s, p);

    const float a0 = float(rope_data_a[src + 0]);
    const float a1 = float(rope_data_a[src + 1]);

    rope_data_d[dst + 0] = ROPE_D_TYPE(a0*c - a1*s);
    rope_data_d[dst + 1] = ROPE_D_TYPE(a0*s + a1*c);
}

// Rotates the pair (i0/2, i0/2 + n_dims/2) of row i1 and writes it to rope_data_d.
//   i0 - even column counter; i0/2 selects the first element of the pair
//   i1 - flattened row index; split by p.p_delta_rows into (i01, i02)
//   p  - rope parameters
// For i0 >= p.n_dims the two elements at columns i0 and i0 + 1 are copied unrotated.
void rope_neox(const uint i0, const uint i1, rope_params p) {
    const uint row_len = p.ncols;
    if (i0 >= row_len) {
        return;
    }

    const uint row_split = p.p_delta_rows;
    const uint i01 = i1 % row_split;
    const uint i02 = i1 / row_split;
    const uint half_i0 = i0 / 2;

    const uint src = rope_a_coord(half_i0, i01, i02, p);

    // Fusion optimization: ROPE + VIEW + SET_ROWS.
    // With a non-zero set_rows_stride the output is addressed through the row
    // index stored in rope_data_i[i02].x instead of the flattened row i1.
    uint dst;
    if (p.set_rows_stride != 0) {
        dst = i01*row_len + half_i0 + rope_data_i[i02].x * p.set_rows_stride;
    } else {
        dst = i1*row_len + half_i0;
    }

    // Outside the rotated range: src/dst already include i0/2, so adding it
    // once more lands on column i0 of the row.
    if (i0 >= p.n_dims) {
        for (uint k = 0; k < 2; ++k) {
            rope_data_d[dst + half_i0 + k] = ROPE_D_TYPE(rope_data_a[src + half_i0 + k]);
        }
        return;
    }

    const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);

    float freq_factor = 1.0f;
    if (p.has_ff != 0) {
        freq_factor = rope_data_ff[half_i0];
    }

    float c, s;
    rope_yarn(theta_base / freq_factor, i0, c, s, p);

    // Partner element sits half of the rotated width further along.
    const uint pair_gap = p.n_dims/2;
    const float a0 = float(rope_data_a[src + 0]);
    const float a1 = float(rope_data_a[src + pair_gap]);

    rope_data_d[dst + 0]        = ROPE_D_TYPE(a0*c - a1*s);
    rope_data_d[dst + pair_gap] = ROPE_D_TYPE(a0*s + a1*c);
}


// Multi-section rope: like rope_neox, but the position used for the angle is
// taken from one of four planes of rope_data_pos (plane k starts at k * p.ne02),
// chosen from the pair's sector and p.sections.
//   i0 - even column counter; i0/2 selects the first element of the pair
//   i1 - flattened row index; split by p.p_delta_rows into (i01, i02)
//   p  - rope parameters
void rope_multi(const uint i0, const uint i1, rope_params p) {
    const uint row_len = p.ncols;
    if (i0 >= row_len) {
        return;
    }

    const uint row_split = p.p_delta_rows;
    const uint i01 = i1 % row_split;
    const uint i02 = i1 / row_split;
    const uint half_i0 = i0 / 2;

    const uint dst = i1*row_len + half_i0;
    const uint src = rope_a_coord(half_i0, i01, i02, p);

    // Outside the rotated range: src/dst already include i0/2, so adding it
    // once more lands on column i0 of the row.
    if (i0 >= p.n_dims) {
        for (uint k = 0; k < 2; ++k) {
            rope_data_d[dst + half_i0 + k] = ROPE_D_TYPE(rope_data_a[src + half_i0 + k]);
        }
        return;
    }

    const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
    const uint sector = half_i0 % sect_dims;

    // Select which position plane (0..3) drives this pair.
    uint plane;
    if (p.is_imrope != 0) {
        // Interleaved layout: sector % 3 names the candidate plane, which is
        // used only while sector is below 3x that section's size; plane 3 otherwise.
        const uint lane = sector % 3;
        plane = (sector < 3 * p.sections[lane]) ? lane : 3u;
    } else {
        // Contiguous layout: sections are laid out back to back.
        const int sec_w = p.sections[1] + p.sections[0];
        if (sector < p.sections[0]) {
            plane = 0u;
        } else if (sector < sec_w) {
            plane = 1u;
        } else if (sector < sec_w + p.sections[2]) {
            plane = 2u;
        } else {
            plane = 3u;
        }
    }

    const float theta_base = rope_data_pos[i02 + p.ne02 * plane]*pow(p.theta_scale, i0/2.0f);

    float freq_factor = 1.0f;
    if (p.has_ff != 0) {
        freq_factor = rope_data_ff[half_i0];
    }

    float c, s;
    rope_yarn(theta_base / freq_factor, i0, c, s, p);

    // Partner element sits half of the rotated width further along.
    const uint pair_gap = p.n_dims/2;
    const float a0 = float(rope_data_a[src + 0]);
    const float a1 = float(rope_data_a[src + pair_gap]);

    rope_data_d[dst + 0]        = ROPE_D_TYPE(a0*c - a1*s);
    rope_data_d[dst + pair_gap] = ROPE_D_TYPE(a0*s + a1*c);
}

// Vision rope: two sections, each with its own position plane in
// rope_data_pos (plane 1 starts at p.ne02). The exponent of theta_scale is the
// pair's index within its section, and the partner element is p.n_dims away.
//   i0 - even column counter; i0/2 selects the first element of the pair
//   i1 - flattened row index; split by p.p_delta_rows into (i01, i02)
//   p  - rope parameters
void rope_vision(const uint i0, const uint i1, rope_params p) {
    const uint row_len = p.ncols;
    if (i0 >= row_len) {
        return;
    }

    const uint row_split = p.p_delta_rows;
    const uint i01 = i1 % row_split;
    const uint i02 = i1 / row_split;
    const uint half_i0 = i0 / 2;

    const uint dst = i1*row_len + half_i0;
    const uint src = rope_a_coord(half_i0, i01, i02, p);

    // Combined width of both sections; also the upper bound of section 1.
    const int sect_dims = p.sections[0] + p.sections[1];
    const uint sector = half_i0 % sect_dims;

    // theta_base stays 0 if the sector falls in neither section.
    float theta_base = 0.0;
    if (sector < p.sections[0]) {
        theta_base = rope_data_pos[i02]*pow(p.theta_scale, float(sector));
    } else if (sector < sect_dims) {
        const uint local_idx = sector - p.sections[0];
        theta_base = rope_data_pos[i02 + p.ne02]*pow(p.theta_scale, float(local_idx));
    }

    float freq_factor = 1.0f;
    if (p.has_ff != 0) {
        freq_factor = rope_data_ff[half_i0];
    }

    float c, s;
    rope_yarn(theta_base / freq_factor, i0, c, s, p);

    const float a0 = float(rope_data_a[src + 0]);
    const float a1 = float(rope_data_a[src + p.n_dims]);

    rope_data_d[dst + 0]        = ROPE_D_TYPE(a0*c - a1*s);
    rope_data_d[dst + p.n_dims] = ROPE_D_TYPE(a0*s + a1*c);
}

56 changes: 9 additions & 47 deletions ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -3,56 +3,18 @@
#extension GL_EXT_shader_16bit_storage : require

#include "rte.glsl"
#include "rope_params.glsl"

layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {int data_pos[];};
layout (binding = 2) readonly buffer Z {float data_ff[];};
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
layout (binding = 4) readonly buffer I {uvec2 data_i[];}; // indices for set_rows
layout (binding = 0) readonly buffer X {A_TYPE rope_data_a[];};
layout (binding = 1) readonly buffer Y {int rope_data_pos[];};
layout (binding = 2) readonly buffer Z {float rope_data_ff[];};
layout (binding = 3) writeonly buffer D {ROPE_D_TYPE rope_data_d[];};
layout (binding = 4) readonly buffer I {uvec2 rope_data_i[];}; // indices for set_rows

layout (push_constant) uniform parameter {
uint ncols;
uint n_dims;
float freq_scale;
uint p_delta_rows;
float freq_base;
float ext_factor;
float attn_factor;
float corr_dims[2];
float theta_scale;
uint has_ff;
uint ne02;
uint s1;
uint s2;
int sections[4];
uint is_imrope;
uint is_back;
uint set_rows_stride;
} p;

float rope_yarn_ramp(const float low, const float high, const uint i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
}

void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) {
float mscale = p.attn_factor;
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = p.freq_scale * theta_extrap;
float theta = theta_interp;
if (p.ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
layout (push_constant) uniform parameter {
rope_params pc;
};

// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
}
// Backpropagation uses inverted rotation
if (p.is_back != 0) {
theta = -theta;
}
cos_theta = cos(theta) * mscale;
sin_theta = sin(theta) * mscale;
}
Loading
Loading