Skip to content
Permalink
Browse files

refactor: move func of packet creation to MQTTReader

  • Loading branch information...
bati11 committed Mar 12, 2019
1 parent 4a5cb80 commit 1d1adca733d3cc915a1c76fef623cd868b53cbdb
@@ -1,26 +1,14 @@
package handler

import (
"bufio"
"fmt"

"github.com/bati11/oreno-mqtt/mqtt/packet"
)

// CONNECTパケットの可変ヘッダーのバイト数
var variableHeaderLength = 10

func HandleConnect(fixedHeader packet.FixedHeader, r *bufio.Reader) (packet.Connack, error) {
func HandleConnect(reader *packet.MQTTReader) (packet.Connack, error) {
fmt.Printf("HandleConnect\n")
variableHeader, err := packet.ToConnectVariableHeader(fixedHeader, r)
if err != nil {
if ce, ok := err.(packet.ConnectError); ok {
return ce.Connack(), nil
}
return packet.Connack{}, err
}

payload, err := packet.ToConnectPayload(r)
connect, err := reader.ReadConnect()
if err != nil {
if ce, ok := err.(packet.ConnectError); ok {
return ce.Connack(), nil
@@ -29,8 +17,8 @@ func HandleConnect(fixedHeader packet.FixedHeader, r *bufio.Reader) (packet.Conn
}

// TODO variableHeaderとpayloadを使って何かしらの処理
fmt.Printf(" %#v\n", variableHeader)
fmt.Printf(" %#v\n", payload)
fmt.Printf(" %#v\n", connect.VariableHeader)
fmt.Printf(" %#v\n", connect.Payload)

return packet.NewConnackForAccepted(), nil
}
@@ -1,28 +1,19 @@
package handler

import (
"bufio"
"fmt"
"io"

"github.com/bati11/oreno-mqtt/mqtt/packet"
)

func HandlePublish(fixedHeader packet.PublishFixedHeader, r *bufio.Reader) error {
func HandlePublish(reader *packet.MQTTReader) error {
fmt.Printf(" HandlePublish\n")
variableHeader, err := packet.ToPublishVariableHeader(fixedHeader, r)
publish, err := reader.ReadPublish()
if err != nil {
return err
}
fmt.Printf(" %#v\n", variableHeader)

payloadLength := fixedHeader.RemainingLength - variableHeader.Length()
payload := make([]byte, payloadLength)
_, err = io.ReadFull(r, payload)
if err != nil {
return err
}
fmt.Printf(" Payload: %v\n", string(payload))
fmt.Printf(" %#v\n", publish.VariableHeader)
fmt.Printf(" Payload: %v\n", string(publish.Payload))

// TODO QoS0なのでレスポンスなし
return nil
@@ -0,0 +1,23 @@
package packet

type Connect struct {
FixedHeader *FixedHeader
VariableHeader *ConnectVariableHeader
Payload *ConnectPayload
}

func (reader *MQTTReader) ReadConnect() (*Connect, error) {
fixedHeader, err := reader.readFixedHeader()
if err != nil {
return nil, err
}
variableHeader, err := reader.readConnectVariableHeader()
if err != nil {
return nil, err
}
payload, err := reader.readConnectPayload()
if err != nil {
return nil, err
}
return &Connect{fixedHeader, variableHeader, payload}, nil
}
@@ -1,7 +1,6 @@
package packet

import (
"bufio"
"encoding/binary"
"io"
"regexp"
@@ -13,25 +12,25 @@ type ConnectPayload struct {

var clientIDRegex = regexp.MustCompile("^[a-zA-Z0-9-|]*$")

func ToConnectPayload(r *bufio.Reader) (ConnectPayload, error) {
func (reader *MQTTReader) readConnectPayload() (*ConnectPayload, error) {
lengthBytes := make([]byte, 2)
_, err := io.ReadFull(r, lengthBytes)
_, err := io.ReadFull(reader.r, lengthBytes)
if err != nil {
return ConnectPayload{}, err
return nil, err
}
length := binary.BigEndian.Uint16(lengthBytes)

clientIDBytes := make([]byte, length)
_, err = io.ReadFull(r, clientIDBytes)
_, err = io.ReadFull(reader.r, clientIDBytes)
if err != nil {
return ConnectPayload{}, err
return nil, err
}
clientID := string(clientIDBytes)
if len(clientID) < 1 || len(clientID) > 23 {
return ConnectPayload{}, RefusedByIdentifierRejected("ClientID length is invalid")
return nil, RefusedByIdentifierRejected("ClientID length is invalid")
}
if !clientIDRegex.MatchString(clientID) {
return ConnectPayload{}, RefusedByIdentifierRejected("ClientId format shoud be \"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ\"")
return nil, RefusedByIdentifierRejected("ClientId format shoud be \"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ\"")
}
return ConnectPayload{ClientID: clientID}, nil
return &ConnectPayload{ClientID: clientID}, nil
}
@@ -1,62 +1,63 @@
package packet
package packet_test

import (
"bufio"
"bytes"
"reflect"
"testing"

"github.com/bati11/oreno-mqtt/mqtt/packet"
)

func TestToConnectPayload(t *testing.T) {
func TestMQTTReader_ReadConnectPayload(t *testing.T) {
type args struct {
r *bufio.Reader
r *packet.MQTTReader
}
tests := []struct {
name string
args args
want ConnectPayload
want *packet.ConnectPayload
wantErr bool
}{
{
name: "ClientIDが1文字",
args: args{bufio.NewReader(bytes.NewBuffer([]byte{0x00, 0x01, 'a'}))},
want: ConnectPayload{ClientID: "a"},
args: args{packet.NewMQTTReader(bytes.NewBuffer([]byte{0x00, 0x01, 'a'}))},
want: &packet.ConnectPayload{ClientID: "a"},
wantErr: false,
},
{
name: "ペイロードが0byte",
args: args{bufio.NewReader(bytes.NewBuffer([]byte{}))},
want: ConnectPayload{},
args: args{packet.NewMQTTReader(bytes.NewBuffer([]byte{}))},
want: nil,
wantErr: true,
},
{
name: "ClientIDが23文字を超える",
args: args{bufio.NewReader(bytes.NewBuffer([]byte{0x00, 0x18, '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', 'a', 'b', 'c', 'd'}))},
want: ConnectPayload{},
args: args{packet.NewMQTTReader(bytes.NewBuffer([]byte{0x00, 0x18, '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', 'a', 'b', 'c', 'd'}))},
want: nil,
wantErr: true,
},
{
name: "使えない文字がある",
args: args{bufio.NewReader(bytes.NewBuffer([]byte{0x00, 0x02, '1', '%'}))},
want: ConnectPayload{},
args: args{packet.NewMQTTReader(bytes.NewBuffer([]byte{0x00, 0x02, '1', '%'}))},
want: nil,
wantErr: true,
},
{
name: "指定された長さよりも実際に取得できたClientIDが短い",
args: args{bufio.NewReader(bytes.NewBuffer([]byte{0x00, 0x03, '1', '2'}))},
want: ConnectPayload{},
args: args{packet.NewMQTTReader(bytes.NewBuffer([]byte{0x00, 0x03, '1', '2'}))},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ToConnectPayload(tt.args.r)
got, err := packet.ExportReadConnectPayload(tt.args.r)
if (err != nil) != tt.wantErr {
t.Errorf("ToConnectPayload() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("ExportReadConnectPayload() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("ToConnectPayload() = %v, want %v", got, tt.want)
t.Errorf("ExportReadConnectPayload() = %v, want %v", got, tt.want)
}
})
}
@@ -1,10 +1,7 @@
package packet

import (
"bufio"
"io"

"github.com/pkg/errors"
)

type ConnectFlags struct {
@@ -23,35 +20,32 @@ type ConnectVariableHeader struct {
KeepAlive uint16
}

func ToConnectVariableHeader(fixedHeader FixedHeader, r *bufio.Reader) (ConnectVariableHeader, error) {
if fixedHeader.PacketType != 1 {
return ConnectVariableHeader{}, errors.New("fixedHeader.PacketType must be 1")
}
func (reader *MQTTReader) readConnectVariableHeader() (*ConnectVariableHeader, error) {
protocolName := make([]byte, 6)
_, err := io.ReadFull(r, protocolName)
_, err := io.ReadFull(reader.r, protocolName)
if err != nil || !isValidProtocolName(protocolName) {
return ConnectVariableHeader{}, RefusedByUnacceptableProtocolVersion("protocol name is invalid")
return nil, RefusedByUnacceptableProtocolVersion("protocol name is invalid")
}
protocolLevel, err := r.ReadByte()
protocolLevel, err := reader.r.ReadByte()
if err != nil || protocolLevel != 4 {
return ConnectVariableHeader{}, RefusedByUnacceptableProtocolVersion("protocol level must be 4")
return nil, RefusedByUnacceptableProtocolVersion("protocol level must be 4")
}

// TODO
_, err = r.ReadByte() // connectFlags
_, err = reader.r.ReadByte() // connectFlags
if err != nil {
return ConnectVariableHeader{}, err
return nil, err
}
_, err = r.ReadByte() // keepAlive MSB
_, err = reader.r.ReadByte() // keepAlive MSB
if err != nil {
return ConnectVariableHeader{}, err
return nil, err
}
_, err = r.ReadByte() // keepAlive LSB
_, err = reader.r.ReadByte() // keepAlive LSB
if err != nil {
return ConnectVariableHeader{}, err
return nil, err
}

return ConnectVariableHeader{
return &ConnectVariableHeader{
ProtocolName: "MQTT",
ProtocolLevel: 4,
ConnectFlags: ConnectFlags{UserNameFlag: true, PasswordFlag: true, WillRetain: false, WillQoS: 1, WillFlag: true, CleanSession: true},
@@ -1,96 +1,77 @@
package packet_test

import (
"bufio"
"bytes"
"reflect"
"testing"

"github.com/bati11/oreno-mqtt/mqtt/packet"
)

func TestToConnectVariableHeader(t *testing.T) {
func TestMQTTReader_ReadConnectVariableHeader(t *testing.T) {
type args struct {
fixedHeader packet.FixedHeader
r *bufio.Reader
r *packet.MQTTReader
}
tests := []struct {
name string
args args
want packet.ConnectVariableHeader
want *packet.ConnectVariableHeader
wantErr bool
}{
{
name: "仕様書のexample",
args: args{
fixedHeader: packet.FixedHeader{PacketType: 1},
r: bufio.NewReader(bytes.NewBuffer([]byte{
r: packet.NewMQTTReader(bytes.NewBuffer([]byte{
0x00, 0x04, 'M', 'Q', 'T', 'T', // Protocol Name
0x04, // Protocol Level
0xCE, // Connect Flags
0x00, 0x0A, // Keep Alive
})),
},
want: packet.ConnectVariableHeader{
want: &packet.ConnectVariableHeader{
ProtocolName: "MQTT",
ProtocolLevel: 4,
ConnectFlags: packet.ConnectFlags{UserNameFlag: true, PasswordFlag: true, WillRetain: false, WillQoS: 1, WillFlag: true, CleanSession: true},
KeepAlive: 10,
},
wantErr: false,
},
{
name: "固定ヘッダーのPacketTypeが1ではない",
args: args{
fixedHeader: packet.FixedHeader{PacketType: 2},
r: bufio.NewReader(bytes.NewReader([]byte{
0x00, 0x04, 'M', 'Q', 'T', 'T', // Protocol Name
0x04, // Protocol Level
0xCE, // Connect Flags
0x00, 0x0A, // Keep Alive
})),
},
want: packet.ConnectVariableHeader{},
wantErr: true,
},
{
name: "Protocol Nameが不正",
args: args{
fixedHeader: packet.FixedHeader{PacketType: 1},
r: bufio.NewReader(bytes.NewReader([]byte{
r: packet.NewMQTTReader(bytes.NewReader([]byte{
0x00, 0x04, 'M', 'Q', 'T', 't', // Protocol Name
0x04, // Protocol Level
0xCE, // Connect Flags
0x00, 0x0A, // Keep Alive
})),
},
want: packet.ConnectVariableHeader{},
want: nil,
wantErr: true,
},
{
name: "Protocol Levelが不正",
args: args{
fixedHeader: packet.FixedHeader{PacketType: 1},
r: bufio.NewReader(bytes.NewReader([]byte{
r: packet.NewMQTTReader(bytes.NewReader([]byte{
0x00, 0x04, 'M', 'Q', 'T', 'T', // Protocol Name
0x03, // Protocol Level
0xCE, // Connect Flags
0x00, 0x0A, // Keep Alive
})),
},
want: packet.ConnectVariableHeader{},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := packet.ToConnectVariableHeader(tt.args.fixedHeader, tt.args.r)
got, err := packet.ExportReadVariableConnectHeader(tt.args.r)
if (err != nil) != tt.wantErr {
t.Errorf("ToConnectVariableHeader() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("ExportReadVariableConnectHeader() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("ToConnectVariableHeader() = %v, want %v", got, tt.want)
t.Errorf("ExportReadVariableConnectHeader() = %v, want %v", got, tt.want)
}
})
}
@@ -0,0 +1,7 @@
package packet

var ExportReadPublishFixedHeader = (*MQTTReader).readPublishFixedHeader
var ExportReadPublishVariableHeader = (*MQTTReader).readPublishVariableHeader

var ExportReadVariableConnectHeader = (*MQTTReader).readConnectVariableHeader
var ExportReadConnectPayload = (*MQTTReader).readConnectPayload
Oops, something went wrong.

0 comments on commit 1d1adca

Please sign in to comment.
You can’t perform that action at this time.