diff --git a/internal/conn/udp.go b/internal/conn/udp.go index 266cc98..a27fbc6 100644 --- a/internal/conn/udp.go +++ b/internal/conn/udp.go @@ -6,6 +6,7 @@ import ( C "github.com/Dreamacro/clash/constant" "github.com/igoogolx/itun2socks/internal/constants" "net" + "sync" "time" ) @@ -23,6 +24,11 @@ type UdpConnContext struct { metadata *C.Metadata conn UdpConn rule constants.RuleType + wg *sync.WaitGroup +} + +func (u *UdpConnContext) Wg() *sync.WaitGroup { + return u.wg } func (u *UdpConnContext) Ctx() context.Context { @@ -41,12 +47,13 @@ func (u *UdpConnContext) Conn() UdpConn { return u.conn } -func NewUdpConnContext(ctx context.Context, conn UdpConn, metadata *C.Metadata) (*UdpConnContext, error) { +func NewUdpConnContext(ctx context.Context, conn UdpConn, metadata *C.Metadata, wg *sync.WaitGroup) (*UdpConnContext, error) { var connContext = &UdpConnContext{ ctx, metadata, conn, constants.RuleProxy, + wg, } for _, matcher := range GetConnMatcher() { diff --git a/internal/local_server/listener.go b/internal/local_server/listener.go index 02a963b..2370060 100644 --- a/internal/local_server/listener.go +++ b/internal/local_server/listener.go @@ -75,7 +75,10 @@ func processTcp(t C.ConnContext) { } func processUdp(u *inbound.PacketAdapter) { - ct, err := conn.NewUdpConnContext(context.Background(), udpConn{u.UDPPacket}, u.Metadata()) + var wg sync.WaitGroup + wg.Add(1) + defer wg.Wait() + ct, err := conn.NewUdpConnContext(context.Background(), udpConn{u.UDPPacket}, u.Metadata(), &wg) if err != nil { return } diff --git a/internal/proxy_handler/handler.go b/internal/proxy_handler/handler.go index 784abbc..a1060b3 100644 --- a/internal/proxy_handler/handler.go +++ b/internal/proxy_handler/handler.go @@ -7,28 +7,56 @@ import ( "github.com/igoogolx/itun2socks/pkg/log" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/bufio/deadline" M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/network" - "io" "net" - "time" + "sync" ) type udpConn struct { network.PacketConn - read bool - dest M.Socksaddr - buff *buf.Buffer } func (uc *udpConn) ReadFrom(data []byte) (int, net.Addr, error) { - if uc.read { - return 0, nil, io.EOF + + var err error + var buff *buf.Buffer + var dest M.Socksaddr + + defer func() { + if buff != nil { + buff.Release() + } + + }() + + newBuffer := func() *buf.Buffer { + buff = buf.NewPacket() // do not use stack buffer + return buff + } + readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(uc) + if isReadWaiter { + readWaiter.InitializeReadWaiter(newBuffer) + } + + if isReadWaiter { + dest, err = readWaiter.WaitReadPacket() + } else { + dest, err = uc.ReadPacket(newBuffer()) + } + + if err != nil { + return 0, nil, err + } + + n, err := buff.Read(data) + + if err != nil { + return 0, nil, err } - n, err := uc.buff.Read(data) - uc.buff.Release() - uc.read = true - return n, uc.dest.UDPAddr(), err + + return n, dest.UDPAddr(), nil } func (uc *udpConn) WriteTo(data []byte, addr net.Addr) (int, error) { @@ -42,10 +70,6 @@ func (uc *udpConn) WriteTo(data []byte, addr net.Addr) (int, error) { return len(data), err } -func (uc ConnHandler) SetReadDeadline(t time.Time) error { - return nil -} - type ConnHandler struct { tcpIn chan conn.TcpConnContext udpIn chan conn.UdpConnContext @@ -70,12 +94,6 @@ func (uc ConnHandler) NewConnection(ctx context.Context, netConn net.Conn, metad } func (uc ConnHandler) NewPacketConnection(ctx context.Context, packetConn network.PacketConn, metadata M.Metadata) error { - defer func(packetConn network.PacketConn) { - err := packetConn.Close() - if err != nil { - log.Errorln("fail to close packetConn") - } - }(packetConn) local, err := net.ResolveUDPAddr("udp", metadata.Source.String()) if err != nil { return err @@ -86,39 +104,19 @@ func (uc ConnHandler) NewPacketConnection(ctx context.Context, packetConn networ } m := tunnel.CreateUdpMetadata(*local, *remote) - for { - var buff *buf.Buffer - newBuffer := func() *buf.Buffer { - buff = buf.NewPacket() // do not use stack buffer - return buff - } - var err error - var dest M.Socksaddr - readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(packetConn) - if isReadWaiter { - readWaiter.InitializeReadWaiter(newBuffer) - } - if isReadWaiter { - dest, err = readWaiter.WaitReadPacket() - } else { - dest, err = packetConn.ReadPacket(newBuffer()) - } - if err != nil { - if buff != nil { - buff.Release() - } - break - } - ct, err := conn.NewUdpConnContext(ctx, &udpConn{PacketConn: packetConn, dest: dest, buff: buff}, &m) - if err != nil { - if buff != nil { - buff.Release() - } - break - } - uc.udpIn <- *ct + if deadline.NeedAdditionalReadDeadline(packetConn) { + packetConn = deadline.NewFallbackPacketConn(bufio.NewNetPacketConn(packetConn)) // conn from sing should check NeedAdditionalReadDeadline } + var wg sync.WaitGroup + wg.Add(1) + defer wg.Wait() + + ct, err := conn.NewUdpConnContext(ctx, &udpConn{PacketConn: packetConn}, &m, &wg) + if err != nil { + return err + } + uc.udpIn <- *ct return nil } diff --git a/internal/tunnel/udp.go b/internal/tunnel/udp.go index 398c5c1..9801324 100644 --- a/internal/tunnel/udp.go +++ b/internal/tunnel/udp.go @@ -34,6 +34,7 @@ func copyUdpPacket(lc conn.UdpConn, rc conn.UdpConn) error { receivedBuf := pool.NewBytes(pool.BufSize) defer pool.FreeBytes(receivedBuf) for { + err := rc.SetReadDeadline(time.Now().Add(5 * time.Second)) if err != nil { return fmt.Errorf("fail to set udp conn read deadline: %v", err) @@ -66,6 +67,11 @@ func copyUdpPacket(lc conn.UdpConn, rc conn.UdpConn) error { func handleUdpConn(ct conn.UdpConnContext) { log.Debugln(log.FormatLog(log.UdpPrefix, "handle udp conn, remote address: %v"), ct.Metadata().RemoteAddress()) defer func() { + err := closeConn(ct.Conn()) + ct.Wg().Done() + if err != nil { + log.Warnln(log.FormatLog(log.UdpPrefix, "fail to close remote conn,err: %v"), err) + } log.Debugln(log.FormatLog(log.UdpPrefix, "close remote conn: %v"), ct.Metadata().String()) }() var lc conn.UdpConn