Skip to content

Commit

Permalink
Change multicast interface to accept an external network callback (#76)
Browse files Browse the repository at this point in the history
* Change multicast interface to accept an external callback for network interfaces

* Remove callback registration from gobind interface.

* Add comments explaining multicast interface callback registration.

* Further clarify multicast interface callback registration
  • Loading branch information
devonh committed Oct 26, 2022
1 parent ba39dd2 commit 639feef
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 30 deletions.
4 changes: 0 additions & 4 deletions build/gobind/gobind.go
Expand Up @@ -56,10 +56,6 @@ func (m *Pinecone) PublicKey() string {
return m.PineconeRouter.PublicKey().String()
}

func (m *Pinecone) RegisterNetworkInterface(info pineconeMulticast.InterfaceInfo) {
m.PineconeMulticast.RegisterInterface(info)
}

func (m *Pinecone) SetMulticastEnabled(enabled bool) {
if enabled {
m.PineconeMulticast.Start()
Expand Down
92 changes: 66 additions & 26 deletions multicast/multicast.go
Expand Up @@ -55,19 +55,21 @@ type AltInterface struct {
}

type Multicast struct {
r *router.Router
log types.Logger
ctx context.Context
cancel context.CancelFunc
id string
started atomic.Bool
interfaces sync.Map // -> *multicastInterface
dialling sync.Map
listener net.Listener
dialer net.Dialer
tcpLC net.ListenConfig
udpLC net.ListenConfig
altInterfaces []AltInterface
r *router.Router
log types.Logger
ctx context.Context
cancel context.CancelFunc
id string
started atomic.Bool
interfaces sync.Map // -> *multicastInterface
dialling sync.Map
listener net.Listener
dialer net.Dialer
tcpLC net.ListenConfig
udpLC net.ListenConfig
altInterfaces map[string]AltInterface
interfaceCallback func()
callbackMutex sync.Mutex
}

type multicastInterface struct {
Expand Down Expand Up @@ -98,7 +100,40 @@ func NewMulticast(
return m
}

func (m *Multicast) RegisterInterface(info InterfaceInfo) {
func (m *Multicast) RegisterNetworkCallback(intfCallback func() []InterfaceInfo) {
if intfCallback == nil {
return
}

m.callbackMutex.Lock()
defer m.callbackMutex.Unlock()
// Assign the callback function used to obtain current interface information.
m.interfaceCallback = func() {
// Save a reference to the previously registered interfaces.
oldInterfaces := m.altInterfaces
// Clear out any previously registered interfaces.
m.altInterfaces = make(map[string]AltInterface)

// Register each returned interface.
for _, intf := range intfCallback() {
m.registerInterface(intf)
}

// If any of the previously registered interfaces that were being used for
// multicast discovery are no longer present, cancel their context/s so they
// are cleaned up appropriately.
for _, intf := range oldInterfaces {
if _, ok := m.altInterfaces[intf.iface.Name]; !ok {
if v, ok := m.interfaces.Load(intf.iface.Name); ok {
mi := v.(*multicastInterface)
mi.cancel()
}
}
}
}
}

func (m *Multicast) registerInterface(info InterfaceInfo) {
iface := AltInterface{
net.Interface{
Name: info.Name,
Expand Down Expand Up @@ -133,7 +168,7 @@ func (m *Multicast) RegisterInterface(info InterfaceInfo) {
}
}

m.altInterfaces = append(m.altInterfaces, iface)
m.altInterfaces[iface.iface.Name] = iface
m.log.Println("Registered interface ", iface.iface.Name)
}

Expand Down Expand Up @@ -161,16 +196,21 @@ func (m *Multicast) Start() {
}

intfs := []net.Interface{}
if len(m.altInterfaces) > 0 {
for _, iface := range m.altInterfaces {
intfs = append(intfs, iface.iface)
}
} else {
intfs, err = net.Interfaces()
if err != nil {
return
func() {
m.callbackMutex.Lock()
defer m.callbackMutex.Unlock()
if m.interfaceCallback != nil {
m.interfaceCallback()
for _, iface := range m.altInterfaces {
intfs = append(intfs, iface.iface)
}
} else {
intfs, err = net.Interfaces()
if err != nil {
return
}
}
}
}()

for _, intf := range intfs {
unsuitable := intf.Flags&net.FlagUp == 0 ||
Expand Down Expand Up @@ -380,7 +420,7 @@ func (m *Multicast) startIPv6(intf *multicastInterface) {

func (m *Multicast) advertise(intf *multicastInterface, conn net.PacketConn, addr net.Addr) {
defer m.interfaces.Delete(intf.Name)
//defer m.log.Println("Stop advertising on", intf.Name)
// defer m.log.Println("Stop advertising on", intf.Name)
tcpaddr, _ := m.listener.Addr().(*net.TCPAddr)
portBytes := make([]byte, 2)
binary.BigEndian.PutUint16(portBytes, uint16(tcpaddr.Port))
Expand Down Expand Up @@ -410,7 +450,7 @@ func (m *Multicast) advertise(intf *multicastInterface, conn net.PacketConn, add

func (m *Multicast) listen(intf *multicastInterface, conn net.PacketConn, srcaddr net.Addr) {
defer m.interfaces.Delete(intf.Name)
//defer m.log.Println("Stop listening on", intf.Name)
// defer m.log.Println("Stop listening on", intf.Name)
dialer := m.dialer
dialer.LocalAddr = srcaddr
dialer.Control = m.tcpOptions
Expand Down

0 comments on commit 639feef

Please sign in to comment.