Commit 5b8023d

Implement prototype for instant mmap() loading
This change uses a custom malloc() implementation to transactionally capture to a file the dynamic memory created during the loading process. That includes (1) the malloc() allocation for mem_buffer and (2) all the C++ STL objects. On my $1000 personal computer, this change lets me run ./main to generate a single token (-n 1) using the float16 7B model (~12GB) in one second. In order to do that, there's a one-time cost: a 13GB file needs to be generated first. This change rocks, but it shouldn't be necessary to do something this heroic. We should instead change the file format so that tensors don't need reshaping and realignment in order to be loaded.
1 parent 2788f37 commit 5b8023d
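
The core trick, spelled out: malloc() (and with it every C++ STL allocation) is redirected to a bump allocator that hands out memory from a file-backed mmap() at a fixed virtual address, so all the pointers woven through the loaded data structures stay valid when a later process maps the same file back to the same address. Defining malloc()/calloc()/realloc()/free() in the executable interposes on the libc versions on typical ELF/glibc systems, which is how the STL traffic gets captured too. Here is a minimal sketch of the idea, not the code from this commit: HEAP_PATH, HEAP_BASE, HEAP_SIZE, heap_open(), and bump_alloc() are illustrative names, and growth, locking, and error handling are pared down.

// Minimal sketch: a bump allocator over a file-backed mapping at a fixed
// address. A fresh run allocates into the file; a later run sees the same
// bytes at the same address, so pointers stored inside the heap stay valid.
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <fcntl.h>
#include <sys/mman.h>
#include <unistd.h>

#define HEAP_PATH "heap.dat"
#define HEAP_BASE ((char *)0x330000000000)  // fixed address => stable pointers
#define HEAP_SIZE (1ull << 34)              // 16 GiB, sparse until touched

static size_t *g_off;  // bump offset, stored in the file's first word

static void heap_open(void) {
    int fd = open(HEAP_PATH, O_RDWR | O_CREAT, 0644);
    if (fd == -1 || ftruncate(fd, HEAP_SIZE) == -1) { perror(HEAP_PATH); exit(1); }
    void *p = mmap(HEAP_BASE, HEAP_SIZE, PROT_READ | PROT_WRITE,
                   MAP_SHARED | MAP_FIXED, fd, 0);  // writes persist in the file
    if (p == MAP_FAILED) { perror("mmap"); exit(1); }
    g_off = (size_t *)HEAP_BASE;
    if (!*g_off) *g_off = sizeof(size_t);  // brand-new file: skip the offset word
}

static void *bump_alloc(size_t n, size_t align) {  // align: power of two
    size_t i = (*g_off + align - 1) & -align;      // round offset up to alignment
    *g_off = i + n;
    return HEAP_BASE + i;
}

On the next run heap_open() maps the same bytes back at HEAP_BASE, so a root pointer saved inside the heap can be dereferenced immediately. In the diff below, struct magic plays that role (stashing the vocab and model pointers), the 0xFEEDABEE stamp marks the heap as fully loaded, and the mapping grows in 2 MiB (MAGIC_GRAN) steps instead of reserving everything up front.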

File tree

2 files changed: +207 -19


.gitignore (+1)
@@ -18,6 +18,7 @@ models/*
 
 /main
 /quantize
+/magic.dat
 
 arm_neon.h
 compile_commands.json

main.cpp (+206 -19)
@@ -3,19 +3,32 @@
 #include "utils.h"
 
 #include <cassert>
+#include <cerrno>
 #include <cmath>
 #include <cstdio>
 #include <cstring>
 #include <fstream>
 #include <map>
 #include <string>
 #include <vector>
+#include <atomic>
 
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+#include <fcntl.h>
 #include <signal.h>
 #include <unistd.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
 #endif
 
+#define ROUNDUP(X, K) (((X) + (K)-1) & -(K))
+#define IS2POW(X) (!((X) & ((X)-1)))
+
+#define MAGIC_PATH "magic.dat"
+#define MAGIC_ADDR (char *)0x330000000000
+#define MAGIC_GRAN 2097152
+#define MAGIC_ALGN (sizeof(size_t) * 2)
+
 #define ANSI_COLOR_RED "\x1b[31m"
 #define ANSI_COLOR_GREEN "\x1b[32m"
 #define ANSI_COLOR_YELLOW "\x1b[33m"
@@ -83,6 +96,173 @@ struct llama_model {
     std::map<std::string, struct ggml_tensor *> tensors;
 };
 
+struct magic {
+    uint32_t magic;
+    std::atomic<unsigned> lock;
+    int fd;
+    size_t commit;
+    size_t offset;
+    size_t capacity;
+    gpt_vocab *vocab;
+    llama_model *model;
+};
+
+static struct magic *mag;
+
+static inline void spin_lock(std::atomic<unsigned> &lock) {
+    while (!lock.exchange(1, std::memory_order_acquire));
+}
+
+static inline void spin_unlock(std::atomic<unsigned> &lock) {
+    lock.store(0, std::memory_order_release);
+}
+
+static void *Mmap(void *addr, size_t length, int prot, int flags, int fd, off_t offset) {
+    void *res;
+    res = mmap(addr, length, prot, flags, fd, offset);
+    if (res != MAP_FAILED) return res;
+    perror("mmap");
+    exit(77);
+}
+
+static void magic_commit(void) {
+    mag->offset = mag->capacity;
+    mag->commit = mag->capacity;
+    mag->magic = 0xFEEDABEE;
+    msync(mag, mag->commit, MS_ASYNC);
+}
+
+static void magic_init(void) {
+    int fd;
+    size_t n;
+    struct stat st;
+    if (mag) return;
+    n = ROUNDUP(sizeof(struct magic), MAGIC_GRAN);
+    if ((fd = open(MAGIC_PATH, O_RDWR)) != -1) {
+        fstat(fd, &st);
+        if (st.st_size >= n) {
+            mag = (struct magic *)Mmap(MAGIC_ADDR, n,
+                                       PROT_READ | PROT_WRITE,
+                                       MAP_PRIVATE | MAP_FIXED, fd, 0);
+            if (mag->magic == 0xFEEDABEE) {
+                mag = (struct magic *)Mmap(MAGIC_ADDR, mag->capacity,
+                                           PROT_READ | PROT_WRITE,
+                                           MAP_PRIVATE | MAP_FIXED, fd, 0);
+                madvise(MAGIC_ADDR, mag->capacity, MADV_WILLNEED);
+                ftruncate(fd, mag->commit);
+                mag->offset = mag->commit;
+                mag->capacity = mag->commit;
+                mag->fd = -1;
+                return;
+            }
+        }
+        ftruncate(fd, 0);
+    } else if ((fd = open(MAGIC_PATH, O_RDWR | O_CREAT | O_TRUNC, 0644)) == -1) {
+        perror(MAGIC_PATH);
+        exit(77);
+    }
+    ftruncate(fd, n);
+    mag = (struct magic *)Mmap(MAGIC_ADDR, n,
+                               PROT_READ | PROT_WRITE,
+                               MAP_SHARED | MAP_FIXED, fd, 0);
+    mag->offset = MAGIC_GRAN;
+    mag->fd = fd;
+}
+
+void *memalign(size_t a, size_t n) {
+    void *p;
+    size_t i, j, k, m;
+    static int count;
+    magic_init();
+    if (a < MAGIC_ALGN) a = MAGIC_ALGN;
+    while (!IS2POW(a)) ++a;
+    m = n ? n : 1;
+    spin_lock(mag->lock);
+    i = mag->offset;
+    i = i + sizeof(size_t);
+    i = ROUNDUP(i, a);
+    j = ROUNDUP(i + m, MAGIC_GRAN);
+    if (j > mag->capacity) {
+        if (!mag->magic) {
+            ftruncate(mag->fd, j);
+            p = mmap(MAGIC_ADDR + mag->capacity,
+                     j - mag->capacity, PROT_READ | PROT_WRITE,
+                     MAP_SHARED | MAP_FIXED, mag->fd, mag->capacity);
+        } else {
+            p = mmap(MAGIC_ADDR + mag->capacity,
+                     j - mag->capacity, PROT_READ | PROT_WRITE,
+                     MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED, -1, 0);
+        }
+        if (p != MAP_FAILED) {
+            mag->capacity = j;
+        } else {
+            spin_unlock(mag->lock);
+            return 0;
+        }
+    }
+    mag->offset = i + m;
+    spin_unlock(mag->lock);
+    p = MAGIC_ADDR + i;
+    ((size_t *)p)[-1] = n;
+    return p;
+}
+
+int posix_memalign(void **pp, size_t a, size_t n) {
+    int e;
+    void *m;
+    size_t q, r;
+    q = a / sizeof(void *);
+    r = a % sizeof(void *);
+    if (!r && q && IS2POW(q)) {
+        e = errno;
+        m = memalign(a, n);
+        if (m) {
+            *pp = m;
+            return 0;
+        } else {
+            errno = e;
+            return ENOMEM;
+        }
+    } else {
+        return EINVAL;
+    }
+}
+
+void *malloc(size_t n) {
+    return memalign(MAGIC_ALGN, n);
+}
+
+size_t malloc_usable_size(const void *p) {
+    return ((const size_t *)p)[-1];
+}
+
+void *calloc(size_t n, size_t z) {
+    void *p;
+    if ((p = malloc((n *= z)))) {
+        memset(p, 0, n);
+    }
+    return p;
+}
+
+void free(void *p) {
+    // do nothing
+}
+
+void *realloc(void *p, size_t n) {
+    void *q;
+    if (!p) {
+        return malloc(n);
+    }
+    if (!n) {
+        free(p);
+        return 0;
+    }
+    if ((q = malloc(n))) {
+        memcpy(q, p, ((const size_t *)p)[-1]);
+    }
+    return q;
+}
+
 // load the model's weights from a file
 bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
@@ -786,6 +966,8 @@ const char * llama_print_system_info(void) {
 }
 
 int main(int argc, char ** argv) {
+    magic_init();
+
     ggml_time_init();
     const int64_t t_main_start_us = ggml_time_us();
 
@@ -812,19 +994,24 @@ int main(int argc, char ** argv) {
 
     int64_t t_load_us = 0;
 
-    gpt_vocab vocab;
-    llama_model model;
-
     // load the model
-    {
+    gpt_vocab *vocab;
+    llama_model *model;
+    if (!mag->magic) {
+        vocab = new gpt_vocab;
+        model = new llama_model;
         const int64_t t_start_us = ggml_time_us();
-
-        if (!llama_model_load(params.model, model, vocab, 512)) {  // TODO: set context from user input ??
+        if (!llama_model_load(params.model, *model, *vocab, 512)) {  // TODO: set context from user input ??
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             return 1;
         }
-
         t_load_us = ggml_time_us() - t_start_us;
+        mag->vocab = vocab;
+        mag->model = model;
+        magic_commit();
+    } else {
+        vocab = mag->vocab;
+        model = mag->model;
     }
 
     // print system information
@@ -842,18 +1029,18 @@ int main(int argc, char ** argv) {
     std::vector<float> logits;
 
     // tokenize the prompt
-    std::vector<gpt_vocab::id> embd_inp = ::llama_tokenize(vocab, params.prompt, true);
+    std::vector<gpt_vocab::id> embd_inp = ::llama_tokenize(*vocab, params.prompt, true);
 
-    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
+    params.n_predict = std::min(params.n_predict, model->hparams.n_ctx - (int) embd_inp.size());
 
     // tokenize the reverse prompt
-    std::vector<gpt_vocab::id> antiprompt_inp = ::llama_tokenize(vocab, params.antiprompt, false);
+    std::vector<gpt_vocab::id> antiprompt_inp = ::llama_tokenize(*vocab, params.antiprompt, false);
 
     fprintf(stderr, "\n");
     fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
     fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
     for (int i = 0; i < (int) embd_inp.size(); i++) {
-        fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
+        fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], vocab->id_to_token.at(embd_inp[i]).c_str());
     }
     fprintf(stderr, "\n");
     if (params.interactive) {
@@ -871,7 +1058,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.c_str());
         fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
         for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
-            fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
+            fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab->id_to_token.at(antiprompt_inp[i]).c_str());
         }
         fprintf(stderr, "\n");
     }
@@ -883,7 +1070,7 @@ int main(int argc, char ** argv) {
 
     // determine the required inference memory per token:
     size_t mem_per_token = 0;
-    llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+    llama_eval(*model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
 
     int last_n_size = params.repeat_last_n;
     std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
@@ -918,7 +1105,7 @@ int main(int argc, char ** argv) {
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
 
-            if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
+            if (!llama_eval(*model, params.n_threads, n_past, embd, logits, mem_per_token)) {
                 fprintf(stderr, "Failed to predict\n");
                 return 1;
             }
@@ -936,14 +1123,14 @@ int main(int argc, char ** argv) {
             const float temp = params.temp;
             const float repeat_penalty = params.repeat_penalty;
 
-            const int n_vocab = model.hparams.n_vocab;
+            const int n_vocab = model->hparams.n_vocab;
 
             gpt_vocab::id id = 0;
 
             {
                 const int64_t t_start_sample_us = ggml_time_us();
 
-                id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
+                id = llama_sample_top_p_top_k(*vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
 
                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(id);
@@ -980,7 +1167,7 @@ int main(int argc, char ** argv) {
         // display text
         if (!input_noecho) {
            for (auto id : embd) {
-                printf("%s", vocab.id_to_token[id].c_str());
+                printf("%s", vocab->id_to_token[id].c_str());
             }
             fflush(stdout);
         }
@@ -1018,7 +1205,7 @@ int main(int argc, char ** argv) {
                     buf[n_read+1] = 0;
                 }
 
-                std::vector<gpt_vocab::id> line_inp = ::llama_tokenize(vocab, buf, false);
+                std::vector<gpt_vocab::id> line_inp = ::llama_tokenize(*vocab, buf, false);
                 embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
 
                 remaining_tokens -= line_inp.size();
@@ -1050,7 +1237,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
     }
 
-    ggml_free(model.ctx);
+    ggml_free(model->ctx);
 
     return 0;
 }
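
A note on how the two runs fit together (a reading of the diff above, not text from the commit): the first run allocates everything into the MAP_SHARED heap, and only after llama_model_load() succeeds does magic_commit() record the committed size and write the 0xFEEDABEE stamp, so a crash mid-load leaves an unstamped magic.dat that the next magic_init() truncates and rebuilds from scratch. A run that does find the stamp remaps the heap MAP_PRIVATE, so later allocations and stray writes never dirty the file. Restated as a compressed sketch, reusing struct magic and the constants from the hunks above (illustrative, not a drop-in replacement):

// First run: stamp and flush the heap once loading has finished.
static void commit_heap(struct magic *m) {
    m->offset = m->capacity;
    m->commit = m->capacity;
    m->magic  = 0xFEEDABEE;          // stamp written after the sizes are recorded
    msync(m, m->commit, MS_ASYNC);   // queue write-back of the whole heap
}

// Later runs: reuse the heap only if the stamp says it was committed.
static bool restore_heap(struct magic *&m, int fd) {
    if (m->magic != 0xFEEDABEE) return false;  // never committed: rebuild
    void *res = mmap(MAGIC_ADDR, m->capacity, PROT_READ | PROT_WRITE,
                     MAP_PRIVATE | MAP_FIXED, fd, 0);  // copy-on-write reuse
    if (res == MAP_FAILED) return false;
    m = (struct magic *)res;
    return true;
}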
