diff --git a/prov/sockets/include/sock.h b/prov/sockets/include/sock.h index 33191574198..06273c9572e 100644 --- a/prov/sockets/include/sock.h +++ b/prov/sockets/include/sock.h @@ -210,6 +210,7 @@ struct sock_conn_listener { fastlock_t signal_lock; /* acquire before map lock */ pthread_t listener_thread; int do_listen; + bool removed_from_epollfd; }; struct sock_ep_cm_head { @@ -219,6 +220,7 @@ struct sock_ep_cm_head { pthread_t listener_thread; struct dlist_entry msg_list; int do_listen; + bool removed_from_epollfd; }; struct sock_domain { diff --git a/prov/sockets/src/sock_conn.c b/prov/sockets/src/sock_conn.c index 4cf919a6ef7..0d39956a825 100644 --- a/prov/sockets/src/sock_conn.c +++ b/prov/sockets/src/sock_conn.c @@ -332,6 +332,15 @@ static void *sock_conn_listener_thread(void *arg) } fastlock_acquire(&conn_listener->signal_lock); + if (conn_listener->removed_from_epollfd) { + /* The epoll set changed between calling wait and wait + * returning. Get an updated set of events to avoid + * possible use after free error. + */ + conn_listener->removed_from_epollfd = false; + goto skip; + } + for (i = 0; i < num_fds; i++) { conn_handle = ep_contexts[i]; @@ -360,6 +369,7 @@ static void *sock_conn_listener_thread(void *arg) fastlock_release(&ep_attr->cmap.lock); sock_pe_signal(ep_attr->domain->pe); } +skip: fastlock_release(&conn_listener->signal_lock); } @@ -393,6 +403,7 @@ int sock_conn_start_listener_thread(struct sock_conn_listener *conn_listener) } conn_listener->do_listen = 1; + conn_listener->removed_from_epollfd = false; ret = pthread_create(&conn_listener->listener_thread, NULL, sock_conn_listener_thread, conn_listener); if (ret < 0) { diff --git a/prov/sockets/src/sock_ep.c b/prov/sockets/src/sock_ep.c index 64813ebdc58..9f5145d8423 100644 --- a/prov/sockets/src/sock_ep.c +++ b/prov/sockets/src/sock_ep.c @@ -682,6 +682,7 @@ static int sock_ep_close(struct fid *fid) fastlock_acquire(&sock_ep->attr->domain->conn_listener.signal_lock); ofi_epoll_del(sock_ep->attr->domain->conn_listener.epollfd, sock_ep->attr->conn_handle.sock); + sock_ep->attr->domain->conn_listener.removed_from_epollfd = true; fastlock_release(&sock_ep->attr->domain->conn_listener.signal_lock); ofi_close_socket(sock_ep->attr->conn_handle.sock); sock_ep->attr->conn_handle.do_listen = 0; diff --git a/prov/sockets/src/sock_ep_msg.c b/prov/sockets/src/sock_ep_msg.c index b3a4030522d..50498685c43 100644 --- a/prov/sockets/src/sock_ep_msg.c +++ b/prov/sockets/src/sock_ep_msg.c @@ -255,6 +255,7 @@ sock_ep_cm_unmonitor_handle_locked(struct sock_ep_cm_head *cm_head, SOCK_LOG_ERROR("failed to unmonitor fd %d: %d\n", handle->sock_fd, ret); handle->monitored = 0; + cm_head->removed_from_epollfd = true; } /* Multiple threads might call sock_ep_cm_unmonitor_handle() at the @@ -1174,6 +1175,15 @@ static void *sock_ep_cm_thread(void *arg) } pthread_mutex_lock(&cm_head->signal_lock); + if (cm_head->removed_from_epollfd) { + /* If we removed a socket from the epollfd after + * ofi_epoll_wait returned, we can hit a use after + * free error. If a change was made, we skip processing + * and recheck for events. + */ + cm_head->removed_from_epollfd = false; + goto skip; + } for (i = 0; i < num_fds; i++) { handle = ep_contexts[i]; @@ -1195,6 +1205,7 @@ static void *sock_ep_cm_thread(void *arg) assert(handle->sock_fd != INVALID_SOCKET); sock_ep_cm_handle_rx(cm_head, handle); } +skip: pthread_mutex_unlock(&cm_head->signal_lock); } return NULL; @@ -1230,6 +1241,7 @@ int sock_ep_cm_start_thread(struct sock_ep_cm_head *cm_head) } cm_head->do_listen = 1; + cm_head->removed_from_epollfd = false; ret = pthread_create(&cm_head->listener_thread, 0, sock_ep_cm_thread, cm_head); if (ret) {