Skip to content

Commit

Permalink
Add first-class WorkerdApi binding
Browse files Browse the repository at this point in the history
  • Loading branch information
nhynes committed Mar 6, 2024
1 parent 08f318d commit 019b0e3
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 89 deletions.
31 changes: 31 additions & 0 deletions src/workerd/api/workerd.c++
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#include "workerd.h"
#include <kj/compat/http.h>
#include <capnp/compat/json.h>
#include <workerd/util/http-util.h>

namespace workerd::api {

kj::Promise<uint> doCreateWorkerRequest(kj::Own<kj::HttpClient> client, kj::String serializedArgs) {
auto& context = IoContext::current();
auto headers = kj::HttpHeaders(context.getHeaderTable());
auto req = client->request(kj::HttpMethod::POST, "http://workerd.local/workers"_kjc, headers, serializedArgs.size());
co_await req.body->write(serializedArgs.begin(), serializedArgs.size());
auto res = co_await req.response;
auto resBody = co_await res.body->readAllText();
if (res.statusCode >= 400) {
JSG_FAIL_REQUIRE(Error, resBody);
}
co_return atoi(resBody.cStr());
}

jsg::Promise<jsg::Ref<Fetcher>> WorkerdApi::newWorker(jsg::Lock& js, jsg::JsValue args) {
auto& context = IoContext::current();
kj::String serializedArgs = args.toJson(js);
auto client = context.getHttpClient(serviceChannel, true, kj::none, "create_worker"_kjc);
auto promise = doCreateWorkerRequest(kj::mv(client), kj::mv(serializedArgs));
return context.awaitIo(js, kj::mv(promise), [](jsg::Lock& js, uint chan) {
return jsg::alloc<Fetcher>(chan, Fetcher::RequiresHostAndProtocol::NO, true);
});
}

} // namespace workerd::api
24 changes: 24 additions & 0 deletions src/workerd/api/workerd.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once

#include <workerd/jsg/jsg.h>
#include <workerd/api/http.h>

namespace workerd::api {

// A special binding object that allows for dynamic evaluation.
class WorkerdApi: public jsg::Object {
public:
explicit WorkerdApi(uint serviceChannel): serviceChannel(serviceChannel) {}

jsg::Promise<jsg::Ref<Fetcher>> newWorker(jsg::Lock& js, jsg::JsValue args);

JSG_RESOURCE_TYPE(WorkerdApi) {
JSG_METHOD(newWorker);
}

private:
uint serviceChannel;
};
#define EW_WORKERD_API_ISOLATE_TYPES \
api::WorkerdApi
} // namespace workerd::api
180 changes: 92 additions & 88 deletions src/workerd/server/server.c++
Original file line number Diff line number Diff line change
Expand Up @@ -2334,14 +2334,15 @@ kj::Array<kj::byte> measureConfig(config::Worker::Reader& config) {
return digest;
}

class Server::WorkerdApiService final: public Service, private WorkerInterface {
class Server::WorkerdApiService final: public Service {
// Service used when the service is configured as network service.

public:
WorkerdApiService(Server& server): server(server) {}

kj::Own<WorkerInterface> startRequest(IoChannelFactory::SubrequestMetadata metadata) override {
return { this, kj::NullDisposer::instance };
auto& context = IoContext::current();
return kj::heap<WorkerdApiRequestHandler>(*this, context.getWorker());
}

bool hasHandler(kj::StringPtr handlerName) override {
Expand All @@ -2351,102 +2352,102 @@ public:
private:
Server& server;

kj::Promise<void> request(
kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers,
kj::AsyncInputStream& requestBody, kj::HttpService::Response& response) override {
if (url == "http://workerd.local/workers") {
return requestBody.readAllText().then([this, &headers, &response](auto confJson) {
capnp::MallocMessageBuilder confArena;
capnp::JsonCodec json;
json.handleByAnnotation<config::NewWorker>();
auto conf = confArena.initRoot<config::NewWorker>();
json.decode(confJson, conf);
class WorkerdApiRequestHandler final: public WorkerInterface {
public:
WorkerdApiRequestHandler(WorkerdApiService& parent, const Worker& requester)
: parent(parent), requester(requester) {}

kj::String id = workerd::randomUUID(kj::none);
kj::Promise<void> request(
kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers,
kj::AsyncInputStream& requestBody, kj::HttpService::Response& response) override {
if (url != "http://workerd.local/workers") {
auto out = response.send(404, "Not Found", headers, kj::none);
auto errMsg = "Unknown workerd API endpoint"_kjc.asBytes();
co_await out->write(errMsg.begin(), errMsg.size());
co_return;
}

server.actorConfigs.insert(kj::str(id), {});
auto confJson = co_await requestBody.readAllText();

kj::Maybe<kj::Array<kj::byte>> expectedMeasurement = kj::none;
if (conf.hasExpectedMeasurement()) {
auto res = kj::decodeHex(conf.getExpectedMeasurement());
if (res.hadErrors) {
auto out = response.send(400, "Bad Request", headers, kj::none);
auto errMsg = "invalid expected measurement"_kjc.asBytes();
return out->write(errMsg.begin(), errMsg.size());
}
expectedMeasurement = kj::mv(res);
}
capnp::MallocMessageBuilder confArena;
capnp::JsonCodec json;
json.handleByAnnotation<config::NewWorker>();
auto conf = confArena.initRoot<config::NewWorker>();
json.decode(confJson, conf);

kj::Maybe<kj::String> configError = kj::none;
auto workerService = server.makeWorker(
id, conf.getWorker().asReader(), {},
[&configError](auto err) { configError = kj::mv(err); },
kj::mv(expectedMeasurement)
);
KJ_IF_SOME(err, configError) {
throw KJ_EXCEPTION(FAILED, err);
}
auto& worker = server.services.insert(kj::str(id), kj::mv(workerService)).value;
worker->link();
WorkerService* requesterService;
KJ_IF_SOME(svc, parent.server.services.find(requester.getName())) {
requesterService = &kj::downcast<WorkerService>(*svc);
} else {
auto out = response.send(500, "Internal Server Error", headers, kj::none);
auto errMsg = "unable to locate requester service"_kjc.asBytes();
co_await out->write(errMsg.begin(), errMsg.size());
co_return;
}

auto resMessage = kj::heap<capnp::MallocMessageBuilder>(); // TODO: not alloc
auto res = resMessage->initRoot<config::ServiceDesignator>();
res.setName(id);
auto resJson = json.encode(res);
kj::String id = workerd::randomUUID(kj::none);

auto out = response.send(201, "Created", headers, kj::none);
return out->write(resJson.begin(), resJson.size()).attach(kj::mv(out), kj::mv(resJson));
});
} else if (url.startsWith("http://workerd.local/workers/"_kjc) &&
url.endsWith("/events/scheduled")) {
auto workerId = url.slice(29, 29 + 36);
return requestBody.readAllText().then([this, workerId = kj::mv(workerId),
&headers, &response](auto confJson) {
capnp::MallocMessageBuilder confArena;
capnp::JsonCodec json;
json.handleByAnnotation<rpc::Trace::ScheduledEventInfo>();
auto event = confArena.initRoot<rpc::Trace::ScheduledEventInfo>();
json.decode(confJson, event);

KJ_IF_SOME(svc, server.services.find(workerId)) {
IoChannelFactory::SubrequestMetadata metadata;
auto worker = svc->startRequest(kj::mv(metadata));
kj::Date scheduledTime = kj::UNIX_EPOCH +
static_cast<long long>(event.getScheduledTime()) * kj::MILLISECONDS;
auto cron = event.getCron();
return worker->runScheduled(scheduledTime, cron)
.then([&response, &headers](auto scheduledResult) {
return response.send(204, "No Content", headers, kj::none)->write(nullptr, 0);
}).attach(kj::mv(cron));
} else {
return response.send(404, "Not Found", headers, kj::none)->write(nullptr, 0);
}
});
} else {
return response.send(404, "Not Found", headers, kj::none)->write(nullptr, 0);
parent.server.actorConfigs.insert(kj::str(id), {});

kj::Maybe<kj::Array<kj::byte>> expectedMeasurement = kj::none;
if (conf.hasExpectedMeasurement()) {
auto res = kj::decodeHex(conf.getExpectedMeasurement());
if (res.hadErrors) {
auto out = response.send(400, "Bad Request", headers, kj::none);
auto errMsg = "invalid expected measurement"_kjc.asBytes();
co_await out->write(errMsg.begin(), errMsg.size());
co_return;
}
expectedMeasurement = kj::mv(res);
}

kj::Maybe<kj::String> configError = kj::none;
auto workerService = parent.server.makeWorker(
id, conf.getWorker().asReader(), {},
[&configError](auto err) { configError = kj::mv(err); },
kj::mv(expectedMeasurement)
);
KJ_IF_SOME(err, configError) {
auto out = response.send(400, "Bad Request", headers, kj::none);
auto errMsg = kj::str(err);
co_await out->write(errMsg.begin(), errMsg.size());
co_return;
}
auto& worker = parent.server.services.insert(kj::str(id), kj::mv(workerService)).value;
worker->link();

uint newWorkerChannel = requesterService->addChannel(worker);

auto resMsg = kj::str(newWorkerChannel);
auto out = response.send(201, "Created", headers, kj::none);
co_await out->write(resMsg.begin(), resMsg.size()).attach(kj::mv(out), kj::mv(resMsg));
co_return;
}
}

kj::Promise<void> connect(
kj::StringPtr host, const kj::HttpHeaders& headers, kj::AsyncIoStream& connection,
ConnectResponse& tunnel, kj::HttpConnectSettings settings) override {
throwUnsupported();
}
kj::Promise<void> connect(
kj::StringPtr host, const kj::HttpHeaders& headers, kj::AsyncIoStream& connection,
ConnectResponse& tunnel, kj::HttpConnectSettings settings) override {
throwUnsupported();
}
void prewarm(kj::StringPtr url) override {}
kj::Promise<ScheduledResult> runScheduled(kj::Date scheduledTime, kj::StringPtr cron) override {
throwUnsupported();
}
kj::Promise<AlarmResult> runAlarm(kj::Date scheduledTime, uint32_t retryCount) override {
throwUnsupported();
}
kj::Promise<CustomEvent::Result> customEvent(kj::Own<CustomEvent> event) override {
throwUnsupported();
}

void prewarm(kj::StringPtr url) override {}
kj::Promise<ScheduledResult> runScheduled(kj::Date scheduledTime, kj::StringPtr cron) override {
throwUnsupported();
}
kj::Promise<AlarmResult> runAlarm(kj::Date scheduledTime, uint32_t retryCount) override {
throwUnsupported();
}
kj::Promise<CustomEvent::Result> customEvent(kj::Own<CustomEvent> event) override {
throwUnsupported();
}
[[noreturn]] void throwUnsupported() {
JSG_FAIL_REQUIRE(Error, "WorkerdApiService does not support this event type.");
}

[[noreturn]] void throwUnsupported() {
JSG_FAIL_REQUIRE(Error, "WorkerdApiService does not support this event type.");
}
private:
WorkerdApiService& parent;
const Worker& requester;
};
};

// =======================================================================================
Expand Down Expand Up @@ -2609,6 +2610,9 @@ static kj::Maybe<WorkerdApi::Global> createBinding(
binding.getService(),
kj::mv(errorContext)
});
if (binding.getService().getName() == "@workerd") {
return makeGlobal(Global::WorkerdApi { .subrequestChannel = channel });
}
return makeGlobal(Global::Fetcher {
.channel = channel,
.requiresHost = true,
Expand Down
8 changes: 8 additions & 0 deletions src/workerd/server/workerd-api.c++
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <workerd/api/trace.h>
#include <workerd/api/unsafe.h>
#include <workerd/api/urlpattern.h>
#include <workerd/api/workerd.h>
#include <workerd/api/node/node.h>
#include <workerd/io/promise-wrapper.h>
#include <workerd/util/thread-scopes.h>
Expand Down Expand Up @@ -70,6 +71,7 @@ JSG_DECLARE_ISOLATE_TYPE(JsgWorkerdIsolate,
EW_CACHE_ISOLATE_TYPES,
EW_CRYPTO_ISOLATE_TYPES,
EW_NSM_ISOLATE_TYPES,
EW_WORKERD_API_ISOLATE_TYPES,
EW_ENCODING_ISOLATE_TYPES,
EW_FORMDATA_ISOLATE_TYPES,
EW_HTML_REWRITER_ISOLATE_TYPES,
Expand Down Expand Up @@ -699,6 +701,9 @@ static v8::Local<v8::Value> createBindingValue(
KJ_CASE_ONEOF(nsm, Global::NitroSecureModule) {
value = lock.wrap(context, jsg::alloc<api::NitroSecureModule>());
}
KJ_CASE_ONEOF(workerdApi, Global::WorkerdApi) {
value = lock.wrap(context, jsg::alloc<api::WorkerdApi>(workerdApi.subrequestChannel));
}
}

return value;
Expand Down Expand Up @@ -785,6 +790,9 @@ WorkerdApi::Global WorkerdApi::Global::clone() const {
KJ_CASE_ONEOF(nsm, Global::NitroSecureModule) {
result.value = Global::NitroSecureModule {};
}
KJ_CASE_ONEOF(workerdApi, Global::WorkerdApi) {
result.value = workerdApi.clone();
}
}

return result;
Expand Down
11 changes: 10 additions & 1 deletion src/workerd/server/workerd-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,19 @@ class WorkerdApi final: public Worker::Api {
};
struct UnsafeEval {};
struct NitroSecureModule {};
struct WorkerdApi {
uint subrequestChannel;

WorkerdApi clone() const {
return WorkerdApi {
.subrequestChannel = subrequestChannel,
};
}
};
kj::String name;
kj::OneOf<Json, Fetcher, KvNamespace, R2Bucket, R2Admin, CryptoKey, EphemeralActorNamespace,
DurableActorNamespace, QueueBinding, kj::String, kj::Array<byte>, Wrapped,
AnalyticsEngine, Hyperdrive, UnsafeEval, NitroSecureModule> value;
AnalyticsEngine, Hyperdrive, UnsafeEval, WorkerdApi, NitroSecureModule> value;

Global clone() const;
};
Expand Down

0 comments on commit 019b0e3

Please sign in to comment.