Skip to content

Commit

Permalink
Remove gin from the poll handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
juanfont committed Jun 20, 2022
1 parent dedeb4c commit 53e5c05
Showing 1 changed file with 74 additions and 71 deletions.
145 changes: 74 additions & 71 deletions poll.go
Expand Up @@ -2,13 +2,14 @@ package headscale

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"

"github.com/gin-gonic/gin"
"github.com/gorilla/mux"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
Expand All @@ -33,13 +34,25 @@ const machineNameContextKey = contextKey("machineName")
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
func (h *Headscale) PollNetMapHandler(
w http.ResponseWriter,
r *http.Request,
) {
vars := mux.Vars(r)
machineKeyStr, ok := vars["mkey"]
if !ok || machineKeyStr == "" {
log.Error().
Str("handler", "PollNetMap").
Msg("No machine key in request")
http.Error(w, "No machine key in request", http.StatusBadRequest)

return
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", ctx.Param("id")).
Str("id", machineKeyStr).
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(ctx.Request.Body)
machineKeyStr := ctx.Param("id")
body, _ := io.ReadAll(r.Body)

var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr)))
Expand All @@ -48,7 +61,8 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot parse client key")
ctx.String(http.StatusBadRequest, "")

http.Error(w, "Cannot parse client key", http.StatusBadRequest)

return
}
Expand All @@ -59,7 +73,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot decode message")
ctx.String(http.StatusBadRequest, "")
http.Error(w, "Cannot decode message", http.StatusBadRequest)

return
}
Expand All @@ -70,20 +84,21 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
ctx.String(http.StatusUnauthorized, "")

http.Error(w, "", http.StatusUnauthorized)

return
}
log.Error().
Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String())
ctx.String(http.StatusInternalServerError, "")
http.Error(w, "", http.StatusInternalServerError)

return
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", ctx.Param("id")).
Str("id", machineKeyStr).
Str("machine", machine.Hostname).
Msg("Found machine in database")

Expand Down Expand Up @@ -120,11 +135,11 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Str("id", ctx.Param("id")).
Str("id", machineKeyStr).
Str("machine", machine.Hostname).
Err(err).
Msg("Failed to persist/update machine in the database")
ctx.String(http.StatusInternalServerError, ":(")
http.Error(w, "", http.StatusInternalServerError)

return
}
Expand All @@ -134,11 +149,11 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Str("id", ctx.Param("id")).
Str("id", machineKeyStr).
Str("machine", machine.Hostname).
Err(err).
Msg("Failed to get Map response")
ctx.String(http.StatusInternalServerError, ":(")
http.Error(w, "", http.StatusInternalServerError)

return
}
Expand All @@ -150,7 +165,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug().
Str("handler", "PollNetMap").
Str("id", ctx.Param("id")).
Str("id", machineKeyStr).
Str("machine", machine.Hostname).
Bool("readOnly", req.ReadOnly).
Bool("omitPeers", req.OmitPeers).
Expand All @@ -162,7 +177,10 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
Str("handler", "PollNetMap").
Str("machine", machine.Hostname).
Msg("Client is starting up. Probably interested in a DERP map")
ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)

w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(data)

return
}
Expand All @@ -177,7 +195,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
// Only create update channel if it has not been created
log.Trace().
Str("handler", "PollNetMap").
Str("id", ctx.Param("id")).
Str("id", machineKeyStr).
Str("machine", machine.Hostname).
Msg("Loading or creating update channel")

Expand All @@ -194,8 +212,9 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
Str("handler", "PollNetMap").
Str("machine", machine.Hostname).
Msg("Client sent endpoint update and is ok with a response without peer list")
ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)

w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(data)
// It sounds like we should update the nodes when we have received a endpoint update
// even tho the comments in the tailscale code dont explicitly say so.
updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "endpoint-update").
Expand All @@ -208,7 +227,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
Str("handler", "PollNetMap").
Str("machine", machine.Hostname).
Msg("Ignoring request, don't know how to handle it")
ctx.String(http.StatusBadRequest, "")
http.Error(w, "", http.StatusBadRequest)

return
}
Expand All @@ -232,7 +251,8 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
updateChan <- struct{}{}

h.PollNetMapStream(
ctx,
w,
r,
machine,
req,
machineKey,
Expand All @@ -242,7 +262,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
)
log.Trace().
Str("handler", "PollNetMap").
Str("id", ctx.Param("id")).
Str("id", machineKeyStr).
Str("machine", machine.Hostname).
Msg("Finished stream, closing PollNetMap session")
}
Expand All @@ -251,49 +271,30 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
// stream logic, ensuring we communicate updates and data
// to the connected clients.
func (h *Headscale) PollNetMapStream(
ctx *gin.Context,
w http.ResponseWriter,
r *http.Request,
machine *Machine,
mapRequest tailcfg.MapRequest,
machineKey key.MachinePublic,
pollDataChan chan []byte,
keepAliveChan chan []byte,
updateChan chan struct{},
) {
{
machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
ctx.String(http.StatusUnauthorized, "")
ctx := context.WithValue(context.Background(), machineNameContextKey, machine.Hostname)

return
}
log.Error().
Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String())
ctx.String(http.StatusInternalServerError, "")
ctx, cancel := context.WithCancel(ctx)
defer cancel()

return
}

ctx := context.WithValue(ctx.Request.Context(), machineNameContextKey, machine.Hostname)

ctx, cancel := context.WithCancel(ctx)
defer cancel()

go h.scheduledPollWorker(
ctx,
updateChan,
keepAliveChan,
machineKey,
mapRequest,
machine,
)
}
go h.scheduledPollWorker(
ctx,
updateChan,
keepAliveChan,
machineKey,
mapRequest,
machine,
)

ctx.Stream(func(writer io.Writer) bool {
for {
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname).
Expand All @@ -312,7 +313,7 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Sending data received via pollData channel")
_, err := writer.Write(data)
_, err := w.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Expand All @@ -321,7 +322,7 @@ func (h *Headscale) PollNetMapStream(
Err(err).
Msg("Cannot write data")

return false
break
}
log.Trace().
Str("handler", "PollNetMapStream").
Expand All @@ -343,7 +344,7 @@ func (h *Headscale) PollNetMapStream(

// client has been removed from database
// since the stream opened, terminate connection.
return false
break
}
now := time.Now().UTC()
machine.LastSeen = &now
Expand All @@ -369,7 +370,7 @@ func (h *Headscale) PollNetMapStream(
Msg("Machine entry in database updated successfully after sending pollData")
}

return true
break

case data := <-keepAliveChan:
log.Trace().
Expand All @@ -378,7 +379,7 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Sending keep alive message")
_, err := writer.Write(data)
_, err := w.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Expand All @@ -387,7 +388,7 @@ func (h *Headscale) PollNetMapStream(
Err(err).
Msg("Cannot write keep alive message")

return false
break
}
log.Trace().
Str("handler", "PollNetMapStream").
Expand All @@ -409,7 +410,7 @@ func (h *Headscale) PollNetMapStream(

// client has been removed from database
// since the stream opened, terminate connection.
return false
break
}
now := time.Now().UTC()
machine.LastSeen = &now
Expand All @@ -430,7 +431,7 @@ func (h *Headscale) PollNetMapStream(
Msg("Machine updated successfully after sending keep alive")
}

return true
break

case <-updateChan:
log.Trace().
Expand Down Expand Up @@ -460,7 +461,7 @@ func (h *Headscale) PollNetMapStream(
Err(err).
Msg("Could not get the map update")
}
_, err = writer.Write(data)
_, err = w.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Expand All @@ -471,7 +472,7 @@ func (h *Headscale) PollNetMapStream(
updateRequestsSentToNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "failed").
Inc()

return false
return
}
log.Trace().
Str("handler", "PollNetMapStream").
Expand Down Expand Up @@ -499,7 +500,7 @@ func (h *Headscale) PollNetMapStream(

// client has been removed from database
// since the stream opened, terminate connection.
return false
return
}
now := time.Now().UTC()

Expand Down Expand Up @@ -529,9 +530,9 @@ func (h *Headscale) PollNetMapStream(
Msgf("%s is up to date", machine.Hostname)
}

return true
return

case <-ctx.Request.Context().Done():
case <-ctx.Done():
log.Info().
Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname).
Expand All @@ -550,7 +551,7 @@ func (h *Headscale) PollNetMapStream(

// client has been removed from database
// since the stream opened, terminate connection.
return false
break
}
now := time.Now().UTC()
machine.LastSeen = &now
Expand All @@ -564,9 +565,11 @@ func (h *Headscale) PollNetMapStream(
Msg("Cannot update machine LastSeen")
}

return false
break
}
})
}

log.Info().Msgf("Closing poll loop to %s", machine.Hostname)
}

func (h *Headscale) scheduledPollWorker(
Expand Down

0 comments on commit 53e5c05

Please sign in to comment.