bigchange: add multiquery support in run.c. We can now train and run inference on multiquery models (where n_kv_heads < n_heads). This also means that, in principle, we support the Llama 2 34B and 70B models, which are multiquery.
karpathy committed Aug 13, 2023
1 parent 9ff459b commit 38bfac9
Showing 5 changed files with 33 additions and 26 deletions.
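
Before the per-file changes, it may help to spell out the two derived quantities the new run.c code leans on: kv_dim, the width of the key/value activations and of each KV-cache row, and kv_mul, the number of query heads that share one key/value head. The standalone sketch below is illustrative only; the config numbers are made-up examples, not taken from any checkpoint in the commit.

#include <stdio.h>

// Standalone sketch of the grouped-query bookkeeping this commit introduces.
// The config values are illustrative examples, not from a real model.
int main(void) {
    int dim = 4096;      // transformer width
    int n_heads = 32;    // query heads
    int n_kv_heads = 8;  // key/value heads, must divide n_heads
    int head_size = dim / n_heads;
    int kv_dim = (dim * n_kv_heads) / n_heads; // width of k, v and of each cache row
    int kv_mul = n_heads / n_kv_heads;         // query heads per key/value head
    printf("head_size=%d kv_dim=%d kv_mul=%d\n", head_size, kv_dim, kv_mul);
    // query head h reads the key/value slice of kv head h / kv_mul
    for (int h = 0; h < n_heads; h++) {
        printf("query head %2d -> kv head %d\n", h, h / kv_mul);
    }
    return 0;
}
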
README.md (1 change: 0 additions & 1 deletion)
@@ -294,7 +294,6 @@ If your candidate PRs have elements of these it doesn't mean they won't get merged

- revive tests; train a tiny Llama test model (committed to repo) and use it as reference in unit tests
- make it easier to add a new dataset with not too much pain
-- add multiquery support into run.c
- should calculate freq_cis online in the script run.c instead of loading them
- int4/8 quantization
- export the model in a more sensible output format with a proper header, etc.
model.py (1 change: 1 addition & 0 deletions)
@@ -94,6 +94,7 @@ class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
+assert args.n_heads % self.n_kv_heads == 0
model_parallel_size = 1
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
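
The new assert guarantees the query heads split evenly over the key/value heads. The same parameter covers three regimes: n_kv_heads == n_heads is ordinary multi-head attention, n_kv_heads == 1 is pure multiquery, and anything in between is grouped-query attention. A small C-side sketch of the same sanity check, with example values borrowed from the train.py defaults further below, purely for illustration:

#include <assert.h>
#include <stdio.h>

// Illustrative mirror of the new model.py assert; the values are examples only.
// n_kv_heads == n_heads    -> standard multi-head attention
// n_kv_heads == 1          -> pure multiquery attention
// 1 < n_kv_heads < n_heads -> grouped-query attention
int main(void) {
    int n_heads = 6, n_kv_heads = 6;
    assert(n_kv_heads >= 1 && n_kv_heads <= n_heads);
    assert(n_heads % n_kv_heads == 0); // every kv head serves the same number of query heads
    printf("%d query head(s) per kv head\n", n_heads / n_kv_heads);
    return 0;
}
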
run.c (53 changes: 30 additions & 23 deletions)
@@ -39,11 +39,11 @@ typedef struct {
// weights for rmsnorms
float* rms_att_weight; // (layer, dim) rmsnorm weights
float* rms_ffn_weight; // (layer, dim)
-// weights for matmuls
-float* wq; // (layer, dim, dim)
-float* wk; // (layer, dim, dim)
-float* wv; // (layer, dim, dim)
-float* wo; // (layer, dim, dim)
+// weights for matmuls. note dim == n_heads * head_size
+float* wq; // (layer, dim, n_heads * head_size)
+float* wk; // (layer, dim, n_kv_heads * head_size)
+float* wv; // (layer, dim, n_kv_heads * head_size)
+float* wo; // (layer, n_heads * head_size, dim)
// weights for ffn
float* w1; // (layer, hidden_dim, dim)
float* w2; // (layer, dim, hidden_dim)
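
Only wq and wo keep a full dim-by-dim shape; wk and wv shrink by a factor of n_heads / n_kv_heads. A rough sketch of the per-layer float counts under the new shapes; the config values are assumed for illustration and are not read from a real checkpoint:

#include <stdio.h>

// Per-layer float counts for the attention matmuls under the new shapes.
// The config values below are assumed, for illustration only.
int main(void) {
    long long dim = 4096, n_heads = 32, n_kv_heads = 8;
    long long head_size = dim / n_heads;
    long long wq = dim * (n_heads * head_size);    // == dim * dim
    long long wk = dim * (n_kv_heads * head_size); // smaller by n_heads / n_kv_heads
    long long wv = dim * (n_kv_heads * head_size);
    long long wo = (n_heads * head_size) * dim;    // == dim * dim
    printf("wq=%lld wk=%lld wv=%lld wo=%lld floats per layer\n", wq, wk, wv, wo);
    return 0;
}
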
@@ -82,6 +82,7 @@ typedef struct {

void malloc_run_state(RunState* s, Config* p) {
// we calloc instead of malloc to keep valgrind happy
+int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
s->x = calloc(p->dim, sizeof(float));
s->xb = calloc(p->dim, sizeof(float));
s->xb2 = calloc(p->dim, sizeof(float));
@@ -93,8 +94,8 @@ void malloc_run_state(RunState* s, Config* p) {
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
s->logits = calloc(p->vocab_size, sizeof(float));
s->probindex = calloc(p->vocab_size, sizeof(ProbIndex));
-s->key_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
-s->value_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
+s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
+s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
// ensure all mallocs went fine
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
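
Sizing key_cache and value_cache by kv_dim instead of dim is where multiquery pays off at inference time. A back-of-the-envelope sketch of the float32 cache footprint; the 70B-like numbers below are approximate and used only to show the scale of the saving:

#include <stdio.h>

// Back-of-the-envelope KV cache footprint in float32, before and after sizing
// by kv_dim. The "70B-like" config below is approximate and illustrative.
int main(void) {
    long long n_layers = 80, seq_len = 4096, dim = 8192, n_heads = 64, n_kv_heads = 8;
    long long kv_dim = (dim * n_kv_heads) / n_heads;
    long long old_floats = 2LL * n_layers * seq_len * dim;    // key + value caches at full dim
    long long new_floats = 2LL * n_layers * seq_len * kv_dim; // key + value caches at kv_dim
    printf("old: %.1f GiB  new: %.1f GiB  (%lldx smaller)\n",
           old_floats * 4.0 / (1LL << 30), new_floats * 4.0 / (1LL << 30),
           old_floats / new_floats);
    return 0;
}
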
@@ -124,19 +125,20 @@ void free_run_state(RunState* s) {
// initialization: read from checkpoint

void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f, int shared_weights) {
+int head_size = p->dim / p->n_heads;
float* ptr = f;
w->token_embedding_table = ptr;
ptr += p->vocab_size * p->dim;
w->rms_att_weight = ptr;
ptr += p->n_layers * p->dim;
w->wq = ptr;
-ptr += p->n_layers * p->dim * p->dim;
+ptr += p->n_layers * p->dim * (p->n_heads * head_size);
w->wk = ptr;
-ptr += p->n_layers * p->dim * p->dim;
+ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
w->wv = ptr;
-ptr += p->n_layers * p->dim * p->dim;
+ptr += p->n_layers * p->dim * (p->n_kv_heads * head_size);
w->wo = ptr;
-ptr += p->n_layers * p->dim * p->dim;
+ptr += p->n_layers * (p->n_heads * head_size) * p->dim;
w->rms_ffn_weight = ptr;
ptr += p->n_layers * p->dim;
w->w1 = ptr;
@@ -148,7 +150,6 @@ void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f, int shared_weights) {
w->rms_final_weight = ptr;
ptr += p->dim;
w->freq_cis_real = ptr;
-int head_size = p->dim / p->n_heads;
ptr += p->seq_len * head_size / 2;
w->freq_cis_imag = ptr;
ptr += p->seq_len * head_size / 2;
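
The pointer walk above fixes the on-disk order of the tensors; the only change is that wk and wv now advance by n_kv_heads * head_size columns per layer. A hedged sketch that prints the float offset at which each attention tensor starts, for an assumed toy config (roughly the train.py defaults); it only tracks the tensors up to rms_ffn_weight:

#include <stdio.h>

// Sketch of the float offsets implied by the pointer walk above, for an
// assumed toy config. Only the tensors up to rms_ffn_weight are tracked.
int main(void) {
    long long dim = 288, n_layers = 6, n_heads = 6, n_kv_heads = 6, vocab_size = 32000; // assumed
    long long head_size = dim / n_heads;
    long long off = 0;
    printf("token_embedding_table at %lld\n", off); off += vocab_size * dim;
    printf("rms_att_weight        at %lld\n", off); off += n_layers * dim;
    printf("wq                    at %lld\n", off); off += n_layers * dim * (n_heads * head_size);
    printf("wk                    at %lld\n", off); off += n_layers * dim * (n_kv_heads * head_size);
    printf("wv                    at %lld\n", off); off += n_layers * dim * (n_kv_heads * head_size);
    printf("wo                    at %lld\n", off); off += n_layers * (n_heads * head_size) * dim;
    printf("rms_ffn_weight        at %lld\n", off);
    return 0;
}
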
@@ -218,6 +219,8 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {
// a few convenience variables
float *x = s->x;
int dim = p->dim;
+int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
+int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
int hidden_dim = p->hidden_dim;
int head_size = dim / p->n_heads;

@@ -237,29 +240,33 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {

// qkv matmuls for this position
matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
-matmul(s->k, s->xb, w->wk + l*dim*dim, dim, dim);
-matmul(s->v, s->xb, w->wv + l*dim*dim, dim, dim);
+matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
+matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);

// RoPE relative positional encoding: complex-valued rotate q and k by freq_cis in each head
for (int i = 0; i < dim; i+=2) {
float q0 = s->q[i];
float q1 = s->q[i+1];
-float k0 = s->k[i];
-float k1 = s->k[i+1];
float fcr = freq_cis_real_row[(i % head_size) / 2];
float fci = freq_cis_imag_row[(i % head_size) / 2];
s->q[i] = q0 * fcr - q1 * fci;
s->q[i+1] = q0 * fci + q1 * fcr;
+}
+for (int i = 0; i < kv_dim; i+=2) {
+float k0 = s->k[i];
+float k1 = s->k[i+1];
+float fcr = freq_cis_real_row[(i % head_size) / 2];
+float fci = freq_cis_imag_row[(i % head_size) / 2];
s->k[i] = k0 * fcr - k1 * fci;
s->k[i+1] = k0 * fci + k1 * fcr;
}

// save key,value at this time step (pos) to our kv cache
-int loff = l * p->seq_len * dim; // kv cache layer offset for convenience
-float* key_cache_row = s->key_cache + loff + pos * dim;
-float* value_cache_row = s->value_cache + loff + pos * dim;
-memcpy(key_cache_row, s->k, dim*sizeof(*key_cache_row));
-memcpy(value_cache_row, s->v, dim*sizeof(*value_cache_row));
+int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
+float* key_cache_row = s->key_cache + loff + pos * kv_dim;
+float* value_cache_row = s->value_cache + loff + pos * kv_dim;
+memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row));
+memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row));

// multihead attention. iterate over all heads
int h;
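
The next two hunks do the per-head lookups: every cache row is kv_dim floats wide, and query head h reads the slice belonging to kv head h / kv_mul. Pulling that addressing into a helper may make it easier to see; the helper name and the config values here are illustrative and not part of the commit:

#include <stdio.h>

// Illustrative helper: float offset of the key (or value) vector that query
// head h should read, for layer l and timestep t. It mirrors the indexing in
// the hunks below; the name and config values are examples only.
long long kv_offset(long long l, long long t, long long h,
                    long long seq_len, long long kv_dim,
                    long long kv_mul, long long head_size) {
    long long loff = l * seq_len * kv_dim;                // layer offset into the cache
    return loff + t * kv_dim + (h / kv_mul) * head_size;  // row for timestep t, slice of the shared kv head
}

int main(void) {
    long long seq_len = 256, n_heads = 8, n_kv_heads = 2, dim = 64; // assumed example config
    long long head_size = dim / n_heads, kv_dim = (dim * n_kv_heads) / n_heads;
    long long kv_mul = n_heads / n_kv_heads;
    // with kv_mul = 4, query heads 0..3 share kv head 0 and heads 4..7 share kv head 1
    for (long long h = 0; h < n_heads; h++) {
        printf("layer 0, t=5, query head %lld reads key cache offset %lld\n",
               h, kv_offset(0, 5, h, seq_len, kv_dim, kv_mul, head_size));
    }
    return 0;
}
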
@@ -272,7 +279,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {
// iterate over all timesteps, including the current one
for (int t = 0; t <= pos; t++) {
// get the key vector for this head and at this timestep
-float* k = s->key_cache + loff + t * dim + h * head_size;
+float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
// calculate the attention score as the dot product of q and k
float score = 0.0f;
for (int i = 0; i < head_size; i++) {
@@ -291,7 +298,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {
memset(xb, 0, head_size * sizeof(float));
for (int t = 0; t <= pos; t++) {
// get the value vector for this head and at this timestep
-float* v = s->value_cache + loff + t * dim + h * head_size;
+float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
// get the attention weight for this timestep
float a = att[t];
// accumulate the weighted value into xb
sample.py (1 change: 0 additions & 1 deletion)
@@ -53,7 +53,6 @@
model = torch.compile(model) # requires PyTorch 2.0 (optional)

# load the tokenizer
-assert checkpoint["config"]["dataset"] == "tinystories" # TODO: generalize
tokenizer_model = get_tokenizer_model_path(vocab_size=gptconf.vocab_size)
enc = Tokenizer(tokenizer_model=tokenizer_model)

train.py (3 changes: 2 additions & 1 deletion)
@@ -52,6 +52,7 @@
dim = 288
n_layers = 6
n_heads = 6
+n_kv_heads = 6
multiple_of = 32
dropout = 0.0
# adamw optimizer
@@ -146,7 +147,7 @@
dim=dim,
n_layers=n_layers,
n_heads=n_heads,
-n_kv_heads=n_heads,
+n_kv_heads=n_kv_heads,
vocab_size=vocab_size,
multiple_of=multiple_of,
max_seq_len=max_seq_len,
