Skip to content

Commit

Permalink
lib-master: master-login-auth - Use the connection API.
Browse files Browse the repository at this point in the history
  • Loading branch information
stephanbosch committed Feb 3, 2019
1 parent 5a11903 commit 4b676d0
Showing 1 changed file with 132 additions and 102 deletions.
234 changes: 132 additions & 102 deletions src/lib-master/master-login-auth.c
Expand Up @@ -12,6 +12,7 @@
#include "str.h"
#include "strescape.h"
#include "time-util.h"
#include "connection.h"
#include "master-interface.h"
#include "master-service.h"
#include "master-auth.h"
Expand All @@ -38,16 +39,15 @@ struct master_login_auth_request {
};

struct master_login_auth {
struct connection conn;
struct connection_list *clist;
pool_t pool;
const char *auth_socket_path;
int refcount;

const char *auth_socket_path;

struct timeval connect_time, handshake_time;

int fd;
struct io *io;
struct istream *input;
struct ostream *output;
struct timeout *to;

unsigned int id_counter;
Expand All @@ -59,14 +59,39 @@ struct master_login_auth {

unsigned int timeout_msecs;

bool connected:1;
bool request_auth_token:1;
bool version_received:1;
bool spid_received:1;
};

static void master_login_auth_connected(struct connection *_conn, bool success);
static int
master_login_auth_input_args(struct connection *_conn, const char *const *args);
static int
master_login_auth_handshake_line(struct connection *_conn, const char *line);
static void master_login_auth_destroy(struct connection *_conn);

static void master_login_auth_update_timeout(struct master_login_auth *auth);
static void master_login_auth_check_spids(struct master_login_auth *auth);

static const struct connection_vfuncs master_login_auth_vfuncs = {
.destroy = master_login_auth_destroy,
.handshake_line = master_login_auth_handshake_line,
.input_args = master_login_auth_input_args,
.client_connected = master_login_auth_connected,
};

static const struct connection_settings master_login_auth_set = {
.dont_send_version = TRUE,
.service_name_in = "auth-master",
.service_name_out = "auth-master",
.major_version = AUTH_MASTER_PROTOCOL_MAJOR_VERSION,
.minor_version = AUTH_MASTER_PROTOCOL_MINOR_VERSION,
.unix_client_connect_msecs = 1000,
.input_max_size = AUTH_MAX_INBUF_SIZE,
.output_max_size = (size_t)-1,
.client = TRUE,
};

struct master_login_auth *
master_login_auth_init(const char *auth_socket_path, bool request_auth_token)
{
Expand All @@ -79,10 +104,14 @@ master_login_auth_init(const char *auth_socket_path, bool request_auth_token)
auth->auth_socket_path = p_strdup(pool, auth_socket_path);
auth->request_auth_token = request_auth_token;
auth->refcount = 1;
auth->fd = -1;
hash_table_create_direct(&auth->requests, pool, 0);
auth->id_counter = i_rand_limit(32767) * 131072U;

auth->clist = connection_list_init(&master_login_auth_set,
&master_login_auth_vfuncs);
connection_init_client_unix(auth->clist, &auth->conn,
auth->auth_socket_path);

auth->timeout_msecs = 1000 * MASTER_AUTH_LOOKUP_TIMEOUT_SECS;
return auth;
}
Expand Down Expand Up @@ -115,38 +144,42 @@ request_internal_failure(struct master_login_auth *auth,
request_failure(auth, request, reason, MASTER_AUTH_ERRMSG_INTERNAL_FAILURE);
}

void master_login_auth_disconnect(struct master_login_auth *auth)
static void
master_login_auth_fail(struct master_login_auth *auth,
const char *reason) ATTR_NULL(2)
{
struct master_login_auth_request *request;

if (reason == NULL)
reason = "Disconnected from auth server, aborting";

auth->connected = FALSE;

while (auth->request_head != NULL) {
request = auth->request_head;
DLLIST2_REMOVE(&auth->request_head,
&auth->request_tail, request);

request_internal_failure(auth, request,
"Disconnected from auth server, aborting");
request_internal_failure(auth, request, reason);
i_free(request);
}
hash_table_clear(auth->requests, FALSE);

timeout_remove(&auth->to);
io_remove(&auth->io);
if (auth->fd != -1) {
i_stream_destroy(&auth->input);
o_stream_destroy(&auth->output);

net_disconnect(auth->fd);
auth->fd = -1;
}
auth->version_received = FALSE;
i_zero(&auth->connect_time);
i_zero(&auth->handshake_time);
}

void master_login_auth_disconnect(struct master_login_auth *auth)
{
connection_disconnect(&auth->conn);
master_login_auth_fail(auth, NULL);
}

static void master_login_auth_unref(struct master_login_auth **_auth)
{
struct master_login_auth *auth = *_auth;
struct connection_list *clist = auth->clist;

*_auth = NULL;

Expand All @@ -155,6 +188,8 @@ static void master_login_auth_unref(struct master_login_auth **_auth)
return;

hash_table_destroy(&auth->requests);
connection_deinit(&auth->conn);
connection_list_deinit(&clist);
pool_unref(&auth->pool);
}

Expand All @@ -174,6 +209,32 @@ void master_login_auth_set_timeout(struct master_login_auth *auth,
auth->timeout_msecs = msecs;
}

static void master_login_auth_destroy(struct connection *_conn)
{
struct master_login_auth *auth =
container_of(_conn, struct master_login_auth, conn);

auth->connected = FALSE;

switch (_conn->disconnect_reason) {
case CONNECTION_DISCONNECT_HANDSHAKE_FAILED:
master_login_auth_fail(auth,
"Handshake with auth service failed");
break;
case CONNECTION_DISCONNECT_BUFFER_FULL:
/* buffer full */
i_error("Auth server sent us too long line");
master_login_auth_fail(auth, NULL);
break;
default:
/* disconnected. stop accepting new connections, because in
default configuration we no longer have permissions to
connect back to auth-master */
master_service_stop_new_connections(master_service);
master_login_auth_fail(auth, NULL);
}
}

static unsigned int auth_get_next_timeout_msecs(struct master_login_auth *auth)
{
struct timeval expires;
Expand Down Expand Up @@ -222,6 +283,40 @@ static void master_login_auth_update_timeout(struct master_login_auth *auth)
}
}

static int
master_login_auth_handshake_line(struct connection *_conn, const char *line)
{
struct master_login_auth *auth =
container_of(_conn, struct master_login_auth, conn);
const char *const *tmp;
unsigned int major_version, minor_version;

tmp = t_strsplit_tabescaped(line);
if (!auth->conn.version_received && strcmp(tmp[0], "VERSION") == 0 &&
tmp[1] != NULL && tmp[2] != NULL) {
if (str_to_uint(tmp[1], &major_version) < 0 ||
str_to_uint(tmp[2], &minor_version) < 0) {
i_error("Auth server sent invalid version line: %s",
line);
return -1;
}

if (connection_verify_version(_conn, "auth-master",
major_version,
minor_version) < 0)
return -1;
return 0;
}
if (strcmp(tmp[0], "SPID") != 0 ||
str_to_pid(tmp[1], &auth->auth_server_pid) < 0) {
i_error("Auth server did not send valid SPID: %s", line);
return -1;
}

master_login_auth_check_spids(auth);
return 1;
}

static void
master_login_auth_request_remove(struct master_login_auth *auth,
struct master_login_auth_request *request)
Expand Down Expand Up @@ -321,10 +416,10 @@ master_login_auth_input_fail(struct master_login_auth *auth, unsigned int id,
}

static int
master_login_auth_input_args(struct master_login_auth *auth,
const char *const *args)
master_login_auth_input_args(struct connection *_conn, const char *const *args)
{
bool ret = TRUE;
struct master_login_auth *auth =
container_of(_conn, struct master_login_auth, conn);
unsigned int id;

if (args[0] == NULL || args[1] == NULL ||
Expand All @@ -341,108 +436,43 @@ master_login_auth_input_args(struct master_login_auth *auth,
master_login_auth_input_notfound(auth, id, &args[2]);
else if (strcmp(args[0], "FAIL") == 0)
master_login_auth_input_fail(auth, id, &args[2]);

if (auth->input == NULL) {
master_login_auth_disconnect(auth);
ret = FALSE;
}
master_login_auth_unref(&auth);

return (ret ? 0 : -1);
return 0;
}

static void master_login_auth_input(struct master_login_auth *auth)
static void master_login_auth_connected(struct connection *_conn, bool success)
{
const char *line;

switch (i_stream_read(auth->input)) {
case 0:
return;
case -1:
/* disconnected. stop accepting new connections, because in
default configuration we no longer have permissions to
connect back to auth-master */
master_service_stop_new_connections(master_service);
master_login_auth_disconnect(auth);
return;
case -2:
/* buffer full */
i_error("Auth server sent us too long line");
master_login_auth_disconnect(auth);
return;
}

if (!auth->version_received) {
line = i_stream_next_line(auth->input);
if (line == NULL)
return;

/* make sure the major version matches */
if (!str_begins(line, "VERSION\t") ||
!str_uint_equals(t_strcut(line + 8, '\t'),
AUTH_MASTER_PROTOCOL_MAJOR_VERSION)) {
i_error("Authentication server not compatible with "
"master process (mixed old and new binaries?)");
master_login_auth_disconnect(auth);
return;
}
auth->version_received = TRUE;
auth->handshake_time = ioloop_timeval;
}
if (!auth->spid_received) {
line = i_stream_next_line(auth->input);
if (line == NULL)
return;
struct master_login_auth *auth =
container_of(_conn, struct master_login_auth, conn);

if (!str_begins(line, "SPID\t") ||
str_to_pid(line + 5, &auth->auth_server_pid) < 0) {
i_error("Authentication server didn't "
"send valid SPID as expected: %s", line);
master_login_auth_disconnect(auth);
return;
}
auth->spid_received = TRUE;
master_login_auth_check_spids(auth);
}

while ((line = i_stream_next_line(auth->input)) != NULL) {
const char *const *args = t_strsplit_tabescaped(line);
int ret;
/* Cannot get here unless connect() was successful */
i_assert(success);

ret = master_login_auth_input_args(auth, args);
if (ret < 0)
break;
}
auth->connected = TRUE;
}

static int
master_login_auth_connect(struct master_login_auth *auth)
{
int fd;

i_assert(auth->fd == -1);
i_assert(!auth->connected);

fd = net_connect_unix_with_retries(auth->auth_socket_path, 1000);
if (fd == -1) {
if (connection_client_connect(&auth->conn) < 0) {
i_error("net_connect_unix(%s) failed: %m",
auth->auth_socket_path);
return -1;
}
io_loop_time_refresh();
auth->connect_time = ioloop_timeval;
auth->fd = fd;
auth->input = i_stream_create_fd(fd, AUTH_MAX_INBUF_SIZE);
auth->output = o_stream_create_fd(fd, (size_t)-1);
o_stream_set_no_error_handling(auth->output, TRUE);
auth->io = io_add(fd, IO_READ, master_login_auth_input, auth);
return 0;
}

static bool
auth_request_check_spid(struct master_login_auth *auth,
struct master_login_auth_request *req)
{
if (auth->auth_server_pid != req->auth_pid && auth->spid_received) {
if (auth->auth_server_pid != req->auth_pid &&
auth->conn.handshake_received) {
/* auth server was restarted. don't even attempt a login. */
i_warning("Auth server restarted (pid %u -> %u), aborting auth",
(unsigned int)req->auth_pid,
Expand Down Expand Up @@ -485,7 +515,7 @@ master_login_auth_send_request(struct master_login_auth *auth,
if (auth->request_auth_token)
str_append(str, "\trequest_auth_token");
str_append_c(str, '\n');
o_stream_nsend(auth->output, str_data(str), str_len(str));
o_stream_nsend(auth->conn.output, str_data(str), str_len(str));
}

void master_login_auth_request(struct master_login_auth *auth,
Expand All @@ -496,7 +526,7 @@ void master_login_auth_request(struct master_login_auth *auth,
struct master_login_auth_request *login_req;
unsigned int id;

if (auth->fd == -1) {
if (!auth->connected) {
if (master_login_auth_connect(auth) < 0) {
/* we couldn't connect to auth now,
so we probably can't in future either. */
Expand All @@ -505,7 +535,7 @@ void master_login_auth_request(struct master_login_auth *auth,
context);
return;
}
o_stream_nsend_str(auth->output,
o_stream_nsend_str(auth->conn.output,
t_strdup_printf("VERSION\t%u\t%u\n",
AUTH_MASTER_PROTOCOL_MAJOR_VERSION,
AUTH_MASTER_PROTOCOL_MINOR_VERSION));
Expand Down

0 comments on commit 4b676d0

Please sign in to comment.