diff --git a/include/rdma/providers/fi_peer.h b/include/rdma/providers/fi_peer.h index 093290d1b4c..9968027c3aa 100644 --- a/include/rdma/providers/fi_peer.h +++ b/include/rdma/providers/fi_peer.h @@ -173,16 +173,28 @@ struct fi_peer_rx_entry { size_t count; void **desc; void *peer_context; + void *peer_md; void *owner_context; struct iovec *iov; + uint32_t match_id; +}; + +struct fi_peer_match { + fi_addr_t addr; + uint64_t tag; + size_t size; + void *context; + uint32_t match_id; }; struct fi_ops_srx_owner { size_t size; - int (*get_msg)(struct fid_peer_srx *srx, fi_addr_t addr, - size_t size, struct fi_peer_rx_entry **entry); - int (*get_tag)(struct fid_peer_srx *srx, fi_addr_t addr, - uint64_t tag, struct fi_peer_rx_entry **entry); + int (*get_msg)(struct fid_peer_srx *srx, + struct fi_peer_match *match, + struct fi_peer_rx_entry **entry); + int (*get_tag)(struct fid_peer_srx *srx, + struct fi_peer_match *match, + struct fi_peer_rx_entry **entry); int (*queue_msg)(struct fi_peer_rx_entry *entry); int (*queue_tag)(struct fi_peer_rx_entry *entry); void (*foreach_unspec_addr)(struct fid_peer_srx *srx, @@ -197,6 +209,9 @@ struct fi_ops_srx_peer { int (*start_tag)(struct fi_peer_rx_entry *entry); int (*discard_msg)(struct fi_peer_rx_entry *entry); int (*discard_tag)(struct fi_peer_rx_entry *entry); + int (*addr_match)(fi_addr_t addr, struct fi_peer_match *match); + int (*mem_reg)(struct fid_ep *ep, struct iovec *iov, fi_addr_t addr, + void **md, uint32_t *match_id); }; struct fid_peer_srx { diff --git a/prov/efa/src/rdm/efa_rdm_msg.c b/prov/efa/src/rdm/efa_rdm_msg.c index 7ff4a116928..e71fbc3c988 100644 --- a/prov/efa/src/rdm/efa_rdm_msg.c +++ b/prov/efa/src/rdm/efa_rdm_msg.c @@ -764,6 +764,7 @@ struct efa_rdm_ope *efa_rdm_msg_alloc_rxe_for_msgrtm(struct efa_rdm_ep *ep, size_t data_size; int ret; int pkt_type; + struct fi_peer_match match = {0}; if ((*pkt_entry_ptr)->alloc_type == EFA_RDM_PKE_FROM_USER_BUFFER) { /* If a pkt_entry is constructred from user supplied buffer, @@ -782,7 +783,10 @@ struct efa_rdm_ope *efa_rdm_msg_alloc_rxe_for_msgrtm(struct efa_rdm_ep *ep, peer_srx = util_get_peer_srx(ep->peer_srx_ep); data_size = efa_rdm_pke_get_rtm_msg_length(*pkt_entry_ptr); - ret = peer_srx->owner_ops->get_msg(peer_srx, (*pkt_entry_ptr)->addr, data_size, &peer_rxe); + match.addr = (*pkt_entry_ptr)->addr; + match.size = data_size; + + ret = peer_srx->owner_ops->get_msg(peer_srx, &match, &peer_rxe); if (ret == FI_SUCCESS) { /* A matched rxe is found */ rxe = efa_rdm_msg_alloc_matched_rxe_for_rtm(ep, *pkt_entry_ptr, peer_rxe, ofi_op_msg); @@ -844,12 +848,14 @@ struct efa_rdm_ope *efa_rdm_msg_alloc_rxe_for_tagrtm(struct efa_rdm_ep *ep, struct efa_rdm_ope *rxe; int ret; int pkt_type; + struct fi_peer_match match = {0}; peer_srx = util_get_peer_srx(ep->peer_srx_ep); - ret = peer_srx->owner_ops->get_tag(peer_srx, (*pkt_entry_ptr)->addr, - efa_rdm_pke_get_rtm_tag(*pkt_entry_ptr), - &peer_rxe); + match.addr = (*pkt_entry_ptr)->addr; + match.tag = efa_rdm_pke_get_rtm_tag(*pkt_entry_ptr); + + ret = peer_srx->owner_ops->get_tag(peer_srx, &match, &peer_rxe); if (ret == FI_SUCCESS) { /* A matched rxe is found */ rxe = efa_rdm_msg_alloc_matched_rxe_for_rtm(ep, *pkt_entry_ptr, peer_rxe, ofi_op_tagged); diff --git a/prov/shm/src/smr_progress.c b/prov/shm/src/smr_progress.c index b30918e6986..b46d1916f4c 100644 --- a/prov/shm/src/smr_progress.c +++ b/prov/shm/src/smr_progress.c @@ -987,13 +987,16 @@ static int smr_progress_cmd_msg(struct smr_ep *ep, struct smr_cmd *cmd) { struct fid_peer_srx *peer_srx = smr_get_peer_srx(ep); struct fi_peer_rx_entry *rx_entry; - fi_addr_t addr; + struct fi_peer_match match_info = {0}; int ret; - addr = ep->region->map->peers[cmd->msg.hdr.id].fiaddr; + match_info.addr = ep->region->map->peers[cmd->msg.hdr.id].fiaddr; + match_info.context = NULL; + if (cmd->msg.hdr.op == ofi_op_tagged) { - ret = peer_srx->owner_ops->get_tag(peer_srx, addr, - cmd->msg.hdr.tag, &rx_entry); + match_info.tag = cmd->msg.hdr.tag; + ret = peer_srx->owner_ops->get_tag(peer_srx, &match_info, + &rx_entry); if (ret == -FI_ENOENT) { ret = smr_alloc_cmd_ctx(ep, rx_entry, cmd); if (ret) { @@ -1009,8 +1012,9 @@ static int smr_progress_cmd_msg(struct smr_ep *ep, struct smr_cmd *cmd) goto out; } } else { - ret = peer_srx->owner_ops->get_msg(peer_srx, addr, - cmd->msg.hdr.size, &rx_entry); + match_info.size = cmd->msg.hdr.size; + ret = peer_srx->owner_ops->get_msg(peer_srx, &match_info, + &rx_entry); if (ret == -FI_ENOENT) { ret = smr_alloc_cmd_ctx(ep, rx_entry, cmd); if (ret) { diff --git a/prov/sm2/src/sm2_progress.c b/prov/sm2/src/sm2_progress.c index 8d2b822907f..fb7ab53b250 100644 --- a/prov/sm2/src/sm2_progress.c +++ b/prov/sm2/src/sm2_progress.c @@ -520,6 +520,7 @@ static int sm2_progress_recv_msg(struct sm2_ep *ep, { struct fid_peer_srx *peer_srx = sm2_get_peer_srx(ep); struct fi_peer_rx_entry *rx_entry; + struct fi_peer_match match; struct sm2_av *sm2_av; fi_addr_t addr; int ret = 0; @@ -540,9 +541,12 @@ static int sm2_progress_recv_msg(struct sm2_ep *ep, sm2_av = container_of(ep->util_ep.av, struct sm2_av, util_av); addr = sm2_av->reverse_lookup[xfer_entry->hdr.sender_gid]; + memset(&match, 0, sizeof(match)); + match.addr = addr; + if (xfer_entry->hdr.op == ofi_op_tagged) { - ret = peer_srx->owner_ops->get_tag( - peer_srx, addr, xfer_entry->hdr.tag, &rx_entry); + match.tag = xfer_entry->hdr.tag; + ret = peer_srx->owner_ops->get_tag(peer_srx, &match, &rx_entry); if (ret == -FI_ENOENT) { xfer_entry->hdr.proto_flags |= SM2_UNEXP; ret = sm2_alloc_xfer_entry_ctx(ep, rx_entry, @@ -557,8 +561,8 @@ static int sm2_progress_recv_msg(struct sm2_ep *ep, goto out; } } else { - ret = peer_srx->owner_ops->get_msg( - peer_srx, addr, xfer_entry->hdr.size, &rx_entry); + match.size = xfer_entry->hdr.size; + ret = peer_srx->owner_ops->get_msg(peer_srx, &match, &rx_entry); if (ret == -FI_ENOENT) { xfer_entry->hdr.proto_flags |= SM2_UNEXP; ret = sm2_alloc_xfer_entry_ctx(ep, rx_entry, diff --git a/prov/util/src/util_attr.c b/prov/util/src/util_attr.c index 1665e2e5f1a..8a2b91b5955 100644 --- a/prov/util/src/util_attr.c +++ b/prov/util/src/util_attr.c @@ -406,6 +406,7 @@ int ofi_check_fabric_attr(const struct fi_provider *prov, * user's hints, if one is specified. */ if (prov_attr->prov_name && user_attr->prov_name && + user_attr->prov_name[0] != '^' && !strcasestr(user_attr->prov_name, prov_attr->prov_name)) { FI_INFO(prov, FI_LOG_CORE, "Requesting provider %s, skipping %s\n", diff --git a/prov/util/src/util_srx.c b/prov/util/src/util_srx.c index 6911acb9755..7c49af516c4 100644 --- a/prov/util/src/util_srx.c +++ b/prov/util/src/util_srx.c @@ -194,12 +194,15 @@ static int util_match_msg(struct fid_peer_srx *srx, fi_addr_t addr, size_t size, return ret; } -static int util_get_msg(struct fid_peer_srx *srx, fi_addr_t addr, - size_t size, struct fi_peer_rx_entry **rx_entry) +static int util_get_msg(struct fid_peer_srx *srx, + struct fi_peer_match *match_info, + struct fi_peer_rx_entry **rx_entry) { struct util_srx_ctx *srx_ctx; struct util_rx_entry *util_entry, *any_entry; struct slist *queue; + fi_addr_t addr = match_info->addr; + size_t size = match_info->size; srx_ctx = srx->ep_fid.fid.context; assert(ofi_genlock_held(srx_ctx->lock)); @@ -269,14 +272,17 @@ static int util_match_tag(struct fid_peer_srx *srx, fi_addr_t addr, return ret; } -static int util_get_tag(struct fid_peer_srx *srx, fi_addr_t addr, - uint64_t tag, struct fi_peer_rx_entry **rx_entry) +static int util_get_tag(struct fid_peer_srx *srx, + struct fi_peer_match *match_info, + struct fi_peer_rx_entry **rx_entry) { struct util_srx_ctx *srx_ctx; struct slist *queue; struct slist_entry *any_item, *any_prev; struct slist_entry *item, *prev; struct util_rx_entry *util_entry, *any_entry; + uint64_t tag = match_info->tag; + fi_addr_t addr = match_info->addr; int ret = FI_SUCCESS; srx_ctx = srx->ep_fid.fid.context; diff --git a/src/fi_tostr.c b/src/fi_tostr.c index 9cc20e4f4d7..67e03151007 100644 --- a/src/fi_tostr.c +++ b/src/fi_tostr.c @@ -227,6 +227,7 @@ static void ofi_tostr_caps(char *buf, size_t len, uint64_t caps) IFFLAGSTRN(caps, FI_NAMED_RX_CTX, len); IFFLAGSTRN(caps, FI_DIRECTED_RECV, len); IFFLAGSTRN(caps, FI_HMEM, len); + IFFLAGSTRN(caps, FI_PEER, len); ofi_remove_comma(buf); }