diff --git a/api/smp.go b/api/smp.go index f323072..38e9442 100644 --- a/api/smp.go +++ b/api/smp.go @@ -207,6 +207,66 @@ func (mp *MPPeerSock) WaitForPeerConnect(sel pathselection.CustomPathSelection) return remote, err } +// +// This method waits until a remote MPPeerSock calls connect to this +// socket's local address +// A pathselection may be passed, which lets the socket dialing back to its remote +// (e.g. for server-side path selection) +// Since the MPPeerSock waits for only one incoming connection to determine a new peer +// it starts waiting for other connections (if no selection passed) and fires the +// OnConnectionsChange event for each new incoming connection +// +func (mp *MPPeerSock) WaitForPeerConnectWithContext(ctx context.Context, sel pathselection.CustomPathSelection) (*snet.UDPAddr, error) { + log.Debugf("Waiting for incoming connection") + remote, err := mp.UnderlaySocket.WaitForDialInWithContext(ctx) + if err != nil { + return nil, err + } + log.Debugf("Accepted connection from %s", remote.String()) + mp.Peer = remote + mp.selection = sel + // Start selection process -> will update DB + mp.StartPathSelection(sel, sel == nil) + log.Debugf("Done path selection") + // wait until first signal on channel + // selectedPathSet := <-mp.OnPathsetChange + // time.Sleep(1 * time.Second) + // dial all paths selected by user algorithm + if sel != nil { + err = mp.DialAll(mp.SelectedPathSet, &socket.ConnectOptions{ + SendAddrPacket: false, + }) + mp.collectMetrics() + } else { + mp.collectMetrics() + go func() { + conns := mp.UnderlaySocket.GetConnections() + mp.PacketScheduler.SetConnections(conns) + mp.PathQualityDB.SetConnections(conns) + mp.connectionSetChange(conns) + for { + log.Debugf("Waiting for new connections...") + conn, err := mp.UnderlaySocket.WaitForIncomingConnWithContext(ctx) + if conn == nil && err == nil { + log.Debugf("Socket does not implement WaitForIncomingConn, stopping here...") + return + } + if err != nil { + log.Errorf("Failed to wait for incoming connection %s", err.Error()) + return + } + + conns := mp.UnderlaySocket.GetConnections() + mp.PacketScheduler.SetConnections(conns) + mp.PathQualityDB.SetConnections(conns) + mp.connectionSetChange(conns) + } + }() + } + + return remote, err +} + func (mp *MPPeerSock) collectMetrics() { mp.metricsTicker = time.NewTicker(mp.MetricsInterval) go func() { diff --git a/socket/quicsocket.go b/socket/quicsocket.go index ac82f0d..6bbbd31 100644 --- a/socket/quicsocket.go +++ b/socket/quicsocket.go @@ -137,6 +137,70 @@ func (s *QUICSocket) WaitForIncomingConn() (packets.UDPConn, error) { } } +func (s *QUICSocket) WaitForIncomingConnWithContext(ctx context.Context) (packets.UDPConn, error) { + if s.options == nil || !s.options.MultiportMode { + log.Debugf("Waiting for new connection") + stream, err := s.listenConns[0].AcceptStreamWithContext(ctx) + if err != nil { + log.Fatalf("QUIC Accept err %s", err.Error()) + } + + log.Debugf("Accepted new Stream on listen socket") + + bts := make([]byte, packets.PACKET_SIZE) + _, err = stream.Read(bts) + + if s.listenConns[0].GetInternalConn() == nil { + s.listenConns[0].SetStream(stream) + select { + case s.listenConns[0].Ready <- true: + default: + } + + return s.listenConns[0], nil + } else { + newConn := &packets.QUICReliableConn{} + id := RandStringBytes(32) + newConn.SetId(id) + newConn.SetLocal(*s.localAddr) + newConn.SetRemote(s.listenConns[0].GetRemote()) + newConn.SetStream(stream) + s.listenConns = append(s.listenConns, newConn) + + _, err = stream.Read(bts) + if err != nil { + return nil, err + } + return newConn, nil + } + } else { + addr := s.localAddr.Copy() + addr.Host.Port = s.localAddr.Host.Port + len(s.listenConns) + conn := &packets.QUICReliableConn{} + err := conn.Listen(*addr) + if err != nil { + return nil, err + } + + stream, err := conn.AcceptStreamWithContext(ctx) + if err != nil { + return nil, err + } + + id := RandStringBytes(32) + conn.SetId(id) + + conn.SetStream(stream) + s.listenConns = append(s.listenConns, conn) + bts := make([]byte, packets.PACKET_SIZE) + _, err = stream.Read(bts) + if err != nil { + return nil, err + } + return conn, nil + } +} + func (s *QUICSocket) WaitForDialIn() (*snet.UDPAddr, error) { bts := make([]byte, packets.PACKET_SIZE) log.Debugf("Wait for Dial In") @@ -210,7 +274,7 @@ func (s *QUICSocket) WaitForDialInWithContext(ctx context.Context) (*snet.UDPAdd log.Debugf("Waiting for %d more connections", p.NumPaths-1) for i := 1; i < p.NumPaths; i++ { - _, err := s.WaitForIncomingConn() + _, err := s.WaitForIncomingConnWithContext(ctx) if err != nil { return nil, err } diff --git a/socket/scionsocket.go b/socket/scionsocket.go index 1c05a3d..f676d3a 100644 --- a/socket/scionsocket.go +++ b/socket/scionsocket.go @@ -140,6 +140,10 @@ func (s *SCIONSocket) WaitForIncomingConn() (packets.UDPConn, error) { return nil, nil } +func (s *SCIONSocket) WaitForIncomingConnWithContext(ctx context.Context) (packets.UDPConn, error) { + return nil, nil +} + func (s *SCIONSocket) DialAll(remote snet.UDPAddr, path []pathselection.PathQuality, options DialOptions) ([]packets.UDPConn, error) { // There is always one listening connection conns := make([]packets.UDPConn, 1) diff --git a/socket/socket.go b/socket/socket.go index 3ea54a6..6400f5e 100644 --- a/socket/socket.go +++ b/socket/socket.go @@ -29,6 +29,7 @@ type UnderlaySocket interface { WaitForDialIn() (*snet.UDPAddr, error) WaitForDialInWithContext(ctx context.Context) (*snet.UDPAddr, error) WaitForIncomingConn() (packets.UDPConn, error) + WaitForIncomingConnWithContext(ctx context.Context) (packets.UDPConn, error) Dial(remote snet.UDPAddr, path snet.Path, options DialOptions, i int) (packets.UDPConn, error) DialAll(remote snet.UDPAddr, path []pathselection.PathQuality, options DialOptions) ([]packets.UDPConn, error) CloseAll() []error