diff --git a/build/gobind/gobind.go b/build/gobind/gobind.go index 814e85f3..b98d2294 100644 --- a/build/gobind/gobind.go +++ b/build/gobind/gobind.go @@ -56,10 +56,6 @@ func (m *Pinecone) PublicKey() string { return m.PineconeRouter.PublicKey().String() } -func (m *Pinecone) RegisterNetworkInterface(info pineconeMulticast.InterfaceInfo) { - m.PineconeMulticast.RegisterInterface(info) -} - func (m *Pinecone) SetMulticastEnabled(enabled bool) { if enabled { m.PineconeMulticast.Start() diff --git a/multicast/multicast.go b/multicast/multicast.go index c025a6e4..bb031447 100644 --- a/multicast/multicast.go +++ b/multicast/multicast.go @@ -55,19 +55,21 @@ type AltInterface struct { } type Multicast struct { - r *router.Router - log types.Logger - ctx context.Context - cancel context.CancelFunc - id string - started atomic.Bool - interfaces sync.Map // -> *multicastInterface - dialling sync.Map - listener net.Listener - dialer net.Dialer - tcpLC net.ListenConfig - udpLC net.ListenConfig - altInterfaces []AltInterface + r *router.Router + log types.Logger + ctx context.Context + cancel context.CancelFunc + id string + started atomic.Bool + interfaces sync.Map // -> *multicastInterface + dialling sync.Map + listener net.Listener + dialer net.Dialer + tcpLC net.ListenConfig + udpLC net.ListenConfig + altInterfaces map[string]AltInterface + interfaceCallback func() + callbackMutex sync.Mutex } type multicastInterface struct { @@ -98,7 +100,40 @@ func NewMulticast( return m } -func (m *Multicast) RegisterInterface(info InterfaceInfo) { +func (m *Multicast) RegisterNetworkCallback(intfCallback func() []InterfaceInfo) { + if intfCallback == nil { + return + } + + m.callbackMutex.Lock() + defer m.callbackMutex.Unlock() + // Assign the callback function used to obtain current interface information. + m.interfaceCallback = func() { + // Save a reference to the previously registered interfaces. + oldInterfaces := m.altInterfaces + // Clear out any previously registered interfaces. + m.altInterfaces = make(map[string]AltInterface) + + // Register each returned interface. + for _, intf := range intfCallback() { + m.registerInterface(intf) + } + + // If any of the previously registered interfaces that were being used for + // multicast discovery are no longer present, cancel their context/s so they + // are cleaned up appropriately. + for _, intf := range oldInterfaces { + if _, ok := m.altInterfaces[intf.iface.Name]; !ok { + if v, ok := m.interfaces.Load(intf.iface.Name); ok { + mi := v.(*multicastInterface) + mi.cancel() + } + } + } + } +} + +func (m *Multicast) registerInterface(info InterfaceInfo) { iface := AltInterface{ net.Interface{ Name: info.Name, @@ -133,7 +168,7 @@ func (m *Multicast) RegisterInterface(info InterfaceInfo) { } } - m.altInterfaces = append(m.altInterfaces, iface) + m.altInterfaces[iface.iface.Name] = iface m.log.Println("Registered interface ", iface.iface.Name) } @@ -161,16 +196,21 @@ func (m *Multicast) Start() { } intfs := []net.Interface{} - if len(m.altInterfaces) > 0 { - for _, iface := range m.altInterfaces { - intfs = append(intfs, iface.iface) - } - } else { - intfs, err = net.Interfaces() - if err != nil { - return + func() { + m.callbackMutex.Lock() + defer m.callbackMutex.Unlock() + if m.interfaceCallback != nil { + m.interfaceCallback() + for _, iface := range m.altInterfaces { + intfs = append(intfs, iface.iface) + } + } else { + intfs, err = net.Interfaces() + if err != nil { + return + } } - } + }() for _, intf := range intfs { unsuitable := intf.Flags&net.FlagUp == 0 || @@ -380,7 +420,7 @@ func (m *Multicast) startIPv6(intf *multicastInterface) { func (m *Multicast) advertise(intf *multicastInterface, conn net.PacketConn, addr net.Addr) { defer m.interfaces.Delete(intf.Name) - //defer m.log.Println("Stop advertising on", intf.Name) + // defer m.log.Println("Stop advertising on", intf.Name) tcpaddr, _ := m.listener.Addr().(*net.TCPAddr) portBytes := make([]byte, 2) binary.BigEndian.PutUint16(portBytes, uint16(tcpaddr.Port)) @@ -410,7 +450,7 @@ func (m *Multicast) advertise(intf *multicastInterface, conn net.PacketConn, add func (m *Multicast) listen(intf *multicastInterface, conn net.PacketConn, srcaddr net.Addr) { defer m.interfaces.Delete(intf.Name) - //defer m.log.Println("Stop listening on", intf.Name) + // defer m.log.Println("Stop listening on", intf.Name) dialer := m.dialer dialer.LocalAddr = srcaddr dialer.Control = m.tcpOptions