Skip to content

Commit

Permalink
Merge pull request #54 from axw/validate-service-name
Browse files Browse the repository at this point in the history
elasticapm: validate/sanitize service name
  • Loading branch information
axw committed Apr 26, 2018
2 parents c48d673 + 4fec446 commit ace23e4
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 36 deletions.
54 changes: 50 additions & 4 deletions env_test.go
@@ -1,14 +1,17 @@
package elasticapm_test

import (
"context"
"os"
"os/exec"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/elastic/apm-agent-go"
"github.com/elastic/apm-agent-go/model"
"github.com/elastic/apm-agent-go/transport/transporttest"
)

Expand All @@ -25,15 +28,15 @@ func TestTracerFlushIntervalEnvInvalid(t *testing.T) {
os.Setenv("ELASTIC_APM_FLUSH_INTERVAL", "aeon")
defer os.Unsetenv("ELASTIC_APM_FLUSH_INTERVAL")

_, err := elasticapm.NewTracer("tracer.testing", "")
_, err := elasticapm.NewTracer("tracer_testing", "")
assert.EqualError(t, err, "failed to parse ELASTIC_APM_FLUSH_INTERVAL: time: invalid duration aeon")
}

func testTracerFlushIntervalEnv(t *testing.T, envValue string, expectedInterval time.Duration) {
os.Setenv("ELASTIC_APM_FLUSH_INTERVAL", envValue)
defer os.Unsetenv("ELASTIC_APM_FLUSH_INTERVAL")

tracer, err := elasticapm.NewTracer("tracer.testing", "")
tracer, err := elasticapm.NewTracer("tracer_testing", "")
require.NoError(t, err)
defer tracer.Close()
tracer.Transport = transporttest.Discard
Expand Down Expand Up @@ -63,15 +66,15 @@ func TestTracerTransactionRateEnvInvalid(t *testing.T) {
os.Setenv("ELASTIC_APM_TRANSACTION_SAMPLE_RATE", "2.0")
defer os.Unsetenv("ELASTIC_APM_TRANSACTION_SAMPLE_RATE")

_, err := elasticapm.NewTracer("tracer.testing", "")
_, err := elasticapm.NewTracer("tracer_testing", "")
assert.EqualError(t, err, "invalid ELASTIC_APM_TRANSACTION_SAMPLE_RATE value 2.0: out of range [0,1.0]")
}

func testTracerTransactionRateEnv(t *testing.T, envValue string, ratio float64) {
os.Setenv("ELASTIC_APM_TRANSACTION_SAMPLE_RATE", envValue)
defer os.Unsetenv("ELASTIC_APM_TRANSACTION_SAMPLE_RATE")

tracer, err := elasticapm.NewTracer("tracer.testing", "")
tracer, err := elasticapm.NewTracer("tracer_testing", "")
require.NoError(t, err)
defer tracer.Close()
tracer.Transport = transporttest.Discard
Expand All @@ -87,3 +90,46 @@ func testTracerTransactionRateEnv(t *testing.T, envValue string, ratio float64)
}
assert.InDelta(t, N*ratio, sampled, N*0.02) // allow 2% error
}

func TestTracerServiceNameEnvSanitizationSpecified(t *testing.T) {
testTracerServiceNameSanitization(
t, "TestTracerServiceNameEnvSanitizationSpecified",
"foo_bar", "ELASTIC_APM_SERVICE_NAME=foo!bar",
)
}

func TestTracerServiceNameEnvSanitizationExecutableName(t *testing.T) {
testTracerServiceNameSanitization(
t, "TestTracerServiceNameEnvSanitizationExecutableName",
"apm-agent-go_test", // .test -> _test
)
}

func testTracerServiceNameSanitization(t *testing.T, testName, sanitizedServiceName string, env ...string) {
if os.Getenv("_INSIDE_TEST") != "1" {
cmd := exec.Command(os.Args[0], "-test.run=^"+testName+"$")
cmd.Env = append(cmd.Env, "_INSIDE_TEST=1")
cmd.Env = append(cmd.Env, env...)
err := cmd.Run()
assert.NoError(t, err)
return
}

tracer, err := elasticapm.NewTracer("", "")
require.NoError(t, err)
defer tracer.Close()

var called bool
tracer.Transport = transporttest.CallbackTransport{
Transactions: func(_ context.Context, payload *model.TransactionsPayload) error {
assert.Equal(t, sanitizedServiceName, payload.Service.Name)
called = true
return nil
},
}

tx := tracer.StartTransaction("name", "type")
tx.Done(-1)
tracer.Flush(nil)
assert.True(t, called)
}
2 changes: 1 addition & 1 deletion error_test.go
Expand Up @@ -61,7 +61,7 @@ func TestInternalStackTrace(t *testing.T) {

func sendError(t *testing.T, err error, f ...func(*elasticapm.Error)) *model.Error {
var r transporttest.RecorderTransport
tracer, newTracerErr := elasticapm.NewTracer("tracer.testing", "")
tracer, newTracerErr := elasticapm.NewTracer("tracer_testing", "")
assert.NoError(t, newTracerErr)
defer tracer.Close()

Expand Down
14 changes: 5 additions & 9 deletions tracer.go
Expand Up @@ -6,8 +6,6 @@ import (
"sync"
"time"

"github.com/pkg/errors"

"github.com/elastic/apm-agent-go/model"
"github.com/elastic/apm-agent-go/stacktrace"
"github.com/elastic/apm-agent-go/transport"
Expand Down Expand Up @@ -144,17 +142,15 @@ type Tracer struct {
// or taking the service name and version from the environment
// if unspecified.
//
// If service is nil, then the service will be defined using the
// ELASTIC_APM_* environment variables.
// If serviceName is empty, then the service name will be defined
// using the ELASTIC_APM_SERVER_NAME environment variable.
func NewTracer(serviceName, serviceVersion string) (*Tracer, error) {
service := &envService
if serviceName != "" {
if err := validateServiceName(serviceName); err != nil {
return nil, err
}
service = newService(serviceName, serviceVersion)
} else if service == nil {
return nil, errors.Errorf(
"no service name specified, and %s not specified",
envServiceName,
)
}
var opts options
if err := opts.init(false); err != nil {
Expand Down
35 changes: 14 additions & 21 deletions tracer_test.go
Expand Up @@ -15,7 +15,7 @@ import (
)

func TestTracerStats(t *testing.T) {
tracer, err := elasticapm.NewTracer("tracer.testing", "")
tracer, err := elasticapm.NewTracer("tracer_testing", "")
assert.NoError(t, err)
defer tracer.Close()
tracer.Transport = transporttest.Discard
Expand All @@ -30,7 +30,7 @@ func TestTracerStats(t *testing.T) {
}

func TestTracerClosedSendNonblocking(t *testing.T) {
tracer, err := elasticapm.NewTracer("tracer.testing", "")
tracer, err := elasticapm.NewTracer("tracer_testing", "")
assert.NoError(t, err)
tracer.Close()

Expand All @@ -41,7 +41,7 @@ func TestTracerClosedSendNonblocking(t *testing.T) {
}

func TestTracerFlushInterval(t *testing.T) {
tracer, err := elasticapm.NewTracer("tracer.testing", "")
tracer, err := elasticapm.NewTracer("tracer_testing", "")
assert.NoError(t, err)
defer tracer.Close()
tracer.Transport = transporttest.Discard
Expand All @@ -59,7 +59,7 @@ func TestTracerFlushInterval(t *testing.T) {
}

func TestTracerMaxQueueSize(t *testing.T) {
tracer, err := elasticapm.NewTracer("tracer.testing", "")
tracer, err := elasticapm.NewTracer("tracer_testing", "")
assert.NoError(t, err)
defer tracer.Close()

Expand All @@ -84,7 +84,7 @@ func TestTracerMaxQueueSize(t *testing.T) {
}

func TestTracerRetryTimer(t *testing.T) {
tracer, err := elasticapm.NewTracer("tracer.testing", "")
tracer, err := elasticapm.NewTracer("tracer_testing", "")
assert.NoError(t, err)
defer tracer.Close()

Expand Down Expand Up @@ -123,7 +123,7 @@ func TestTracerRetryTimer(t *testing.T) {
}

func TestTracerRetryTimerFlush(t *testing.T) {
tracer, err := elasticapm.NewTracer("tracer.testing", "")
tracer, err := elasticapm.NewTracer("tracer_testing", "")
assert.NoError(t, err)
defer tracer.Close()
interval := time.Second
Expand Down Expand Up @@ -163,7 +163,7 @@ func TestTracerRetryTimerFlush(t *testing.T) {

func TestTracerMaxSpans(t *testing.T) {
var r transporttest.RecorderTransport
tracer, err := elasticapm.NewTracer("tracer.testing", "")
tracer, err := elasticapm.NewTracer("tracer_testing", "")
assert.NoError(t, err)
defer tracer.Close()
tracer.Transport = &r
Expand Down Expand Up @@ -194,7 +194,7 @@ func TestTracerMaxSpans(t *testing.T) {

func TestTracerErrors(t *testing.T) {
var r transporttest.RecorderTransport
tracer, err := elasticapm.NewTracer("tracer.testing", "")
tracer, err := elasticapm.NewTracer("tracer_testing", "")
assert.NoError(t, err)
defer tracer.Close()
tracer.Transport = &r
Expand All @@ -216,7 +216,7 @@ func TestTracerErrors(t *testing.T) {
}

func TestTracerErrorsBuffered(t *testing.T) {
tracer, err := elasticapm.NewTracer("tracer.testing", "")
tracer, err := elasticapm.NewTracer("tracer_testing", "")
assert.NoError(t, err)
defer tracer.Close()
errors := make(chan transporttest.SendErrorsRequest)
Expand Down Expand Up @@ -268,7 +268,7 @@ func TestTracerErrorsBuffered(t *testing.T) {
}

func TestTracerProcessor(t *testing.T) {
tracer, err := elasticapm.NewTracer("tracer.testing", "")
tracer, err := elasticapm.NewTracer("tracer_testing", "")
assert.NoError(t, err)
defer tracer.Close()
tracer.Transport = transporttest.Discard
Expand Down Expand Up @@ -308,7 +308,7 @@ func TestTracerProcessor(t *testing.T) {

func TestTracerRecover(t *testing.T) {
var r transporttest.RecorderTransport
tracer, err := elasticapm.NewTracer("tracer.testing", "")
tracer, err := elasticapm.NewTracer("tracer_testing", "")
assert.NoError(t, err)
defer tracer.Close()
tracer.Transport = &r
Expand All @@ -331,14 +331,7 @@ func capturePanic(tracer *elasticapm.Tracer, v interface{}) {
panic(v)
}

type testLogger struct {
t *testing.T
}

func (l testLogger) Debugf(format string, args ...interface{}) {
l.t.Logf("[DEBUG] "+format, args...)
}

func (l testLogger) Errorf(format string, args ...interface{}) {
l.t.Logf("[ERROR] "+format, args...)
func TestTracerServiceNameValidation(t *testing.T) {
_, err := elasticapm.NewTracer("wot!", "")
assert.EqualError(t, err, `invalid service name "wot!": character '!' is not in the allowed set (a-zA-Z0-9 _-)`)
}
27 changes: 26 additions & 1 deletion utils.go
Expand Up @@ -3,9 +3,12 @@ package elasticapm
import (
"os"
"path/filepath"
"regexp"
"runtime"
"strings"

"github.com/pkg/errors"

"github.com/elastic/apm-agent-go/model"
)

Expand All @@ -17,6 +20,8 @@ var (
goLanguage = model.Language{Name: "go", Version: runtime.Version()}
goRuntime = model.Runtime{Name: runtime.Compiler, Version: runtime.Version()}
localSystem model.System

serviceNameInvalidRegexp = regexp.MustCompile("[^" + serviceNameValidClass + "]")
)

const (
Expand All @@ -26,6 +31,8 @@ const (
envHostname = "ELASTIC_APM_HOSTNAME"
envServiceName = "ELASTIC_APM_SERVICE_NAME"
envServiceVersion = "ELASTIC_APM_SERVICE_VERSION"

serviceNameValidClass = "a-zA-Z0-9 _-"
)

func init() {
Expand Down Expand Up @@ -64,8 +71,11 @@ func getEnvironmentService() model.Service {
name := os.Getenv(envServiceName)
if name == "" {
name = filepath.Base(os.Args[0])
if runtime.GOOS == "windows" {
name = strings.TrimSuffix(name, filepath.Ext(name))
}
}
svc := newService(name, "")
svc := newService(sanitizeServiceName(name), "")
return *svc
}

Expand Down Expand Up @@ -101,3 +111,18 @@ func getLocalSystem() model.System {
func validTagKey(k string) bool {
return !strings.ContainsAny(k, `.*"`)
}

func validateServiceName(name string) error {
idx := serviceNameInvalidRegexp.FindStringIndex(name)
if idx == nil {
return nil
}
return errors.Errorf(
"invalid service name %q: character %q is not in the allowed set (%s)",
name, name[idx[0]], serviceNameValidClass,
)
}

func sanitizeServiceName(name string) string {
return serviceNameInvalidRegexp.ReplaceAllString(name, "_")
}

0 comments on commit ace23e4

Please sign in to comment.