From 9046d8d0cbf106fb1f5945f764bece52e4f9d553 Mon Sep 17 00:00:00 2001 From: Norio Nomura Date: Mon, 20 Oct 2025 17:37:58 +0900 Subject: [PATCH] pkg/portfwd: Refactor `Forwarder.OnEvent()` Change `client` parameter to `dialContext` that supports dialers other than using `guestagentclient.GuestAgentClient`. Signed-off-by: Norio Nomura --- pkg/hostagent/hostagent.go | 3 +- pkg/portfwd/client.go | 65 ++++++++++++++++---------------------- pkg/portfwd/forward.go | 5 ++- pkg/portfwd/listener.go | 16 ++++------ 4 files changed, 38 insertions(+), 51 deletions(-) diff --git a/pkg/hostagent/hostagent.go b/pkg/hostagent/hostagent.go index d66f2812a91..53ba5fd1039 100644 --- a/pkg/hostagent/hostagent.go +++ b/pkg/hostagent/hostagent.go @@ -813,7 +813,8 @@ func (a *HostAgent) processGuestAgentEvents(ctx context.Context, client *guestag if useSSHFwd { a.portForwarder.OnEvent(ctx, ev) } else { - a.grpcPortForwarder.OnEvent(ctx, client, ev) + dialContext := portfwd.DialContextToGRPCTunnel(client) + a.grpcPortForwarder.OnEvent(ctx, dialContext, ev) } } diff --git a/pkg/portfwd/client.go b/pkg/portfwd/client.go index ffd41ac6b5e..5090f20b268 100644 --- a/pkg/portfwd/client.go +++ b/pkg/portfwd/client.go @@ -18,61 +18,50 @@ import ( guestagentclient "github.com/lima-vm/lima/v2/pkg/guestagent/api/client" ) -func HandleTCPConnection(ctx context.Context, client *guestagentclient.GuestAgentClient, conn net.Conn, guestAddr string) { - id := fmt.Sprintf("tcp-%s-%s", conn.LocalAddr().String(), conn.RemoteAddr().String()) +func HandleTCPConnection(_ context.Context, dialContext func(ctx context.Context, network string, addr string) (net.Conn, error), conn net.Conn, guestAddr string) { + proxy := tcpproxy.DialProxy{Addr: guestAddr, DialContext: dialContext} + proxy.HandleConn(conn) +} - stream, err := client.Tunnel(ctx) +func HandleUDPConnection(ctx context.Context, dialContext func(ctx context.Context, network string, addr string) (net.Conn, error), conn net.PacketConn, guestAddr string) { + proxy, err := forwarder.NewUDPProxy(conn, func() (net.Conn, error) { + return dialContext(ctx, "udp", guestAddr) + }) if err != nil { - logrus.Errorf("could not open tcp tunnel for id: %s error:%v", id, err) - return - } - - // Handshake message to start tunnel - if err := stream.Send(&api.TunnelMessage{Id: id, Protocol: "tcp", GuestAddr: guestAddr}); err != nil { - logrus.Errorf("could not start tcp tunnel for id: %s error:%v", id, err) + logrus.WithError(err).Error("error in udp tunnel proxy") return } - rw := &GrpcClientRW{stream: stream, id: id, addr: guestAddr, protocol: "tcp"} - proxy := tcpproxy.DialProxy{DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { - return conn, nil - }} - proxy.HandleConn(rw) + defer func() { + err := proxy.Close() + if err != nil { + logrus.WithError(err).Error("error in closing udp tunnel proxy") + } + }() + proxy.Run() } -func HandleUDPConnection(ctx context.Context, client *guestagentclient.GuestAgentClient, conn net.PacketConn, guestAddr string) { - var udpConnectionCounter atomic.Uint32 - initialID := fmt.Sprintf("udp-%s", conn.LocalAddr().String()) - +func DialContextToGRPCTunnel(client *guestagentclient.GuestAgentClient) func(ctx context.Context, network, addr string) (net.Conn, error) { // gvisor-tap-vsock's UDPProxy demultiplexes client connections internally based on their source address. // It calls this dialer function only when it receives a datagram from a new, unrecognized client. // For each new client, we must return a new net.Conn, which in our case is a new gRPC stream. // The atomic counter ensures that each stream has a unique ID to distinguish them on the server side. - proxy, err := forwarder.NewUDPProxy(conn, func() (net.Conn, error) { - id := fmt.Sprintf("%s-%d", initialID, udpConnectionCounter.Add(1)) - stream, err := client.Tunnel(ctx) + var connectionCounter atomic.Uint32 + return func(_ context.Context, network, addr string) (net.Conn, error) { + // Passed context.Context is used for timeout on initiate connection, not for the lifetime of the connection. + // We use context.Background() here to avoid unexpected cancellation. + stream, err := client.Tunnel(context.Background()) if err != nil { - return nil, fmt.Errorf("could not open udp tunnel for id: %s error:%w", id, err) + return nil, fmt.Errorf("could not open tunnel for addr: %s error:%w", addr, err) } // Handshake message to start tunnel - if err := stream.Send(&api.TunnelMessage{Id: id, Protocol: "udp", GuestAddr: guestAddr}); err != nil { - return nil, fmt.Errorf("could not start udp tunnel for id: %s error:%w", id, err) + id := fmt.Sprintf("%s-%s-%d", network, addr, connectionCounter.Add(1)) + if err := stream.Send(&api.TunnelMessage{Id: id, Protocol: network, GuestAddr: addr}); err != nil { + return nil, fmt.Errorf("could not start tunnel for id: %s addr: %s error:%w", id, addr, err) } - rw := &GrpcClientRW{stream: stream, id: id, addr: guestAddr, protocol: "udp"} + rw := &GrpcClientRW{stream: stream, id: id, addr: addr, protocol: network} return rw, nil - }) - if err != nil { - logrus.Errorf("error in udp tunnel proxy for id: %s error:%v", initialID, err) - return } - - defer func() { - err := proxy.Close() - if err != nil { - logrus.Errorf("error in closing udp tunnel proxy for id: %s error:%v", initialID, err) - } - }() - proxy.Run() } type GrpcClientRW struct { diff --git a/pkg/portfwd/forward.go b/pkg/portfwd/forward.go index 003fca40884..1edd08fee0a 100644 --- a/pkg/portfwd/forward.go +++ b/pkg/portfwd/forward.go @@ -11,7 +11,6 @@ import ( "github.com/sirupsen/logrus" "github.com/lima-vm/lima/v2/pkg/guestagent/api" - guestagentclient "github.com/lima-vm/lima/v2/pkg/guestagent/api/client" "github.com/lima-vm/lima/v2/pkg/limatype" "github.com/lima-vm/lima/v2/pkg/limayaml" ) @@ -38,7 +37,7 @@ func (fw *Forwarder) Close() error { return fw.closableListeners.Close() } -func (fw *Forwarder) OnEvent(ctx context.Context, client *guestagentclient.GuestAgentClient, ev *api.Event) { +func (fw *Forwarder) OnEvent(ctx context.Context, dialContext func(ctx context.Context, network string, addr string) (net.Conn, error), ev *api.Event) { for _, f := range ev.AddedLocalPorts { // Before forwarding, check if any static rule matches this port otherwise it will be forwarded twice and cause a port conflict if fw.isPortStaticallyForwarded(f) { @@ -55,7 +54,7 @@ func (fw *Forwarder) OnEvent(ctx context.Context, client *guestagentclient.Guest continue } logrus.Infof("Forwarding %s from %s to %s", strings.ToUpper(f.Protocol), remote, local) - fw.closableListeners.Forward(ctx, client, f.Protocol, local, remote) + fw.closableListeners.Forward(ctx, dialContext, f.Protocol, local, remote) } for _, f := range ev.RemovedLocalPorts { local, remote := fw.forwardingAddresses(f) diff --git a/pkg/portfwd/listener.go b/pkg/portfwd/listener.go index 58e992f9aad..190a3bd06e0 100644 --- a/pkg/portfwd/listener.go +++ b/pkg/portfwd/listener.go @@ -14,8 +14,6 @@ import ( "sync" "github.com/sirupsen/logrus" - - guestagentclient "github.com/lima-vm/lima/v2/pkg/guestagent/api/client" ) type ClosableListeners struct { @@ -59,14 +57,14 @@ func (p *ClosableListeners) Close() error { return errors.Join(errs...) } -func (p *ClosableListeners) Forward(ctx context.Context, client *guestagentclient.GuestAgentClient, +func (p *ClosableListeners) Forward(ctx context.Context, dialContext func(ctx context.Context, network string, addr string) (net.Conn, error), protocol string, hostAddress string, guestAddress string, ) { switch protocol { case "tcp", "tcp6": - go p.forwardTCP(ctx, client, hostAddress, guestAddress) + go p.forwardTCP(ctx, dialContext, hostAddress, guestAddress) case "udp", "udp6": - go p.forwardUDP(ctx, client, hostAddress, guestAddress) + go p.forwardUDP(ctx, dialContext, hostAddress, guestAddress) } } @@ -93,7 +91,7 @@ func (p *ClosableListeners) Remove(_ context.Context, protocol, hostAddress, gue } } -func (p *ClosableListeners) forwardTCP(ctx context.Context, client *guestagentclient.GuestAgentClient, hostAddress, guestAddress string) { +func (p *ClosableListeners) forwardTCP(ctx context.Context, dialContext func(ctx context.Context, network string, addr string) (net.Conn, error), hostAddress, guestAddress string) { key := key("tcp", hostAddress, guestAddress) p.listenersRW.Lock() @@ -124,11 +122,11 @@ func (p *ClosableListeners) forwardTCP(ctx context.Context, client *guestagentcl } return } - go HandleTCPConnection(ctx, client, conn, guestAddress) + go HandleTCPConnection(ctx, dialContext, conn, guestAddress) } } -func (p *ClosableListeners) forwardUDP(ctx context.Context, client *guestagentclient.GuestAgentClient, hostAddress, guestAddress string) { +func (p *ClosableListeners) forwardUDP(ctx context.Context, dialContext func(ctx context.Context, network string, addr string) (net.Conn, error), hostAddress, guestAddress string) { key := key("udp", hostAddress, guestAddress) defer p.Remove(ctx, "udp", hostAddress, guestAddress) @@ -148,7 +146,7 @@ func (p *ClosableListeners) forwardUDP(ctx context.Context, client *guestagentcl p.udpListeners[key] = udpConn p.udpListenersRW.Unlock() - HandleUDPConnection(ctx, client, udpConn, guestAddress) + HandleUDPConnection(ctx, dialContext, udpConn, guestAddress) } func key(protocol, hostAddress, guestAddress string) string {