Skip to content

Commit

Permalink
calculate the freq_cis online, no need to write/read them to/from che…
Browse files Browse the repository at this point in the history
…ckpoints
  • Loading branch information
karpathy committed Aug 17, 2023
1 parent b68a6d2 commit bd18228
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions run.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ typedef struct {
float* w3; // (layer, hidden_dim, dim)
// final rmsnorm
float* rms_final_weight; // (dim,)
// freq_cis for RoPE relatively positional embeddings
// freq_cis for RoPE relatively positional embeddings (not used anymore)
float* freq_cis_real; // (seq_len, head_size/2)
float* freq_cis_imag; // (seq_len, head_size/2)
// (optional) classifier weights for the logits, on the last layer
Expand Down Expand Up @@ -214,10 +214,6 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
float* content_row = &(w->token_embedding_table[token * dim]);
memcpy(x, content_row, dim*sizeof(*x));

// pluck out the "pos" row of freq_cis_real and freq_cis_imag
float* freq_cis_real_row = w->freq_cis_real + pos * head_size / 2;
float* freq_cis_imag_row = w->freq_cis_imag + pos * head_size / 2;

// forward all the layers
for(int l = 0; l < p->n_layers; l++) {

Expand All @@ -229,15 +225,18 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);

// RoPE relative positional encoding: complex-valued rotate q and k by freq_cis in each head
for (int v = 0; v < 2; v++) {
float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
int vec_size = v == 0 ? dim : kv_dim; // the size of the vector
for (int i = 0; i < vec_size; i+=2) {
// RoPE relative positional encoding: complex-valued rotate q and k in each head
for (int i = 0; i < dim; i+=2) {
int head_dim = i % head_size;
float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
float val = pos * freq;
float fcr = cosf(val);
float fci = sinf(val);
int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
for (int v = 0; v < rotn; v++) {
float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
float v0 = vec[i];
float v1 = vec[i+1];
float fcr = freq_cis_real_row[(i % head_size) / 2];
float fci = freq_cis_imag_row[(i % head_size) / 2];
vec[i] = v0 * fcr - v1 * fci;
vec[i+1] = v0 * fci + v1 * fcr;
}
Expand Down

0 comments on commit bd18228

Please sign in to comment.