From c0d22b8f4a45b2f98f515a30c40250a0949bc82e Mon Sep 17 00:00:00 2001 From: Trond Norbye Date: Fri, 24 Feb 2012 11:40:00 +0100 Subject: [PATCH] Add support for SASL auth in mcstat Change-Id: I5a309e131e80c3a89c94a879e2a316549ea1b8a4 --- programs/mcstat.c | 109 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 97 insertions(+), 12 deletions(-) diff --git a/programs/mcstat.c b/programs/mcstat.c index 22fb741b..772a54d0 100644 --- a/programs/mcstat.c +++ b/programs/mcstat.c @@ -9,18 +9,88 @@ #include #include +static void retry_send(int sock, const void* buf, size_t len); +static void retry_recv(int sock, void *buf, size_t len); + +#if defined(ENABLE_ISASL) || defined(ENABLE_SASL) +static int do_sasl_auth(int sock, const char *user, const char *pass) +{ + /* + * For now just shortcut the SASL phase by requesting a "PLAIN" + * sasl authentication. + */ + size_t ulen = strlen(user) + 1; + size_t plen = pass ? strlen(pass) + 1 : 1; + size_t tlen = ulen + plen + 1; + + protocol_binary_request_stats request = { + .message.header.request = { + .magic = PROTOCOL_BINARY_REQ, + .opcode = PROTOCOL_BINARY_CMD_SASL_AUTH, + .keylen = htons(5), + .bodylen = htonl(5 + tlen) + } + }; + + retry_send(sock, &request, sizeof(request)); + retry_send(sock, "PLAIN", 5); + retry_send(sock, "", 1); + retry_send(sock, user, ulen); + if (pass) { + retry_send(sock, pass, plen); + } else { + retry_send(sock, "", 1); + } + + protocol_binary_response_no_extras response; + retry_recv(sock, &response, sizeof(response.bytes)); + uint32_t vallen = ntohl(response.message.header.response.bodylen); + char *buffer = NULL; + + if (vallen != 0) { + buffer = malloc(vallen); + retry_recv(sock, buffer, vallen); + } + + protocol_binary_response_status status; + status = ntohs(response.message.header.response.status); + + if (status != PROTOCOL_BINARY_RESPONSE_SUCCESS) { + fprintf(stderr, "Failed to authenticate to the server\n"); + close(sock); + sock = -1; + return -1; + } + + free(buffer); + return sock; +} +#else +static int do_sasl_auth(int sock, const char *user, const char *pass) +{ + (void)sock; (void)user; (void)pass; + fprintf(stderr, "mcstat is not built with sasl support\n"); + return -1; +} +#endif + /** * Try to connect to the server * @param host the name of the server * @param port the port to connect to + * @param user the username to use for SASL auth (NULL = NO SASL) + * @param pass the password to use for SASL auth * @return a socket descriptor connected to host:port for success, -1 otherwise */ -static int connect_server(const char *hostname, const char *port) +static int connect_server(const char *hostname, const char *port, + const char *user, const char *pass) { struct addrinfo *ainfo = NULL; - struct addrinfo hints = { .ai_family = AF_UNSPEC, - .ai_protocol = IPPROTO_TCP, - .ai_socktype = SOCK_STREAM }; + struct addrinfo hints = { + .ai_flags = AI_ALL, + .ai_family = PF_UNSPEC, + .ai_socktype = SOCK_STREAM, + .ai_protocol = IPPROTO_TCP}; if (getaddrinfo(hostname, port, &hints, &ainfo) != 0) { return -1; @@ -30,7 +100,6 @@ static int connect_server(const char *hostname, const char *port) struct addrinfo *ai = ainfo; while (ai != NULL) { sock = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol); - if (sock != -1) { if (connect(sock, ai->ai_addr, ai->ai_addrlen) != -1) { break; @@ -42,6 +111,15 @@ static int connect_server(const char *hostname, const char *port) } freeaddrinfo(ainfo); + + if (sock == -1) { + fprintf(stderr, "Failed to connect to memcached server (%s:%s): %s\n", + hostname, port, strerror(errno)); + } else if (user != NULL && do_sasl_auth(sock, user, pass) == -1) { + close(sock); + sock = -1; + } + return sock; } @@ -79,7 +157,8 @@ static void retry_send(int sock, const void* buf, size_t len) * @param buf buffer to store data to * @param len length of data to receive */ -static void retry_recv(int sock, void *buf, size_t len) { +static void retry_recv(int sock, void *buf, size_t len) +{ if (len == 0) { return; } @@ -180,12 +259,14 @@ int main(int argc, char **argv) const char * const default_ports[] = { "memcache", "11211", NULL }; const char *port = NULL; const char *host = NULL; + const char *user = NULL; + const char *pass = NULL; char *ptr; /* Initialize the socket subsystem */ initialize_sockets(); - while ((cmd = getopt(argc, argv, "h:p:")) != EOF) { + while ((cmd = getopt(argc, argv, "h:p:u:P:")) != EOF) { switch (cmd) { case 'h' : host = optarg; @@ -198,9 +279,15 @@ int main(int argc, char **argv) case 'p': port = optarg; break; + case 'u' : + user = optarg; + break; + case 'P': + pass = optarg; + break; default: fprintf(stderr, - "Usage mcstat [-h host[:port]] [-p port] [statkey]*\n"); + "Usage mcstat [-h host[:port]] [-p port] [-u user] [-p pass] [statkey]*\n"); return 1; } } @@ -214,15 +301,13 @@ int main(int argc, char **argv) int ii = 0; do { port = default_ports[ii++]; - sock = connect_server(host, port); + sock = connect_server(host, port, user, pass); } while (sock == -1 && default_ports[ii] != NULL); } else { - sock = connect_server(host, port); + sock = connect_server(host, port, user, pass); } if (sock == -1) { - fprintf(stderr, "Failed to connect to memcached server (%s:%s): %s\n", - host, port, strerror(errno)); return 1; }