diff --git a/internal/dev_server/db/sqlite.go b/internal/dev_server/db/sqlite.go index 2e847ff2..5bd66ec7 100644 --- a/internal/dev_server/db/sqlite.go +++ b/internal/dev_server/db/sqlite.go @@ -6,6 +6,7 @@ import ( "encoding/json" "io" "os" + "strings" _ "github.com/mattn/go-sqlite3" "github.com/pkg/errors" @@ -47,12 +48,15 @@ func (s *Sqlite) GetDevProject(ctx context.Context, key string) (*model.Project, var flagStateData string row := s.database.QueryRowContext(ctx, ` - SELECT key, source_environment_key, context, last_sync_time, flag_state - FROM projects + SELECT key, source_environment_key, context, last_sync_time, flag_state, payload_version + FROM projects WHERE key = ? `, key) - if err := row.Scan(&project.Key, &project.SourceEnvironmentKey, &contextData, &project.LastSyncTime, &flagStateData); err != nil { + if err := row.Scan( + &project.Key, &project.SourceEnvironmentKey, &contextData, + &project.LastSyncTime, &flagStateData, &project.PayloadVersion, + ); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, model.NewErrNotFound("project", key) } @@ -200,14 +204,15 @@ SELECT 1 FROM projects WHERE key = ? return } _, err = tx.Exec(` -INSERT INTO projects (key, source_environment_key, context, last_sync_time, flag_state) -VALUES (?, ?, ?, ?, ?) +INSERT INTO projects (key, source_environment_key, context, last_sync_time, flag_state, payload_version) +VALUES (?, ?, ?, ?, ?, ?) `, project.Key, project.SourceEnvironmentKey, project.Context.JSONString(), project.LastSyncTime, string(flagsStateJson), + project.PayloadVersion, ) if err != nil { return @@ -341,6 +346,20 @@ func (s *Sqlite) UpsertOverride(ctx context.Context, override model.Override) (m return override, nil } +func (s *Sqlite) IncrementProjectPayloadVersion(ctx context.Context, projectKey string) (int, error) { + row := s.database.QueryRowContext(ctx, ` + UPDATE projects + SET payload_version = payload_version + 1 + WHERE key = ? + RETURNING payload_version + `, projectKey) + var version int + if err := row.Scan(&version); err != nil { + return 0, errors.Wrap(err, "unable to increment payload version") + } + return version, nil +} + func (s *Sqlite) DeactivateOverride(ctx context.Context, projectKey, flagKey string) (int, error) { row := s.database.QueryRowContext(ctx, ` UPDATE overrides @@ -373,12 +392,12 @@ func (s *Sqlite) RestoreBackup(ctx context.Context, stream io.Reader) (string, e } err = os.Rename(filepath, s.dbPath) if err != nil { - //panic because this would really leave the app in an invalid state + // panic because this would really leave the app in an invalid state panic(err) } s.database, err = sql.Open("sqlite3", s.dbPath) if err != nil { - //panic because this would really leave the app in an invalid state + // panic because this would really leave the app in an invalid state panic(err) } @@ -445,12 +464,20 @@ func (s *Sqlite) runMigrations(ctx context.Context) error { source_environment_key text NOT NULL, context text NOT NULL, last_sync_time timestamp NOT NULL, - flag_state TEXT NOT NULL + flag_state TEXT NOT NULL, + payload_version INTEGER NOT NULL DEFAULT 1 )`) if err != nil { return err } + // Migration: add payload_version to existing databases that predate this column. + _, err = tx.Exec(`ALTER TABLE projects ADD COLUMN payload_version INTEGER NOT NULL DEFAULT 1`) + if err != nil && !strings.Contains(err.Error(), "duplicate column name") { + return err + } + err = nil + _, err = tx.Exec(` CREATE TABLE IF NOT EXISTS overrides ( project_key text NOT NULL, diff --git a/internal/dev_server/db/sqlite_test.go b/internal/dev_server/db/sqlite_test.go index a15ef72c..3eb3cb46 100644 --- a/internal/dev_server/db/sqlite_test.go +++ b/internal/dev_server/db/sqlite_test.go @@ -36,6 +36,7 @@ func TestDBFunctions(t *testing.T) { SourceEnvironmentKey: "env-1", Context: ldContext, LastSyncTime: now, + PayloadVersion: 1, AllFlagsState: model.FlagsState{ "flag-1": model.FlagState{Value: ldvalue.Bool(true), Version: 2}, "flag-2": model.FlagState{Value: ldvalue.String("cool"), Version: 2}, @@ -71,6 +72,7 @@ func TestDBFunctions(t *testing.T) { SourceEnvironmentKey: "env-2", Context: ldContext, LastSyncTime: now, + PayloadVersion: 1, AllFlagsState: model.FlagsState{ "flag-1": model.FlagState{Value: ldvalue.Int(123), Version: 2}, "flag-2": model.FlagState{Value: ldvalue.Float64(99.99), Version: 2}, @@ -97,6 +99,7 @@ func TestDBFunctions(t *testing.T) { SourceEnvironmentKey: "env-3", Context: ldContext, LastSyncTime: now, + PayloadVersion: 1, AllFlagsState: model.FlagsState{ "flag-1": model.FlagState{Value: ldvalue.Int(123), Version: 2}, "flag-2": model.FlagState{Value: ldvalue.Float64(99.99), Version: 2}, @@ -169,6 +172,7 @@ func TestDBFunctions(t *testing.T) { assert.Equal(t, expected.SourceEnvironmentKey, p.SourceEnvironmentKey) assert.Equal(t, expected.Context, p.Context) assert.True(t, expected.LastSyncTime.Equal(p.LastSyncTime)) + assert.Equal(t, expected.PayloadVersion, p.PayloadVersion) }) t.Run("GetAvailableVariations returns variations", func(t *testing.T) { @@ -364,6 +368,25 @@ func TestDBFunctions(t *testing.T) { assert.True(t, found) }) + t.Run("IncrementProjectPayloadVersion increments and returns new version", func(t *testing.T) { + proj, err := store.GetDevProject(ctx, projects[0].Key) + require.NoError(t, err) + initialVersion := proj.PayloadVersion + + newVersion, err := store.IncrementProjectPayloadVersion(ctx, projects[0].Key) + require.NoError(t, err) + assert.Equal(t, initialVersion+1, newVersion) + + proj, err = store.GetDevProject(ctx, projects[0].Key) + require.NoError(t, err) + assert.Equal(t, initialVersion+1, proj.PayloadVersion) + + // Calling again should increment once more + newVersion2, err := store.IncrementProjectPayloadVersion(ctx, projects[0].Key) + require.NoError(t, err) + assert.Equal(t, initialVersion+2, newVersion2) + }) + t.Run("UpdateProject deletes overrides for flags that are no longer in the project", func(t *testing.T) { project := projects[2] diff --git a/internal/dev_server/model/import_project.go b/internal/dev_server/model/import_project.go index 2b6133ad..d5360a05 100644 --- a/internal/dev_server/model/import_project.go +++ b/internal/dev_server/model/import_project.go @@ -53,6 +53,7 @@ func ImportProject(ctx context.Context, projectKey string, importData ImportData Context: importData.Context, AllFlagsState: importData.FlagsState, AvailableVariations: []FlagVariation{}, + PayloadVersion: 1, } // Convert available variations if present diff --git a/internal/dev_server/model/mocks/store.go b/internal/dev_server/model/mocks/store.go index 5487ece9..dd037c2e 100644 --- a/internal/dev_server/model/mocks/store.go +++ b/internal/dev_server/model/mocks/store.go @@ -148,6 +148,21 @@ func (mr *MockStoreMockRecorder) GetOverridesForProject(ctx, projectKey any) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOverridesForProject", reflect.TypeOf((*MockStore)(nil).GetOverridesForProject), ctx, projectKey) } +// IncrementProjectPayloadVersion mocks base method. +func (m *MockStore) IncrementProjectPayloadVersion(ctx context.Context, projectKey string) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IncrementProjectPayloadVersion", ctx, projectKey) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IncrementProjectPayloadVersion indicates an expected call of IncrementProjectPayloadVersion. +func (mr *MockStoreMockRecorder) IncrementProjectPayloadVersion(ctx, projectKey any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementProjectPayloadVersion", reflect.TypeOf((*MockStore)(nil).IncrementProjectPayloadVersion), ctx, projectKey) +} + // InsertProject mocks base method. func (m *MockStore) InsertProject(ctx context.Context, project model.Project) error { m.ctrl.T.Helper() diff --git a/internal/dev_server/model/override.go b/internal/dev_server/model/override.go index 36359bbd..cfc97c34 100644 --- a/internal/dev_server/model/override.go +++ b/internal/dev_server/model/override.go @@ -3,6 +3,8 @@ package model import ( "context" + "github.com/pkg/errors" + "github.com/launchdarkly/go-sdk-common/v3/ldvalue" ) @@ -58,6 +60,11 @@ func UpsertOverride(ctx context.Context, projectKey, flagKey string, value ldval return Override{}, err } + _, err = store.IncrementProjectPayloadVersion(ctx, projectKey) + if err != nil { + return Override{}, errors.Wrap(err, "unable to increment payload version") + } + GetObserversFromContext(ctx).Notify(OverrideEvent{ FlagKey: flagKey, ProjectKey: projectKey, @@ -76,6 +83,12 @@ func DeleteOverride(ctx context.Context, projectKey, flagKey string) error { if err != nil { return err } + + _, err = store.IncrementProjectPayloadVersion(ctx, projectKey) + if err != nil { + return errors.Wrap(err, "unable to increment payload version") + } + override := Override{ ProjectKey: projectKey, FlagKey: flagKey, diff --git a/internal/dev_server/model/override_test.go b/internal/dev_server/model/override_test.go index 57e58ddd..b9198336 100644 --- a/internal/dev_server/model/override_test.go +++ b/internal/dev_server/model/override_test.go @@ -73,6 +73,7 @@ func TestUpsertOverride(t *testing.T) { t.Run("override is applied, observers are notified", func(t *testing.T) { store.EXPECT().GetDevProject(gomock.Any(), projKey).Return(project, nil) store.EXPECT().UpsertOverride(gomock.Any(), override).Return(override, nil) + store.EXPECT().IncrementProjectPayloadVersion(gomock.Any(), projKey).Return(1, nil) observer. EXPECT(). Handle(model.OverrideEvent{ @@ -128,6 +129,7 @@ func TestDeleteOverride(t *testing.T) { t.Run("override is applied, observers are notified", func(t *testing.T) { store.EXPECT().GetDevProject(gomock.Any(), projKey).Return(project, nil) store.EXPECT().DeactivateOverride(gomock.Any(), projKey, flagKey).Return(2, nil) + store.EXPECT().IncrementProjectPayloadVersion(gomock.Any(), projKey).Return(1, nil) observer. EXPECT(). Handle(model.OverrideEvent{ @@ -198,11 +200,13 @@ func TestDeleteOverrides(t *testing.T) { // Expectations for first override store.EXPECT().GetDevProject(gomock.Any(), projKey).Return(project, nil) store.EXPECT().DeactivateOverride(gomock.Any(), projKey, flagKey).Return(2, nil) + store.EXPECT().IncrementProjectPayloadVersion(gomock.Any(), projKey).Return(1, nil) observer.EXPECT().Handle(gomock.Any()) // Expectations for second override store.EXPECT().GetDevProject(gomock.Any(), projKey).Return(project, nil) store.EXPECT().DeactivateOverride(gomock.Any(), projKey, "flag2").Return(2, nil) + store.EXPECT().IncrementProjectPayloadVersion(gomock.Any(), projKey).Return(2, nil) observer.EXPECT().Handle(gomock.Any()) err := model.DeleteOverrides(ctx, projKey) diff --git a/internal/dev_server/model/project.go b/internal/dev_server/model/project.go index f796498f..ad3e899d 100644 --- a/internal/dev_server/model/project.go +++ b/internal/dev_server/model/project.go @@ -18,6 +18,7 @@ type Project struct { LastSyncTime time.Time AllFlagsState FlagsState AvailableVariations []FlagVariation + PayloadVersion int } // CreateProject creates a project and adds it to the database. @@ -25,6 +26,7 @@ func CreateProject(ctx context.Context, projectKey, sourceEnvironmentKey string, project := Project{ Key: projectKey, SourceEnvironmentKey: sourceEnvironmentKey, + PayloadVersion: 1, } if ldCtx == nil { @@ -87,6 +89,12 @@ func UpdateProject(ctx context.Context, projectKey string, context *ldcontext.Co return Project{}, errors.New("Project not updated") } + newPayloadVersion, err := store.IncrementProjectPayloadVersion(ctx, projectKey) + if err != nil { + return Project{}, errors.Wrap(err, "unable to increment payload version") + } + project.PayloadVersion = newPayloadVersion + allFlagsWithOverrides, err := project.GetFlagStateWithOverridesForProject(ctx) if err != nil { return Project{}, errors.Wrapf(err, "unable to get overrides for project, %s", projectKey) diff --git a/internal/dev_server/model/project_test.go b/internal/dev_server/model/project_test.go index eb47455c..4a6d4a15 100644 --- a/internal/dev_server/model/project_test.go +++ b/internal/dev_server/model/project_test.go @@ -183,6 +183,7 @@ func TestUpdateProject(t *testing.T) { sdk.EXPECT().GetAllFlagsState(gomock.Any(), gomock.Any(), "sdkKey").Return(allFlagsState, nil) api.EXPECT().GetAllFlags(gomock.Any(), proj.Key).Return(allFlags, nil) store.EXPECT().UpdateProject(gomock.Any(), gomock.Any()).Return(true, nil) + store.EXPECT().IncrementProjectPayloadVersion(gomock.Any(), proj.Key).Return(2, nil) store.EXPECT().GetOverridesForProject(gomock.Any(), proj.Key).Return(model.Overrides{}, nil) observer. EXPECT(). @@ -193,7 +194,9 @@ func TestUpdateProject(t *testing.T) { project, err := model.UpdateProject(ctx, proj.Key, nil, nil) require.Nil(t, err) - assert.Equal(t, proj, project) + expectedProj := proj + expectedProj.PayloadVersion = 2 + assert.Equal(t, expectedProj, project) }) } diff --git a/internal/dev_server/model/store.go b/internal/dev_server/model/store.go index 8c3d9284..7a3b3883 100644 --- a/internal/dev_server/model/store.go +++ b/internal/dev_server/model/store.go @@ -28,6 +28,8 @@ type Store interface { UpsertOverride(ctx context.Context, override Override) (Override, error) GetOverridesForProject(ctx context.Context, projectKey string) (Overrides, error) GetAvailableVariationsForProject(ctx context.Context, projectKey string) (map[string][]Variation, error) + // IncrementProjectPayloadVersion atomically increments the payload version for the project and returns the new version. + IncrementProjectPayloadVersion(ctx context.Context, projectKey string) (int, error) CreateBackup(ctx context.Context) (io.ReadCloser, int64, error) RestoreBackup(ctx context.Context, stream io.Reader) (string, error) diff --git a/internal/dev_server/model/sync_test.go b/internal/dev_server/model/sync_test.go index 643e3653..ed508135 100644 --- a/internal/dev_server/model/sync_test.go +++ b/internal/dev_server/model/sync_test.go @@ -153,8 +153,9 @@ func TestInitialSync(t *testing.T) { sdk.EXPECT().GetAllFlagsState(gomock.Any(), gomock.Any(), sdkKey).Return(allFlagsState, nil) api.EXPECT().GetAllFlags(gomock.Any(), projKey).Return(allFlags, nil) store.EXPECT().InsertProject(gomock.Any(), gomock.Any()).Return(nil) - store.EXPECT().UpsertOverride(gomock.Any(), override).Return(override, nil) store.EXPECT().GetDevProject(gomock.Any(), projKey).Return(&proj, nil) + store.EXPECT().UpsertOverride(gomock.Any(), override).Return(override, nil) + store.EXPECT().IncrementProjectPayloadVersion(gomock.Any(), projKey).Return(1, nil) input := model.InitialProjectSettings{ Enabled: true, diff --git a/internal/dev_server/sdk/fdv2.go b/internal/dev_server/sdk/fdv2.go new file mode 100644 index 00000000..4f665a97 --- /dev/null +++ b/internal/dev_server/sdk/fdv2.go @@ -0,0 +1,136 @@ +package sdk + +import ( + "encoding/json" + "fmt" + "strconv" + "strings" + + "github.com/launchdarkly/go-server-sdk/v7/subsystems" + "github.com/launchdarkly/ldcli/internal/dev_server/model" +) + +const ( + fdv2ReasonUpToDate = "up-to-date" + fdv2ReasonCantCatchup = "cant-catchup" + fdv2ReasonPayloadMissing = "payload-missing" +) + +// parseBasis extracts the payload ID and version from a basis state string of the +// form "(p::)". Returns ("", 0) if the string is absent or unparseable. +// +// Note: in production LD selectors the payload ID is an opaque server-assigned value. +// The dev server uses the project key as the payload ID (see makePayloadTransferredEvent). +// This is a dev-server-specific convention and should not be assumed elsewhere. +func parseBasis(basis string) (string, int) { + if !strings.HasPrefix(basis, "(p:") || !strings.HasSuffix(basis, ")") { + return "", 0 + } + // Strip the "(p:" prefix and ")" suffix to get ":". + inner := basis[3 : len(basis)-1] + lastColon := strings.LastIndex(inner, ":") + if lastColon == -1 { + return "", 0 + } + version, err := strconv.Atoi(inner[lastColon+1:]) + if err != nil || version < 0 { + return "", 0 + } + return inner[:lastColon], version +} + +// buildPollResponse constructs the FDv2 polling response. +// +// payloadID is the stable identifier for this payload (the project key). +// currentVersion is the project's current PayloadVersion. +// flags is the current flag state with overrides applied. +// basis is the raw ?basis query param from the SDK (empty string = no basis provided). +// +// Delta transfers are not supported: stale clients always receive a full payload. +// Tracking the change history required for deltas is overkill for a local dev server. +func buildPollResponse(payloadID string, currentVersion int, flags model.FlagsState, basis string) (subsystems.PollingPayload, error) { + basisPayloadID, basisVersion := parseBasis(basis) + switch { + case basisVersion == 0: + return buildFullTransferResponse(payloadID, currentVersion, flags, fdv2ReasonPayloadMissing) + case basisPayloadID == payloadID && basisVersion == currentVersion: + event, err := makeServerIntentEvent(payloadID, currentVersion, subsystems.IntentNone, fdv2ReasonUpToDate) + if err != nil { + return subsystems.PollingPayload{}, err + } + return subsystems.PollingPayload{Events: []subsystems.RawEvent{event}}, nil + default: + // Payload ID mismatch, stale version, or version ahead of current (e.g. project recreated): + // we can't compute a delta — send the full payload. + return buildFullTransferResponse(payloadID, currentVersion, flags, fdv2ReasonCantCatchup) + } +} + +func buildFullTransferResponse(payloadID string, version int, flags model.FlagsState, reason string) (subsystems.PollingPayload, error) { + intentEvent, err := makeServerIntentEvent(payloadID, version, subsystems.IntentTransferFull, reason) + if err != nil { + return subsystems.PollingPayload{}, err + } + events := []subsystems.RawEvent{intentEvent} + + for key, flagState := range flags { + event, err := makePutObjectEvent(version, key, flagState) + if err != nil { + return subsystems.PollingPayload{}, err + } + events = append(events, event) + } + + transferredEvent, err := makePayloadTransferredEvent(payloadID, version) + if err != nil { + return subsystems.PollingPayload{}, err + } + events = append(events, transferredEvent) + + return subsystems.PollingPayload{Events: events}, nil +} + +func makeServerIntentEvent(payloadID string, target int, intentCode subsystems.IntentCode, reason string) (subsystems.RawEvent, error) { + data, err := json.Marshal(subsystems.ServerIntent{ + Payload: subsystems.Payload{ + ID: payloadID, + Target: target, + Code: intentCode, + Reason: reason, + }, + }) + if err != nil { + return subsystems.RawEvent{}, err + } + return subsystems.RawEvent{Name: subsystems.EventServerIntent, Data: data}, nil +} + +func makePutObjectEvent(version int, key string, flagState model.FlagState) (subsystems.RawEvent, error) { + object, err := json.Marshal(serverFlagFromFlagState(key, flagState)) + if err != nil { + return subsystems.RawEvent{}, err + } + data, err := json.Marshal(subsystems.PutObject{ + Version: version, + Kind: subsystems.FlagKind, + Key: key, + Object: object, + }) + if err != nil { + return subsystems.RawEvent{}, err + } + return subsystems.RawEvent{Name: subsystems.EventPutObject, Data: data}, nil +} + +func makePayloadTransferredEvent(payloadID string, version int) (subsystems.RawEvent, error) { + // The selector state is synthetic and dev-server-specific: the dev server uses the + // project key as the payload ID rather than a server-assigned opaque value. The SDK + // echoes this selector back as ?basis on subsequent polls, where parseBasisVersion + // extracts the version from it. + selector := subsystems.NewSelector(fmt.Sprintf("(p:%s:%d)", payloadID, version), version) + data, err := json.Marshal(selector) + if err != nil { + return subsystems.RawEvent{}, err + } + return subsystems.RawEvent{Name: subsystems.EventPayloadTransferred, Data: data}, nil +} diff --git a/internal/dev_server/sdk/fdv2_test.go b/internal/dev_server/sdk/fdv2_test.go new file mode 100644 index 00000000..4020f5cb --- /dev/null +++ b/internal/dev_server/sdk/fdv2_test.go @@ -0,0 +1,257 @@ +package sdk + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/gorilla/mux" + "github.com/launchdarkly/go-sdk-common/v3/ldcontext" + "github.com/launchdarkly/go-sdk-common/v3/ldvalue" + "github.com/launchdarkly/go-server-sdk/v7/subsystems" + "github.com/launchdarkly/ldcli/internal/dev_server/model" + "github.com/launchdarkly/ldcli/internal/dev_server/model/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func TestParseBasis(t *testing.T) { + tests := []struct { + basis string + expectedID string + expectedVersion int + }{ + {"", "", 0}, + {"(p:my-project:5)", "my-project", 5}, + {"(p:my-project:1)", "my-project", 1}, + {"(p:complex:key:with:colons:99)", "complex:key:with:colons", 99}, + {"not-valid", "", 0}, + {"(p:no-version)", "", 0}, + {"(p:negative:-1)", "", 0}, + {"(p:nan:abc)", "", 0}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("basis=%q", tt.basis), func(t *testing.T) { + id, version := parseBasis(tt.basis) + assert.Equal(t, tt.expectedID, id) + assert.Equal(t, tt.expectedVersion, version) + }) + } +} + +func TestBuildPollResponse(t *testing.T) { + payloadID := "test-project" + currentVersion := 5 + flags := model.FlagsState{ + "flag-1": model.FlagState{Value: ldvalue.Bool(true), Version: 2}, + } + + t.Run("no basis sends xfer-full with payload-missing", func(t *testing.T) { + resp, err := buildPollResponse(payloadID, currentVersion, flags, "") + require.NoError(t, err) + + require.GreaterOrEqual(t, len(resp.Events), 3) // server-intent + put-objects + payload-transferred + + assertServerIntentEvent(t, resp.Events[0], payloadID, currentVersion, subsystems.IntentTransferFull, fdv2ReasonPayloadMissing) + assertPayloadTransferredEvent(t, resp.Events[len(resp.Events)-1], payloadID, currentVersion) + }) + + t.Run("up-to-date basis sends none with up-to-date", func(t *testing.T) { + basis := fmt.Sprintf("(p:%s:%d)", payloadID, currentVersion) + resp, err := buildPollResponse(payloadID, currentVersion, flags, basis) + require.NoError(t, err) + + require.Len(t, resp.Events, 1) + assertServerIntentEvent(t, resp.Events[0], payloadID, currentVersion, subsystems.IntentNone, fdv2ReasonUpToDate) + }) + + t.Run("basis ahead of current version sends full transfer (e.g. project recreated)", func(t *testing.T) { + basis := fmt.Sprintf("(p:%s:%d)", payloadID, currentVersion+10) + resp, err := buildPollResponse(payloadID, currentVersion, flags, basis) + require.NoError(t, err) + + require.GreaterOrEqual(t, len(resp.Events), 3) + assertServerIntentEvent(t, resp.Events[0], payloadID, currentVersion, subsystems.IntentTransferFull, fdv2ReasonCantCatchup) + assertPayloadTransferredEvent(t, resp.Events[len(resp.Events)-1], payloadID, currentVersion) + }) + + t.Run("stale basis sends xfer-full with cant-catchup", func(t *testing.T) { + basis := fmt.Sprintf("(p:%s:%d)", payloadID, currentVersion-1) + resp, err := buildPollResponse(payloadID, currentVersion, flags, basis) + require.NoError(t, err) + + require.GreaterOrEqual(t, len(resp.Events), 3) + assertServerIntentEvent(t, resp.Events[0], payloadID, currentVersion, subsystems.IntentTransferFull, fdv2ReasonCantCatchup) + assertPayloadTransferredEvent(t, resp.Events[len(resp.Events)-1], payloadID, currentVersion) + }) + + t.Run("basis with wrong payload ID sends xfer-full", func(t *testing.T) { + basis := fmt.Sprintf("(p:%s:%d)", "other-project", currentVersion) + resp, err := buildPollResponse(payloadID, currentVersion, flags, basis) + require.NoError(t, err) + + require.GreaterOrEqual(t, len(resp.Events), 3) + assertServerIntentEvent(t, resp.Events[0], payloadID, currentVersion, subsystems.IntentTransferFull, fdv2ReasonCantCatchup) + assertPayloadTransferredEvent(t, resp.Events[len(resp.Events)-1], payloadID, currentVersion) + }) + + t.Run("full transfer includes a put-object for each flag", func(t *testing.T) { + multiFlags := model.FlagsState{ + "flag-a": model.FlagState{Value: ldvalue.Bool(true), Version: 1}, + "flag-b": model.FlagState{Value: ldvalue.String("hello"), Version: 2}, + } + resp, err := buildPollResponse(payloadID, currentVersion, multiFlags, "") + require.NoError(t, err) + + // server-intent + 2 put-objects + payload-transferred + assert.Len(t, resp.Events, 4) + putKeys := make(map[string]bool) + for _, event := range resp.Events { + if event.Name == subsystems.EventPutObject { + var put subsystems.PutObject + require.NoError(t, json.Unmarshal(event.Data, &put)) + putKeys[put.Key] = true + assert.Equal(t, currentVersion, put.Version) + assert.Equal(t, subsystems.FlagKind, put.Kind) + } + } + assert.True(t, putKeys["flag-a"]) + assert.True(t, putKeys["flag-b"]) + }) +} + +func TestPollV2Handler(t *testing.T) { + mockController := gomock.NewController(t) + store := mocks.NewMockStore(mockController) + observers := model.NewObservers() + + router := mux.NewRouter() + router.Use(model.ObserversMiddleware(observers)) + router.Use(model.StoreMiddleware(store)) + BindRoutes(router) + + project := &model.Project{ + Key: exampleProjectKey, + SourceEnvironmentKey: "my-environment", + Context: ldcontext.Context{}, + LastSyncTime: time.Unix(0, 0), + AllFlagsState: model.FlagsState{ + "flag-1": model.FlagState{Value: ldvalue.Bool(true), Version: 1}, + }, + AvailableVariations: nil, + PayloadVersion: 3, + } + + t.Run("no basis returns full payload", func(t *testing.T) { + store.EXPECT().GetDevProject(gomock.Any(), exampleProjectKey).Return(project, nil) + store.EXPECT().GetOverridesForProject(gomock.Any(), exampleProjectKey).Return(nil, nil) + + req := httptest.NewRequest(http.MethodGet, "/sdk/poll", nil) + req.Header.Set("Authorization", exampleProjectKey) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + + var resp subsystems.PollingPayload + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.GreaterOrEqual(t, len(resp.Events), 3) + assertServerIntentEvent(t, resp.Events[0], exampleProjectKey, 3, subsystems.IntentTransferFull, fdv2ReasonPayloadMissing) + assertPayloadTransferredEvent(t, resp.Events[len(resp.Events)-1], exampleProjectKey, 3) + }) + + t.Run("up-to-date basis returns none intent", func(t *testing.T) { + store.EXPECT().GetDevProject(gomock.Any(), exampleProjectKey).Return(project, nil) + store.EXPECT().GetOverridesForProject(gomock.Any(), exampleProjectKey).Return(nil, nil) + + basisState := fmt.Sprintf("(p:%s:%d)", exampleProjectKey, project.PayloadVersion) + req := httptest.NewRequest(http.MethodGet, "/sdk/poll?basis="+basisState, nil) + req.Header.Set("Authorization", exampleProjectKey) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp subsystems.PollingPayload + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Len(t, resp.Events, 1) + assertServerIntentEvent(t, resp.Events[0], exampleProjectKey, 3, subsystems.IntentNone, fdv2ReasonUpToDate) + }) + + t.Run("url-encoded basis is decoded correctly", func(t *testing.T) { + store.EXPECT().GetDevProject(gomock.Any(), exampleProjectKey).Return(project, nil) + store.EXPECT().GetOverridesForProject(gomock.Any(), exampleProjectKey).Return(nil, nil) + + basisState := fmt.Sprintf("(p:%s:%d)", exampleProjectKey, project.PayloadVersion) + req := httptest.NewRequest(http.MethodGet, "/sdk/poll?basis="+url.QueryEscape(basisState), nil) + req.Header.Set("Authorization", exampleProjectKey) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp subsystems.PollingPayload + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Len(t, resp.Events, 1) + assertServerIntentEvent(t, resp.Events[0], exampleProjectKey, 3, subsystems.IntentNone, fdv2ReasonUpToDate) + }) + + t.Run("stale basis returns full payload with cant-catchup", func(t *testing.T) { + store.EXPECT().GetDevProject(gomock.Any(), exampleProjectKey).Return(project, nil) + store.EXPECT().GetOverridesForProject(gomock.Any(), exampleProjectKey).Return(nil, nil) + + basisState := fmt.Sprintf("(p:%s:%d)", exampleProjectKey, project.PayloadVersion-1) + req := httptest.NewRequest(http.MethodGet, "/sdk/poll?basis="+basisState, nil) + req.Header.Set("Authorization", exampleProjectKey) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp subsystems.PollingPayload + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.GreaterOrEqual(t, len(resp.Events), 3) + assertServerIntentEvent(t, resp.Events[0], exampleProjectKey, 3, subsystems.IntentTransferFull, fdv2ReasonCantCatchup) + assertPayloadTransferredEvent(t, resp.Events[len(resp.Events)-1], exampleProjectKey, 3) + }) + + t.Run("unknown project returns 404", func(t *testing.T) { + store.EXPECT().GetDevProject(gomock.Any(), exampleProjectKey).Return(nil, model.NewErrNotFound("project", exampleProjectKey)) + + req := httptest.NewRequest(http.MethodGet, "/sdk/poll", nil) + req.Header.Set("Authorization", exampleProjectKey) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + }) +} + +// assertServerIntentEvent unmarshals a server-intent event and checks its fields. +func assertServerIntentEvent(t *testing.T, event subsystems.RawEvent, payloadID string, target int, intentCode subsystems.IntentCode, reason string) { + t.Helper() + assert.Equal(t, subsystems.EventServerIntent, event.Name) + var data subsystems.ServerIntent + require.NoError(t, json.Unmarshal(event.Data, &data)) + assert.Equal(t, payloadID, data.Payload.ID) + assert.Equal(t, target, data.Payload.Target) + assert.Equal(t, intentCode, data.Payload.Code) + assert.Equal(t, reason, data.Payload.Reason) +} + +// assertPayloadTransferredEvent unmarshals a payload-transferred event and checks its fields. +func assertPayloadTransferredEvent(t *testing.T, event subsystems.RawEvent, payloadID string, version int) { + t.Helper() + assert.Equal(t, subsystems.EventPayloadTransferred, event.Name) + var data subsystems.Selector + require.NoError(t, json.Unmarshal(event.Data, &data)) + assert.Equal(t, version, data.Version()) + assert.Equal(t, fmt.Sprintf("(p:%s:%d)", payloadID, version), data.State()) +} diff --git a/internal/dev_server/sdk/polling.go b/internal/dev_server/sdk/polling.go index 20547f3a..b3d46989 100644 --- a/internal/dev_server/sdk/polling.go +++ b/internal/dev_server/sdk/polling.go @@ -4,9 +4,39 @@ import ( "encoding/json" "net/http" + "github.com/launchdarkly/ldcli/internal/dev_server/model" "github.com/pkg/errors" ) +func PollV2(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + store := model.StoreFromContext(ctx) + projectKey := GetProjectKeyFromContext(ctx) + + project, err := store.GetDevProject(ctx, projectKey) + if err != nil { + WriteError(ctx, w, errors.Wrap(err, "failed to get project")) + return + } + + allFlags, err := project.GetFlagStateWithOverridesForProject(ctx) + if err != nil { + WriteError(ctx, w, errors.Wrap(err, "failed to get flag state")) + return + } + + response, err := buildPollResponse(projectKey, project.PayloadVersion, allFlags, r.URL.Query().Get("basis")) + if err != nil { + WriteError(ctx, w, errors.Wrap(err, "failed to build poll response")) + return + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + WriteError(ctx, w, errors.Wrap(err, "failed to encode response")) + } +} + func LatestAll(w http.ResponseWriter, r *http.Request) { ctx := r.Context() allFlags, err := GetAllFlagsFromContext(ctx) diff --git a/internal/dev_server/sdk/routes.go b/internal/dev_server/sdk/routes.go index f651ad09..fdc221be 100644 --- a/internal/dev_server/sdk/routes.go +++ b/internal/dev_server/sdk/routes.go @@ -21,6 +21,7 @@ func BindRoutes(router *mux.Router) { router.Handle("/all", GetProjectKeyFromAuthorizationHeader(http.HandlerFunc(StreamServerAllPayload))) router.Handle("/sdk/latest-all", GetProjectKeyFromAuthorizationHeader(http.HandlerFunc(LatestAll))) + router.Handle("/sdk/poll", GetProjectKeyFromAuthorizationHeader(http.HandlerFunc(PollV2))) router.PathPrefix("/sdk/flags/{flagKey}"). Methods(http.MethodGet).