Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug Fix] Apply default domain before caching derived values #116

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
67 changes: 31 additions & 36 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,11 @@ const (
// Register a service by given arguments. This call will take the system's hostname
// and lookup IP by that hostname.
func Register(instance, service, domain string, port int, text []string, ifaces []net.Interface) (*Server, error) {
entry := NewServiceEntry(instance, service, domain)
entry.Port = port
entry.Text = text

if entry.Instance == "" {
return nil, fmt.Errorf("missing service instance name")
}
if entry.Service == "" {
return nil, fmt.Errorf("missing service name")
}
if entry.Domain == "" {
entry.Domain = "local."
}
if entry.Port == 0 {
return nil, fmt.Errorf("missing port")
entry, err := newRegisterServiceEntry(instance, service, domain, port, text)
if err != nil {
return nil, err
}

var err error
if entry.HostName == "" {
entry.HostName, err = os.Hostname()
if err != nil {
Expand Down Expand Up @@ -83,25 +70,9 @@ func Register(instance, service, domain string, port int, text []string, ifaces
// RegisterProxy registers a service proxy. This call will skip the hostname/IP lookup and
// will use the provided values.
func RegisterProxy(instance, service, domain string, port int, host string, ips []string, text []string, ifaces []net.Interface) (*Server, error) {
entry := NewServiceEntry(instance, service, domain)
entry.Port = port
entry.Text = text
entry.HostName = host

if entry.Instance == "" {
return nil, fmt.Errorf("missing service instance name")
}
if entry.Service == "" {
return nil, fmt.Errorf("missing service name")
}
if entry.HostName == "" {
return nil, fmt.Errorf("missing host name")
}
if entry.Domain == "" {
entry.Domain = "local"
}
if entry.Port == 0 {
return nil, fmt.Errorf("missing port")
entry, err := newRegisterServiceEntry(instance, service, domain, port, text)
if err != nil {
return nil, err
}

if !strings.HasSuffix(trimDot(entry.HostName), entry.Domain) {
Expand Down Expand Up @@ -137,6 +108,30 @@ func RegisterProxy(instance, service, domain string, port int, host string, ips
return s, nil
}

// newRegisterServiceEntry returns a ServiceEntry with defaults substituted as required.
func newRegisterServiceEntry(instance, service, domain string, port int, text []string) (*ServiceEntry, error) {
// Required parameters
if instance == "" {
return nil, fmt.Errorf("missing service instance name")
}
if service == "" {
return nil, fmt.Errorf("missing service name")
}
if port == 0 {
return nil, fmt.Errorf("missing port")
}
// Defaulted parameters
if domain == "" {
domain = "local."
}

entry := NewServiceEntry(instance, service, domain)
entry.Port = port
entry.Text = text

return entry, nil
}

const (
qClassCacheFlush uint16 = 1 << 15
)
Expand Down Expand Up @@ -525,7 +520,7 @@ func (s *Server) serviceTypeName(resp *dns.Msg, ttl uint32) {
}

// Perform probing & announcement
//TODO: implement a proper probing & conflict resolution
// TODO: implement a proper probing & conflict resolution
func (s *Server) probe() {
q := new(dns.Msg)
q.SetQuestion(s.service.ServiceInstanceName(), dns.TypePTR)
Expand Down
93 changes: 93 additions & 0 deletions service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package zeroconf
import (
"context"
"log"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -164,3 +165,95 @@ func TestSubtype(t *testing.T) {
}
})
}

// Test the default domain is applied.
func TestDefaultDomain(t *testing.T) {
t.Run("register", func(t *testing.T) {
server, err := Register(mdnsName, mdnsService, "", mdnsPort, []string{"txtv=0", "lo=2", "la=3"}, nil)
if err != nil {
t.Fatal(err)
}
if server == nil {
t.Fatal("expect non-nil")
}
// Check the service record's cached fields
sr := server.service.ServiceRecord
if strings.Contains(sr.serviceName, "..") {
t.Errorf("malformed service name: %s", sr.serviceName)
}
if strings.Contains(sr.serviceInstanceName, "..") {
t.Errorf("malformed service instance name: %s", sr.serviceInstanceName)
}

t.Logf("Published service: %+v", server.service)

ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()

// Wait for context to time out
<-ctx.Done()

t.Log("Shutting down.")
server.Shutdown()
})

t.Run("registerproxy", func(t *testing.T) {
server, err := RegisterProxy(mdnsName, mdnsService, "", mdnsPort, "localhost", []string{"::1"}, []string{"txtv=0", "lo=2", "la=3"}, nil)
if err != nil {
t.Fatal(err)
}
if server == nil {
t.Fatal("expect non-nil")
}
// Check the service record's cached fields
sr := server.service.ServiceRecord
if strings.Contains(sr.serviceName, "..") {
t.Errorf("malformed service name: %s", sr.serviceName)
}
if strings.Contains(sr.serviceInstanceName, "..") {
t.Errorf("malformed service instance name: %s", sr.serviceInstanceName)
}

t.Logf("Published service: %+v", server.service)

ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()

// Wait for context to time out
<-ctx.Done()

t.Log("Shutting down.")
server.Shutdown()
})
}

func TestNewRegisterServiceEntry(t *testing.T) {
tests := []struct {
name string
instance, service, domain string
port int
text []string
err bool
}{
{"minimal", mdnsName, mdnsService, mdnsDomain, mdnsPort, []string{}, false},
// Required parameters
{"require-instance", "", mdnsService, mdnsDomain, mdnsPort, []string{}, true},
{"require-service", mdnsName, "", mdnsDomain, mdnsPort, []string{}, true},
{"require-port", mdnsName, mdnsService, mdnsDomain, 0, []string{}, true},
// Default domain
{"default-domain", mdnsName, mdnsService, "", mdnsPort, []string{}, false},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
se, err := newRegisterServiceEntry(test.instance, test.service, test.domain, test.port, test.text)
if test.err && err == nil {
t.Error("expect error")
} else if !test.err && err != nil {
t.Error(err)
}
if err == nil && se == nil {
t.Error("expect non-nil")
}
})
}
}