Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Fetching contributors…

Cannot retrieve contributors at this time

394 lines (344 sloc) 11.974 kB
#include "coroutine.h"
#include <assert.h>
#include <errno.h>
#define __GNU_SOURCE
#include <dlfcn.h>
#include <pthread.h>
#include <ucontext.h>
#include <stdexcept>
#include <stack>
#include <vector>
#define MAX_POOL_SIZE 120
#include <iostream>
using namespace std;
/**
* These are all the pthread functions we hook. Some are more complicated to hook than others.
*
* Each `o_*` pointer holds the address of the next (original) implementation of the function with
* the same name. All of them except `o_pthread_key_create` are assigned in Loader's constructor
* via dlsym(RTLD_NEXT, ...); `o_pthread_key_create` is resolved lazily by
* dyn_pthread_key_create() because it can be needed before Loader has run.
*/
// Signature of a TLS destructor as accepted by pthread_key_create.
typedef void(*pthread_dtor_t)(void*);
static int (*o_pthread_create)(pthread_t*, const pthread_attr_t*, void*(*)(void*), void*);
static int (*o_pthread_key_delete)(pthread_key_t);
static int (*o_pthread_equal)(pthread_t, pthread_t);
static void* (*o_pthread_getspecific)(pthread_key_t);
// NOTE(review): the first parameter is declared as pthread_key_t, but POSIX pthread_join takes a
// pthread_t. The declaration has to stay in sync with the cast used where this pointer is
// assigned, so both would need to change together.
static int (*o_pthread_join)(pthread_key_t, void**);
static pthread_t (*o_pthread_self)(void);
static int (*o_pthread_setspecific)(pthread_key_t, const void*);
// Function type (not pointer type) of pthread_key_create.
typedef int(pthread_key_create_t)(pthread_key_t*, pthread_dtor_t);
static pthread_key_create_t* o_pthread_key_create = NULL;
// Returns the original pthread_key_create, looking it up on first use.
static pthread_key_create_t& dyn_pthread_key_create();
/**
* Very early on when this library is loaded there are callers to pthread_key_create,
* pthread_getspecific, and pthread_setspecific. We normally could pass these calls through down to
* the original implementation but that doesn't work because dlsym() tries to get a lock which ends
* up calling pthread_getspecific and pthread_setspecific. So have to implement our own versions of
* these functions assuming one thread only and then as soon as we can, put all that saved data into
* a better structure.
*
* If there are keys reserved that are lower than `thread_key` we always pass those through to the
* underlying implementation.
*
* This code assumes the underlying pthread library hands out increasing TLS keys, and that the
* number of keys stays below a fixed limit. Neither assumption is safe, since pthread_key_t is
* technically opaque.
*/
// Capacity of the fixed-size tables backing the TLS emulation used before initialization.
static const size_t MAX_EARLY_KEYS = 500;
// Values stored through the hooked pthread_setspecific before initialization. The hooked
// pthread_getspecific / pthread_setspecific index this table with `key - thread_key - 1`.
static const void* pthread_early_vals[MAX_EARLY_KEYS] = { NULL };
// Destructors registered through the hooked pthread_key_create before initialization. The hooked
// pthread_key_create indexes this table with the synthetic key value itself.
static pthread_dtor_t pthread_early_dtors[MAX_EARLY_KEYS] = { NULL };
// Most recently handed-out synthetic key; starts from `thread_key` once that key is reserved.
// pthread_key_t is used as an integer throughout this file, so initialize it with 0, not NULL.
static pthread_key_t prev_synthetic_key = 0;
// Set the first time the original pthread_key_create is looked up.
static bool did_hook_pthreads = false;
// True once `thread_key` has been obtained from the original pthread_key_create.
static bool did_reserve_key = false;
// The one real TLS key owned by this library; its value is the calling thread's Thread object.
static pthread_key_t thread_key;
// True once Loader has run and the heap-backed structures may be used.
static bool initialized = false;
/**
* Forward declaration of the start routine that the hooked pthread_create hands to the original
* pthread_create. `data` is the heap-allocated argument vector built by the hooked pthread_create
* (user entry point, user argument, Thread object); the trampoline takes ownership of it.
*/
void thread_trampoline(void** data);
/**
* Thread is only used internally for this library. It keeps track of all the fibers this thread
* is currently running, and handles all the fiber-local storage logic. We store a handle to a
* Thread object in TLS, and then it emulates TLS on top of fibers.
*/
class Thread {
private:
static vector<pthread_dtor_t> dtors;
vector<Coroutine*> fiber_pool;
public:
pthread_t handle;
volatile Coroutine* current_fiber;
volatile Coroutine* delete_me;
static void free(void* that) {
delete static_cast<Thread*>(that);
}
Thread() : handle(NULL), delete_me(NULL) {
current_fiber = new Coroutine(*this);
}
~Thread() {
for (size_t ii = 0; ii < fiber_pool.size(); ++ii) {
delete fiber_pool[ii];
}
}
void coroutine_fls_dtor(Coroutine& fiber) {
bool did_delete;
do {
did_delete = false;
for (size_t ii = 0; ii < fiber.fls_data.size(); ++ii) {
if (fiber.fls_data[ii]) {
if (dtors[ii]) {
void* tmp = fiber.fls_data[ii];
fiber.fls_data[ii] = NULL;
dtors[ii](tmp);
did_delete = true;
} else {
fiber.fls_data[ii] = NULL;
}
}
}
} while (did_delete);
}
void fiber_did_finish(Coroutine& fiber) {
if (fiber_pool.size() < MAX_POOL_SIZE) {
fiber_pool.push_back(&fiber);
} else {
coroutine_fls_dtor(fiber);
// Can't delete right now because we're currently on this stack!
assert(delete_me == NULL);
delete_me = &fiber;
}
}
Coroutine& create_fiber(Coroutine::entry_t& entry, void* arg) {
if (!fiber_pool.empty()) {
Coroutine& fiber = *fiber_pool.back();
fiber_pool.pop_back();
fiber.reset(entry, arg);
return fiber;
}
return *new Coroutine(*this, entry, arg);
}
void* get_specific(pthread_key_t key) {
if (const_cast<Coroutine*>(current_fiber)->fls_data.size() <= key) {
return NULL;
}
return const_cast<Coroutine*>(current_fiber)->fls_data[key];
}
void set_specific(pthread_key_t key, const void* data) {
if (const_cast<Coroutine*>(current_fiber)->fls_data.size() <= key) {
const_cast<Coroutine*>(current_fiber)->fls_data.resize(key + 1);
}
const_cast<Coroutine*>(current_fiber)->fls_data[key] = (void*)data;
}
void key_create(pthread_key_t* key, pthread_dtor_t dtor) {
dtors.push_back(dtor);
*key = dtors.size() - 1; // TODO: This is NOT thread-safe! =O
}
void key_delete(pthread_key_t key) {
if (!dtors[key]) {
return;
}
// This doesn't call the dtor on all threads / fibers. Do I really care?
if (get_specific(key)) {
dtors[key](get_specific(key));
set_specific(key, NULL);
}
}
};
vector<pthread_dtor_t> Thread::dtors;
/**
* Coroutine class definition
*/
// Size used for every fiber stack (both the stack container and uc_stack.ss_size). It stays 0
// until set_stack_size() is called, which asserts that it has not been set before.
size_t Coroutine::stack_size = 0;
// Entry point installed into each fiber's context by makecontext. It never returns: every time
// the entry function comes back, the (possibly reset) entry/arg pair is invoked again, which is
// what allows a pooled fiber to be reused without rebuilding its context.
void Coroutine::trampoline(Coroutine &that) {
    for (;;) {
        void* const argument = const_cast<void*>(that.arg);
        that.entry(argument);
    }
}
// Returns the fiber currently running on the calling thread, found through the Thread object
// stored under `thread_key` in real TLS.
Coroutine& Coroutine::current() {
    void* const slot = o_pthread_getspecific(thread_key);
    Thread* const owner = static_cast<Thread*>(slot);
    return *const_cast<Coroutine*>(owner->current_fiber);
}
// Reports whether the original pthread_key_create has been looked up, i.e. whether the pthread
// hooks in this file have been engaged.
const bool Coroutine::is_local_storage_enabled() {
    const bool hooked = did_hook_pthreads;
    return hooked;
}
// Sets the stack size used for fibers created from now on. May only be called while the size is
// still unset (zero); a second call trips the assertion.
void Coroutine::set_stack_size(size_t size) {
    assert(Coroutine::stack_size == 0);
    Coroutine::stack_size = size;
}
// Constructs a fiber that is only bound to its owning thread; no context is prepared and no
// entry point is set. Thread's constructor uses this for the fiber it installs as current_fiber.
Coroutine::Coroutine(Thread& t)
    : thread(t) {
}
// Constructs a runnable fiber owned by `t`: allocates a stack of `stack_size` and prepares a
// context that will start in trampoline(), which in turn calls `entry(arg)`.
Coroutine::Coroutine(Thread& t, entry_t& entry, void* arg) :
    thread(t),
    stack(stack_size),
    entry(entry),
    arg(arg) {
    // makecontext expects a `void (*)(void)`; the real argument is passed through its varargs.
    typedef void (*context_entry_t)(void);

    getcontext(&context);
    context.uc_stack.ss_sp = &stack[0];
    context.uc_stack.ss_size = stack_size;
    makecontext(&context, (context_entry_t)trampoline, 1, this);
}
// Creates (or recycles) a fiber on the calling thread that will run `entry(arg)`. The work is
// delegated to the Thread object stored under `thread_key`.
Coroutine& Coroutine::create_fiber(entry_t* entry, void* arg) {
    Thread* const owner = static_cast<Thread*>(o_pthread_getspecific(thread_key));
    return owner->create_fiber(*entry, arg);
}
// Re-arms a pooled fiber with a new entry point and argument. The context itself is left alone;
// trampoline() picks up the new pair the next time it loops.
void Coroutine::reset(entry_t* entry, void* arg) {
    this->arg = arg;
    this->entry = entry;
}
void Coroutine::run() volatile {
Coroutine& current = *const_cast<Coroutine*>(thread.current_fiber);
assert(&current != this);
if (thread.delete_me) {
assert(this != thread.delete_me);
assert(&current != thread.delete_me);
delete thread.delete_me;
thread.delete_me = NULL;
}
thread.current_fiber = this;
swapcontext(&current.context, const_cast<ucontext_t*>(&context));
}
// Marks this fiber as finished (handing it back to its Thread for pooling or deferred deletion)
// and transfers control to `next`.
void Coroutine::finish(Coroutine& next) {
    thread.fiber_did_finish(*this);
    thread.current_fiber = &next;

    ucontext_t* const target = &next.context;
    swapcontext(&context, target);
}
// Returns the address of the first element of this fiber's stack storage.
void* Coroutine::bottom() const {
    void* const base = (char*)&stack[0];
    return base;
}
// Returns the footprint attributed to one fiber: the object itself plus the configured stack
// size.
size_t Coroutine::size() const {
    const size_t total = stack_size + sizeof(Coroutine);
    return total;
}
/**
* TLS hooks
*/
// See the comment above MAX_EARLY_KEYS for why these functions are difficult to hook.
// Note well that in the `!initialized` case there is no heap yet: any call to malloc and friends
// from these code paths will crash the process.
void* pthread_getspecific(pthread_key_t key) {
if (initialized) {
if (thread_key >= key) {
return o_pthread_getspecific(key);
}
Thread& thread = *static_cast<Thread*>(o_pthread_getspecific(thread_key));
return thread.get_specific(key - thread_key - 1);
} else {
// We can't invoke the original function because dlsym tries to call pthread_getspecific
return const_cast<void*>(pthread_early_vals[key - thread_key - 1]);
}
}
int pthread_setspecific(pthread_key_t key, const void* data) {
if (initialized) {
if (thread_key >= key) {
return o_pthread_setspecific(key, data);
}
Thread& thread = *static_cast<Thread*>(o_pthread_getspecific(thread_key));
thread.set_specific(key - thread_key - 1, data);
return 0;
} else {
pthread_early_vals[key - thread_key - 1] = data;
return 0;
}
}
// Returns the original pthread_key_create, resolving it with dlsym(RTLD_NEXT, ...) on first use
// and caching the result. Also records that the pthread hooks are in effect.
static pthread_key_create_t& dyn_pthread_key_create() {
    did_hook_pthreads = true;
    if (!o_pthread_key_create) {
        void* const symbol = dlsym(RTLD_NEXT, "pthread_key_create");
        o_pthread_key_create = (pthread_key_create_t*)symbol;
    }
    return *o_pthread_key_create;
}
// Hooked pthread_key_create.
//
// After initialization a fiber-local key is allocated through the calling thread's Thread object
// and shifted above `thread_key`, so that it can be told apart from real keys. Before
// initialization a synthetic key is handed out and its destructor is remembered in the static
// early table; the first such call also reserves `thread_key` from the original implementation.
//
// Returns 0 on success, or EAGAIN when the early table has no room for another key.
int pthread_key_create(pthread_key_t* key, pthread_dtor_t dtor) {
    if (initialized) {
        Thread& thread = *static_cast<Thread*>(o_pthread_getspecific(thread_key));
        thread.key_create(key, dtor);
        *key += thread_key + 1;
        return 0;
    }

    if (!did_reserve_key) {
        did_reserve_key = true;
        dyn_pthread_key_create()(&thread_key, Thread::free);
        prev_synthetic_key = thread_key;
    }

    // Validate the new key BEFORE it is used as an index into the early table. The assert keeps
    // the loud failure in debug builds; the runtime check protects builds with NDEBUG.
    const pthread_key_t candidate = prev_synthetic_key + 1;
    assert(candidate < MAX_EARLY_KEYS);
    if (candidate >= MAX_EARLY_KEYS) {
        return EAGAIN;
    }
    prev_synthetic_key = candidate;
    *key = candidate;
    pthread_early_dtors[candidate] = dtor;
    return 0;
}
/**
* Other pthread-related hooks.
*/
// Entry point for pthread_create. We need this to record the Thread in real TLS.
void thread_trampoline(void** args_vector) {
void* (*entry)(void*) = (void*(*)(void*))args_vector[0];
void* arg = args_vector[1];
Thread& thread = *static_cast<Thread*>(args_vector[2]);
delete[] args_vector;
o_pthread_setspecific(thread_key, &thread);
entry(arg);
}
int pthread_create(pthread_t* handle, const pthread_attr_t* attr, void* (*entry)(void*), void* arg) {
assert(initialized);
void** args_vector = new void*[3];
args_vector[0] = (void*)entry;
args_vector[1] = arg;
Thread* thread = new Thread;
args_vector[2] = thread;
*handle = (pthread_t)thread;
return o_pthread_create(
&thread->handle, attr, (void* (*)(void*))thread_trampoline, (void*)args_vector);
}
int pthread_key_delete(pthread_key_t key) {
assert(initialized);
if (thread_key >= key) {
return o_pthread_key_delete(key);
}
Thread& thread = *static_cast<Thread*>(o_pthread_getspecific(thread_key));
thread.key_delete(key - thread_key - 1);
return 0;
}
// Hooked pthread_equal. The handles this library gives out are plain pointer values, so two
// handles are equal exactly when they compare equal. Returns 1 for equal, 0 otherwise.
int pthread_equal(pthread_t left, pthread_t right) {
    if (left == right) {
        return 1;
    }
    return 0;
}
int pthread_join(pthread_t thread, void** retval) {
assert(initialized);
// pthread_join should return EDEADLK if you try to join with yourself..
return pthread_join(reinterpret_cast<Thread*>(thread)->handle, retval);
}
pthread_t pthread_self() {
assert(initialized);
Thread& thread = *static_cast<Thread*>(o_pthread_getspecific(thread_key));
return (pthread_t)thread.current_fiber;
}
/**
* Initialization of this library. By the time we make it here the heap should be good to go. Also
* it's possible the TLS functions have been called, so we need to clean up that mess.
*/
// Performs the one-time setup of the shim. A single file-scope instance (`loader`, below) runs
// the constructor during static initialization.
class Loader {
  public:
    Loader() {
        // Grab hooks to the real version of all hooked functions.
        o_pthread_create = (int(*)(pthread_t*, const pthread_attr_t*, void* (*)(void*), void*))dlsym(RTLD_NEXT, "pthread_create");
        o_pthread_key_delete = (int(*)(pthread_key_t))dlsym(RTLD_NEXT, "pthread_key_delete");
        o_pthread_equal = (int(*)(pthread_t, pthread_t))dlsym(RTLD_NEXT, "pthread_equal");
        o_pthread_getspecific = (void*(*)(pthread_key_t))dlsym(RTLD_NEXT, "pthread_getspecific");
        o_pthread_join = (int(*)(pthread_key_t, void**))dlsym(RTLD_NEXT, "pthread_join");
        o_pthread_self = (pthread_t(*)(void))dlsym(RTLD_NEXT, "pthread_self");
        o_pthread_setspecific = (int(*)(pthread_key_t, const void*))dlsym(RTLD_NEXT, "pthread_setspecific");
        dyn_pthread_key_create();

        // Create a real TLS key to store the handle to Thread, unless an early call to the
        // hooked pthread_key_create has already reserved it.
        if (!did_reserve_key) {
            did_reserve_key = true;
            o_pthread_key_create(&thread_key, Thread::free);
            prev_synthetic_key = thread_key;
        }

        // Register the thread that is running this constructor.
        Thread* thread = new Thread;
        thread->handle = o_pthread_self();
        o_pthread_setspecific(thread_key, thread);

        // Put all the data from the fake pthread_setspecific into FLS.
        initialized = true;
        for (size_t ii = thread_key + 1; ii <= prev_synthetic_key; ++ii) {
            pthread_key_t tmp;
            // Early destructors are stored under the synthetic key itself...
            thread->key_create(&tmp, pthread_early_dtors[ii]);
            assert(tmp == ii - thread_key - 1);
            // ...but early values are stored under `key - thread_key - 1` by the hooked
            // pthread_setspecific, which is exactly the fiber-local key `tmp`. Reading them
            // under `ii` would pick up the value that belongs to a different key.
            thread->set_specific(tmp, pthread_early_vals[tmp]);
        }

        // Undo fiber-shim so that child processes don't get shimmed as well. This also seems to
        // prevent this library from being loaded multiple times.
        setenv("DYLD_INSERT_LIBRARIES", "", 1);
        setenv("LD_PRELOAD", "", 1);
    }
};
Loader loader;
Jump to Line
Something went wrong with that request. Please try again.