Skip to content

Commit

Permalink
Reduce duplication and function size in main.go
Browse files Browse the repository at this point in the history
Signed-off-by: Eric Cornelissen <ericornelissen@gmail.com>
  • Loading branch information
ericcornelissen committed Sep 6, 2023
1 parent fbcaa05 commit 6d3efc0
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 84 deletions.
166 changes: 87 additions & 79 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ import (
"fmt"
"os"
"path"
"regexp"
)

const (
exitSuccess = 0
exitError = 1
exitProblems = 2
exitSuccess = 0
exitError = 1
exitViolations = 2
)

var (
Expand All @@ -42,10 +43,10 @@ var (
)

func main() {
os.Exit(ades())
os.Exit(run())
}

func ades() int {
func run() int {
flag.Usage = func() { usage() }
flag.Parse()

Expand All @@ -65,20 +66,24 @@ func ades() int {
return exitError
}

hasProblems, err := false, nil
hasViolations, hasError := false, false
for i, target := range targets {
if len(targets) > 1 {
fmt.Println("Scanning", target)
}

targetHasProblems, targetErr := run(target)
if targetErr != nil {
err = targetErr
fmt.Printf("An unexpected error occurred: %s\n", targetErr)
violations, err := analyzeTarget(target)
if err == nil {
printViolations(violations)
} else {
fmt.Printf("An unexpected error occurred: %s\n", err)
hasError = true
}

if targetHasProblems {
hasProblems = true
for _, fileVioviolations := range violations {
if len(fileVioviolations) > 0 {
hasViolations = true
}
}

if len(targets) > 1 && i < len(targets)-1 {
Expand All @@ -87,81 +92,86 @@ func ades() int {
}

switch {
case err != nil:
case hasError:
return exitError
case hasProblems:
return exitProblems
case hasViolations:
return exitViolations
default:
return exitSuccess
}
}

func run(target string) (hasProblems bool, err error) {
func analyzeTarget(target string) (map[string][]Violation, error) {
stat, err := os.Stat(target)
if err != nil {
return hasProblems, fmt.Errorf("could not process %s: %v", target, err)
return nil, fmt.Errorf("could not process %s: %v", target, err)
}

if stat.IsDir() {
if violations, err := tryManifest(path.Join(target, "action.yml")); err != nil {
fmt.Printf("Could not process manifest 'action.yml': %v\n", err)
} else {
hasProblems = len(violations) > 0 || hasProblems
printProblems("action.yml", violations)
return analyzeRepository(target)
} else {
fileViolations, err := analyzeFile(target)
if err != nil {
return nil, err
}

if violations, err := tryManifest(path.Join(target, "action.yaml")); err != nil {
fmt.Printf("Could not process manifest 'action.yaml': %v\n", err)
} else {
hasProblems = len(violations) > 0 || hasProblems
printProblems("action.yaml", violations)
}
violations := make(map[string][]Violation)
violations[target] = fileViolations
return violations, nil
}
}

workflowsDir := path.Join(target, ".github", "workflows")
workflows, err := os.ReadDir(workflowsDir)
if err != nil {
return hasProblems, fmt.Errorf("could not read workflows directory: %v", err)
}
func analyzeRepository(target string) (map[string][]Violation, error) {
violations := make(map[string][]Violation)

for _, entry := range workflows {
if entry.Type().IsDir() {
continue
}
if fileViolations, err := tryManifest(path.Join(target, "action.yml")); err == nil {
violations["action.yml"] = fileViolations
} else {
fmt.Printf("Could not process manifest 'action.yml': %v\n", err)
}

if path.Ext(entry.Name()) != ".yml" {
continue
}
if fileViolations, err := tryManifest(path.Join(target, "action.yaml")); err == nil {
violations["action.yaml"] = fileViolations
} else {
fmt.Printf("Could not process manifest 'action.yaml': %v\n", err)
}

workflowPath := path.Join(workflowsDir, entry.Name())
if violations, err := tryWorkflow(workflowPath); err != nil {
fmt.Printf("Could not process workflow %s: %v\n", entry.Name(), err)
} else {
hasProblems = len(violations) > 0 || hasProblems
printProblems(entry.Name(), violations)
}
workflowsDir := path.Join(target, ".github", "workflows")
workflows, err := os.ReadDir(workflowsDir)
if err != nil {
return violations, fmt.Errorf("could not read workflows directory: %v", err)
}

for _, entry := range workflows {
if entry.IsDir() {
continue
}
} else {
if stat.Name() == "action.yml" || stat.Name() == "action.yaml" {
if violations, err := tryManifest(target); err != nil {
return hasProblems, err
} else {
hasProblems = len(violations) > 0 || hasProblems
printProblems(target, violations)
}

if path.Ext(entry.Name()) != ".yml" {
continue
}

workflowPath := path.Join(workflowsDir, entry.Name())
if workflowViolations, err := tryWorkflow(workflowPath); err == nil {
targetRelativePath := path.Join(".github", "workflows", entry.Name())
violations[targetRelativePath] = workflowViolations
} else {
if violations, err := tryWorkflow(target); err != nil {
return hasProblems, err
} else {
hasProblems = len(violations) > 0 || hasProblems
printProblems(target, violations)
}
fmt.Printf("Could not process workflow %s: %v\n", entry.Name(), err)
}
}

return hasProblems, nil
return violations, nil
}

func tryManifest(manifestPath string) (violations []Violation, err error) {
func analyzeFile(target string) ([]Violation, error) {
if matched, _ := regexp.MatchString("action.ya?ml", target); matched {
return tryManifest(target)
} else {
return tryWorkflow(target)
}
}

func tryManifest(manifestPath string) ([]Violation, error) {
data, err := os.ReadFile(manifestPath)
if err != nil {
return nil, nil
Expand All @@ -175,7 +185,7 @@ func tryManifest(manifestPath string) (violations []Violation, err error) {
return analyzeManifest(&manifest), nil
}

func tryWorkflow(workflowPath string) (violations []Violation, err error) {
func tryWorkflow(workflowPath string) ([]Violation, error) {
data, err := os.ReadFile(workflowPath)
if err != nil {
return nil, err
Expand All @@ -189,30 +199,28 @@ func tryWorkflow(workflowPath string) (violations []Violation, err error) {
return analyzeWorkflow(&workflow), nil
}

func printProblems(file string, violations []Violation) {
if cnt := len(violations); cnt > 0 {
fmt.Printf("Detected %d violation(s) in '%s':\n", cnt, file)
for _, violation := range violations {
if violation.jobId == "" {
fmt.Printf(" step %s has '%s'\n", violation.stepId, violation.problem)
} else {
fmt.Printf(" job %s, step %s has '%s'\n", violation.jobId, violation.stepId, violation.problem)
func printViolations(violations map[string][]Violation) {
for file, fileViolations := range violations {
if cnt := len(fileViolations); cnt > 0 {
fmt.Printf("Detected %d violation(s) in '%s':\n", cnt, file)
for _, violation := range fileViolations {
if violation.jobId == "" {
fmt.Printf(" step %s has '%s'\n", violation.stepId, violation.problem)
} else {
fmt.Printf(" job %s, step %s has '%s'\n", violation.jobId, violation.stepId, violation.problem)
}
}
}
}
}

func getTargets(argv []string) ([]string, error) {
if len(argv) > 0 {
return argv, nil
} else {
if len(argv) == 0 {
wd, err := os.Getwd()
if err != nil {
return nil, err
}

return []string{wd}, err
}

return argv, nil
}

func legal() {
Expand Down
2 changes: 1 addition & 1 deletion main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

func TestMain(m *testing.M) {
commands := map[string]func() int{
"ades": ades,
"ades": run,
}

os.Exit(testscript.RunMain(m, commands))
Expand Down
4 changes: 2 additions & 2 deletions test/args-multiple.txtar
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ on: [push]

jobs:
example:
name: Example safe job
name: Example safe job
runs-on: ubuntu-latest
steps:
- name: Checkout repository
Expand All @@ -35,5 +35,5 @@ jobs:
Scanning project-a

Scanning project-b
Detected 1 violation(s) in 'unsafe.yml':
Detected 1 violation(s) in '.github/workflows/unsafe.yml':
job 'Example unsafe job', step 'Unsafe run' has '${{ inputs.name }}'
4 changes: 2 additions & 2 deletions test/cwd-workflows.txtar
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ on: [push]

jobs:
example:
name: Example safe job
name: Example safe job
runs-on: ubuntu-latest
steps:
- name: Checkout repository
Expand All @@ -32,5 +32,5 @@ jobs:
run: echo 'Hello ${{ inputs.name }}'

-- stdout.txt --
Detected 1 violation(s) in 'unsafe.yml':
Detected 1 violation(s) in '.github/workflows/unsafe.yml':
job 'Example unsafe job', step 'Unsafe run' has '${{ inputs.name }}'

0 comments on commit 6d3efc0

Please sign in to comment.