Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADDED] MQTT: Support for Websocket #2735

Merged
merged 1 commit into from
Dec 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion server/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -1106,7 +1106,7 @@ func validateAllowedConnectionTypes(m map[string]struct{}) error {
switch ctuc {
case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket,
jwt.ConnectionTypeLeafnode, jwt.ConnectionTypeLeafnodeWS,
jwt.ConnectionTypeMqtt:
jwt.ConnectionTypeMqtt, jwt.ConnectionTypeMqttWS:
default:
return fmt.Errorf("unknown connection type %q", ct)
}
Expand Down
14 changes: 11 additions & 3 deletions server/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,11 @@ func (c *client) initClient() {
case WS:
c.ncs.Store(fmt.Sprintf("%s - wid:%d", conn, c.cid))
case MQTT:
c.ncs.Store(fmt.Sprintf("%s - mid:%d", conn, c.cid))
var ws string
if c.isWebsocket() {
ws = "_ws"
}
c.ncs.Store(fmt.Sprintf("%s - mid%s:%d", conn, ws, c.cid))
}
case ROUTER:
c.ncs.Store(fmt.Sprintf("%s - rid:%d", conn, c.cid))
Expand Down Expand Up @@ -5180,7 +5184,7 @@ func convertAllowedConnectionTypes(cts []string) (map[string]struct{}, error) {
switch i {
case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket,
jwt.ConnectionTypeLeafnode, jwt.ConnectionTypeLeafnodeWS,
jwt.ConnectionTypeMqtt:
jwt.ConnectionTypeMqtt, jwt.ConnectionTypeMqttWS:
m[i] = struct{}{}
default:
unknown = append(unknown, i)
Expand Down Expand Up @@ -5211,7 +5215,11 @@ func (c *client) connectionTypeAllowed(acts map[string]struct{}) bool {
case WS:
want = jwt.ConnectionTypeWebsocket
case MQTT:
want = jwt.ConnectionTypeMqtt
if c.isWebsocket() {
want = jwt.ConnectionTypeMqttWS
} else {
want = jwt.ConnectionTypeMqtt
}
}
case LEAF:
if c.isWebsocket() {
Expand Down
102 changes: 44 additions & 58 deletions server/mqtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ const (

// Default retry delay if transfer of old session streams to new one fails
mqttDefaultTransferRetry = 5 * time.Second

// For Websocket URLs
mqttWSPath = "/mqtt"
)

var (
Expand All @@ -181,7 +184,7 @@ var (
)

var (
errMQTTWebsocketNotSupported = errors.New("invalid connection, websocket currently not supported")
errMQTTNotWebsocketPort = errors.New("MQTT clients over websocket must connect to the Websocket port, not the MQTT port")
errMQTTTopicFilterCannotBeEmpty = errors.New("topic filter cannot be empty")
errMQTTMalformedVarInt = errors.New("malformed variable int")
errMQTTSecondConnectPacket = errors.New("received a second CONNECT packet")
Expand Down Expand Up @@ -343,6 +346,8 @@ type mqttReader struct {
reader mqttIOReader
buf []byte
pos int
pstart int
pbuf []byte
}

type mqttWriter struct {
Expand Down Expand Up @@ -407,14 +412,14 @@ func (s *Server) startMQTT() {
scheme = "tls"
}
s.Noticef("Listening for MQTT clients on %s://%s:%d", scheme, o.Host, o.Port)
go s.acceptConnections(hl, "MQTT", func(conn net.Conn) { s.createMQTTClient(conn) }, nil)
go s.acceptConnections(hl, "MQTT", func(conn net.Conn) { s.createMQTTClient(conn, nil) }, nil)
s.mu.Unlock()
}

// This is similar to createClient() but has some modifications specifi to MQTT clients.
// The comments have been kept to minimum to reduce code size. Check createClient() for
// more details.
func (s *Server) createMQTTClient(conn net.Conn) *client {
func (s *Server) createMQTTClient(conn net.Conn, ws *websocket) *client {
opts := s.getOpts()

maxPay := int32(opts.MaxPayload)
Expand All @@ -424,7 +429,7 @@ func (s *Server) createMQTTClient(conn net.Conn) *client {
}
now := time.Now().UTC()

c := &client{srv: s, nc: conn, mpay: maxPay, msubs: maxSubs, start: now, last: now, mqtt: &mqtt{}}
c := &client{srv: s, nc: conn, mpay: maxPay, msubs: maxSubs, start: now, last: now, mqtt: &mqtt{}, ws: ws}
c.headers = true
c.mqtt.pp = &mqttPublish{}
// MQTT clients don't send NATS CONNECT protocols. So make it an "echo"
Expand Down Expand Up @@ -463,7 +468,8 @@ func (s *Server) createMQTTClient(conn net.Conn) *client {
}
s.clients[c.cid] = c

tlsRequired := opts.MQTT.TLSConfig != nil
// Websocket TLS handshake is already done when getting to this function.
tlsRequired := opts.MQTT.TLSConfig != nil && ws == nil
s.mu.Unlock()

c.mu.Lock()
Expand Down Expand Up @@ -623,9 +629,13 @@ func (c *client) mqttParse(buf []byte) error {
var err error
var b byte
var pl int
var complete bool

for err == nil && r.hasMore() {

// Keep track of the starting of the packet, in case we have a partial
r.pstart = r.pos

// Read packet type and flags
if b, err = r.readByte("packet type"); err != nil {
break
Expand All @@ -637,17 +647,19 @@ func (c *client) mqttParse(buf []byte) error {
// If client was not connected yet, the first packet must be
// a mqttPacketConnect otherwise we fail the connection.
if !connected && pt != mqttPacketConnect {
// Try to guess if the client is trying to connect using Websocket,
// which is currently not supported
if bytes.HasPrefix(buf, []byte("GET ")) {
err = errMQTTWebsocketNotSupported
// If the buffer indicates that it may be a websocket handshake
// but the client is not websocket, it means that the client
// connected to the MQTT port instead of the Websocket port.
if bytes.HasPrefix(buf, []byte("GET ")) && !c.isWebsocket() {
err = errMQTTNotWebsocketPort
} else {
err = fmt.Errorf("the first packet should be a CONNECT (%v), got %v", mqttPacketConnect, pt)
}
break
}

if pl, err = r.readPacketLen(); err != nil {
pl, complete, err = r.readPacketLen()
if err != nil || !complete {
break
}

Expand Down Expand Up @@ -2379,13 +2391,6 @@ func (sess *mqttSession) deleteConsumer(cc *ConsumerConfig) {

// Parse the MQTT connect protocol
func (c *client) mqttParseConnect(r *mqttReader, pl int) (byte, *mqttConnectProto, error) {

// Make sure that we have the expected length in the buffer,
// and if not, this will read it from the underlying reader.
if err := r.ensurePacketInBuffer(pl); err != nil {
return 0, nil, err
}

// Protocol name
proto, err := r.readBytes("protocol name", false)
if err != nil {
Expand Down Expand Up @@ -2804,9 +2809,6 @@ func (c *client) mqttParsePub(r *mqttReader, pl int, pp *mqttPublish) error {
if qos > 1 {
return fmt.Errorf("publish QoS=%v not supported", qos)
}
if err := r.ensurePacketInBuffer(pl); err != nil {
return err
}
// Keep track of where we are when starting to read the variable header
start := r.pos

Expand Down Expand Up @@ -3086,9 +3088,6 @@ func (c *client) mqttEnqueuePubAck(pi uint16) {
}

func mqttParsePubAck(r *mqttReader, pl int) (uint16, error) {
if err := r.ensurePacketInBuffer(pl); err != nil {
return 0, err
}
pi, err := r.readUint16("packet identifier")
if err != nil {
return 0, err
Expand Down Expand Up @@ -3168,9 +3167,6 @@ func (c *client) mqttParseSubsOrUnsubs(r *mqttReader, b byte, pl int, sub bool)
if rf := b & 0xf; rf != expectedFlag {
return 0, nil, fmt.Errorf("wrong %ssubscribe reserved flags: %x", action, rf)
}
if err := r.ensurePacketInBuffer(pl); err != nil {
return 0, nil, err
}
pi, err := r.readUint16("packet identifier")
if err != nil {
return 0, nil, fmt.Errorf("reading packet identifier: %v", err)
Expand Down Expand Up @@ -3870,8 +3866,16 @@ func mqttNeedSubForLevelUp(subject string) bool {
//////////////////////////////////////////////////////////////////////////////

func (r *mqttReader) reset(buf []byte) {
if l := len(r.pbuf); l > 0 {
tmp := make([]byte, l+len(buf))
copy(tmp, r.pbuf)
copy(tmp[l:], buf)
buf = tmp
r.pbuf = nil
}
r.buf = buf
r.pos = 0
r.pstart = 0
}

func (r *mqttReader) hasMore() bool {
Expand All @@ -3887,7 +3891,11 @@ func (r *mqttReader) readByte(field string) (byte, error) {
return b, nil
}

func (r *mqttReader) readPacketLen() (int, error) {
func (r *mqttReader) readPacketLen() (int, bool, error) {
return r.readPacketLenWithCheck(true)
}

func (r *mqttReader) readPacketLenWithCheck(check bool) (int, bool, error) {
m := 1
v := 0
for {
Expand All @@ -3896,45 +3904,23 @@ func (r *mqttReader) readPacketLen() (int, error) {
b = r.buf[r.pos]
r.pos++
} else {
var buf [1]byte
if _, err := r.reader.Read(buf[:1]); err != nil {
if err == io.EOF {
return 0, io.ErrUnexpectedEOF
}
return 0, fmt.Errorf("error reading packet length: %v", err)
}
b = buf[0]
break
}
v += int(b&0x7f) * m
if (b & 0x80) == 0 {
return v, nil
if check && r.pos+v > len(r.buf) {
break
}
return v, true, nil
}
m *= 0x80
if m > 0x200000 {
return 0, errMQTTMalformedVarInt
}
}
}

func (r *mqttReader) ensurePacketInBuffer(pl int) error {
rem := len(r.buf) - r.pos
if rem >= pl {
return nil
}
b := make([]byte, pl)
start := copy(b, r.buf[r.pos:])
for start != pl {
n, err := r.reader.Read(b[start:cap(b)])
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return fmt.Errorf("error ensuring protocol is loaded: %v", err)
return 0, false, errMQTTMalformedVarInt
}
start += n
}
r.reset(b)
return nil
r.pbuf = make([]byte, len(r.buf)-r.pstart)
copy(r.pbuf, r.buf[r.pstart:])
return 0, false, nil
}

func (r *mqttReader) readString(field string) (string, error) {
Expand Down