Skip to content

Commit ba5cc7e

Browse files
committed
BINDINGS/GO: Bug fix - make gobindings thread safe.
1 parent a8f9138 commit ba5cc7e

File tree

9 files changed

+140
-94
lines changed

9 files changed

+140
-94
lines changed

bindings/go/src/ucx/am_data.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@ type UcpAmData struct {
2020
flags UcpAmRecvAttrs
2121
}
2222

23-
// To connect callback id with worker, to use in AmData.Receive()
24-
var idToWorker = make(map[uint64]*UcpWorker)
25-
2623
// Whether actual data is received or need to call UcpAmData.Receive()
2724
func (d *UcpAmData) IsDataValid() bool {
2825
return (d.flags & UCP_AM_RECV_ATTR_FLAG_RNDV) == 0

bindings/go/src/ucx/callbacks.go

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ package ucx
88
// #include "goucx.h"
99
import "C"
1010
import (
11-
"sync"
1211
"unsafe"
12+
"runtime/cgo"
1313
)
1414

1515
type UcpCallback interface{}
@@ -23,46 +23,48 @@ type UcpAmDataRecvCallback = func(request *UcpRequest, status UcsStatus, length
2323
type UcpAmRecvCallback = func(header unsafe.Pointer, headerSize uint64,
2424
data *UcpAmData, replyEp *UcpEp) UcsStatus
2525

26+
type UcpAmRecvCallbackBundle struct {
27+
cb UcpAmRecvCallback
28+
worker *UcpWorker
29+
}
30+
2631
// This callback routine is invoked on the server side to handle incoming
2732
// connections from remote clients.
2833
type UcpListenerConnectionHandler = func(connRequest *UcpConnectionRequest)
2934

30-
// Map from the callback id that is passed to C to the actual go callback.
31-
var callback_map = make(map[uint64]UcpCallback)
35+
type PackedCallback cgo.Handle
36+
func packCallback(cb UcpCallback) PackedCallback {
37+
if cb == nil {
38+
return 0
39+
}
40+
41+
return PackedCallback(cgo.NewHandle(cb))
42+
}
3243

33-
// Unique index for each go callback, that passes to user_data.
34-
var callback_id uint64 = 1
44+
func (pc PackedCallback) unpackInternal(freeHandle bool) UcpCallback {
45+
if pc == 0 {
46+
return nil
47+
}
3548

36-
var mu sync.Mutex
49+
h := cgo.Handle(pc)
50+
if freeHandle {
51+
defer h.Delete()
52+
}
3753

38-
// Associates go callback with a unique id
39-
func register(cb UcpCallback) uint64 {
40-
mu.Lock()
41-
defer mu.Unlock()
42-
callback_id++
43-
callback_map[callback_id] = cb
44-
return callback_id
54+
return h.Value().(UcpCallback)
4555
}
4656

47-
// Atomically removes registered callback by it's id
48-
func deregister(id uint64) (UcpCallback, bool) {
49-
mu.Lock()
50-
defer mu.Unlock()
51-
val, ret := callback_map[id]
52-
delete(callback_map, id)
53-
return val, ret
57+
func (pc PackedCallback) unpack() UcpCallback {
58+
return pc.unpackInternal(false)
5459
}
5560

56-
func getCallback(id uint64) (UcpCallback, bool) {
57-
mu.Lock()
58-
defer mu.Unlock()
59-
val, ret := callback_map[id]
60-
return val, ret
61+
func (pc PackedCallback) unpackAndFree() UcpCallback {
62+
return pc.unpackInternal(true)
6163
}
6264

6365
//export ucxgo_completeGoSendRequest
64-
func ucxgo_completeGoSendRequest(request unsafe.Pointer, status C.ucs_status_t, callbackId unsafe.Pointer) {
65-
if callback, found := deregister(uint64(uintptr(callbackId))); found {
66+
func ucxgo_completeGoSendRequest(request unsafe.Pointer, status C.ucs_status_t, packedCb unsafe.Pointer) {
67+
if callback := PackedCallback(packedCb).unpackAndFree(); callback != nil {
6668
callback.(UcpSendCallback)(&UcpRequest{
6769
request: request,
6870
Status: UcsStatus(status),
@@ -71,8 +73,8 @@ func ucxgo_completeGoSendRequest(request unsafe.Pointer, status C.ucs_status_t,
7173
}
7274

7375
//export ucxgo_completeGoTagRecvRequest
74-
func ucxgo_completeGoTagRecvRequest(request unsafe.Pointer, status C.ucs_status_t, tag_info *C.ucp_tag_recv_info_t, callbackId unsafe.Pointer) {
75-
if callback, found := deregister(uint64(uintptr(callbackId))); found {
76+
func ucxgo_completeGoTagRecvRequest(request unsafe.Pointer, status C.ucs_status_t, tag_info *C.ucp_tag_recv_info_t, packedCb unsafe.Pointer) {
77+
if callback := PackedCallback(packedCb).unpackAndFree(); callback != nil {
7678
callback.(UcpTagRecvCallback)(&UcpRequest{
7779
request: request,
7880
Status: UcsStatus(status),
@@ -84,30 +86,30 @@ func ucxgo_completeGoTagRecvRequest(request unsafe.Pointer, status C.ucs_status_
8486
}
8587

8688
//export ucxgo_amRecvCallback
87-
func ucxgo_amRecvCallback(calbackId unsafe.Pointer, header unsafe.Pointer, headerSize C.size_t,
89+
func ucxgo_amRecvCallback(packedCb unsafe.Pointer, header unsafe.Pointer, headerSize C.size_t,
8890
data unsafe.Pointer, dataSize C.size_t, params *C.ucp_am_recv_param_t) C.ucs_status_t {
89-
cbId := uint64(uintptr(calbackId))
90-
if callback, found := getCallback(cbId); found {
91+
if callback := PackedCallback(packedCb).unpack(); callback != nil {
92+
bundle := callback.(*UcpAmRecvCallbackBundle)
9193
var replyEp *UcpEp
9294
if (params.recv_attr & C.UCP_AM_RECV_ATTR_FIELD_REPLY_EP) != 0 {
9395
replyEp = &UcpEp{ep: params.reply_ep}
9496
}
9597
amData := &UcpAmData{
96-
worker: idToWorker[cbId],
98+
worker: bundle.worker,
9799
flags: UcpAmRecvAttrs(params.recv_attr),
98100
dataPtr: data,
99101
length: uint64(dataSize),
100102
}
101-
return C.ucs_status_t(callback.(UcpAmRecvCallback)(header, uint64(headerSize), amData, replyEp))
103+
return C.ucs_status_t(bundle.cb(header, uint64(headerSize), amData, replyEp))
102104
}
103105
return C.UCS_OK
104106
}
105107

106108
//export ucxgo_completeAmRecvData
107109
func ucxgo_completeAmRecvData(request unsafe.Pointer, status C.ucs_status_t,
108-
length C.size_t, callbackId unsafe.Pointer) {
110+
length C.size_t, packedCb unsafe.Pointer) {
109111

110-
if callback, found := deregister(uint64(uintptr(callbackId))); found {
112+
if callback := PackedCallback(packedCb).unpackAndFree(); callback != nil {
111113
callback.(UcpAmDataRecvCallback)(&UcpRequest{
112114
request: request,
113115
Status: UcsStatus(status),

bindings/go/src/ucx/context.go

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55

66
package ucx
77

8+
// #include <stdlib.h>
89
// #include <ucp/api/ucp.h>
910
// #include <ucs/type/status.h>
1011
import "C"
12+
import "unsafe"
1113

1214
// UCP application context (or just a context) is an opaque handle that holds a
1315
// UCP communication instance's global information. It represents a single UCP
@@ -27,10 +29,44 @@ type UcpContext struct {
2729
context C.ucp_context_h
2830
}
2931

32+
type UcpConfig struct {
33+
config *C.ucp_config_t
34+
}
35+
36+
func ReadConfig() (UcpConfig, error) {
37+
var config *C.ucp_config_t
38+
if status := C.ucp_config_read(nil, nil, &config); status != C.UCS_OK {
39+
return UcpConfig{}, newUcxError(status)
40+
}
41+
return UcpConfig{config: config}, nil
42+
}
43+
44+
func (c *UcpConfig) Modify(name string, value string) error {
45+
cName := C.CString(name)
46+
cValue := C.CString(value)
47+
defer C.free(unsafe.Pointer(cName))
48+
defer C.free(unsafe.Pointer(cValue))
49+
if status := C.ucp_config_modify(c.config, cName, cValue); status != C.UCS_OK {
50+
return newUcxError(status)
51+
}
52+
return nil
53+
}
54+
55+
func (c *UcpConfig) Close() {
56+
C.ucp_config_release(c.config)
57+
}
58+
3059
func NewUcpContext(contextParams *UcpParams) (*UcpContext, error) {
3160
var ucp_context C.ucp_context_h
3261

33-
if status := C.ucp_init(&contextParams.params, nil, &ucp_context); status != C.UCS_OK {
62+
config, err := ReadConfig()
63+
if err != nil {
64+
return nil, err
65+
}
66+
defer config.Close()
67+
config.Modify("GVA_ENABLE", "auto")
68+
69+
if status := C.ucp_init(&contextParams.params, config.config, &ucp_context); status != C.UCS_OK {
3470
return nil, newUcxError(status)
3571
}
3672

@@ -92,6 +128,10 @@ func (c *UcpContext) Query(attrs ...UcpContextAttr) (*C.ucp_context_attr_t, erro
92128
func (c *UcpContext) NewWorker(workerParams *UcpWorkerParams) (*UcpWorker, error) {
93129
var ucp_worker C.ucp_worker_h
94130

131+
if ((workerParams.params.field_mask & C.UCP_WORKER_PARAM_FIELD_THREAD_MODE) == 0) {
132+
workerParams.SetThreadMode(UCS_THREAD_MODE_MULTI)
133+
}
134+
95135
if status := C.ucp_worker_create(c.context, &workerParams.params, &ucp_worker); status != C.UCS_OK {
96136
return nil, newUcxError(status)
97137
}

bindings/go/src/ucx/endpoint.go

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ type UcpEp struct {
1818

1919
var errorHandles = make(map[C.ucp_ep_h]UcpEpErrHandler)
2020

21-
func setSendParams(goRequestParams *UcpRequestParams, cRequestParams *C.ucp_request_param_t) uint64 {
21+
func setSendParams(goRequestParams *UcpRequestParams, cRequestParams *C.ucp_request_param_t) unsafe.Pointer {
2222
return packParams(goRequestParams, cRequestParams, unsafe.Pointer(C.ucxgo_completeGoSendRequest))
2323
}
2424

@@ -28,22 +28,20 @@ func setSendParams(goRequestParams *UcpRequestParams, cRequestParams *C.ucp_requ
2828
func (e *UcpEp) FlushNonBlocking(params *UcpRequestParams) (*UcpRequest, error) {
2929
var requestParams C.ucp_request_param_t
3030

31-
cbId := setSendParams(params, &requestParams)
32-
31+
callback := setSendParams(params, &requestParams)
3332
request := C.ucp_ep_flush_nbx(e.ep, &requestParams)
34-
return NewRequest(request, cbId, nil)
33+
return newRequest(request, callback, nil)
3534
}
3635

3736
func (e *UcpEp) CloseNonBlocking(mode C.uint, params *UcpRequestParams) (*UcpRequest, error) {
3837
var requestParams C.ucp_request_param_t
3938
requestParams.op_attr_mask = C.UCP_OP_ATTR_FIELD_FLAGS
4039
requestParams.flags = mode
4140

42-
cbId := setSendParams(params, &requestParams)
43-
41+
callback := setSendParams(params, &requestParams)
4442
request := C.ucp_ep_close_nbx(e.ep, &requestParams)
4543
delete(errorHandles, e.ep)
46-
return NewRequest(request, cbId, nil)
44+
return newRequest(request, callback, nil)
4745
}
4846

4947
// Non-blocking endpoint closure. Releases the endpoint without any
@@ -68,10 +66,9 @@ func (e *UcpEp) SendTagNonBlocking(tag uint64, address unsafe.Pointer, size uint
6866
params *UcpRequestParams) (*UcpRequest, error) {
6967
var requestParams C.ucp_request_param_t
7068

71-
cbId := setSendParams(params, &requestParams)
72-
69+
callback := setSendParams(params, &requestParams)
7370
request := C.ucp_tag_send_nbx(e.ep, address, C.size_t(size), C.ucp_tag_t(tag), &requestParams)
74-
return NewRequest(request, cbId, nil)
71+
return newRequest(request, callback, nil)
7572
}
7673

7774
// This routine sends an Active Message to an ep.
@@ -81,11 +78,10 @@ func (e *UcpEp) SendAmNonBlocking(id uint, header unsafe.Pointer, headerSize uin
8178
data unsafe.Pointer, dataSize uint64, flags UcpAmSendFlags, params *UcpRequestParams) (*UcpRequest, error) {
8279
var requestParams C.ucp_request_param_t
8380

84-
cbId := setSendParams(params, &requestParams)
85-
81+
callback := setSendParams(params, &requestParams)
8682
requestParams.op_attr_mask |= C.UCP_OP_ATTR_FIELD_FLAGS
8783
requestParams.flags = C.uint(flags)
8884

8985
request := C.ucp_am_send_nbx(e.ep, C.uint(id), header, C.size_t(headerSize), data, C.size_t(dataSize), &requestParams)
90-
return NewRequest(request, cbId, nil)
86+
return newRequest(request, callback, nil)
9187
}

bindings/go/src/ucx/listener.go

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,23 @@ package ucx
77

88
// #include <ucp/api/ucp.h>
99
import "C"
10-
import "net"
10+
import (
11+
"net"
12+
)
1113

1214
type UcpListener struct {
13-
listener C.ucp_listener_h
14-
connHandlerId uint64
15+
listener C.ucp_listener_h
16+
callback UcpListenerConnectionHandler
17+
packedCb PackedCallback
1518
}
1619

17-
// Needed to call connHandler.Reject() rather than listener.Reject(connHandler)
18-
var connHandles2Listener = make(map[uint64]C.ucp_listener_h)
19-
2020
type UcpListenerAttributes struct {
2121
Address *net.TCPAddr
2222
}
2323

2424
func (l *UcpListener) Close() {
2525
C.ucp_listener_destroy(l.listener)
26-
deregister(l.connHandlerId)
27-
delete(connHandles2Listener, l.connHandlerId)
26+
l.packedCb.unpackAndFree()
2827
}
2928

3029
func (l *UcpListener) Query(attrs ...UcpListenerAttribute) (*UcpListenerAttributes, error) {

bindings/go/src/ucx/listener_params.go

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,17 @@ import (
1616

1717
// Tuning parameters for the UCP listener.
1818
type UcpListenerParams struct {
19-
params C.ucp_listener_params_t
20-
connHandlerId uint64
19+
params C.ucp_listener_params_t
20+
callback UcpListenerConnectionHandler
2121
}
2222

2323
//export ucxgo_completeConnHandler
24-
func ucxgo_completeConnHandler(connRequest C.ucp_conn_request_h, cbId unsafe.Pointer) {
25-
id := uint64(uintptr((cbId)))
26-
if callback, found := getCallback(id); found {
27-
listener := connHandles2Listener[id]
28-
callback.(UcpListenerConnectionHandler)(&UcpConnectionRequest{
24+
func ucxgo_completeConnHandler(connRequest C.ucp_conn_request_h, packedCb unsafe.Pointer) {
25+
if callback := PackedCallback(packedCb).unpack(); callback != nil {
26+
listener := callback.(*UcpListener)
27+
listener.callback(&UcpConnectionRequest{
2928
connRequest: connRequest,
30-
listener: listener,
29+
listener: listener.listener,
3130
})
3231
}
3332
}
@@ -49,10 +48,8 @@ func (p *UcpListenerParams) SetSocketAddress(a *net.TCPAddr) (*UcpListenerParams
4948
// Handler of an incoming connection request in a client-server connection flow.
5049
func (p *UcpListenerParams) SetConnectionHandler(connHandler UcpListenerConnectionHandler) *UcpListenerParams {
5150
var ucpConnHndl C.ucp_listener_conn_handler_t
52-
cbId := register(connHandler)
5351

54-
p.connHandlerId = cbId
55-
ucpConnHndl.arg = unsafe.Pointer(uintptr(cbId))
52+
p.callback = connHandler
5653
ucpConnHndl.cb = (C.ucp_listener_conn_callback_t)(C.ucxgo_completeConnHandler)
5754
p.params.field_mask |= C.UCP_LISTENER_PARAM_FIELD_CONN_HANDLER
5855
p.params.conn_handler = ucpConnHndl

0 commit comments

Comments
 (0)