diff --git a/premium/tables.go b/premium/tables.go index 2f9a8217ba..ec8884fbff 100644 --- a/premium/tables.go +++ b/premium/tables.go @@ -12,16 +12,11 @@ func ContainsPaidTables(tables schema.Tables) bool { return false } -// MakeAllTablesPaid sets all tables to paid +// MakeAllTablesPaid sets all tables to paid (including relations) func MakeAllTablesPaid(tables schema.Tables) schema.Tables { for _, table := range tables { - MakeTablePaid(table) + table.IsPaid = true + MakeAllTablesPaid(table.Relations) } return tables } - -// MakeTablePaid sets the table to paid -func MakeTablePaid(table *schema.Table) *schema.Table { - table.IsPaid = true - return table -} diff --git a/premium/tables_test.go b/premium/tables_test.go index 251be02d0b..5da642209c 100644 --- a/premium/tables_test.go +++ b/premium/tables_test.go @@ -1,9 +1,10 @@ package premium import ( + "testing" + "github.com/cloudquery/plugin-sdk/v4/schema" "github.com/stretchr/testify/assert" - "testing" ) func TestContainsPaidTables(t *testing.T) { @@ -28,12 +29,22 @@ func TestMakeAllTablesPaid(t *testing.T) { &schema.Table{Name: "table1", IsPaid: false}, &schema.Table{Name: "table2", IsPaid: false}, &schema.Table{Name: "table3", IsPaid: false}, + &schema.Table{Name: "table_with_relations", IsPaid: false, Relations: schema.Tables{ + &schema.Table{Name: "relation_table", IsPaid: false}, + }}, } paidTables := MakeAllTablesPaid(noPaidTables) - assert.Equal(t, 3, len(paidTables)) - for _, table := range paidTables { + assert.Equal(t, 4, len(paidTables)) + assert.Equal(t, 5, len(paidTables.FlattenTables())) + assertAllArePaid(t, paidTables) +} + +func assertAllArePaid(t *testing.T, tables schema.Tables) { + t.Helper() + for _, table := range tables { assert.True(t, table.IsPaid) + assertAllArePaid(t, table.Relations) } }