Skip to content

Commit

Permalink
appc,ipn/ipnlocal: add app connector routes if any part of a CNAME ch…
Browse files Browse the repository at this point in the history
…ain is routed

If any domain along a CNAME chain matches any of the routed domains, add
routes for the discovered domains.

Fixes tailscale/corp#16928

Signed-off-by: James Tucker <james@tailscale.com>
  • Loading branch information
raggi committed Feb 1, 2024
1 parent 2aeef4e commit e1a4b89
Show file tree
Hide file tree
Showing 3 changed files with 263 additions and 41 deletions.
156 changes: 115 additions & 41 deletions appc/appconnector.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"tailscale.com/types/views"
"tailscale.com/util/dnsname"
"tailscale.com/util/execqueue"
"tailscale.com/util/mak"
)

// RouteAdvertiser is an interface that allows the AppConnector to advertise
Expand Down Expand Up @@ -206,7 +207,16 @@ func (e *AppConnector) ObserveDNSResponse(res []byte) {
return
}

nextAnswer:
// cnameChain tracks a chain of CNAMEs for a given query in order to reverse
// a CNAME chain back to the original query for flattening. The keys are
// CNAME record targets, and the value is the name the record answers, so
// for www.example.com CNAME example.com, the map would contain
// ["example.com"] = "www.example.com".
var cnameChain map[string]string

// addressRecords is a list of address records found in the response.
var addressRecords map[string][]netip.Addr

for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
Expand All @@ -222,83 +232,147 @@ nextAnswer:
}
continue
}
if h.Type != dnsmessage.TypeA && h.Type != dnsmessage.TypeAAAA {

switch h.Type {
case dnsmessage.TypeCNAME, dnsmessage.TypeA, dnsmessage.TypeAAAA:
default:
if err := p.SkipAnswer(); err != nil {
return
}
continue
}

domain := h.Name.String()
if len(domain) == 0 {
return
}
domain = strings.TrimSuffix(domain, ".")
domain = strings.ToLower(domain)
e.logf("[v2] observed DNS response for %s", domain)

e.mu.Lock()
addrs, ok := e.domains[domain]
// match wildcard domains
if !ok {
for _, wc := range e.wildcards {
if dnsname.HasSuffix(domain, wc) {
e.domains[domain] = nil
ok = true
break
}
}
domain := strings.TrimSuffix(strings.ToLower(h.Name.String()), ".")
if len(domain) == 0 {
continue
}
e.mu.Unlock()

if !ok {
if err := p.SkipAnswer(); err != nil {
if h.Type == dnsmessage.TypeCNAME {
res, err := p.CNAMEResource()
if err != nil {
return
}
cname := strings.TrimSuffix(strings.ToLower(res.CNAME.String()), ".")
if len(cname) == 0 {
continue
}
mak.Set(&cnameChain, cname, domain)
continue
}

var addr netip.Addr
switch h.Type {
case dnsmessage.TypeA:
r, err := p.AResource()
if err != nil {
return
}
addr = netip.AddrFrom4(r.A)
addr := netip.AddrFrom4(r.A)
mak.Set(&addressRecords, domain, append(addressRecords[domain], addr))
case dnsmessage.TypeAAAA:
r, err := p.AAAAResource()
if err != nil {
return
}
addr = netip.AddrFrom16(r.AAAA)
addr := netip.AddrFrom16(r.AAAA)
mak.Set(&addressRecords, domain, append(addressRecords[domain], addr))
default:
if err := p.SkipAnswer(); err != nil {
return
}
continue
}
if slices.Contains(addrs, addr) {
}

e.mu.Lock()
defer e.mu.Unlock()

for domain, addrs := range addressRecords {
domain, isRouted := e.findRoutedDomainLocked(domain, cnameChain)

// domain and none of the CNAMEs in the chain are routed
if !isRouted {
continue
}
for _, route := range e.controlRoutes {
if route.Contains(addr) {
// record the new address associated with the domain for faster matching in subsequent
// requests and for diagnostic records.
e.mu.Lock()
e.domains[domain] = append(addrs, addr)
e.mu.Unlock()
continue nextAnswer

// advertise each address we have learned for the routed domain, that
// was not already known.
for _, addr := range addrs {
e.logf("[v2] observed routed DNS response for %s: %s", domain, addr)
if e.isAddrKnownLocked(domain, addr) {
continue
}

e.scheduleAdvertisement(domain, addr)
}
}
}

// starting from the given domain that resolved to an address, find it, or any
// of the domains in the CNAME chain toward resolving it, that are routed
// domains, returning the routed domain name and a bool indicating whether a
// routed domain was found.
// e.mu must be held.
func (e *AppConnector) findRoutedDomainLocked(domain string, cnameChain map[string]string) (string, bool) {
var isRouted bool
for {
_, isRouted = e.domains[domain]
if isRouted {
break
}

// match wildcard domains
for _, wc := range e.wildcards {
if dnsname.HasSuffix(domain, wc) {
e.domains[domain] = nil
isRouted = true
break
}
}

next, ok := cnameChain[domain]
if !ok {
break
}
domain = next
}
return domain, isRouted
}

// isAddrKnownLocked returns true if the address is known to be associated with
// the given domain. Known domain tables are updated for covered routes to speed
// up future matches.
// e.mu must be held.
func (e *AppConnector) isAddrKnownLocked(domain string, addr netip.Addr) bool {
if slices.Contains(e.domains[domain], addr) {
return true
}
for _, route := range e.controlRoutes {
if route.Contains(addr) {
// record the new address associated with the domain for faster matching in subsequent
// requests and for diagnostic records.
e.domains[domain] = append(e.domains[domain], addr)
return true
}
}
return false

}

// scheduleAdvertisement schedules an advertisement of the given address
// associated with the given domain.
func (e *AppConnector) scheduleAdvertisement(domain string, addr netip.Addr) {
e.queue.Add(func() {
if err := e.routeAdvertiser.AdvertiseRoute(netip.PrefixFrom(addr, addr.BitLen())); err != nil {
e.logf("failed to advertise route for %s: %v: %v", domain, addr, err)
continue
return
}
e.logf("[v2] advertised route for %v: %v", domain, addr)

e.mu.Lock()
e.domains[domain] = append(addrs, addr)
e.mu.Unlock()
}
defer e.mu.Unlock()

if !slices.Contains(e.domains[domain], addr) {
e.logf("[v2] advertised route for %v: %v", domain, addr)
e.domains[domain] = append(e.domains[domain], addr)
}
})
}
82 changes: 82 additions & 0 deletions appc/appconnector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ func TestDomainRoutes(t *testing.T) {
a := NewAppConnector(t.Logf, rc)
a.updateDomains([]string{"example.com"})
a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8"))
a.Wait(context.Background())

want := map[string][]netip.Addr{
"example.com": {netip.MustParseAddr("192.0.0.8")},
Expand All @@ -110,6 +111,7 @@ func TestDomainRoutes(t *testing.T) {
}

func TestObserveDNSResponse(t *testing.T) {
ctx := context.Background()
rc := &appctest.RouteCollector{}
a := NewAppConnector(t.Logf, rc)

Expand All @@ -123,19 +125,41 @@ func TestObserveDNSResponse(t *testing.T) {

a.updateDomains([]string{"example.com"})
a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8"))
a.Wait(ctx)
if got, want := rc.Routes(), wantRoutes; !slices.Equal(got, want) {
t.Errorf("got %v; want %v", got, want)
}

// a CNAME record chain should result in a route being added if the chain
// matches a routed domain.
a.updateDomains([]string{"www.example.com", "example.com"})
a.ObserveDNSResponse(dnsCNAMEResponse("192.0.0.9", "www.example.com.", "chain.example.com.", "example.com."))
a.Wait(ctx)
wantRoutes = append(wantRoutes, netip.MustParsePrefix("192.0.0.9/32"))
if got, want := rc.Routes(), wantRoutes; !slices.Equal(got, want) {
t.Errorf("got %v; want %v", got, want)
}

// a CNAME record chain should result in a route being added if the chain
// even if only found in the middle of the chain
a.ObserveDNSResponse(dnsCNAMEResponse("192.0.0.10", "outside.example.org.", "www.example.com.", "example.org."))
a.Wait(ctx)
wantRoutes = append(wantRoutes, netip.MustParsePrefix("192.0.0.10/32"))
if got, want := rc.Routes(), wantRoutes; !slices.Equal(got, want) {
t.Errorf("got %v; want %v", got, want)
}

wantRoutes = append(wantRoutes, netip.MustParsePrefix("2001:db8::1/128"))

a.ObserveDNSResponse(dnsResponse("example.com.", "2001:db8::1"))
a.Wait(ctx)
if got, want := rc.Routes(), wantRoutes; !slices.Equal(got, want) {
t.Errorf("got %v; want %v", got, want)
}

// don't re-advertise routes that have already been advertised
a.ObserveDNSResponse(dnsResponse("example.com.", "2001:db8::1"))
a.Wait(ctx)
if !slices.Equal(rc.Routes(), wantRoutes) {
t.Errorf("rc.Routes(): got %v; want %v", rc.Routes(), wantRoutes)
}
Expand All @@ -145,6 +169,7 @@ func TestObserveDNSResponse(t *testing.T) {
a.updateRoutes([]netip.Prefix{pfx})
wantRoutes = append(wantRoutes, pfx)
a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.2.1"))
a.Wait(ctx)
if !slices.Equal(rc.Routes(), wantRoutes) {
t.Errorf("rc.Routes(): got %v; want %v", rc.Routes(), wantRoutes)
}
Expand All @@ -154,11 +179,13 @@ func TestObserveDNSResponse(t *testing.T) {
}

func TestWildcardDomains(t *testing.T) {
ctx := context.Background()
rc := &appctest.RouteCollector{}
a := NewAppConnector(t.Logf, rc)

a.updateDomains([]string{"*.example.com"})
a.ObserveDNSResponse(dnsResponse("foo.example.com.", "192.0.0.8"))
a.Wait(ctx)
if got, want := rc.Routes(), []netip.Prefix{netip.MustParsePrefix("192.0.0.8/32")}; !slices.Equal(got, want) {
t.Errorf("routes: got %v; want %v", got, want)
}
Expand Down Expand Up @@ -218,6 +245,61 @@ func dnsResponse(domain, address string) []byte {
return must.Get(b.Finish())
}

func dnsCNAMEResponse(address string, domains ...string) []byte {
addr := netip.MustParseAddr(address)
b := dnsmessage.NewBuilder(nil, dnsmessage.Header{})
b.EnableCompression()
b.StartAnswers()

if len(domains) >= 2 {
for i, domain := range domains[:len(domains)-1] {
b.CNAMEResource(
dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName(domain),
Type: dnsmessage.TypeCNAME,
Class: dnsmessage.ClassINET,
TTL: 0,
},
dnsmessage.CNAMEResource{
CNAME: dnsmessage.MustNewName(domains[i+1]),
},
)
}
}

domain := domains[len(domains)-1]

switch addr.BitLen() {
case 32:
b.AResource(
dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName(domain),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
TTL: 0,
},
dnsmessage.AResource{
A: addr.As4(),
},
)
case 128:
b.AAAAResource(
dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName(domain),
Type: dnsmessage.TypeAAAA,
Class: dnsmessage.ClassINET,
TTL: 0,
},
dnsmessage.AAAAResource{
AAAA: addr.As16(),
},
)
default:
panic("invalid address length")
}
return must.Get(b.Finish())
}

func prefixEqual(a, b netip.Prefix) bool {
return a == b
}
Expand Down

0 comments on commit e1a4b89

Please sign in to comment.