/
mux.go
182 lines (147 loc) · 4.19 KB
/
mux.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
// Copyright 2018 The go-hep Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package mux implements a multiplexer that routes data to channels keyed by
the StreamID defined in the XRootD protocol specification.

Example of usage:

	m := New()
	defer m.Close()

	// Claim a channel on which to receive the response.
	id, channel, err := m.Claim()
	if err != nil {
		// handle error.
	}

	// Send a request to the server using id as the streamID.

	go func() {
		// Read the response from the server.
		// ...

		// Send the response to the awaiting caller using the streamID from the server.
		err := m.SendData(streamID, resp)
		if err != nil {
			// handle error.
		}
	}()

	// Fetch the response.
	response := <-channel
*/
package mux // import "go-hep.org/x/hep/xrootd/internal/mux"
import (
"math"
"sync"
"github.com/pkg/errors"
"go-hep.org/x/hep/xrootd/protocol"
)
// ServerResponse is a single unit of data routed by the Mux to a waiting caller.
//
// Data carries the bytes of an XRootD server response (see the XRootD protocol
// specification); Err carries an error reported by the server or encountered
// while decoding the response.
type ServerResponse struct {
	// Data is the raw response payload.
	Data []byte

	// Err is the server-reported or decoding error, if any.
	Err error
}
type (
	// dataSendChan is the send-only view of a waiter channel, as stored in Mux.dataWaiters.
	dataSendChan chan<- ServerResponse

	// DataRecvChan is the receive-only view of a waiter channel, returned to
	// callers of Claim and ClaimWithID.
	DataRecvChan <-chan ServerResponse
)
const (
	// streamIDPartSize is math.MaxUint8, the largest value of a single byte.
	streamIDPartSize = math.MaxUint8

	// streamIDPoolSize bounds the IDs cycled through by the generator started
	// in New: it emits 0 .. streamIDPoolSize-1 and then wraps around.
	// NOTE(review): 255*255 = 65025, so the top 511 two-byte stream IDs are
	// never generated; confirm this is intended.
	streamIDPoolSize = streamIDPartSize * streamIDPartSize
)
// Mux routes server responses to waiting callers by stream ID.
//
// Conceptually it is a map[protocol.StreamID]chan<- ServerResponse, together
// with methods to claim and release IDs and to deliver data to the channel
// registered for a given ID.
type Mux struct {
	// mu is held whenever dataWaiters or closed is modified.
	mu sync.Mutex

	// dataWaiters maps each claimed stream ID to the channel its responses are sent on.
	dataWaiters map[protocol.StreamID]dataSendChan

	// freeIDs yields candidate IDs produced by the generator goroutine started in New.
	freeIDs chan uint16

	// quit is closed by Close to stop the ID generator.
	quit chan struct{}

	// closed reports whether Close has been called.
	closed bool
}
// New returns a ready-to-use Mux.
//
// It starts a goroutine that keeps feeding candidate stream IDs
// 0, 1, ..., streamIDPoolSize-1 (then wrapping around) into a small buffered
// channel consumed by Claim. That goroutine only exits, closing the channel,
// once Close is called, so callers should always Close the Mux.
func New() *Mux {
	// The buffer size is arbitrary for now and should be tuned against real use cases.
	const freeIDsBufferSize = 32

	m := &Mux{
		dataWaiters: make(map[protocol.StreamID]dataSendChan),
		freeIDs:     make(chan uint16, freeIDsBufferSize),
		quit:        make(chan struct{}),
	}

	go func() {
		var next uint16
		for {
			select {
			case <-m.quit:
				close(m.freeIDs)
				return
			case m.freeIDs <- next:
				next = (next + 1) % streamIDPoolSize
			}
		}
	}()

	return m
}
// Close closes the Mux.
//
// Close marks the Mux as closed, so later Claim and ClaimWithID calls fail,
// and stops the stream ID generator. Then, for every ID still claimed, it sends
// a ServerResponse carrying a "close was called" error and unclaims the ID,
// which closes its channel. Calling Close more than once is a no-op.
//
// Delivery goes through SendData, which performs a blocking send on each
// waiter's channel, so Close may block until every outstanding waiter has
// read its error.
func (m *Mux) Close() {
	m.mu.Lock()
	if m.closed {
		m.mu.Unlock()
		return
	}
	m.closed = true
	// Snapshot the claimed IDs while holding mu: other goroutines may still call
	// Unclaim (which deletes from the map), so ranging over dataWaiters without
	// the lock would be a data race. No new IDs can be added once closed is set.
	ids := make([]protocol.StreamID, 0, len(m.dataWaiters))
	for id := range m.dataWaiters {
		ids = append(ids, id)
	}
	m.mu.Unlock()
	close(m.quit)

	response := ServerResponse{Err: errors.New("xrootd: close was called before response was fully received")}
	for _, id := range ids {
		// The waiter may have been unclaimed concurrently since the snapshot;
		// SendData then reports a missing waiter, which is expected here.
		_ = m.SendData(id, response)
		m.Unclaim(id)
	}
}
// Claim reserves the next free stream ID and returns it together with the
// channel on which responses for that ID will be delivered.
//
// IDs that are already claimed (e.g. via ClaimWithID) are skipped.
// Claim returns an error if the Mux has been closed.
func (m *Mux) Claim() (protocol.StreamID, DataRecvChan, error) {
	ch := make(chan ServerResponse)
	for {
		id, ok := <-m.freeIDs
		if !ok {
			// The generator goroutine closes freeIDs only after Close has closed quit.
			return protocol.StreamID{}, nil, errors.New("mux: Claim was called on closed Mux")
		}
		streamID := protocol.StreamID{byte(id >> 8), byte(id)}

		m.mu.Lock()
		if m.closed {
			m.mu.Unlock()
			return protocol.StreamID{}, nil, errors.New("mux: Claim was called on closed Mux")
		}
		if _, claimed := m.dataWaiters[streamID]; claimed {
			// Already taken, e.g. claimed manually via ClaimWithID: try the next ID.
			m.mu.Unlock()
			continue
		}
		m.dataWaiters[streamID] = ch
		m.mu.Unlock()
		return streamID, ch, nil
	}
}
// ClaimWithID reserves the specific stream ID id and returns the channel on
// which responses for it will be delivered.
//
// It fails if the Mux has been closed or if id is already claimed.
func (m *Mux) ClaimWithID(id protocol.StreamID) (DataRecvChan, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	switch _, taken := m.dataWaiters[id]; {
	case m.closed:
		return nil, errors.New("mux: ClaimWithID was called on closed Mux")
	case taken:
		return nil, errors.Errorf("mux: channel with id %s is already claimed", id)
	}

	waiter := make(chan ServerResponse)
	m.dataWaiters[id] = waiter
	return waiter, nil
}
// Unclaim releases the stream ID id so that it can be claimed again.
//
// If a channel is registered for id, the channel is closed and removed.
// Unclaiming an ID that is not currently claimed is a no-op.
func (m *Mux) Unclaim(id protocol.StreamID) {
	m.mu.Lock()
	defer m.mu.Unlock()

	ch, ok := m.dataWaiters[id]
	if !ok {
		return
	}
	close(ch)
	delete(m.dataWaiters, id)
}
// SendData delivers data to the channel claimed for id.
//
// The send is performed while mu is held, so a concurrent Unclaim cannot close
// the channel mid-send. As a consequence, for the unbuffered channels created by
// Claim and ClaimWithID, SendData blocks, and keeps other Mux operations waiting
// on mu, until the receiver reads data.
//
// It returns an error if no channel is claimed for id.
func (m *Mux) SendData(id protocol.StreamID, data ServerResponse) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	ch, ok := m.dataWaiters[id]
	if !ok {
		return errors.Errorf("mux: cannot find data waiter for id %s", id)
	}
	ch <- data
	return nil
}