32 changes: 16 additions & 16 deletions pkg/database/vercode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,13 @@ func TestVerificationCode_FindVerificationCode(t *testing.T) {
LongExpiresAt: time.Now().Add(2 * time.Hour),
}

if err := db.SaveVerificationCode(vc, realm); err != nil {
if err := realm.SaveVerificationCode(db, vc); err != nil {
t.Fatalf("error creating verification code: %v", err)
}

{
// Find by raw code
got, err := db.FindVerificationCode(code)
got, err := realm.FindVerificationCode(db, code)
if err != nil {
t.Fatal(err)
}
Expand All @@ -106,7 +106,7 @@ func TestVerificationCode_FindVerificationCode(t *testing.T) {

{
// Find by raw long code
got, err := db.FindVerificationCode(longCode)
got, err := realm.FindVerificationCode(db, longCode)
if err != nil {
t.Fatal(err)
}
Expand All @@ -116,7 +116,7 @@ func TestVerificationCode_FindVerificationCode(t *testing.T) {
}

vc.Claimed = true
if err := db.SaveVerificationCode(vc, realm); err != nil {
if err := realm.SaveVerificationCode(db, vc); err != nil {
t.Fatal(err)
}
}
Expand Down Expand Up @@ -145,7 +145,7 @@ func TestVerificationCode_FindVerificationCodeByUUID(t *testing.T) {
LongExpiresAt: time.Now().Add(2 * time.Hour),
}

if err := db.SaveVerificationCode(vc, realm); err != nil {
if err := realm.SaveVerificationCode(db, vc); err != nil {
t.Fatal(err)
}

Expand Down Expand Up @@ -203,7 +203,7 @@ func TestVerificationCode_ListRecentCodes(t *testing.T) {
LongExpiresAt: time.Now().Add(2 * time.Hour),
}

if err := db.SaveVerificationCode(vc, realm); err != nil {
if err := realm.SaveVerificationCode(db, vc); err != nil {
t.Fatal(err)
}

Expand All @@ -214,7 +214,7 @@ func TestVerificationCode_ListRecentCodes(t *testing.T) {

{
u := &User{Model: gorm.Model{ID: userID}}
got, err := db.ListRecentCodes(realm, u)
got, err := realm.ListRecentCodes(db, u)
if err != nil {
t.Fatal(err)
}
Expand All @@ -238,7 +238,7 @@ func TestVerificationCode_ExpireVerificationCode(t *testing.T) {
LongExpiresAt: time.Now().Add(2 * time.Hour),
}

if err := db.SaveVerificationCode(vc, realm); err != nil {
if err := realm.SaveVerificationCode(db, vc); err != nil {
t.Fatal(err)
}

Expand All @@ -248,7 +248,7 @@ func TestVerificationCode_ExpireVerificationCode(t *testing.T) {
}

{
got, err := db.ExpireCode(uuid)
got, err := realm.ExpireCode(db, uuid, SystemTest)
if err != nil {
t.Fatal(err)
}
Expand All @@ -260,7 +260,7 @@ func TestVerificationCode_ExpireVerificationCode(t *testing.T) {
}
}

if _, err := db.ExpireCode(uuid); err == nil {
if _, err := realm.ExpireCode(db, uuid, SystemTest); err == nil {
t.Errorf("Expected code already expired, got %v", err)
}
}
Expand All @@ -287,7 +287,7 @@ func TestSaveUserReport(t *testing.T) {
NonceRequired: true,
}

if err := db.SaveVerificationCode(vc, realm); err != nil {
if err := realm.SaveVerificationCode(db, vc); err != nil {
t.Fatal(err)
}

Expand Down Expand Up @@ -384,15 +384,15 @@ func TestDeleteVerificationCode(t *testing.T) {
LongExpiresAt: time.Now().Add(time.Hour),
}

if err := db.SaveVerificationCode(&code, realm); err != nil {
if err := realm.SaveVerificationCode(db, &code); err != nil {
t.Fatal(err)
}

if err := db.DeleteVerificationCode(code.ID); err != nil {
if err := realm.DeleteVerificationCode(db, code.ID); err != nil {
t.Fatal(err)
}

if _, err := db.FindVerificationCode("12345678"); !errors.Is(err, gorm.ErrRecordNotFound) {
if _, err := realm.FindVerificationCode(db, "12345678"); !errors.Is(err, gorm.ErrRecordNotFound) {
t.Fatal(err)
}
}
Expand All @@ -418,7 +418,7 @@ func TestVerificationCodesCleanup(t *testing.T) {
{Code: "333333", LongCode: "333333ABCDEF", RealmID: realm.ID, TestType: "negative", ExpiresAt: now.Add(time.Minute), LongExpiresAt: now.Add(time.Hour)},
}
for _, rec := range testData {
if err := db.SaveVerificationCode(rec, realm); err != nil {
if err := realm.SaveVerificationCode(db, rec); err != nil {
t.Fatalf("can't save test data: %v", err)
}
}
Expand Down Expand Up @@ -528,7 +528,7 @@ func TestStatDates(t *testing.T) {
}

for i, test := range tests {
if err := db.SaveVerificationCode(test.code, realm); err != nil {
if err := realm.SaveVerificationCode(db, test.code); err != nil {
t.Fatalf("[%d] error saving code: %v", i, err)
}

Expand Down
2 changes: 1 addition & 1 deletion tools/seed/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ func generateCodesAndStats(db *database.Database, realm *database.Realm) (map[st
}

// If a verification code already exists, it will fail to save, and we retry.
if err := db.SaveVerificationCode(verificationCode, realm); err != nil {
if err := realm.SaveVerificationCode(db, verificationCode); err != nil {
return nil, fmt.Errorf("failed to create verification code: %w", err)
}
db.UpdateStats(ctx, verificationCode)
Expand Down