fix IO re-start #192

Merged: 2 commits, Dec 2, 2023
src/tlsuv.c: 99 changed lines (58 additions, 41 deletions)
@@ -40,10 +40,7 @@
#endif

static uv_os_sock_t new_socket(const struct addrinfo *addr);

static void tcp_connect_cb(uv_connect_t* req, int status);
static void on_clt_io(uv_poll_t *, int, int);
static ssize_t try_write(tlsuv_stream_t *, uv_buf_t *);

static tls_context *DEFAULT_TLS = NULL;

@@ -216,8 +213,8 @@ static void process_connect(tlsuv_stream_t *clt, int status) {
if (status != 0) {
UM_LOG(ERR, "failed connect: %d/%s", status, uv_strerror(status));
clt->conn_req = NULL;
req->cb(req, status);
uv_poll_stop(&clt->watcher);
req->cb(req, status);
return;
}

@@ -236,22 +233,52 @@ static void process_connect(tlsuv_stream_t *clt, int status) {
const char *error = clt->tls_engine->strerror(clt->tls_engine);
UM_LOG(ERR, "TLS handshake failed: %s", error);
clt->conn_req = NULL;
req->cb(req, UV_ECONNABORTED);
uv_poll_stop(&clt->watcher);
req->cb(req, UV_ECONNABORTED);
return;
}

if (rc == TLS_HS_COMPLETE) {
UM_LOG(DEBG, "handshake completed");
clt->conn_req = NULL;
req->cb(req, 0);
start_io(clt);
req->cb(req, 0);
} else {
// wait for incoming handshake messages
uv_poll_start(&clt->watcher, UV_READABLE, on_clt_io);
}
}

static ssize_t write_req(tlsuv_stream_t *clt, uv_buf_t *buf) {
int rc = clt->tls_engine->write(clt->tls_engine, buf->base, buf->len);
if (rc > 0) {
return rc;
}

if (rc == TLS_ERR) {
UM_LOG(WARN, "tls connection error: %s", clt->tls_engine->strerror(clt->tls_engine));
return UV_ECONNABORTED;
}

if (rc == TLS_AGAIN) {
return UV_EAGAIN;
}

return UV_EINVAL;
}

static void fail_pending_reqs(tlsuv_stream_t *clt, int err) {
while(!TAILQ_EMPTY(&clt->queue)) {
tlsuv_write_t *req = TAILQ_FIRST(&clt->queue);
TAILQ_REMOVE(&clt->queue, req, _next);
clt->queue_len -= 1;

req->wr->cb(req->wr, (int)err);
free(req);

}
}

static void process_outbound(tlsuv_stream_t *clt) {
tlsuv_write_t *req;
ssize_t ret;
@@ -262,7 +289,7 @@ static void process_outbound(tlsuv_stream_t *clt) {
}

req = TAILQ_FIRST(&clt->queue);
ret = tlsuv_stream_try_write(clt, &req->buf);
ret = write_req(clt, &req->buf);
Review comment (Member): is there no concern for jumping the line here?

Reply (Member, Author): here we're processing the head of the line

(An ordering sketch follows this hunk.)

if (ret > 0) {
req->buf.base += ret;
req->buf.len -= ret;
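To make the ordering question from the review thread concrete, here is a small standalone model (not tlsuv code) of the discipline the diff implements: the public try-write path refuses to write while requests are queued, and the drain path only ever services the head of the queue, so later writes cannot overtake earlier ones. All names below are illustrative; only the TAILQ usage mirrors the diff, and `<sys/queue.h>` is assumed to be available as it is in tlsuv itself.

```c
// Standalone sketch of the queue discipline (illustrative names, not tlsuv code).
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/queue.h>

struct pending {
    const char *data;
    TAILQ_ENTRY(pending) next;
};

static TAILQ_HEAD(, pending) queue = TAILQ_HEAD_INITIALIZER(queue);

static void enqueue(const char *data) {
    struct pending *p = malloc(sizeof(*p));
    p->data = data;
    TAILQ_INSERT_TAIL(&queue, p, next);
}

// mirrors the new tlsuv_stream_try_write(): refuse to cut the line while writes are queued
static int model_try_write(const char *data) {
    if (!TAILQ_EMPTY(&queue)) {
        return -1; // UV_EAGAIN in the real code
    }
    printf("wrote directly: %s\n", data);
    return (int) strlen(data);
}

// mirrors tlsuv_stream_write(): write directly if possible, otherwise queue in arrival order
static void model_write(const char *data) {
    if (model_try_write(data) > 0) {
        return;
    }
    enqueue(data);
}

// mirrors process_outbound(): always service the head of the line
static void model_drain(void) {
    while (!TAILQ_EMPTY(&queue)) {
        struct pending *p = TAILQ_FIRST(&queue);
        TAILQ_REMOVE(&queue, p, next);
        printf("drained in order: %s\n", p->data);
        free(p);
    }
}

int main(void) {
    model_write("first");          // queue empty: written immediately
    enqueue("second (remainder)"); // simulate a partial write left queued after TLS_AGAIN
    model_write("third");          // cannot jump the line: lands behind "second"
    model_drain();                 // FIFO order: second, then third
    return 0;
}
```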
@@ -287,13 +314,7 @@

// error handling
// fail all pending requests
do {
clt->queue_len -= 1;
TAILQ_REMOVE(&clt->queue, req, _next);
req->wr->cb(req->wr, (int)ret);
free(req);
req = TAILQ_FIRST(&clt->queue);
} while (!TAILQ_EMPTY(&clt->queue));
fail_pending_reqs(clt, (int)ret);
}

static void on_clt_io(uv_poll_t *p, int status, int events) {
@@ -305,6 +326,7 @@ static void on_clt_io(uv_poll_t *p, int status, int events) {
}

if (status != 0) {
UM_LOG(WARN, "IO failed: %d/%s", status, uv_strerror(status));
if (clt->read_cb) {
uv_buf_t buf;
clt->alloc_cb((uv_handle_t *) clt, 32 * 1024, &buf);
@@ -340,6 +362,7 @@ static void on_clt_io(uv_poll_t *p, int status, int events) {

if (rc == TLS_ERR) {
clt->read_cb((uv_stream_t *)clt, UV_ECONNABORTED, &buf);
fail_pending_reqs(clt, UV_ECONNABORTED);
return;
}

@@ -461,10 +484,13 @@ int tlsuv_stream_read_start(tlsuv_stream_t *clt, uv_alloc_cb alloc_cb, uv_read_c
return UV_EALREADY;
}

int rc = uv_poll_start(&clt->watcher, UV_READABLE, on_clt_io);
if (rc == 0) {
clt->alloc_cb = alloc_cb;
clt->read_cb = read_cb;
clt->alloc_cb = alloc_cb;
clt->read_cb = read_cb;

int rc = start_io(clt);
if (rc != 0) {
clt->alloc_cb = NULL;
clt->read_cb = NULL;
}
return rc;
}
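For context, a hedged usage sketch of the reworked tlsuv_stream_read_start(): per the hunk above, the callbacks are recorded first and cleared again if start_io() fails, so they only stay set when polling actually (re)started. The connected `clt`, the helper names, and the `tlsuv/tlsuv.h` include path are assumptions for illustration, not taken from this diff.

```c
#include <stdio.h>
#include <stdlib.h>
#include <tlsuv/tlsuv.h> // assumed header path

static void on_alloc(uv_handle_t *h, size_t suggested, uv_buf_t *buf) {
    (void) h;
    *buf = uv_buf_init(malloc(suggested), (unsigned int) suggested);
}

static void on_read(uv_stream_t *s, ssize_t nread, const uv_buf_t *buf) {
    if (nread > 0) {
        fwrite(buf->base, 1, (size_t) nread, stdout);
    } else if (nread < 0) {
        // nread carries a libuv error code, e.g. UV_EOF or UV_ECONNABORTED
        fprintf(stderr, "read error: %s\n", uv_strerror((int) nread));
        tlsuv_stream_read_stop((tlsuv_stream_t *) s);
    }
    free(buf->base);
}

// `clt` is assumed to already be a connected tlsuv_stream_t
static int start_reading(tlsuv_stream_t *clt) {
    int rc = tlsuv_stream_read_start(clt, on_alloc, on_read);
    if (rc != 0) {
        // UV_EALREADY if reading was already started, otherwise the start_io() failure code
        fprintf(stderr, "read_start failed: %s\n", uv_strerror(rc));
    }
    return rc;
}
```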
@@ -484,21 +510,12 @@ int tlsuv_stream_read_stop(tlsuv_stream_t *clt) {
}

int tlsuv_stream_try_write(tlsuv_stream_t *clt, uv_buf_t *buf) {
int rc = clt->tls_engine->write(clt->tls_engine, buf->base, buf->len);
if (rc > 0) {
return rc;
}

if (rc == TLS_ERR) {
UM_LOG(WARN, "tls connection error: %s", clt->tls_engine->strerror(clt->tls_engine));
return UV_ECONNABORTED;
}

if (rc == TLS_AGAIN) {
// do not allow to cut the line
if (!TAILQ_EMPTY(&clt->queue)) {
return UV_EAGAIN;
}

return UV_EINVAL;
return (int) write_req(clt, buf);
}

int tlsuv_stream_write(uv_write_t *req, tlsuv_stream_t *clt, uv_buf_t *buf, uv_write_cb cb) {
@@ -524,21 +541,21 @@ int tlsuv_stream_write(uv_write_t *req, tlsuv_stream_t *clt, uv_buf_t *buf, uv_w
return (int)count;
}

// successfully wrote the whole request
if (count == buf->len) {
// successfully wrote the whole request
cb(req, 0);
} else {
// queue request
tlsuv_write_t *wr = malloc(sizeof(*wr));
wr->wr = req;
wr->buf = uv_buf_init(buf->base + count, buf->len - count);

clt->queue_len += 1;
TAILQ_INSERT_TAIL(&clt->queue, wr, _next);
UM_LOG(INFO, "qlen = %zd", clt->queue_len);
return 0;
}

return 0;
// queue request or whatever left
tlsuv_write_t *wr = malloc(sizeof(*wr));
wr->wr = req;
wr->buf = uv_buf_init(buf->base + count, buf->len - count);
clt->queue_len += 1;
TAILQ_INSERT_TAIL(&clt->queue, wr, _next);

// make sure to re-arm IO after queuing request
return start_io(clt);
}

int tlsuv_stream_free(tlsuv_stream_t *clt) {
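To round out the picture, a hedged usage sketch of tlsuv_stream_write() after this change: per the hunk above, a write that the TLS engine accepts in full completes inline (the callback runs before the call returns), while a partial write leaves the remainder queued and re-arms polling so the callback fires once process_outbound() drains it. The include path, the connected `clt`, and the helper names are assumptions for illustration.

```c
#include <stdio.h>
#include <string.h>
#include <tlsuv/tlsuv.h> // assumed header path

static uv_write_t write_req; // static for brevity; real callers typically allocate one per write

static void on_write(uv_write_t *req, int status) {
    (void) req;
    if (status != 0) {
        // e.g. UV_ECONNABORTED when fail_pending_reqs() flushes the queue on a TLS/IO error
        fprintf(stderr, "write failed: %s\n", uv_strerror(status));
    }
}

// `clt` is assumed to be a connected tlsuv_stream_t; `msg` must stay valid until on_write() runs,
// because a queued remainder keeps pointing into the caller's buffer
static int send_message(tlsuv_stream_t *clt, const char *msg) {
    uv_buf_t buf = uv_buf_init((char *) msg, (unsigned int) strlen(msg));
    int rc = tlsuv_stream_write(&write_req, clt, &buf, on_write);
    if (rc < 0) {
        // negative rc: the TLS write failed outright or polling could not be re-armed
        fprintf(stderr, "tlsuv_stream_write: %s\n", uv_strerror(rc));
    }
    return rc;
}
```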