diff --git a/pkg/sqlite/load.go b/pkg/sqlite/load.go index c63643710..7f8f3666c 100644 --- a/pkg/sqlite/load.go +++ b/pkg/sqlite/load.go @@ -1816,7 +1816,8 @@ func (s sqlLoader) RemoveOverwrittenChannelHead(pkg, bundle string) error { return err } } else { - if _, err := tx.Exec(`DELETE FROM channel WHERE name = ? AND package_name = ?`, channel, pkg); err != nil { + // NULL default channel before dropping to let packagemanifest detect default channel + if _, err := tx.Exec(`UPDATE channel SET head_operatorbundle_name = NULL WHERE name = ? AND package_name = ? AND name IN (SELECT default_channel FROM package WHERE name = ?)`, channel, pkg, pkg); err != nil { return err } } diff --git a/pkg/sqlite/load_test.go b/pkg/sqlite/load_test.go index 00e8498d2..7af8112ec 100644 --- a/pkg/sqlite/load_test.go +++ b/pkg/sqlite/load_test.go @@ -2,6 +2,7 @@ package sqlite import ( "context" + "database/sql" "encoding/json" "fmt" "strings" @@ -1093,6 +1094,42 @@ func TestRemoveOverwrittenChannelHead(t *testing.T) { }, }, }, + + { + description: "PersistDefaultChannel", + fields: fields{ + bundles: []*registry.Bundle{ + newBundle(t, "csv-a", "pkg-0", []string{"a"}, newUnstructuredCSV(t, "csv-a", "")), + newBundle(t, "csv-b", "pkg-0", []string{"b"}, newUnstructuredCSV(t, "csv-b", "")), + }, + pkgs: []registry.PackageManifest{ + { + PackageName: "pkg-0", + Channels: []registry.PackageChannel{ + { + Name: "a", + CurrentCSVName: "csv-a", + }, + { + Name: "b", + CurrentCSVName: "csv-b", + }, + }, + DefaultChannelName: "a", + }, + }, + }, + args: args{ + bundle: "csv-a", + pkg: "pkg-0", + }, + expected: expected{ + err: nil, + bundles: map[string]struct{}{ + "pkg-0/b/csv-b": {}, + }, + }, + }, } for _, tt := range tests { t.Run(tt.description, func(t *testing.T) { @@ -1100,7 +1137,7 @@ func TestRemoveOverwrittenChannelHead(t *testing.T) { defer cleanup() store, err := NewSQLLiteLoader(db) require.NoError(t, err) - err = store.Migrate(context.TODO()) + err = store.Migrate(context.Background()) require.NoError(t, err) for _, bundle := range tt.fields.bundles { @@ -1112,6 +1149,21 @@ func TestRemoveOverwrittenChannelHead(t *testing.T) { // Throw away any errors loading packages (not testing this) store.AddPackageChannels(pkg) } + + getDefaultChannel := func(pkg string) sql.NullString { + // get defaultChannel before delete + rows, err := db.QueryContext(context.Background(), `SELECT default_channel FROM package WHERE name = ?`, pkg) + require.NoError(t, err) + defer rows.Close() + var defaultChannel sql.NullString + for rows.Next() { + require.NoError(t, rows.Scan(&defaultChannel)) + break + } + return defaultChannel + } + oldDefaultChannel := getDefaultChannel(tt.args.pkg) + err = store.(registry.HeadOverwriter).RemoveOverwrittenChannelHead(tt.args.pkg, tt.args.bundle) if tt.expected.err != nil { require.EqualError(t, err, tt.expected.err.Error()) @@ -1121,7 +1173,7 @@ func TestRemoveOverwrittenChannelHead(t *testing.T) { querier := NewSQLLiteQuerierFromDb(db) - bundles, err := querier.ListBundles(context.TODO()) + bundles, err := querier.ListBundles(context.Background()) require.NoError(t, err) var extra []string @@ -1141,6 +1193,9 @@ func TestRemoveOverwrittenChannelHead(t *testing.T) { t.Errorf("unexpected bundles found: %v", extra) } + // should preserve defaultChannel entry in package table + currentDefaultChannel := getDefaultChannel(tt.args.pkg) + require.Equal(t, oldDefaultChannel, currentDefaultChannel) }) } }