@@ -8,8 +8,8 @@ package ucx
88// #include "goucx.h"
99import "C"
1010import (
11- "sync"
1211 "unsafe"
12+ "runtime/cgo"
1313)
1414
1515type UcpCallback interface {}
@@ -23,46 +23,48 @@ type UcpAmDataRecvCallback = func(request *UcpRequest, status UcsStatus, length
2323type UcpAmRecvCallback = func (header unsafe.Pointer , headerSize uint64 ,
2424 data * UcpAmData , replyEp * UcpEp ) UcsStatus
2525
26+ type UcpAmRecvCallbackBundle struct {
27+ cb UcpAmRecvCallback
28+ worker * UcpWorker
29+ }
30+
2631// This callback routine is invoked on the server side to handle incoming
2732// connections from remote clients.
2833type UcpListenerConnectionHandler = func (connRequest * UcpConnectionRequest )
2934
30- // Map from the callback id that is passed to C to the actual go callback.
31- var callback_map = make (map [uint64 ]UcpCallback )
35+ type PackedCallback cgo.Handle
36+ func packCallback (cb UcpCallback ) PackedCallback {
37+ if cb == nil {
38+ return 0
39+ }
40+
41+ return PackedCallback (cgo .NewHandle (cb ))
42+ }
3243
33- // Unique index for each go callback, that passes to user_data.
34- var callback_id uint64 = 1
44+ func (pc PackedCallback ) unpackInternal (freeHandle bool ) UcpCallback {
45+ if pc == 0 {
46+ return nil
47+ }
3548
36- var mu sync.Mutex
49+ h := cgo .Handle (pc )
50+ if freeHandle {
51+ defer h .Delete ()
52+ }
3753
38- // Associates go callback with a unique id
39- func register (cb UcpCallback ) uint64 {
40- mu .Lock ()
41- defer mu .Unlock ()
42- callback_id ++
43- callback_map [callback_id ] = cb
44- return callback_id
54+ return h .Value ().(UcpCallback )
4555}
4656
47- // Atomically removes registered callback by it's id
48- func deregister (id uint64 ) (UcpCallback , bool ) {
49- mu .Lock ()
50- defer mu .Unlock ()
51- val , ret := callback_map [id ]
52- delete (callback_map , id )
53- return val , ret
57+ func (pc PackedCallback ) unpack () UcpCallback {
58+ return pc .unpackInternal (false )
5459}
5560
56- func getCallback (id uint64 ) (UcpCallback , bool ) {
57- mu .Lock ()
58- defer mu .Unlock ()
59- val , ret := callback_map [id ]
60- return val , ret
61+ func (pc PackedCallback ) unpackAndFree () UcpCallback {
62+ return pc .unpackInternal (true )
6163}
6264
6365//export ucxgo_completeGoSendRequest
64- func ucxgo_completeGoSendRequest (request unsafe.Pointer , status C.ucs_status_t , callbackId unsafe.Pointer ) {
65- if callback , found := deregister ( uint64 ( uintptr ( callbackId ))); found {
66+ func ucxgo_completeGoSendRequest (request unsafe.Pointer , status C.ucs_status_t , packedCb unsafe.Pointer ) {
67+ if callback := PackedCallback ( packedCb ). unpackAndFree (); callback != nil {
6668 callback .(UcpSendCallback )(& UcpRequest {
6769 request : request ,
6870 Status : UcsStatus (status ),
@@ -71,8 +73,8 @@ func ucxgo_completeGoSendRequest(request unsafe.Pointer, status C.ucs_status_t,
7173}
7274
7375//export ucxgo_completeGoTagRecvRequest
74- func ucxgo_completeGoTagRecvRequest (request unsafe.Pointer , status C.ucs_status_t , tag_info * C.ucp_tag_recv_info_t , callbackId unsafe.Pointer ) {
75- if callback , found := deregister ( uint64 ( uintptr ( callbackId ))); found {
76+ func ucxgo_completeGoTagRecvRequest (request unsafe.Pointer , status C.ucs_status_t , tag_info * C.ucp_tag_recv_info_t , packedCb unsafe.Pointer ) {
77+ if callback := PackedCallback ( packedCb ). unpackAndFree (); callback != nil {
7678 callback .(UcpTagRecvCallback )(& UcpRequest {
7779 request : request ,
7880 Status : UcsStatus (status ),
@@ -84,30 +86,30 @@ func ucxgo_completeGoTagRecvRequest(request unsafe.Pointer, status C.ucs_status_
8486}
8587
8688//export ucxgo_amRecvCallback
87- func ucxgo_amRecvCallback (calbackId unsafe.Pointer , header unsafe.Pointer , headerSize C.size_t ,
89+ func ucxgo_amRecvCallback (packedCb unsafe.Pointer , header unsafe.Pointer , headerSize C.size_t ,
8890 data unsafe.Pointer , dataSize C.size_t , params * C.ucp_am_recv_param_t ) C.ucs_status_t {
89- cbId := uint64 ( uintptr ( calbackId ))
90- if callback , found := getCallback ( cbId ); found {
91+ if callback := PackedCallback ( packedCb ). unpack (); callback != nil {
92+ bundle := callback .( * UcpAmRecvCallbackBundle )
9193 var replyEp * UcpEp
9294 if (params .recv_attr & C .UCP_AM_RECV_ATTR_FIELD_REPLY_EP ) != 0 {
9395 replyEp = & UcpEp {ep : params .reply_ep }
9496 }
9597 amData := & UcpAmData {
96- worker : idToWorker [ cbId ] ,
98+ worker : bundle . worker ,
9799 flags : UcpAmRecvAttrs (params .recv_attr ),
98100 dataPtr : data ,
99101 length : uint64 (dataSize ),
100102 }
101- return C .ucs_status_t (callback .( UcpAmRecvCallback ) (header , uint64 (headerSize ), amData , replyEp ))
103+ return C .ucs_status_t (bundle . cb (header , uint64 (headerSize ), amData , replyEp ))
102104 }
103105 return C .UCS_OK
104106}
105107
106108//export ucxgo_completeAmRecvData
107109func ucxgo_completeAmRecvData (request unsafe.Pointer , status C.ucs_status_t ,
108- length C.size_t , callbackId unsafe.Pointer ) {
110+ length C.size_t , packedCb unsafe.Pointer ) {
109111
110- if callback , found := deregister ( uint64 ( uintptr ( callbackId ))); found {
112+ if callback := PackedCallback ( packedCb ). unpackAndFree (); callback != nil {
111113 callback .(UcpAmDataRecvCallback )(& UcpRequest {
112114 request : request ,
113115 Status : UcsStatus (status ),
0 commit comments