Skip to content

Commit

Permalink
Merge 1bd3e43 into 0014f29
Browse files Browse the repository at this point in the history
  • Loading branch information
Alan-Jowett committed Apr 29, 2021
2 parents 0014f29 + 1bd3e43 commit 57b065c
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 52 deletions.
20 changes: 19 additions & 1 deletion vm/inc/ubpf.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ void ubpf_destroy(struct ubpf_vm *vm);
*/
bool toggle_bounds_check(struct ubpf_vm *vm, bool enable);


/*
* Set the function to be invoked if the jitted program hits divide by zero.
*
* fprintf is the default function to be invoked on division by zero.
*/
void set_error_print(struct ubpf_vm *vm, int (*error_printf)(FILE* stream, const char* format, ...));

/*
* Register an external function
*
Expand Down Expand Up @@ -81,8 +89,18 @@ int ubpf_load(struct ubpf_vm *vm, const void *code, uint32_t code_len, char **er
*/
int ubpf_load_elf(struct ubpf_vm *vm, const void *elf, size_t elf_len, char **errmsg);

uint64_t ubpf_exec(const struct ubpf_vm *vm, void *mem, size_t mem_len);
uint64_t ubpf_exec(const struct ubpf_vm *vm, void *mem, size_t mem_len, char **errmsg);

ubpf_jit_fn ubpf_compile(struct ubpf_vm *vm, char **errmsg);

/*
* Translate the eBPF byte code to x64 machine code and store in buffer.
*
* This must be called after registering all functions.
*
* Returns 0 on success, -1 on error. In case of error a pointer to the error
* message will be stored in 'errmsg' and should be freed by the caller.
*/
int ubpf_translate(struct ubpf_vm *vm, uint8_t *buffer, size_t *size, char **errmsg);

#endif
5 changes: 4 additions & 1 deletion vm/test.c
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,10 @@ int main(int argc, char **argv)
}
ret = fn(mem, mem_len);
} else {
ret = ubpf_exec(vm, mem, mem_len);
ret = ubpf_exec(vm, mem, mem_len, &errmsg);
if (errmsg) {
fprintf(stderr, "%s\n", errmsg);
}
}

printf("0x%"PRIx64"\n", ret);
Expand Down
1 change: 1 addition & 0 deletions vm/ubpf_int.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ struct ubpf_vm {
ext_func *ext_funcs;
const char **ext_func_names;
bool bounds_check_enabled;
int (*error_printf)(FILE* stream, const char* format, ...);
};

char *ubpf_error(const char *fmt, ...);
Expand Down
134 changes: 98 additions & 36 deletions vm/ubpf_jit_x86_64.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,52 @@
#include "ubpf_int.h"
#include "ubpf_jit_x86_64.h"

#if !defined(_countof)
#define _countof(array) (sizeof(array) / sizeof(array[0]))
#endif

/* Special values for target_pc in struct jump */
#define TARGET_PC_EXIT -1
#define TARGET_PC_DIV_BY_ZERO -2

static void muldivmod(struct jit_state *state, uint16_t pc, uint8_t opcode, int src, int dst, int32_t imm);

#define REGISTER_MAP_SIZE 11

/*
* There are two common x86-64 calling conventions, as discussed at
* https://en.wikipedia.org/wiki/X86_calling_conventions#x86-64_calling_conventions
*/

#if defined(_WIN32)
static int platform_nonvolatile_registers[] = {
RBP, RBX, RDI, RSI, R12, R13, R14, R15
};
static int platform_parameter_registers[] = {
RCX, RDX, R8, R9
};
#define RCX_ALT R15
static int register_map[REGISTER_MAP_SIZE] = {
RAX,
R15,
RDX,
R8,
R9,
R10,
R11,
R12,
R13,
R14,
RBP,
};
#else
#define RCX_ALT R9
static int platform_nonvolatile_registers[] = {
RBP, RBX, R13, R14, R15
};
static int platform_parameter_registers[] = {
RDI, RSI, RDX, RCX, R8, R9
};
static int register_map[REGISTER_MAP_SIZE] = {
RAX,
RDI,
Expand All @@ -47,6 +86,7 @@ static int register_map[REGISTER_MAP_SIZE] = {
R15,
RBP,
};
#endif

/* Return the x86 register for the given eBPF register */
static int
Expand Down Expand Up @@ -82,15 +122,17 @@ ubpf_set_register_offset(int x)
static int
translate(struct ubpf_vm *vm, struct jit_state *state, char **errmsg)
{
emit_push(state, RBP);
emit_push(state, RBX);
emit_push(state, R13);
emit_push(state, R14);
emit_push(state, R15);

/* Move rdi into register 1 */
if (map_register(1) != RDI) {
emit_mov(state, RDI, map_register(1));
int i;

/* Save platform non-volatile registers */
for (i = 0; i < _countof(platform_nonvolatile_registers); i++)
{
emit_push(state, platform_nonvolatile_registers[i]);
}

/* Move first platform parameter register into register 1 */
if (map_register(1) != platform_parameter_registers[0]) {
emit_mov(state, platform_parameter_registers[0], map_register(1));
}

/* Copy stack pointer to R10 */
Expand All @@ -99,7 +141,6 @@ translate(struct ubpf_vm *vm, struct jit_state *state, char **errmsg)
/* Allocate stack space */
emit_alu64_imm32(state, 0x81, 5, RSP, STACK_SIZE);

int i;
for (i = 0; i < vm->num_insts; i++) {
struct ebpf_inst inst = vm->insts[i];
state->pc_locs[i] = state->offset;
Expand Down Expand Up @@ -359,7 +400,7 @@ translate(struct ubpf_vm *vm, struct jit_state *state, char **errmsg)
break;
case EBPF_OP_CALL:
/* We reserve RCX for shifts */
emit_mov(state, R9, RCX);
emit_mov(state, RCX_ALT, RCX);
emit_call(state, vm->ext_funcs[inst.imm]);
break;
case EBPF_OP_EXIT:
Expand Down Expand Up @@ -431,21 +472,23 @@ translate(struct ubpf_vm *vm, struct jit_state *state, char **errmsg)
/* Deallocate stack space */
emit_alu64_imm32(state, 0x81, 0, RSP, STACK_SIZE);

emit_pop(state, R15);
emit_pop(state, R14);
emit_pop(state, R13);
emit_pop(state, RBX);
emit_pop(state, RBP);
/* Restore platform non-volatile registers */
for (i = 0; i < _countof(platform_nonvolatile_registers); i++)
{
emit_pop(state, platform_nonvolatile_registers[_countof(platform_nonvolatile_registers) - i - 1]);
}

emit1(state, 0xc3); /* ret */

/* Division by zero handler */
const char *div_by_zero_fmt = "uBPF error: division by zero at PC %u\n";
state->div_by_zero_loc = state->offset;
emit_load_imm(state, RDI, (uintptr_t)stderr);
emit_load_imm(state, RSI, (uintptr_t)div_by_zero_fmt);
emit_mov(state, RCX, RDX); /* muldivmod stored pc in RCX */
emit_call(state, fprintf);
const char *div_by_zero_fmt = "uBPF error: division by zero at PC %u\n";
// RCX is the first parameter register for Windows, so first save the value.
emit_mov(state, RCX, platform_parameter_registers[2]); /* muldivmod stored pc in RCX */
emit_load_imm(state, platform_parameter_registers[0], (uintptr_t)stderr);
emit_load_imm(state, platform_parameter_registers[1], (uintptr_t)div_by_zero_fmt);
emit_call(state, vm->error_printf);

emit_load_imm(state, map_register(0), -1);
emit_jmp(state, TARGET_PC_EXIT);

Expand Down Expand Up @@ -538,12 +581,40 @@ resolve_jumps(struct jit_state *state)
}
}

int
ubpf_translate(struct ubpf_vm *vm, uint8_t * buffer, size_t * size, char **errmsg)
{
struct jit_state state;
int result = -1;

state.offset = 0;
state.size = *size;
state.buf = buffer;
state.pc_locs = calloc(MAX_INSTS+1, sizeof(state.pc_locs[0]));
state.jumps = calloc(MAX_INSTS, sizeof(state.jumps[0]));
state.num_jumps = 0;

if (translate(vm, &state, errmsg) < 0) {
goto out;
}

resolve_jumps(&state);
result = 0;

*size = state.offset;

out:
free(state.pc_locs);
free(state.jumps);
return result;
}

ubpf_jit_fn
ubpf_compile(struct ubpf_vm *vm, char **errmsg)
{
void *jitted = NULL;
uint8_t *buffer = NULL;
size_t jitted_size;
struct jit_state state;

if (vm->jitted) {
return vm->jitted;
Expand All @@ -556,27 +627,20 @@ ubpf_compile(struct ubpf_vm *vm, char **errmsg)
return NULL;
}

state.offset = 0;
state.size = 65536;
state.buf = calloc(state.size, 1);
state.pc_locs = calloc(MAX_INSTS+1, sizeof(state.pc_locs[0]));
state.jumps = calloc(MAX_INSTS, sizeof(state.jumps[0]));
state.num_jumps = 0;
jitted_size = 65536;
buffer = calloc(jitted_size, 1);

if (translate(vm, &state, errmsg) < 0) {
if (ubpf_translate(vm, buffer, &jitted_size, errmsg) < 0) {
goto out;
}

resolve_jumps(&state);

jitted_size = state.offset;
jitted = mmap(0, jitted_size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
if (jitted == MAP_FAILED) {
*errmsg = ubpf_error("internal uBPF error: mmap failed: %s\n", strerror(errno));
goto out;
}

memcpy(jitted, state.buf, jitted_size);
memcpy(jitted, buffer, jitted_size);

if (mprotect(jitted, jitted_size, PROT_READ | PROT_EXEC) < 0) {
*errmsg = ubpf_error("internal uBPF error: mprotect failed: %s\n", strerror(errno));
Expand All @@ -587,9 +651,7 @@ ubpf_compile(struct ubpf_vm *vm, char **errmsg)
vm->jitted_size = jitted_size;

out:
free(state.buf);
free(state.pc_locs);
free(state.jumps);
free(buffer);
if (jitted && vm->jitted == NULL) {
munmap(jitted, jitted_size);
}
Expand Down
Loading

0 comments on commit 57b065c

Please sign in to comment.