Skip to content

Commit

Permalink
fix IO re-start (#192)
Browse files Browse the repository at this point in the history
* fix IO re-start

- start_read: set read_cb before configuring IO
- try_write: prevent going before previous writes
- make sure to watch for UV_WRITABLE after queuing write request
  • Loading branch information
ekoby committed Dec 2, 2023
1 parent 01f3c5d commit 1b30ab5
Showing 1 changed file with 58 additions and 41 deletions.
99 changes: 58 additions & 41 deletions src/tlsuv.c
Expand Up @@ -40,10 +40,7 @@
#endif

static uv_os_sock_t new_socket(const struct addrinfo *addr);

static void tcp_connect_cb(uv_connect_t* req, int status);
static void on_clt_io(uv_poll_t *, int, int);
static ssize_t try_write(tlsuv_stream_t *, uv_buf_t *);

static tls_context *DEFAULT_TLS = NULL;

Expand Down Expand Up @@ -216,8 +213,8 @@ static void process_connect(tlsuv_stream_t *clt, int status) {
if (status != 0) {
UM_LOG(ERR, "failed connect: %d/%s", status, uv_strerror(status));
clt->conn_req = NULL;
req->cb(req, status);
uv_poll_stop(&clt->watcher);
req->cb(req, status);
return;
}

Expand All @@ -236,22 +233,52 @@ static void process_connect(tlsuv_stream_t *clt, int status) {
const char *error = clt->tls_engine->strerror(clt->tls_engine);
UM_LOG(ERR, "TLS handshake failed: %s", error);
clt->conn_req = NULL;
req->cb(req, UV_ECONNABORTED);
uv_poll_stop(&clt->watcher);
req->cb(req, UV_ECONNABORTED);
return;
}

if (rc == TLS_HS_COMPLETE) {
UM_LOG(DEBG, "handshake completed");
clt->conn_req = NULL;
req->cb(req, 0);
start_io(clt);
req->cb(req, 0);
} else {
// wait for incoming handshake messages
uv_poll_start(&clt->watcher, UV_READABLE, on_clt_io);
}
}

static ssize_t write_req(tlsuv_stream_t *clt, uv_buf_t *buf) {
int rc = clt->tls_engine->write(clt->tls_engine, buf->base, buf->len);
if (rc > 0) {
return rc;
}

if (rc == TLS_ERR) {
UM_LOG(WARN, "tls connection error: %s", clt->tls_engine->strerror(clt->tls_engine));
return UV_ECONNABORTED;
}

if (rc == TLS_AGAIN) {
return UV_EAGAIN;
}

return UV_EINVAL;
}

static void fail_pending_reqs(tlsuv_stream_t *clt, int err) {
while(!TAILQ_EMPTY(&clt->queue)) {
tlsuv_write_t *req = TAILQ_FIRST(&clt->queue);
TAILQ_REMOVE(&clt->queue, req, _next);
clt->queue_len -= 1;

req->wr->cb(req->wr, (int)err);
free(req);

}
}

static void process_outbound(tlsuv_stream_t *clt) {
tlsuv_write_t *req;
ssize_t ret;
Expand All @@ -262,7 +289,7 @@ static void process_outbound(tlsuv_stream_t *clt) {
}

req = TAILQ_FIRST(&clt->queue);
ret = tlsuv_stream_try_write(clt, &req->buf);
ret = write_req(clt, &req->buf);
if (ret > 0) {
req->buf.base += ret;
req->buf.len -= ret;
Expand All @@ -287,13 +314,7 @@ static void process_outbound(tlsuv_stream_t *clt) {

// error handling
// fail all pending requests
do {
clt->queue_len -= 1;
TAILQ_REMOVE(&clt->queue, req, _next);
req->wr->cb(req->wr, (int)ret);
free(req);
req = TAILQ_FIRST(&clt->queue);
} while (!TAILQ_EMPTY(&clt->queue));
fail_pending_reqs(clt, (int)ret);
}

static void on_clt_io(uv_poll_t *p, int status, int events) {
Expand All @@ -305,6 +326,7 @@ static void on_clt_io(uv_poll_t *p, int status, int events) {
}

if (status != 0) {
UM_LOG(WARN, "IO failed: %d/%s", status, uv_strerror(status));
if (clt->read_cb) {
uv_buf_t buf;
clt->alloc_cb((uv_handle_t *) clt, 32 * 1024, &buf);
Expand Down Expand Up @@ -340,6 +362,7 @@ static void on_clt_io(uv_poll_t *p, int status, int events) {

if (rc == TLS_ERR) {
clt->read_cb((uv_stream_t *)clt, UV_ECONNABORTED, &buf);
fail_pending_reqs(clt, UV_ECONNABORTED);
return;
}

Expand Down Expand Up @@ -461,10 +484,13 @@ int tlsuv_stream_read_start(tlsuv_stream_t *clt, uv_alloc_cb alloc_cb, uv_read_c
return UV_EALREADY;
}

int rc = uv_poll_start(&clt->watcher, UV_READABLE, on_clt_io);
if (rc == 0) {
clt->alloc_cb = alloc_cb;
clt->read_cb = read_cb;
clt->alloc_cb = alloc_cb;
clt->read_cb = read_cb;

int rc = start_io(clt);
if (rc != 0) {
clt->alloc_cb = NULL;
clt->read_cb = NULL;
}
return rc;
}
Expand All @@ -484,21 +510,12 @@ int tlsuv_stream_read_stop(tlsuv_stream_t *clt) {
}

int tlsuv_stream_try_write(tlsuv_stream_t *clt, uv_buf_t *buf) {
int rc = clt->tls_engine->write(clt->tls_engine, buf->base, buf->len);
if (rc > 0) {
return rc;
}

if (rc == TLS_ERR) {
UM_LOG(WARN, "tls connection error: %s", clt->tls_engine->strerror(clt->tls_engine));
return UV_ECONNABORTED;
}

if (rc == TLS_AGAIN) {
// do not allow to cut the line
if (!TAILQ_EMPTY(&clt->queue)) {
return UV_EAGAIN;
}

return UV_EINVAL;
return (int) write_req(clt, buf);
}

int tlsuv_stream_write(uv_write_t *req, tlsuv_stream_t *clt, uv_buf_t *buf, uv_write_cb cb) {
Expand All @@ -524,21 +541,21 @@ int tlsuv_stream_write(uv_write_t *req, tlsuv_stream_t *clt, uv_buf_t *buf, uv_w
return (int)count;
}

// successfully wrote the whole request
if (count == buf->len) {
// successfully wrote the whole request
cb(req, 0);
} else {
// queue request
tlsuv_write_t *wr = malloc(sizeof(*wr));
wr->wr = req;
wr->buf = uv_buf_init(buf->base + count, buf->len - count);

clt->queue_len += 1;
TAILQ_INSERT_TAIL(&clt->queue, wr, _next);
UM_LOG(INFO, "qlen = %zd", clt->queue_len);
return 0;
}

return 0;
// queue request or whatever left
tlsuv_write_t *wr = malloc(sizeof(*wr));
wr->wr = req;
wr->buf = uv_buf_init(buf->base + count, buf->len - count);
clt->queue_len += 1;
TAILQ_INSERT_TAIL(&clt->queue, wr, _next);

// make sure to re-arm IO after queuing request
return start_io(clt);
}

int tlsuv_stream_free(tlsuv_stream_t *clt) {
Expand Down

0 comments on commit 1b30ab5

Please sign in to comment.