Skip to content

Commit

Permalink
refactor: simplify test asserts (#1271)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexandear committed Oct 3, 2023
1 parent d37b38f commit 2959fc0
Show file tree
Hide file tree
Showing 10 changed files with 151 additions and 297 deletions.
17 changes: 4 additions & 13 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

const configPath = "../testdata/config/"
Expand Down Expand Up @@ -126,23 +127,13 @@ func TestTranslate(t *testing.T) {
viper.SetConfigName(tt.cfgName)
viper.SetConfigType("toml")
err := viper.ReadInConfig()
if err != nil {
t.Error(err)
}
require.NoError(t, err)

var vc ViperConfig
err = viper.Unmarshal(&vc)
if err != nil {
t.Error(err)
}
require.NoError(t, err)
cfg, err := vc.Translate()
if tt.wantError != nil {
if err == nil {
t.Errorf("expected error")
}
assert.Equal(t, tt.wantError, err)
}

assert.Equal(t, tt.wantError, err)
assert.Equal(t, cfg.Rules, tt.cfg.Rules)
}
}
5 changes: 3 additions & 2 deletions detect/baseline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/zricethezav/gitleaks/v8/report"
)

Expand Down Expand Up @@ -82,7 +83,7 @@ func TestFileLoadBaseline(t *testing.T) {

for _, test := range tests {
_, err := LoadBaseline(test.Filename)
assert.Equal(t, test.ExpectedError.Error(), err.Error())
assert.Equal(t, test.ExpectedError, err)
}
}

Expand Down Expand Up @@ -132,6 +133,6 @@ func TestIgnoreIssuesInBaseline(t *testing.T) {
for _, finding := range test.findings {
d.addFinding(finding)
}
assert.Equal(t, test.expectCount, len(d.findings))
assert.Len(t, d.findings, test.expectCount)
}
}
135 changes: 37 additions & 98 deletions detect/detect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/zricethezav/gitleaks/v8/config"
"github.com/zricethezav/gitleaks/v8/report"
Expand Down Expand Up @@ -336,23 +337,14 @@ func TestDetect(t *testing.T) {
viper.SetConfigName(tt.cfgName)
viper.SetConfigType("toml")
err := viper.ReadInConfig()
if err != nil {
t.Error(err)
}
require.NoError(t, err)

var vc config.ViperConfig
err = viper.Unmarshal(&vc)
if err != nil {
t.Error(err)
}
require.NoError(t, err)
cfg, err := vc.Translate()
cfg.Path = filepath.Join(configPath, tt.cfgName+".toml")
if tt.wantError != nil {
if err == nil {
t.Errorf("expected error")
}
assert.Equal(t, tt.wantError, err)
}
assert.Equal(t, tt.wantError, err)
d := NewDetector(cfg)
d.baselinePath = tt.baselinePath

Expand Down Expand Up @@ -444,56 +436,38 @@ func TestFromGit(t *testing.T) {
},
}

err := moveDotGit("dotGit", ".git")
if err != nil {
t.Fatal(err)
}
defer func() {
if err := moveDotGit(".git", "dotGit"); err != nil {
t.Error(err)
}
}()
moveDotGit(t, "dotGit", ".git")
defer moveDotGit(t, ".git", "dotGit")

for _, tt := range tests {

viper.AddConfigPath(configPath)
viper.SetConfigName("simple")
viper.SetConfigType("toml")
err = viper.ReadInConfig()
if err != nil {
t.Error(err)
}
err := viper.ReadInConfig()
require.NoError(t, err)

var vc config.ViperConfig
err = viper.Unmarshal(&vc)
if err != nil {
t.Error(err)
}
require.NoError(t, err)
cfg, err := vc.Translate()
if err != nil {
t.Error(err)
}
require.NoError(t, err)
detector := NewDetector(cfg)

var ignorePath string
info, err := os.Stat(tt.source)
if err != nil {
t.Fatalf("could not os.Stat: %v", err)
}
require.NoError(t, err)

if info.IsDir() {
ignorePath = filepath.Join(tt.source, ".gitleaksignore")
} else {
ignorePath = filepath.Join(filepath.Dir(tt.source), ".gitleaksignore")
}
if err = detector.AddGitleaksIgnore(ignorePath); err != nil {
t.Fatalf("could not call AddGitleaksIgnore: %v", err)
}
err = detector.AddGitleaksIgnore(ignorePath)
require.NoError(t, err)

findings, err := detector.DetectGit(tt.source, tt.logOpts, DetectType)
if err != nil {
t.Error(err)
}
require.NoError(t, err)

for _, f := range findings {
f.Match = "" // remove lines cause copying and pasting them has some wack formatting
Expand Down Expand Up @@ -540,43 +514,27 @@ func TestFromGitStaged(t *testing.T) {
},
}

err := moveDotGit("dotGit", ".git")
if err != nil {
t.Fatal(err)
}
defer func() {
if err := moveDotGit(".git", "dotGit"); err != nil {
t.Error(err)
}
}()
moveDotGit(t, "dotGit", ".git")
defer moveDotGit(t, ".git", "dotGit")

for _, tt := range tests {

viper.AddConfigPath(configPath)
viper.SetConfigName("simple")
viper.SetConfigType("toml")
err = viper.ReadInConfig()
if err != nil {
t.Error(err)
}
err := viper.ReadInConfig()
require.NoError(t, err)

var vc config.ViperConfig
err = viper.Unmarshal(&vc)
if err != nil {
t.Error(err)
}
require.NoError(t, err)
cfg, err := vc.Translate()
if err != nil {
t.Error(err)
}
require.NoError(t, err)
detector := NewDetector(cfg)
if err = detector.AddGitleaksIgnore(filepath.Join(tt.source, ".gitleaksignore")); err != nil {
t.Fatalf("could not call AddGitleaksIgnore: %v", err)
}
err = detector.AddGitleaksIgnore(filepath.Join(tt.source, ".gitleaksignore"))
require.NoError(t, err)
findings, err := detector.DetectGit(tt.source, tt.logOpts, ProtectStagedType)
if err != nil {
t.Error(err)
}
require.NoError(t, err)

for _, f := range findings {
f.Match = "" // remove lines cause copying and pasting them has some wack formatting
Expand Down Expand Up @@ -647,38 +605,28 @@ func TestFromFiles(t *testing.T) {
viper.SetConfigName("simple")
viper.SetConfigType("toml")
err := viper.ReadInConfig()
if err != nil {
t.Error(err)
}
require.NoError(t, err)

var vc config.ViperConfig
err = viper.Unmarshal(&vc)
if err != nil {
t.Error(err)
}
require.NoError(t, err)
cfg, _ := vc.Translate()
detector := NewDetector(cfg)

var ignorePath string
info, err := os.Stat(tt.source)
if err != nil {
t.Fatalf("could not call os.Stat: %v", err)
}
require.NoError(t, err)

if info.IsDir() {
ignorePath = filepath.Join(tt.source, ".gitleaksignore")
} else {
ignorePath = filepath.Join(filepath.Dir(tt.source), ".gitleaksignore")
}
if err = detector.AddGitleaksIgnore(ignorePath); err != nil {
t.Fatalf("could not call AddGitleaksIgnore: %v", err)
}
err = detector.AddGitleaksIgnore(ignorePath)
require.NoError(t, err)
detector.FollowSymlinks = true
findings, err := detector.DetectFiles(tt.source)
if err != nil {
t.Error(err)
}

require.NoError(t, err)
assert.ElementsMatch(t, tt.expectedFindings, findings)
}
}
Expand Down Expand Up @@ -718,31 +666,25 @@ func TestDetectWithSymlinks(t *testing.T) {
viper.SetConfigName("simple")
viper.SetConfigType("toml")
err := viper.ReadInConfig()
if err != nil {
t.Error(err)
}
require.NoError(t, err)

var vc config.ViperConfig
err = viper.Unmarshal(&vc)
if err != nil {
t.Error(err)
}
require.NoError(t, err)
cfg, _ := vc.Translate()
detector := NewDetector(cfg)
detector.FollowSymlinks = true
findings, err := detector.DetectFiles(tt.source)
if err != nil {
t.Error(err)
}
require.NoError(t, err)
assert.ElementsMatch(t, tt.expectedFindings, findings)
}
}

func moveDotGit(from, to string) error {
func moveDotGit(t *testing.T, from, to string) {
t.Helper()

repoDirs, err := os.ReadDir("../testdata/repos")
if err != nil {
return err
}
require.NoError(t, err)
for _, dir := range repoDirs {
if to == ".git" {
_, err := os.Stat(fmt.Sprintf("%s/%s/%s", repoBasePath, dir.Name(), "dotGit"))
Expand All @@ -762,9 +704,6 @@ func moveDotGit(from, to string) error {

err = os.Rename(fmt.Sprintf("%s/%s/%s", repoBasePath, dir.Name(), from),
fmt.Sprintf("%s/%s/%s", repoBasePath, dir.Name(), to))
if err != nil {
return err
}
require.NoError(t, err)
}
return nil
}
9 changes: 3 additions & 6 deletions detect/location_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package detect

import (
"testing"

"github.com/stretchr/testify/assert"
)

// TestGetLocation tests the getLocation function.
Expand Down Expand Up @@ -50,11 +52,6 @@ func TestGetLocation(t *testing.T) {

for _, test := range tests {
loc := location(Fragment{newlineIndices: test.linePairs}, []int{test.start, test.end})
if loc != test.wantLocation {
t.Errorf("\nstartLine %d\nstartColumn: %d\nendLine: %d\nendColumn: %d\nstartLineIndex: %d\nendlineIndex %d",
loc.startLine, loc.startColumn, loc.endLine, loc.endColumn, loc.startLineIndex, loc.endLineIndex)

t.Error("got", loc, "want", test.wantLocation)
}
assert.Equal(t, test.wantLocation, loc)
}
}
Loading

0 comments on commit 2959fc0

Please sign in to comment.