Skip to content

Commit

Permalink
Set tags as part of handleAuthKeyCommon
Browse files Browse the repository at this point in the history
  • Loading branch information
tsujamin committed Aug 25, 2022
1 parent 6faa1d2 commit ac18723
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 6 deletions.
12 changes: 12 additions & 0 deletions grpcv1.go
Expand Up @@ -106,6 +106,18 @@ func (api headscaleV1APIServer) CreatePreAuthKey(
expiration = request.GetExpiration().AsTime()
}

if len(request.AclTags) > 0 {
for _, tag := range request.AclTags {
err := validateTag(tag)

if err != nil {
return &v1.CreatePreAuthKeyResponse{
PreAuthKey: nil,
}, status.Error(codes.InvalidArgument, err.Error())
}
}
}

preAuthKey, err := api.h.CreatePreAuthKey(
request.GetNamespace(),
request.GetReusable(),
Expand Down
7 changes: 7 additions & 0 deletions integration_cli_test.go
Expand Up @@ -260,6 +260,8 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
"24h",
"--output",
"json",
"--tags",
"tag:test1,tag:test2",
},
[]string{},
)
Expand Down Expand Up @@ -333,6 +335,11 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
)

// Test that tags are present
for i := 0; i < count; i++ {
assert.DeepEquals(listedPreAuthKeys[i].AclTags, []string{"tag:test1,", "tag:test2"})
}

// Expire three keys
for i := 0; i < 3; i++ {
_, err := ExecuteCommand(
Expand Down
24 changes: 18 additions & 6 deletions preauth_keys.go
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"strconv"
"strings"
"time"

v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
Expand Down Expand Up @@ -55,6 +56,12 @@ func (h *Headscale) CreatePreAuthKey(
return nil, err
}

for _, tag := range aclTags {
if !strings.HasPrefix(tag, "tag:") {
return nil, fmt.Errorf("aclTag '%s' did not begin with 'tag:'", tag)
}
}

now := time.Now().UTC()
kstr, err := h.generateKey()
if err != nil {
Expand All @@ -77,12 +84,17 @@ func (h *Headscale) CreatePreAuthKey(
}

if len(aclTags) > 0 {
seenTags := map[string]bool{}

for _, tag := range aclTags {
if err := db.Save(&PreAuthKeyAclTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
return fmt.Errorf(
"failed to create key tag in the database: %w",
err,
)
if seenTags[tag] == false {
if err := db.Save(&PreAuthKeyAclTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
return fmt.Errorf(
"failed to ceate key tag in the database: %w",
err,
)
}
seenTags[tag] = true
}
}
}
Expand Down Expand Up @@ -222,7 +234,7 @@ func (key *PreAuthKey) toProto() *v1.PreAuthKey {

if len(key.AclTags) > 0 {
for idx := range key.AclTags {
protoKey.AclTags[idx] = key.AclTags[0].Tag
protoKey.AclTags[idx] = key.AclTags[idx].Tag
}
}

Expand Down
17 changes: 17 additions & 0 deletions preauth_keys_test.go
Expand Up @@ -190,3 +190,20 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
_, err = app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
}

func (*Suite) TestPreAuthKeyAclTags(c *check.C) {
namespace, err := app.CreateNamespace("test8")
c.Assert(err, check.IsNil)

_, err = app.CreatePreAuthKey(namespace.Name, false, false, nil, []string{"badtag"})
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected

tags := []string{"tag:test1", "tag:test2"}
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
_, err = app.CreatePreAuthKey(namespace.Name, false, false, nil, tagsWithDuplicate)
c.Assert(err, check.IsNil)

listedPaks, err := app.ListPreAuthKeys("test8")
c.Assert(err, check.IsNil)
c.Assert(listedPaks[0].toProto().AclTags, check.DeepEquals, tags)
}
21 changes: 21 additions & 0 deletions protocol_common.go
Expand Up @@ -345,6 +345,7 @@ func (h *Headscale) handleAuthKeyCommon(
machine.NodeKey = nodeKey
machine.AuthKeyID = uint(pak.ID)
err := h.RefreshMachine(machine, registerRequest.Expiry)

if err != nil {
log.Error().
Caller().
Expand All @@ -355,6 +356,25 @@ func (h *Headscale) handleAuthKeyCommon(

return
}

aclTags := pak.toProto().AclTags
if len(aclTags) > 0 {
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
err = h.SetTags(machine, aclTags)
}

if err != nil {
log.Error().
Caller().
Bool("noise", machineKey.IsZero()).
Str("machine", machine.Hostname).
Strs("aclTags", aclTags).
Err(err).
Msg("Failed to set tags after refreshing machine")

return
}

} else {
now := time.Now().UTC()

Expand All @@ -380,6 +400,7 @@ func (h *Headscale) handleAuthKeyCommon(
NodeKey: nodeKey,
LastSeen: &now,
AuthKeyID: uint(pak.ID),
ForcedTags: pak.toProto().AclTags,
}

machine, err = h.RegisterMachine(
Expand Down

0 comments on commit ac18723

Please sign in to comment.