Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sctp fixes #9

Merged
merged 2 commits into from
Feb 1, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 38 additions & 13 deletions sctp.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"unsafe"
Expand Down Expand Up @@ -330,13 +331,17 @@ func SCTPBind(fd int, addr *SCTPAddr, flags int) error {
}

type SCTPConn struct {
fd int
_fd int32
notificationHandler NotificationHandler
}

func (c *SCTPConn) fd() int {
return int(atomic.LoadInt32(&c._fd))
}

func NewSCTPConn(fd int, handler NotificationHandler) *SCTPConn {
conn := &SCTPConn{
fd: fd,
_fd: int32(fd),
notificationHandler: handler,
}
return conn
Expand All @@ -348,6 +353,9 @@ func (c *SCTPConn) Write(b []byte) (int, error) {

func (c *SCTPConn) Read(b []byte) (int, error) {
n, _, err := c.SCTPRead(b)
if n < 0 {
n = 0
}
return n, err
}

Expand All @@ -359,7 +367,7 @@ func (c *SCTPConn) SetInitMsg(numOstreams, maxInstreams, maxAttempts, maxInitTim
MaxInitTimeout: uint16(maxInitTimeout),
}
optlen := unsafe.Sizeof(param)
_, _, err := setsockopt(c.fd, SCTP_INITMSG, uintptr(unsafe.Pointer(&param)), uintptr(optlen))
_, _, err := setsockopt(c.fd(), SCTP_INITMSG, uintptr(unsafe.Pointer(&param)), uintptr(optlen))
return err
}

Expand Down Expand Up @@ -408,14 +416,14 @@ func (c *SCTPConn) SubscribeEvents(flags int) error {
SenderDry: se,
}
optlen := unsafe.Sizeof(param)
_, _, err := setsockopt(c.fd, SCTP_EVENTS, uintptr(unsafe.Pointer(&param)), uintptr(optlen))
_, _, err := setsockopt(c.fd(), SCTP_EVENTS, uintptr(unsafe.Pointer(&param)), uintptr(optlen))
return err
}

func (c *SCTPConn) SubscribedEvents() (int, error) {
param := EventSubscribe{}
optlen := unsafe.Sizeof(param)
_, _, err := getsockopt(c.fd, SCTP_EVENTS, uintptr(unsafe.Pointer(&param)), uintptr(unsafe.Pointer(&optlen)))
_, _, err := getsockopt(c.fd(), SCTP_EVENTS, uintptr(unsafe.Pointer(&param)), uintptr(unsafe.Pointer(&optlen)))
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -455,14 +463,14 @@ func (c *SCTPConn) SubscribedEvents() (int, error) {

func (c *SCTPConn) SetDefaultSentParam(info *SndRcvInfo) error {
optlen := unsafe.Sizeof(*info)
_, _, err := setsockopt(c.fd, SCTP_DEFAULT_SENT_PARAM, uintptr(unsafe.Pointer(info)), uintptr(optlen))
_, _, err := setsockopt(c.fd(), SCTP_DEFAULT_SENT_PARAM, uintptr(unsafe.Pointer(info)), uintptr(optlen))
return err
}

func (c *SCTPConn) GetDefaultSentParam() (*SndRcvInfo, error) {
info := &SndRcvInfo{}
optlen := unsafe.Sizeof(*info)
_, _, err := getsockopt(c.fd, SCTP_DEFAULT_SENT_PARAM, uintptr(unsafe.Pointer(info)), uintptr(unsafe.Pointer(&optlen)))
_, _, err := getsockopt(c.fd(), SCTP_DEFAULT_SENT_PARAM, uintptr(unsafe.Pointer(info)), uintptr(unsafe.Pointer(&optlen)))
return info, err
}

Expand Down Expand Up @@ -514,24 +522,41 @@ func sctpGetAddrs(fd, id, optname int) (*SCTPAddr, error) {
return resolveFromRawAddr(unsafe.Pointer(&param.addrs), int(param.addrNum))
}

func (c *SCTPConn) SCTPGetPrimaryPeerAddr() (*SCTPAddr, error) {

type sctpGetSetPrim struct {
assocId int32
addrs [128]byte
}
param := sctpGetSetPrim{
assocId: int32(0),
}
optlen := unsafe.Sizeof(param)
_, _, err := getsockopt(c.fd(), SCTP_PRIMARY_ADDR, uintptr(unsafe.Pointer(&param)), uintptr(unsafe.Pointer(&optlen)))
if err != nil {
return nil, err
}
return resolveFromRawAddr(unsafe.Pointer(&param.addrs), 1)
}

func (c *SCTPConn) SCTPLocalAddr(id int) (*SCTPAddr, error) {
return sctpGetAddrs(c.fd, id, SCTP_GET_LOCAL_ADDRS)
return sctpGetAddrs(c.fd(), id, SCTP_GET_LOCAL_ADDRS)
}

func (c *SCTPConn) SCTPRemoteAddr(id int) (*SCTPAddr, error) {
return sctpGetAddrs(c.fd, id, SCTP_GET_PEER_ADDRS)
return sctpGetAddrs(c.fd(), id, SCTP_GET_PEER_ADDRS)
}

func (c *SCTPConn) LocalAddr() net.Addr {
addr, err := sctpGetAddrs(c.fd, 0, SCTP_GET_LOCAL_ADDRS)
addr, err := sctpGetAddrs(c.fd(), 0, SCTP_GET_LOCAL_ADDRS)
if err != nil {
return nil
}
return addr
}

func (c *SCTPConn) RemoteAddr() net.Addr {
addr, err := sctpGetAddrs(c.fd, 0, SCTP_GET_PEER_ADDRS)
addr, err := sctpGetAddrs(c.fd(), 0, SCTP_GET_PEER_ADDRS)
if err != nil {
return nil
}
Expand All @@ -547,11 +572,11 @@ func (c *SCTPConn) PeelOff(id int) (*SCTPConn, error) {
assocId: int32(id),
}
optlen := unsafe.Sizeof(param)
_, _, err := getsockopt(c.fd, SCTP_SOCKOPT_PEELOFF, uintptr(unsafe.Pointer(&param)), uintptr(unsafe.Pointer(&optlen)))
_, _, err := getsockopt(c.fd(), SCTP_SOCKOPT_PEELOFF, uintptr(unsafe.Pointer(&param)), uintptr(unsafe.Pointer(&optlen)))
if err != nil {
return nil, err
}
return &SCTPConn{fd: param.sd}, nil
return &SCTPConn{_fd: int32(param.sd)}, nil
}

func (c *SCTPConn) SetDeadline(t time.Time) error {
Expand Down
21 changes: 14 additions & 7 deletions sctp_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
"sync/atomic"
"syscall"
"unsafe"
)
Expand Down Expand Up @@ -54,7 +55,7 @@ func (c *SCTPConn) SCTPWrite(b []byte, info *SndRcvInfo) (int, error) {
hdr.SetLen(syscall.CmsgSpace(len(cmsgBuf)))
cbuf = append(toBuf(hdr), cmsgBuf...)
}
return syscall.SendmsgN(c.fd, b, cbuf, nil, 0)
return syscall.SendmsgN(c.fd(), b, cbuf, nil, 0)
}

func parseSndRcvInfo(b []byte) (*SndRcvInfo, error) {
Expand All @@ -76,7 +77,7 @@ func parseSndRcvInfo(b []byte) (*SndRcvInfo, error) {
func (c *SCTPConn) SCTPRead(b []byte) (int, *SndRcvInfo, error) {
oob := make([]byte, 254)
for {
n, oobn, recvflags, _, err := syscall.Recvmsg(c.fd, b, oob, 0)
n, oobn, recvflags, _, err := syscall.Recvmsg(c.fd(), b, oob, 0)
if err != nil {
return n, nil, err
}
Expand All @@ -100,12 +101,18 @@ func (c *SCTPConn) SCTPRead(b []byte) (int, *SndRcvInfo, error) {
}

func (c *SCTPConn) Close() error {
info := &SndRcvInfo{
Flags: SCTP_EOF,
if c != nil {
fd := atomic.SwapInt32(&c._fd, -1)
if fd > 0 {
info := &SndRcvInfo{
Flags: SCTP_EOF,
}
c.SCTPWrite(nil, info)
syscall.Shutdown(int(fd), syscall.SHUT_RDWR)
return syscall.Close(int(fd))
}
}
c.SCTPWrite(nil, info)
syscall.Shutdown(c.fd, syscall.SHUT_RDWR)
return syscall.Close(c.fd)
return syscall.EBADF
}

func ListenSCTP(net string, laddr *SCTPAddr) (*SCTPListener, error) {
Expand Down
134 changes: 134 additions & 0 deletions sctp_streams_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package sctp

import (
"fmt"
"io"
"math/rand"
"testing"
"time"
)

const (
STREAM_TEST_CLIENTS = 128
STREAM_TEST_STREAMS = 11
)

func TestStreams(t *testing.T) {

r := rand.New(rand.NewSource(time.Now().UnixNano()))
randomStr := func(strlen int) string {
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, strlen)
for i := range result {
result[i] = chars[r.Intn(len(chars))]
}
return string(result)
}

addr, _ := ResolveSCTPAddr("sctp", "127.0.0.1:0")
ln, err := ListenSCTP("sctp", addr)
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
addr = ln.Addr().(*SCTPAddr)
t.Logf("Listen on %s", ln.Addr())

go func() {
for {
c, err := ln.Accept()
sconn := c.(*SCTPConn)
if err != nil {
t.Errorf("failed to accept: %v", err)
return
}
defer sconn.Close()

sconn.SubscribeEvents(SCTP_EVENT_DATA_IO)
go func() {
totalrcvd := 0
for {
buf := make([]byte, 512)
n, info, err := sconn.SCTPRead(buf)
if err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF {
if n == 0 {
break
}
t.Logf("EOF on server connection. Total bytes received: %d, bytes received: %d", totalrcvd, n)
} else {
t.Errorf("Server connection read err: %v. Total bytes received: %d, bytes received: %d", err, totalrcvd, n)
return
}
}
t.Logf("server read: info: %+v, payload: %s", info, string(buf[:n]))
n, err = sconn.SCTPWrite(buf[:n], info)
if err != nil {
t.Error(err)
return
}
}
}()
}
}()

wait := make(chan struct{})
i := 0
for ; i < STREAM_TEST_CLIENTS; i++ {
go func(test int) {
defer func() { wait <- struct{}{} }()
conn, err := DialSCTP("sctp", nil, addr)
if err != nil {
t.Errorf("failed to dial address %s, test #%d: %v", addr.String(), test, err)
return
}
defer conn.Close()
conn.SubscribeEvents(SCTP_EVENT_DATA_IO)
for ppid := uint16(0); ppid < STREAM_TEST_STREAMS; ppid++ {
info := &SndRcvInfo{
Stream: uint16(ppid),
PPID: uint32(ppid),
}
text := fmt.Sprintf("Test %s ***\n\t\t%d %d ***", randomStr(r.Intn(255)), test, ppid)
n, err := conn.SCTPWrite([]byte(text), info)
if err != nil {
t.Errorf("failed to write %s, len: %d, err: %v, bytes written: %d", text, len(text), err, n)
return
}
rn := 0
cn := 0
buf := make([]byte, 512)
for {
cn, info, err = conn.SCTPRead(buf[rn:])
if err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF {
rn += cn
break
}
t.Errorf("failed to read: %v", err)
return
}
if info.Stream != ppid {
t.Errorf("Mismatched PPIDs: %d != %d", info.Stream, ppid)
return
}
rn += cn
if rn >= n {
break
}
}
rtext := string(buf[:rn])
if rtext != text {
t.Fatalf("Mismatched payload: %s != %s", rtext, text)
}
}
}(i)
}
for ; i > 0; i-- {
select {
case <-wait:
case <-time.After(time.Second * 30):
close(wait)
t.Fatal("timed out")
}
}
}