Skip to content

Commit

Permalink
remove "stripEmailDomain" argument
Browse files Browse the repository at this point in the history
This commit makes a wrapper function round the normalisation requiring
"stripEmailDomain" which has to be passed in almost all functions of
headscale by loading it from Viper instead.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
  • Loading branch information
kradalby committed Jun 21, 2023
1 parent 161243c commit 717abe8
Show file tree
Hide file tree
Showing 16 changed files with 127 additions and 220 deletions.
1 change: 0 additions & 1 deletion hscontrol/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
database, err := db.NewHeadscaleDatabase(
cfg.DBtype,
dbString,
cfg.OIDC.StripEmaildomain,
app.dbDebug,
app.stateUpdateChan,
cfg.IPPrefixes,
Expand Down
12 changes: 6 additions & 6 deletions hscontrol/db/acls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
},
}

got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{})
assert.NoError(t, err)

want := []tailcfg.FilterRule{
Expand Down Expand Up @@ -107,7 +107,7 @@ func TestInvalidTagValidUser(t *testing.T) {
},
}

got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{})
assert.NoError(t, err)

want := []tailcfg.FilterRule{
Expand Down Expand Up @@ -169,7 +169,7 @@ func TestPortGroup(t *testing.T) {
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
assert.NoError(t, err)

got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{})
assert.NoError(t, err)

want := []tailcfg.FilterRule{
Expand Down Expand Up @@ -224,7 +224,7 @@ func TestPortUser(t *testing.T) {
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
assert.NoError(t, err)

got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{})
assert.NoError(t, err)

want := []tailcfg.FilterRule{
Expand Down Expand Up @@ -285,7 +285,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
// c.Assert(rules[0].DstPorts, check.HasLen, 1)
// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")

got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{})
assert.NoError(t, err)

want := []tailcfg.FilterRule{
Expand Down Expand Up @@ -361,7 +361,7 @@ func TestValidTagInvalidUser(t *testing.T) {
},
}

got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{machine2}, false)
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{machine2})
assert.NoError(t, err)

want := []tailcfg.FilterRule{
Expand Down
15 changes: 6 additions & 9 deletions hscontrol/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,15 @@ type HSDatabase struct {

ipAllocationMutex sync.Mutex

ipPrefixes []netip.Prefix
baseDomain string
stripEmailDomain bool
ipPrefixes []netip.Prefix
baseDomain string
}

// TODO(kradalby): assemble this struct from toptions or something typed
// rather than arguments.
func NewHeadscaleDatabase(
dbType, connectionAddr string,
stripEmailDomain, debug bool,
debug bool,
notifyStateChan chan<- struct{},
ipPrefixes []netip.Prefix,
baseDomain string,
Expand All @@ -64,9 +63,8 @@ func NewHeadscaleDatabase(
db: dbConn,
notifyStateChan: notifyStateChan,

ipPrefixes: ipPrefixes,
baseDomain: baseDomain,
stripEmailDomain: stripEmailDomain,
ipPrefixes: ipPrefixes,
baseDomain: baseDomain,
}

log.Debug().Msgf("database %#v", dbConn)
Expand Down Expand Up @@ -202,9 +200,8 @@ func NewHeadscaleDatabase(

for item, machine := range machines {
if machine.GivenName == "" {
normalizedHostname, err := util.NormalizeToFQDNRules(
normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper(
machine.Hostname,
stripEmailDomain,
)
if err != nil {
log.Error().
Expand Down
3 changes: 1 addition & 2 deletions hscontrol/db/machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -632,9 +632,8 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string
}

func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
normalizedHostname, err := util.NormalizeToFQDNRules(
normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper(
suppliedName,
hsdb.stripEmailDomain,
)
if err != nil {
return "", err
Expand Down
32 changes: 9 additions & 23 deletions hscontrol/db/machine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
testPeers, err := db.ListPeers(testMachine)
c.Assert(err, check.IsNil)

adminRules, _, err := policy.GenerateFilterRules(aclPolicy, adminMachine, adminPeers, false)
adminRules, _, err := policy.GenerateFilterRules(aclPolicy, adminMachine, adminPeers)
c.Assert(err, check.IsNil)

testRules, _, err := policy.GenerateFilterRules(aclPolicy, testMachine, testPeers, false)
testRules, _, err := policy.GenerateFilterRules(aclPolicy, testMachine, testPeers)
c.Assert(err, check.IsNil)

peersOfAdminMachine := policy.FilterMachinesByACL(adminMachine, adminPeers, adminRules)
Expand Down Expand Up @@ -482,9 +482,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
}{
{
name: "simple machine name generation",
db: &HSDatabase{
stripEmailDomain: true,
},
db: &HSDatabase{},
args: args{
suppliedName: "testmachine",
randomSuffix: false,
Expand All @@ -494,9 +492,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
},
{
name: "machine name with 53 chars",
db: &HSDatabase{
stripEmailDomain: true,
},
db: &HSDatabase{},
args: args{
suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine",
randomSuffix: false,
Expand All @@ -506,9 +502,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
},
{
name: "machine name with 63 chars",
db: &HSDatabase{
stripEmailDomain: true,
},
db: &HSDatabase{},
args: args{
suppliedName: "machineeee12345678901234567890123456789012345678901234567890123",
randomSuffix: false,
Expand All @@ -518,9 +512,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
},
{
name: "machine name with 64 chars",
db: &HSDatabase{
stripEmailDomain: true,
},
db: &HSDatabase{},
args: args{
suppliedName: "machineeee123456789012345678901234567890123456789012345678901234",
randomSuffix: false,
Expand All @@ -530,9 +522,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
},
{
name: "machine name with 73 chars",
db: &HSDatabase{
stripEmailDomain: true,
},
db: &HSDatabase{},
args: args{
suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123",
randomSuffix: false,
Expand All @@ -542,9 +532,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
},
{
name: "machine name with random suffix",
db: &HSDatabase{
stripEmailDomain: true,
},
db: &HSDatabase{},
args: args{
suppliedName: "test",
randomSuffix: true,
Expand All @@ -554,9 +542,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
},
{
name: "machine name with 63 chars with random suffix",
db: &HSDatabase{
stripEmailDomain: true,
},
db: &HSDatabase{},
args: args{
suppliedName: "machineeee12345678901234567890123456789012345678901234567890123",
randomSuffix: true,
Expand Down
2 changes: 1 addition & 1 deletion hscontrol/db/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
approvedRoutes = append(approvedRoutes, advertisedRoute)
} else {
// TODO(kradalby): figure out how to get this to depend on less stuff
approvedIps, err := aclPolicy.ExpandAlias(types.Machines{*machine}, approvedAlias, hsdb.stripEmailDomain)
approvedIps, err := aclPolicy.ExpandAlias(types.Machines{*machine}, approvedAlias)
if err != nil {
log.Err(err).
Str("alias", approvedAlias).
Expand Down
1 change: 0 additions & 1 deletion hscontrol/db/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ func (s *Suite) ResetDB(c *check.C) {
"sqlite3",
tmpDir+"/headscale_test.db",
false,
false,
sink,
[]netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
Expand Down
1 change: 0 additions & 1 deletion hscontrol/grpcv1.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,6 @@ func (api headscaleV1APIServer) ListMachines(
m := machine.Proto()
validTags, invalidTags := api.h.ACLPolicy.GetTagsOfMachine(
machine,
api.h.cfg.OIDC.StripEmaildomain,
)
m.InvalidTags = invalidTags
m.ValidTags = validTags
Expand Down
10 changes: 2 additions & 8 deletions hscontrol/mapper/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ type Mapper struct {
dnsCfg *tailcfg.DNSConfig
logtail bool
randomClientPort bool
stripEmailDomain bool
}

func NewMapper(
Expand All @@ -53,7 +52,6 @@ func NewMapper(
dnsCfg *tailcfg.DNSConfig,
logtail bool,
randomClientPort bool,
stripEmailDomain bool,
) *Mapper {
return &Mapper{
db: db,
Expand All @@ -66,7 +64,6 @@ func NewMapper(
dnsCfg: dnsCfg,
logtail: logtail,
randomClientPort: randomClientPort,
stripEmailDomain: stripEmailDomain,
}
}

Expand All @@ -87,14 +84,13 @@ func fullMapResponse(
machine *types.Machine,
peers types.Machines,

stripEmailDomain bool,
baseDomain string,
dnsCfg *tailcfg.DNSConfig,
derpMap *tailcfg.DERPMap,
logtail bool,
randomClientPort bool,
) (*tailcfg.MapResponse, error) {
tailnode, err := tailNode(*machine, pol, dnsCfg, baseDomain, stripEmailDomain)
tailnode, err := tailNode(*machine, pol, dnsCfg, baseDomain)
if err != nil {
return nil, err
}
Expand All @@ -103,7 +99,6 @@ func fullMapResponse(
pol,
machine,
peers,
stripEmailDomain,
)
if err != nil {
return nil, err
Expand All @@ -129,7 +124,7 @@ func fullMapResponse(
peers,
)

tailPeers, err := tailNodes(peers, pol, dnsCfg, baseDomain, stripEmailDomain)
tailPeers, err := tailNodes(peers, pol, dnsCfg, baseDomain)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -296,7 +291,6 @@ func (m Mapper) CreateMapResponse(
pol,
machine,
peers,
m.stripEmailDomain,
m.baseDomain,
m.dnsCfg,
m.derpMap,
Expand Down
6 changes: 0 additions & 6 deletions hscontrol/mapper/mapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,6 @@ func Test_fullMapResponse(t *testing.T) {
machine *types.Machine
peers types.Machines

stripEmailDomain bool
baseDomain string
dnsConfig *tailcfg.DNSConfig
derpMap *tailcfg.DERPMap
Expand All @@ -335,7 +334,6 @@ func Test_fullMapResponse(t *testing.T) {
// pol: &policy.ACLPolicy{},
// dnsConfig: &tailcfg.DNSConfig{},
// baseDomain: "",
// stripEmailDomain: false,
// want: nil,
// wantErr: true,
// },
Expand All @@ -344,7 +342,6 @@ func Test_fullMapResponse(t *testing.T) {
pol: &policy.ACLPolicy{},
machine: mini,
peers: []types.Machine{},
stripEmailDomain: false,
baseDomain: "",
dnsConfig: &tailcfg.DNSConfig{},
derpMap: &tailcfg.DERPMap{},
Expand Down Expand Up @@ -375,7 +372,6 @@ func Test_fullMapResponse(t *testing.T) {
peers: []types.Machine{
peer1,
},
stripEmailDomain: false,
baseDomain: "",
dnsConfig: &tailcfg.DNSConfig{},
derpMap: &tailcfg.DERPMap{},
Expand Down Expand Up @@ -417,7 +413,6 @@ func Test_fullMapResponse(t *testing.T) {
peer1,
peer2,
},
stripEmailDomain: false,
baseDomain: "",
dnsConfig: &tailcfg.DNSConfig{},
derpMap: &tailcfg.DERPMap{},
Expand Down Expand Up @@ -458,7 +453,6 @@ func Test_fullMapResponse(t *testing.T) {
tt.pol,
tt.machine,
tt.peers,
tt.stripEmailDomain,
tt.baseDomain,
tt.dnsConfig,
tt.derpMap,
Expand Down
5 changes: 1 addition & 4 deletions hscontrol/mapper/tail.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ func tailNodes(
pol *policy.ACLPolicy,
dnsConfig *tailcfg.DNSConfig,
baseDomain string,
stripEmailDomain bool,
) ([]*tailcfg.Node, error) {
nodes := make([]*tailcfg.Node, len(machines))

Expand All @@ -28,7 +27,6 @@ func tailNodes(
pol,
dnsConfig,
baseDomain,
stripEmailDomain,
)
if err != nil {
return nil, err
Expand All @@ -47,7 +45,6 @@ func tailNode(
pol *policy.ACLPolicy,
dnsConfig *tailcfg.DNSConfig,
baseDomain string,
stripEmailDomain bool,
) (*tailcfg.Node, error) {
nodeKey, err := machine.NodePublicKey()
if err != nil {
Expand Down Expand Up @@ -107,7 +104,7 @@ func tailNode(

online := machine.IsOnline()

tags, _ := pol.GetTagsOfMachine(machine, stripEmailDomain)
tags, _ := pol.GetTagsOfMachine(machine)
tags = lo.Uniq(append(tags, machine.ForcedTags...))

node := tailcfg.Node{
Expand Down

0 comments on commit 717abe8

Please sign in to comment.