Skip to content

Commit

Permalink
include/rdma: new fi_peer_match structure
Browse files Browse the repository at this point in the history
Add a new structure fi_peer_match to collect the parameters which need
to be passed to the get_msg and get_tag functions.

Update the util_get_tag() and util_get_msg() function callbacks.
Compilation gives a warning but does not fail. This causes memory
corruption when the callbacks are called.

Signed-off-by: Amir Shehata <shehataa@ornl.gov>
  • Loading branch information
amirshehataornl committed May 16, 2024
1 parent bcc5181 commit 3ddf85f
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 22 deletions.
17 changes: 13 additions & 4 deletions include/rdma/providers/fi_peer.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,21 @@ struct fi_peer_rx_entry {
struct iovec *iov;
};

/*
 * Matching parameters handed to the get_msg and get_tag owner callbacks
 * of a peer SRX. The values are bundled in one structure rather than
 * being passed as separate arguments.
 */
struct fi_peer_match {
/* address of the sending peer to match the incoming message against */
fi_addr_t addr;
/* tag of the incoming message; read by the get_tag callback */
uint64_t tag;
/* size of the incoming message; read by the get_msg callback */
size_t size;
/* NOTE(review): callers seen here leave this zeroed/NULL and the util
 * callbacks do not read it -- confirm its intended use */
void *context;
};

struct fi_ops_srx_owner {
size_t size;
int (*get_msg)(struct fid_peer_srx *srx, fi_addr_t addr,
size_t size, struct fi_peer_rx_entry **entry);
int (*get_tag)(struct fid_peer_srx *srx, fi_addr_t addr,
uint64_t tag, struct fi_peer_rx_entry **entry);
int (*get_msg)(struct fid_peer_srx *srx,
struct fi_peer_match *match,
struct fi_peer_rx_entry **entry);
int (*get_tag)(struct fid_peer_srx *srx,
struct fi_peer_match *match,
struct fi_peer_rx_entry **entry);
int (*queue_msg)(struct fi_peer_rx_entry *entry);
int (*queue_tag)(struct fi_peer_rx_entry *entry);
void (*foreach_unspec_addr)(struct fid_peer_srx *srx,
Expand Down
14 changes: 10 additions & 4 deletions prov/efa/src/rdm/efa_rdm_msg.c
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,7 @@ struct efa_rdm_ope *efa_rdm_msg_alloc_rxe_for_msgrtm(struct efa_rdm_ep *ep,
size_t data_size;
int ret;
int pkt_type;
struct fi_peer_match match = {0};

if ((*pkt_entry_ptr)->alloc_type == EFA_RDM_PKE_FROM_USER_BUFFER) {
/* If a pkt_entry is constructed from user supplied buffer,
Expand All @@ -782,7 +783,10 @@ struct efa_rdm_ope *efa_rdm_msg_alloc_rxe_for_msgrtm(struct efa_rdm_ep *ep,
peer_srx = util_get_peer_srx(ep->peer_srx_ep);
data_size = efa_rdm_pke_get_rtm_msg_length(*pkt_entry_ptr);

ret = peer_srx->owner_ops->get_msg(peer_srx, (*pkt_entry_ptr)->addr, data_size, &peer_rxe);
match.addr = (*pkt_entry_ptr)->addr;
match.size = data_size;

ret = peer_srx->owner_ops->get_msg(peer_srx, &match, &peer_rxe);

if (ret == FI_SUCCESS) { /* A matched rxe is found */
rxe = efa_rdm_msg_alloc_matched_rxe_for_rtm(ep, *pkt_entry_ptr, peer_rxe, ofi_op_msg);
Expand Down Expand Up @@ -844,12 +848,14 @@ struct efa_rdm_ope *efa_rdm_msg_alloc_rxe_for_tagrtm(struct efa_rdm_ep *ep,
struct efa_rdm_ope *rxe;
int ret;
int pkt_type;
struct fi_peer_match match = {0};

peer_srx = util_get_peer_srx(ep->peer_srx_ep);

ret = peer_srx->owner_ops->get_tag(peer_srx, (*pkt_entry_ptr)->addr,
efa_rdm_pke_get_rtm_tag(*pkt_entry_ptr),
&peer_rxe);
match.addr = (*pkt_entry_ptr)->addr;
match.tag = efa_rdm_pke_get_rtm_tag(*pkt_entry_ptr);

ret = peer_srx->owner_ops->get_tag(peer_srx, &match, &peer_rxe);

if (ret == FI_SUCCESS) { /* A matched rxe is found */
rxe = efa_rdm_msg_alloc_matched_rxe_for_rtm(ep, *pkt_entry_ptr, peer_rxe, ofi_op_tagged);
Expand Down
16 changes: 10 additions & 6 deletions prov/shm/src/smr_progress.c
Original file line number Diff line number Diff line change
Expand Up @@ -987,13 +987,16 @@ static int smr_progress_cmd_msg(struct smr_ep *ep, struct smr_cmd *cmd)
{
struct fid_peer_srx *peer_srx = smr_get_peer_srx(ep);
struct fi_peer_rx_entry *rx_entry;
fi_addr_t addr;
struct fi_peer_match match_info = {0};
int ret;

addr = ep->region->map->peers[cmd->msg.hdr.id].fiaddr;
match_info.addr = ep->region->map->peers[cmd->msg.hdr.id].fiaddr;
match_info.context = NULL;

if (cmd->msg.hdr.op == ofi_op_tagged) {
ret = peer_srx->owner_ops->get_tag(peer_srx, addr,
cmd->msg.hdr.tag, &rx_entry);
match_info.tag = cmd->msg.hdr.tag;
ret = peer_srx->owner_ops->get_tag(peer_srx, &match_info,
&rx_entry);
if (ret == -FI_ENOENT) {
ret = smr_alloc_cmd_ctx(ep, rx_entry, cmd);
if (ret) {
Expand All @@ -1009,8 +1012,9 @@ static int smr_progress_cmd_msg(struct smr_ep *ep, struct smr_cmd *cmd)
goto out;
}
} else {
ret = peer_srx->owner_ops->get_msg(peer_srx, addr,
cmd->msg.hdr.size, &rx_entry);
match_info.size = cmd->msg.hdr.size;
ret = peer_srx->owner_ops->get_msg(peer_srx, &match_info,
&rx_entry);
if (ret == -FI_ENOENT) {
ret = smr_alloc_cmd_ctx(ep, rx_entry, cmd);
if (ret) {
Expand Down
12 changes: 8 additions & 4 deletions prov/sm2/src/sm2_progress.c
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ static int sm2_progress_recv_msg(struct sm2_ep *ep,
{
struct fid_peer_srx *peer_srx = sm2_get_peer_srx(ep);
struct fi_peer_rx_entry *rx_entry;
struct fi_peer_match match;
struct sm2_av *sm2_av;
fi_addr_t addr;
int ret = 0;
Expand All @@ -540,9 +541,12 @@ static int sm2_progress_recv_msg(struct sm2_ep *ep,
sm2_av = container_of(ep->util_ep.av, struct sm2_av, util_av);
addr = sm2_av->reverse_lookup[xfer_entry->hdr.sender_gid];

memset(&match, 0, sizeof(match));
match.addr = addr;

if (xfer_entry->hdr.op == ofi_op_tagged) {
ret = peer_srx->owner_ops->get_tag(
peer_srx, addr, xfer_entry->hdr.tag, &rx_entry);
match.tag = xfer_entry->hdr.tag;
ret = peer_srx->owner_ops->get_tag(peer_srx, &match, &rx_entry);
if (ret == -FI_ENOENT) {
xfer_entry->hdr.proto_flags |= SM2_UNEXP;
ret = sm2_alloc_xfer_entry_ctx(ep, rx_entry,
Expand All @@ -557,8 +561,8 @@ static int sm2_progress_recv_msg(struct sm2_ep *ep,
goto out;
}
} else {
ret = peer_srx->owner_ops->get_msg(
peer_srx, addr, xfer_entry->hdr.size, &rx_entry);
match.size = xfer_entry->hdr.size;
ret = peer_srx->owner_ops->get_msg(peer_srx, &match, &rx_entry);
if (ret == -FI_ENOENT) {
xfer_entry->hdr.proto_flags |= SM2_UNEXP;
ret = sm2_alloc_xfer_entry_ctx(ep, rx_entry,
Expand Down
14 changes: 10 additions & 4 deletions prov/util/src/util_srx.c
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,15 @@ static int util_match_msg(struct fid_peer_srx *srx, fi_addr_t addr, size_t size,
return ret;
}

static int util_get_msg(struct fid_peer_srx *srx, fi_addr_t addr,
size_t size, struct fi_peer_rx_entry **rx_entry)
static int util_get_msg(struct fid_peer_srx *srx,
struct fi_peer_match *match_info,
struct fi_peer_rx_entry **rx_entry)
{
struct util_srx_ctx *srx_ctx;
struct util_rx_entry *util_entry, *any_entry;
struct slist *queue;
fi_addr_t addr = match_info->addr;
size_t size = match_info->size;

srx_ctx = srx->ep_fid.fid.context;
assert(ofi_genlock_held(srx_ctx->lock));
Expand Down Expand Up @@ -269,14 +272,17 @@ static int util_match_tag(struct fid_peer_srx *srx, fi_addr_t addr,
return ret;
}

static int util_get_tag(struct fid_peer_srx *srx, fi_addr_t addr,
uint64_t tag, struct fi_peer_rx_entry **rx_entry)
static int util_get_tag(struct fid_peer_srx *srx,
struct fi_peer_match *match_info,
struct fi_peer_rx_entry **rx_entry)
{
struct util_srx_ctx *srx_ctx;
struct slist *queue;
struct slist_entry *any_item, *any_prev;
struct slist_entry *item, *prev;
struct util_rx_entry *util_entry, *any_entry;
uint64_t tag = match_info->tag;
fi_addr_t addr = match_info->addr;
int ret = FI_SUCCESS;

srx_ctx = srx->ep_fid.fid.context;
Expand Down

0 comments on commit 3ddf85f

Please sign in to comment.