diff --git a/.travis.yml b/.travis.yml index e72c578..8ef5f5d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,6 +5,8 @@ go: - 1.8 - 1.9 - "1.10" + - "1.11" + - "1.12" script: - go test -v -race ./... diff --git a/example/sctp.go b/example/sctp.go index c207cca..19552c0 100644 --- a/example/sctp.go +++ b/example/sctp.go @@ -3,6 +3,7 @@ package main import ( "flag" "log" + "math/rand" "net" "strings" "time" @@ -10,15 +11,18 @@ import ( "github.com/ishidawataru/sctp" ) -func serveClient(conn net.Conn) error { +func serveClient(conn net.Conn, bufsize int) error { for { - buf := make([]byte, 254) + buf := make([]byte, bufsize+128) // add overhead of SCTPSndRcvInfoWrappedConn n, err := conn.Read(buf) if err != nil { + log.Printf("read failed: %v", err) return err } + log.Printf("read: %d", n) n, err = conn.Write(buf[:n]) if err != nil { + log.Printf("write failed: %v", err) return err } log.Printf("write: %d", n) @@ -30,6 +34,9 @@ func main() { var ip = flag.String("ip", "0.0.0.0", "") var port = flag.Int("port", 0, "") var lport = flag.Int("lport", 0, "") + var bufsize = flag.Int("bufsize", 256, "") + var sndbuf = flag.Int("sndbuf", 0, "") + var rcvbuf = flag.Int("rcvbuf", 0, "") flag.Parse() @@ -64,7 +71,29 @@ func main() { } log.Printf("Accepted Connection from RemoteAddr: %s", conn.RemoteAddr()) wconn := sctp.NewSCTPSndRcvInfoWrappedConn(conn.(*sctp.SCTPConn)) - go serveClient(wconn) + if *sndbuf != 0 { + err = wconn.SetWriteBuffer(*sndbuf) + if err != nil { + log.Fatalf("failed to set write buf: %v", err) + } + } + if *rcvbuf != 0 { + err = wconn.SetReadBuffer(*rcvbuf) + if err != nil { + log.Fatalf("failed to set read buf: %v", err) + } + } + *sndbuf, err = wconn.GetWriteBuffer() + if err != nil { + log.Fatalf("failed to get write buf: %v", err) + } + *rcvbuf, err = wconn.GetWriteBuffer() + if err != nil { + log.Fatalf("failed to get read buf: %v", err) + } + log.Printf("SndBufSize: %d, RcvBufSize: %d", *sndbuf, *rcvbuf) + + go serveClient(wconn, *bufsize) } } else { @@ -78,7 +107,32 @@ func main() { if err != nil { log.Fatalf("failed to dial: %v", err) } + log.Printf("Dail LocalAddr: %s; RemoteAddr: %s", conn.LocalAddr(), conn.RemoteAddr()) + + if *sndbuf != 0 { + err = conn.SetWriteBuffer(*sndbuf) + if err != nil { + log.Fatalf("failed to set write buf: %v", err) + } + } + if *rcvbuf != 0 { + err = conn.SetReadBuffer(*rcvbuf) + if err != nil { + log.Fatalf("failed to set read buf: %v", err) + } + } + + *sndbuf, err = conn.GetWriteBuffer() + if err != nil { + log.Fatalf("failed to get write buf: %v", err) + } + *rcvbuf, err = conn.GetReadBuffer() + if err != nil { + log.Fatalf("failed to get read buf: %v", err) + } + log.Printf("SndBufSize: %d, RcvBufSize: %d", *sndbuf, *rcvbuf) + ppid := 0 for { info := &sctp.SndRcvInfo{ @@ -87,17 +141,21 @@ func main() { } ppid += 1 conn.SubscribeEvents(sctp.SCTP_EVENT_DATA_IO) - n, err := conn.SCTPWrite([]byte("hello"), info) + buf := make([]byte, *bufsize) + n, err := rand.Read(buf) + if n != *bufsize { + log.Fatalf("failed to generate random string len: %d", *bufsize) + } + n, err = conn.SCTPWrite(buf, info) if err != nil { log.Fatalf("failed to write: %v", err) } - log.Printf("write: %d", n) - buf := make([]byte, 254) - _, info, err = conn.SCTPRead(buf) + log.Printf("write: len %d", n) + n, info, err = conn.SCTPRead(buf) if err != nil { log.Fatalf("failed to read: %v", err) } - log.Printf("read: info: %+v", info) + log.Printf("read: len %d, info: %+v", n, info) time.Sleep(time.Second) } } diff --git a/sctp.go b/sctp.go index 30d6196..34ea7ca 100644 --- a/sctp.go +++ b/sctp.go @@ -678,3 +678,19 @@ func (c *SCTPSndRcvInfoWrappedConn) SetReadDeadline(t time.Time) error { func (c *SCTPSndRcvInfoWrappedConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } + +func (c *SCTPSndRcvInfoWrappedConn) SetWriteBuffer(bytes int) error { + return c.conn.SetWriteBuffer(bytes) +} + +func (c *SCTPSndRcvInfoWrappedConn) GetWriteBuffer() (int, error) { + return c.conn.GetWriteBuffer() +} + +func (c *SCTPSndRcvInfoWrappedConn) SetReadBuffer(bytes int) error { + return c.conn.SetReadBuffer(bytes) +} + +func (c *SCTPSndRcvInfoWrappedConn) GetReadBuffer() (int, error) { + return c.conn.GetReadBuffer() +} diff --git a/sctp_linux.go b/sctp_linux.go index 5a6ad93..2e1976c 100644 --- a/sctp_linux.go +++ b/sctp_linux.go @@ -114,6 +114,22 @@ func (c *SCTPConn) Close() error { return syscall.EBADF } +func (c *SCTPConn) SetWriteBuffer(bytes int) error { + return syscall.SetsockoptInt(c.fd(), syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes) +} + +func (c *SCTPConn) GetWriteBuffer() (int, error) { + return syscall.GetsockoptInt(c.fd(), syscall.SOL_SOCKET, syscall.SO_SNDBUF) +} + +func (c *SCTPConn) SetReadBuffer(bytes int) error { + return syscall.SetsockoptInt(c.fd(), syscall.SOL_SOCKET, syscall.SO_RCVBUF, bytes) +} + +func (c *SCTPConn) GetReadBuffer() (int, error) { + return syscall.GetsockoptInt(c.fd(), syscall.SOL_SOCKET, syscall.SO_RCVBUF) +} + // ListenSCTP - start listener on specified address/port func ListenSCTP(net string, laddr *SCTPAddr) (*SCTPListener, error) { return ListenSCTPExt(net, laddr, InitMsg{NumOstreams: SCTP_MAX_STREAM}) diff --git a/sctp_unsupported.go b/sctp_unsupported.go index e541584..ed12f46 100644 --- a/sctp_unsupported.go +++ b/sctp_unsupported.go @@ -30,6 +30,22 @@ func (c *SCTPConn) Close() error { return ErrUnsupported } +func (c *SCTPConn) SetWriteBuffer(bytes int) error { + return ErrUnsupported +} + +func (c *SCTPConn) GetWriteBuffer() (int, error) { + return 0, ErrUnsupported +} + +func (c *SCTPConn) SetReadBuffer(bytes int) error { + return ErrUnsupported +} + +func (c *SCTPConn) GetReadBuffer() (int, error) { + return 0, ErrUnsupported +} + func ListenSCTP(net string, laddr *SCTPAddr) (*SCTPListener, error) { return nil, ErrUnsupported }