Skip to content

Commit

Permalink
Merge pull request #725 from juanfont/switch-to-db-d
Browse files Browse the repository at this point in the history
Improve registration protocol implementation and switch to NodeKey as main identifier
  • Loading branch information
juanfont committed Aug 12, 2022
2 parents 73cd428 + 77bf1e8 commit 09cd7ba
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 53 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

- Updated dependencies (including the library that lacked armhf support) [#722](https://github.com/juanfont/headscale/pull/722)
- Fix missing group expansion in function `excludeCorretlyTaggedNodes` [#563](https://github.com/juanfont/headscale/issues/563)
- Improve registration protocol implementation and switch to NodeKey as main identifier [#725](https://github.com/juanfont/headscale/pull/725)

## 0.16.0 (2022-07-25)

Expand Down
62 changes: 52 additions & 10 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
)

const (
// TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed.
registrationHoldoff = time.Second * 5
reservedResponseHeaderSize = 4
RegisterMethodAuthKey = "authkey"
RegisterMethodOIDC = "oidc"
Expand Down Expand Up @@ -107,13 +109,17 @@ var registerWebAPITemplate = template.Must(
`))

// RegisterWebAPI shows a simple message in the browser to point to the CLI
// Listens in /register.
// Listens in /register/:nkey.
//
// This is not part of the Tailscale control API, as we could send whatever URL
// in the RegisterResponse.AuthURL field.
func (h *Headscale) RegisterWebAPI(
writer http.ResponseWriter,
req *http.Request,
) {
machineKeyStr := req.URL.Query().Get("key")
if machineKeyStr == "" {
vars := mux.Vars(req)
nodeKeyStr, ok := vars["nkey"]
if !ok || nodeKeyStr == "" {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Wrong params"))
Expand All @@ -129,7 +135,7 @@ func (h *Headscale) RegisterWebAPI(

var content bytes.Buffer
if err := registerWebAPITemplate.Execute(&content, registerWebAPITemplateConfig{
Key: machineKeyStr,
Key: nodeKeyStr,
}); err != nil {
log.Error().
Str("func", "RegisterWebAPI").
Expand Down Expand Up @@ -206,8 +212,6 @@ func (h *Headscale) RegistrationHandler(
now := time.Now().UTC()
machine, err := h.GetMachineByMachineKey(machineKey)
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Info().Str("machine", registerRequest.Hostinfo.Hostname).Msg("New machine")

machineKeyStr := MachinePublicKeyStripPrefix(machineKey)

// If the machine has AuthKey set, handle registration via PreAuthKeys
Expand All @@ -217,6 +221,44 @@ func (h *Headscale) RegistrationHandler(
return
}

// Check if the node is waiting for interactive login.
//
// TODO(juan): We could use this field to improve our protocol implementation,
// and hold the request until the client closes it, or the interactive
// login is completed (i.e., the user registers the machine).
// This is not implemented yet, as it is no strictly required. The only side-effect
// is that the client will hammer headscale with requests until it gets a
// successful RegisterResponse.
if registerRequest.Followup != "" {
if _, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok {
log.Debug().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Str("follow_up", registerRequest.Followup).
Msg("Machine is waiting for interactive login")

ticker := time.NewTicker(registrationHoldoff)
select {
case <-req.Context().Done():
return
case <-ticker.C:
h.handleMachineRegistrationNew(writer, req, machineKey, registerRequest)

return
}
}
}

log.Info().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Str("follow_up", registerRequest.Followup).
Msg("New machine not yet in the database")

givenName, err := h.GenerateGivenName(registerRequest.Hostinfo.Hostname)
if err != nil {
log.Error().
Expand Down Expand Up @@ -251,7 +293,7 @@ func (h *Headscale) RegistrationHandler(
}

h.registrationCache.Set(
machineKeyStr,
newMachine.NodeKey,
newMachine,
registerCacheExpiration,
)
Expand Down Expand Up @@ -652,16 +694,16 @@ func (h *Headscale) handleMachineRegistrationNew(
// The machine registration is new, redirect the client to the registration URL
log.Debug().
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("The node is sending us a new NodeKey, sending auth url")
Msg("The node seems to be new, sending auth url")
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
machineKey.String(),
)
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey))
resp.AuthURL = fmt.Sprintf("%s/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), NodePublicKeyStripPrefix(registerRequest.NodeKey))
}

respBody, err := encode(resp, &machineKey, h.privateKey)
Expand Down
16 changes: 6 additions & 10 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,21 +417,17 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *mux.Router {

router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
router.HandleFunc("/register", h.RegisterWebAPI).Methods(http.MethodGet)
router.HandleFunc("/machine/{mkey}/map", h.PollNetMapHandler).
Methods(http.MethodPost)
router.HandleFunc("/register/{nkey}", h.RegisterWebAPI).Methods(http.MethodGet)
router.HandleFunc("/machine/{mkey}/map", h.PollNetMapHandler).Methods(http.MethodPost)
router.HandleFunc("/machine/{mkey}", h.RegistrationHandler).Methods(http.MethodPost)
router.HandleFunc("/oidc/register/{mkey}", h.RegisterOIDC).Methods(http.MethodGet)
router.HandleFunc("/oidc/register/{nkey}", h.RegisterOIDC).Methods(http.MethodGet)
router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet)
router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).
Methods(http.MethodGet)
router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).Methods(http.MethodGet)
router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet)
router.HandleFunc("/windows/tailscale.reg", h.WindowsRegConfig).
Methods(http.MethodGet)
router.HandleFunc("/windows/tailscale.reg", h.WindowsRegConfig).Methods(http.MethodGet)
router.HandleFunc("/swagger", SwaggerUI).Methods(http.MethodGet)
router.HandleFunc("/swagger/v1/openapiv2.json", SwaggerAPIv1).
Methods(http.MethodGet)
router.HandleFunc("/swagger/v1/openapiv2.json", SwaggerAPIv1).Methods(http.MethodGet)

if h.cfg.DERP.ServerEnabled {
router.HandleFunc("/derp", h.DERPHandler)
Expand Down
2 changes: 1 addition & 1 deletion cmd/headscale/cli/nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ var registerNodeCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Error getting machine key from flag: %s", err),
fmt.Sprintf("Error getting node key from flag: %s", err),
output,
)

Expand Down
2 changes: 1 addition & 1 deletion grpcv1.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func (api headscaleV1APIServer) RegisterMachine(
) (*v1.RegisterMachineResponse, error) {
log.Trace().
Str("namespace", request.GetNamespace()).
Str("machine_key", request.GetKey()).
Str("node_key", request.GetKey()).
Msg("Registering machine")

machine, err := api.h.RegisterMachineFromAuthCallback(
Expand Down
19 changes: 16 additions & 3 deletions machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
return &m, nil
}

// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
// GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct.
func (h *Headscale) GetMachineByMachineKey(
machineKey key.MachinePublic,
) (*Machine, error) {
Expand All @@ -362,6 +362,19 @@ func (h *Headscale) GetMachineByMachineKey(
return &m, nil
}

// GetMachineByNodeKey finds a Machine by its current NodeKey.
func (h *Headscale) GetMachineByNodeKey(
nodeKey key.NodePublic,
) (*Machine, error) {
machine := Machine{}
if result := h.db.Preload("Namespace").First(&machine, "node_key = ?",
NodePublicKeyStripPrefix(nodeKey)); result.Error != nil {
return nil, result.Error
}

return &machine, nil
}

// UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database
// and updates it with the latest data from the database.
func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error {
Expand Down Expand Up @@ -762,11 +775,11 @@ func getTags(
}

func (h *Headscale) RegisterMachineFromAuthCallback(
machineKeyStr string,
nodeKeyStr string,
namespaceName string,
registrationMethod string,
) (*Machine, error) {
if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok {
if machineInterface, ok := h.registrationCache.Get(nodeKeyStr); ok {
if registrationMachine, ok := machineInterface.(Machine); ok {
namespace, err := h.GetNamespace(namespaceName)
if err != nil {
Expand Down
56 changes: 28 additions & 28 deletions oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const (
errOIDCAllowedDomains = Error("authenticated principal does not match any allowed domain")
errOIDCAllowedUsers = Error("authenticated principal does not match any allowed user")
errOIDCInvalidMachineState = Error("requested machine state key expired before authorisation completed")
errOIDCMachineKeyMissing = Error("could not get machine key from cache")
errOIDCNodeKeyMissing = Error("could not get node key from cache")
)

type IDTokenClaims struct {
Expand Down Expand Up @@ -68,26 +68,26 @@ func (h *Headscale) initOIDC() error {
}

// RegisterOIDC redirects to the OIDC provider for authentication
// Puts machine key in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:mKey.
// Puts NodeKey in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:nKey.
func (h *Headscale) RegisterOIDC(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
machineKeyStr, ok := vars["mkey"]
if !ok || machineKeyStr == "" {
nodeKeyStr, ok := vars["nkey"]
if !ok || nodeKeyStr == "" {
log.Error().
Caller().
Msg("Missing machine key in URL")
http.Error(writer, "Missing machine key in URL", http.StatusBadRequest)
Msg("Missing node key in URL")
http.Error(writer, "Missing node key in URL", http.StatusBadRequest)

return
}

log.Trace().
Caller().
Str("machine_key", machineKeyStr).
Str("node_key", nodeKeyStr).
Msg("Received oidc register call")

randomBlob := make([]byte, randomByteSize)
Expand All @@ -102,8 +102,8 @@ func (h *Headscale) RegisterOIDC(

stateStr := hex.EncodeToString(randomBlob)[:32]

// place the machine key into the state cache, so it can be retrieved later
h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration)
// place the node key into the state cache, so it can be retrieved later
h.registrationCache.Set(stateStr, nodeKeyStr, registerCacheExpiration)

// Add any extra parameter provided in the configuration to the Authorize Endpoint request
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
Expand Down Expand Up @@ -135,7 +135,7 @@ var oidcCallbackTemplate = template.Must(
)

// OIDCCallback handles the callback from the OIDC endpoint
// Retrieves the mkey from the state cache and adds the machine to the users email namespace
// Retrieves the nkey from the state cache and adds the machine to the users email namespace
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into machine HostInfo
// Listens in /oidc/callback.
Expand Down Expand Up @@ -178,7 +178,7 @@ func (h *Headscale) OIDCCallback(
return
}

machineKey, machineExists, err := h.validateMachineForOIDCCallback(writer, state, claims)
nodeKey, machineExists, err := h.validateMachineForOIDCCallback(writer, state, claims)
if err != nil || machineExists {
return
}
Expand All @@ -196,7 +196,7 @@ func (h *Headscale) OIDCCallback(
return
}

if err := h.registerMachineForOIDCCallback(writer, namespace, machineKey); err != nil {
if err := h.registerMachineForOIDCCallback(writer, namespace, nodeKey); err != nil {
return
}

Expand Down Expand Up @@ -401,7 +401,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
writer http.ResponseWriter,
state string,
claims *IDTokenClaims,
) (*key.MachinePublic, bool, error) {
) (*key.NodePublic, bool, error) {
// retrieve machinekey from state cache
machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
if !machineKeyFound {
Expand All @@ -420,14 +420,14 @@ func (h *Headscale) validateMachineForOIDCCallback(
return nil, false, errOIDCInvalidMachineState
}

var machineKey key.MachinePublic
machineKeyFromCache, machineKeyOK := machineKeyIf.(string)
err := machineKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machineKeyFromCache)),
var nodeKey key.NodePublic
nodeKeyFromCache, nodeKeyOK := machineKeyIf.(string)
err := nodeKey.UnmarshalText(
[]byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)),
)
if err != nil {
log.Error().
Msg("could not parse machine public key")
Msg("could not parse node public key")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("could not parse public key"))
Expand All @@ -441,26 +441,26 @@ func (h *Headscale) validateMachineForOIDCCallback(
return nil, false, err
}

if !machineKeyOK {
log.Error().Msg("could not get machine key from cache")
if !nodeKeyOK {
log.Error().Msg("could not get node key from cache")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("could not get machine key from cache"))
_, err := writer.Write([]byte("could not get node key from cache"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}

return nil, false, errOIDCMachineKeyMissing
return nil, false, errOIDCNodeKeyMissing
}

// retrieve machine information if it exist
// The error is not important, because if it does not
// exist, then this is a new machine and we will move
// on to registration.
machine, _ := h.GetMachineByMachineKey(machineKey)
machine, _ := h.GetMachineByNodeKey(nodeKey)

if machine != nil {
log.Trace().
Expand Down Expand Up @@ -520,7 +520,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
return nil, true, nil
}

return &machineKey, false, nil
return &nodeKey, false, nil
}

func getNamespaceName(
Expand Down Expand Up @@ -600,12 +600,12 @@ func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback(
func (h *Headscale) registerMachineForOIDCCallback(
writer http.ResponseWriter,
namespace *Namespace,
machineKey *key.MachinePublic,
nodeKey *key.NodePublic,
) error {
machineKeyStr := MachinePublicKeyStripPrefix(*machineKey)
nodeKeyStr := NodePublicKeyStripPrefix(*nodeKey)

if _, err := h.RegisterMachineFromAuthCallback(
machineKeyStr,
nodeKeyStr,
namespace.Name,
RegisterMethodOIDC,
); err != nil {
Expand Down

0 comments on commit 09cd7ba

Please sign in to comment.