diff --git a/pkg/database/vercode.go b/pkg/database/vercode.go index e1dbe2f03..44f266d7d 100644 --- a/pkg/database/vercode.go +++ b/pkg/database/vercode.go @@ -23,6 +23,7 @@ import ( "strings" "time" + "github.com/google/exposure-notifications-verification-server/pkg/timeutils" "github.com/jinzhu/gorm" ) @@ -73,7 +74,7 @@ func (VerificationCode) TableName() string { // to update statistics about usage. If the executions fail, an error is logged // but the transaction continues. This is called automatically by gorm. func (v *VerificationCode) AfterCreate(scope *gorm.Scope) { - date := v.CreatedAt.Truncate(24 * time.Hour) + date := timeutils.Midnight(v.CreatedAt) // If the issuer was a user, update the user stats for the day. if v.IssuingUserID != 0 { diff --git a/pkg/database/vercode_test.go b/pkg/database/vercode_test.go index 93868f2af..11c8cc2b3 100644 --- a/pkg/database/vercode_test.go +++ b/pkg/database/vercode_test.go @@ -303,3 +303,106 @@ func TestPurgeVerificationCodes(t *testing.T) { t.Fatalf("purge record count mismatch, want: 2, got: %v", count) } } + +func TestStatDatesOnCreate(t *testing.T) { + // Please note, this test is NOT exhaustive. A better engineer would try + // all dates, and a bunch of corner cases. This is intended as a + // smokescreen. + t.Parallel() + db := NewTestDatabase(t) + db.db.LogMode(true) + fmtString := "2006-01-02" + now := time.Now() + nowStr := now.Format(fmtString) + maxAge := time.Hour + + tests := []struct { + code *VerificationCode + statDate string + }{ + { + &VerificationCode{ + Code: "111111", + LongCode: "111111", + TestType: "negative", + ExpiresAt: now.Add(time.Second), + LongExpiresAt: now.Add(time.Second), + IssuingUserID: 100, // need for RealmUserStats + IssuingAppID: 200, // need for AuthorizedAppStats + RealmID: 300, // need for RealmStats + }, + nowStr}, + } + + for i, test := range tests { + if err := db.SaveVerificationCode(test.code, maxAge); err != nil { + t.Errorf("[%d] error saving code: %v", i, err) + } + + { + var stats []*RealmUserStats + if err := db.db. + Model(&UserStats{}). + Select("*"). + Scan(&stats). + Error; err != nil { + if IsNotFound(err) { + t.Fatalf("[%d] Error grabbing user stats %v", i, err) + } + } + if len(stats) != 1 { + t.Fatalf("[%d] expected one user stat", i) + } + if stats[0].CodesIssued != uint(i+1) { + t.Errorf("[%d] expected stat.CodesIssued = %d, expected %d", i, stats[0].CodesIssued, i+1) + } + if f := stats[0].Date.Format(fmtString); f != test.statDate { + t.Errorf("[%d] expected stat.Date = %s, expected %s", i, f, test.statDate) + } + } + + { + var stats []*AuthorizedAppStats + if err := db.db. + Model(&UserStats{}). + Select("*"). + Scan(&stats). + Error; err != nil { + if IsNotFound(err) { + t.Fatalf("[%d] Error grabbing app stats %v", i, err) + } + } + if len(stats) != 1 { + t.Fatalf("[%d] expected one user stat", i) + } + if stats[0].CodesIssued != uint(i+1) { + t.Errorf("[%d] expected stat.CodesIssued = %d, expected %d", i, stats[0].CodesIssued, i+1) + } + if f := stats[0].Date.Format(fmtString); f != test.statDate { + t.Errorf("[%d] expected stat.Date = %s, expected %s", i, f, test.statDate) + } + } + + { + var stats []*RealmStats + if err := db.db. + Model(&UserStats{}). + Select("*"). + Scan(&stats). + Error; err != nil { + if IsNotFound(err) { + t.Fatalf("[%d] Error grabbing realm stats %v", i, err) + } + } + if len(stats) != 1 { + t.Fatalf("[%d] expected one user stat", i) + } + if stats[0].CodesIssued != uint(i+1) { + t.Errorf("[%d] expected stat.CodesIssued = %d, expected %d", i, stats[0].CodesIssued, i+1) + } + if f := stats[0].Date.Format(fmtString); f != test.statDate { + t.Errorf("[%d] expected stat.Date = %s, expected %s", i, f, test.statDate) + } + } + } +} diff --git a/pkg/timeutils/midnight.go b/pkg/timeutils/midnight.go new file mode 100644 index 000000000..62c4058a5 --- /dev/null +++ b/pkg/timeutils/midnight.go @@ -0,0 +1,35 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package timeutils defines functions to close the gaps present in Golang's +// default implementation of Time. +package timeutils + +import "time" + +// LocalMidnight returns the local midnight of the given time. +func LocalMidnight(t time.Time) time.Time { + return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.Local) +} + +// UTCMidnight converts the given time to UTC, and returns the UTC time. +func UTCMidnight(t time.Time) time.Time { + t = t.UTC() + return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC) +} + +// Midnight returns the midnight of the given time. +func Midnight(t time.Time) time.Time { + return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location()) +}