Skip to content

Commit

Permalink
test: Split code and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
gfyrag authored and flemzord committed Sep 2, 2022
1 parent 06a5283 commit ceeef7f
Show file tree
Hide file tree
Showing 6 changed files with 281 additions and 80 deletions.
43 changes: 1 addition & 42 deletions cmd/container.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package cmd

import (
"context"
"crypto/tls"
"fmt"
"io"
Expand All @@ -11,7 +10,6 @@ import (
"os"
"strings"

"github.com/Masterminds/semver/v3"
"github.com/Shopify/sarama"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
Expand All @@ -25,7 +23,6 @@ import (
"github.com/numary/go-libs/sharedpublish/sharedpublishhttp"
"github.com/numary/go-libs/sharedpublish/sharedpublishkafka"
"github.com/numary/ledger/cmd/internal"
"github.com/numary/ledger/pkg/analytics"
"github.com/numary/ledger/pkg/api"
"github.com/numary/ledger/pkg/api/middlewares"
"github.com/numary/ledger/pkg/api/routes"
Expand Down Expand Up @@ -208,45 +205,7 @@ func NewContainer(v *viper.Viper, userOptions ...fx.Option) *fx.App {
}(),
}))

if v.GetBool(telemetryEnabledFlag) || v.GetBool(segmentEnabledFlag) {
applicationId := viper.GetString(telemetryApplicationIdFlag)
if applicationId == "" {
applicationId = viper.GetString(segmentApplicationIdFlag)
}
var appIdProviderModule fx.Option
if applicationId == "" {
appIdProviderModule = fx.Provide(analytics.FromStorageAppIdProvider)
} else {
appIdProviderModule = fx.Provide(func() analytics.AppIdProvider {
return analytics.AppIdProviderFn(func(ctx context.Context) (string, error) {
return applicationId, nil
})
})
}
writeKey := viper.GetString(telemetryWriteKeyFlag)
if writeKey == "" {
writeKey = viper.GetString(segmentWriteKeyFlag)
}
interval := viper.GetDuration(telemetryHeartbeatIntervalFlag)
if interval == 0 {
interval = viper.GetDuration(segmentHeartbeatIntervalFlag)
}
if writeKey == "" {
sharedlogging.GetLogger(context.Background()).Infof("telemetry enabled but no write key provided")
} else if interval == 0 {
sharedlogging.GetLogger(context.Background()).Error("telemetry heartbeat interval is 0")
} else {
_, err := semver.NewVersion(Version)
if err != nil {
sharedlogging.GetLogger(context.Background()).Infof("telemetry enabled but version '%s' is not semver, skip", Version)
} else {
options = append(options,
appIdProviderModule,
analytics.NewHeartbeatModule(Version, writeKey, interval),
)
}
}
}
options = append(options, internal.NewAnalyticsModule(v, Version))

options = append(options, fx.Provide(
fx.Annotate(func() []ledger.LedgerOption {
Expand Down
83 changes: 83 additions & 0 deletions cmd/internal/analytics.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package internal

import (
"context"
"time"

"github.com/Masterminds/semver/v3"
"github.com/numary/go-libs/sharedlogging"
"github.com/numary/ledger/pkg/analytics"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"go.uber.org/fx"
)

const (
// deprecated
segmentEnabledFlag = "segment-enabled"
// deprecated
segmentWriteKeyFlag = "segment-write-key"
// deprecated
segmentApplicationIdFlag = "segment-application-id"
// deprecated
segmentHeartbeatIntervalFlag = "segment-heartbeat-interval"

telemetryEnabledFlag = "telemetry-enabled"
telemetryWriteKeyFlag = "telemetry-write-key"
telemetryApplicationIdFlag = "telemetry-application-id"
telemetryHeartbeatIntervalFlag = "telemetry-heartbeat-interval"
)

func InitAnalyticsFlags(cmd *cobra.Command, defaultWriteKey string) {
cmd.PersistentFlags().Bool(segmentEnabledFlag, true, "Is segment enabled")
cmd.PersistentFlags().String(segmentApplicationIdFlag, "", "Segment application id")
cmd.PersistentFlags().String(segmentWriteKeyFlag, defaultWriteKey, "Segment write key")
cmd.PersistentFlags().Duration(segmentHeartbeatIntervalFlag, 4*time.Hour, "Segment heartbeat interval")
cmd.PersistentFlags().Bool(telemetryEnabledFlag, true, "Is telemetry enabled")
cmd.PersistentFlags().String(telemetryApplicationIdFlag, "", "telemetry application id")
cmd.PersistentFlags().String(telemetryWriteKeyFlag, defaultWriteKey, "telemetry write key")
cmd.PersistentFlags().Duration(telemetryHeartbeatIntervalFlag, 4*time.Hour, "telemetry heartbeat interval")
}

func NewAnalyticsModule(v *viper.Viper, version string) fx.Option {
if v.GetBool(telemetryEnabledFlag) || v.GetBool(segmentEnabledFlag) {
applicationId := viper.GetString(telemetryApplicationIdFlag)
if applicationId == "" {
applicationId = viper.GetString(segmentApplicationIdFlag)
}
var appIdProviderModule fx.Option
if applicationId == "" {
appIdProviderModule = fx.Provide(analytics.FromStorageAppIdProvider)
} else {
appIdProviderModule = fx.Provide(func() analytics.AppIdProvider {
return analytics.AppIdProviderFn(func(ctx context.Context) (string, error) {
return applicationId, nil
})
})
}
writeKey := viper.GetString(telemetryWriteKeyFlag)
if writeKey == "" {
writeKey = viper.GetString(segmentWriteKeyFlag)
}
interval := viper.GetDuration(telemetryHeartbeatIntervalFlag)
if interval == 0 {
interval = viper.GetDuration(segmentHeartbeatIntervalFlag)
}
if writeKey == "" {
sharedlogging.GetLogger(context.Background()).Infof("telemetry enabled but no write key provided")
} else if interval == 0 {
sharedlogging.GetLogger(context.Background()).Error("telemetry heartbeat interval is 0")
} else {
_, err := semver.NewVersion(version)
if err != nil {
sharedlogging.GetLogger(context.Background()).Infof("telemetry enabled but version '%s' is not semver, skip", version)
} else {
return fx.Options(
appIdProviderModule,
analytics.NewHeartbeatModule(version, writeKey, interval),
)
}
}
}
return fx.Options()
}
169 changes: 169 additions & 0 deletions cmd/internal/analytics_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package internal

import (
"context"
"net/http"
"os"
"reflect"
"testing"
"time"

"github.com/numary/ledger/pkg/storage"
"github.com/numary/ledger/pkg/storage/sqlstorage"
"github.com/pborman/uuid"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/stretchr/testify/require"
"go.uber.org/fx"
"gopkg.in/segmentio/analytics-go.v3"
)

func TestAnalyticsFlags(t *testing.T) {
type testCase struct {
name string
key string
envValue string
viperMethod interface{}
expectedValue interface{}
}

for _, testCase := range []testCase{
{
name: "using deprecated segment enabled flag",
key: segmentEnabledFlag,
envValue: "true",
viperMethod: (*viper.Viper).GetBool,
expectedValue: true,
},
{
name: "using deprecated segment write key flag",
key: segmentWriteKeyFlag,
envValue: "foo:bar",
viperMethod: (*viper.Viper).GetString,
expectedValue: "foo:bar",
},
{
name: "using deprecated segment heartbeat interval flag",
key: segmentHeartbeatIntervalFlag,
envValue: "10s",
viperMethod: (*viper.Viper).GetDuration,
expectedValue: 10 * time.Second,
},
{
name: "using deprecated segment application id flag",
key: segmentApplicationIdFlag,
envValue: "foo:bar",
viperMethod: (*viper.Viper).GetString,
expectedValue: "foo:bar",
},
{
name: "using telemetry enabled flag",
key: telemetryEnabledFlag,
envValue: "true",
viperMethod: (*viper.Viper).GetBool,
expectedValue: true,
},
{
name: "using telemetry write key flag",
key: telemetryWriteKeyFlag,
envValue: "foo:bar",
viperMethod: (*viper.Viper).GetString,
expectedValue: "foo:bar",
},
{
name: "using telemetry heartbeat interval flag",
key: telemetryHeartbeatIntervalFlag,
envValue: "10s",
viperMethod: (*viper.Viper).GetDuration,
expectedValue: 10 * time.Second,
},
{
name: "using telemetry application id flag",
key: telemetryApplicationIdFlag,
envValue: "foo:bar",
viperMethod: (*viper.Viper).GetString,
expectedValue: "foo:bar",
},
} {
t.Run(testCase.name, func(t *testing.T) {
v := viper.GetViper()
cmd := &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
ret := reflect.ValueOf(testCase.viperMethod).Call([]reflect.Value{
reflect.ValueOf(v),
reflect.ValueOf(testCase.key),
})
require.Len(t, ret, 1)

rValue := ret[0].Interface()
require.Equal(t, testCase.expectedValue, rValue)
},
}
InitHTTPBasicFlags(cmd)
BindEnv(v)

restoreEnvVar := setEnvVar(testCase.key, testCase.envValue)
defer restoreEnvVar()

require.NoError(t, v.BindPFlags(cmd.PersistentFlags()))
require.NoError(t, cmd.Execute())
})
}
}

func TestAnalyticsModule(t *testing.T) {
v := viper.GetViper()
v.Set(telemetryEnabledFlag, true)
v.Set(telemetryWriteKeyFlag, "XXX")
v.Set(telemetryApplicationIdFlag, "appId")
v.Set(telemetryHeartbeatIntervalFlag, 10*time.Second)

handled := make(chan struct{})

module := NewAnalyticsModule(v, "1.0.0")
app := fx.New(
module,
fx.Provide(func(lc fx.Lifecycle) (storage.Driver, error) {
id := uuid.New()
driver := sqlstorage.NewDriver("sqlite", sqlstorage.NewSQLiteDB(os.TempDir(), id))
lc.Append(fx.Hook{
OnStart: driver.Initialize,
})
return driver, nil
}),
fx.Replace(analytics.Config{
BatchSize: 1,
Transport: roundTripperFn(func(req *http.Request) (*http.Response, error) {
select {
case <-handled:
// Nothing to do, the chan has already been closed
default:
close(handled)
}
return &http.Response{
StatusCode: http.StatusOK,
}, nil
}),
}))
require.NoError(t, app.Start(context.Background()))
defer func() {
require.NoError(t, app.Stop(context.Background()))
}()

select {
case <-time.After(time.Second):
require.Fail(t, "Timeout waiting first stats from analytics module")
case <-handled:
}

}

func TestAnalyticsModuleDisabled(t *testing.T) {
v := viper.GetViper()
v.Set(telemetryEnabledFlag, false)

module := NewAnalyticsModule(v, "1.0.0")
app := fx.New(module)
require.NoError(t, app.Start(context.Background()))
require.NoError(t, app.Stop(context.Background()))
}
15 changes: 0 additions & 15 deletions cmd/internal/http_basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ package internal

import (
"fmt"
"os"
"reflect"
"strings"
"testing"

"github.com/numary/go-libs/sharedauth"
Expand All @@ -13,19 +11,6 @@ import (
"github.com/stretchr/testify/require"
)

func withPrefix(flag string) string {
return strings.ToUpper(fmt.Sprintf("%s_%s", envPrefix, EnvVarReplacer.Replace(flag)))
}

func setEnvVar(key, value string) func() {
prefixedFlag := withPrefix(key)
oldEnv := os.Getenv(prefixedFlag)
os.Setenv(prefixedFlag, value)
return func() {
os.Setenv(prefixedFlag, oldEnv)
}
}

func TestViperEnvBinding(t *testing.T) {

type testCase struct {
Expand Down
27 changes: 27 additions & 0 deletions cmd/internal/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package internal

import (
"fmt"
"net/http"
"os"
"strings"
)

func withPrefix(flag string) string {
return strings.ToUpper(fmt.Sprintf("%s_%s", envPrefix, EnvVarReplacer.Replace(flag)))
}

func setEnvVar(key, value string) func() {
prefixedFlag := withPrefix(key)
oldEnv := os.Getenv(prefixedFlag)
os.Setenv(prefixedFlag, value)
return func() {
os.Setenv(prefixedFlag, oldEnv)
}
}

type roundTripperFn func(req *http.Request) (*http.Response, error)

func (fn roundTripperFn) RoundTrip(req *http.Request) (*http.Response, error) {
return fn(req)
}

0 comments on commit ceeef7f

Please sign in to comment.