Skip to content

Commit

Permalink
Refactoring: Remove global variables & add more tests (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
radulucut committed Feb 17, 2024
1 parent 6022008 commit 62672e5
Show file tree
Hide file tree
Showing 40 changed files with 1,355 additions and 921 deletions.
131 changes: 114 additions & 17 deletions cmd/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"

Expand All @@ -25,28 +26,62 @@ var (

var SESSION_PATH string

func inProgressUpdates(ci bool) bool {
return !(ci)
func (r *Root) updateContext(cmd string, args []string) error {
r.ctx.Cmd = cmd // Get the command name

targetQuery, err := parseTargetQuery(cmd, args)
if err != nil {
return err
}

r.ctx.Target = targetQuery.Target

if targetQuery.From != "" {
r.ctx.From = targetQuery.From
}

if targetQuery.Resolver != "" {
r.ctx.Resolver = targetQuery.Resolver
}

// Check env for CI
if os.Getenv("CI") != "" {
r.ctx.CIMode = true
}

// Check if it is a terminal or being piped/redirected
// We want to disable realtime updates if that is the case
f, ok := r.printer.OutWriter.(*os.File)
if ok {
stdoutFileInfo, err := f.Stat()
if err != nil {
return fmt.Errorf("stdout stat failed: %s", err)
}
if (stdoutFileInfo.Mode() & os.ModeCharDevice) == 0 {
// stdout is piped, run in ci mode
r.ctx.CIMode = true
}
} else {
r.ctx.CIMode = true
}

return nil
}

func createLocations(from string) ([]globalping.Locations, bool, error) {
fromArr := strings.Split(from, ",")
if len(fromArr) == 1 {
mId, err := mapToMeasurementID(fromArr[0])
mId, err := mapFromHistory(fromArr[0])
if err != nil {
return nil, false, err
}
isPreviousMeasurementId := false
isFromHistory := false
if mId == "" {
mId = strings.TrimSpace(fromArr[0])
} else {
isPreviousMeasurementId = true
isFromHistory = true
}
return []globalping.Locations{
{
Magic: mId,
},
}, isPreviousMeasurementId, nil
return []globalping.Locations{{Magic: mId}}, isFromHistory, nil
}
locations := make([]globalping.Locations, len(fromArr))
for i, v := range fromArr {
Expand All @@ -57,8 +92,70 @@ func createLocations(from string) ([]globalping.Locations, bool, error) {
return locations, false, nil
}

// Maps a location to a measurement ID if possible
func mapToMeasurementID(location string) (string, error) {
type TargetQuery struct {
Target string
From string
Resolver string
}

var commandsWithResolver = []string{
"dns",
"http",
}

func parseTargetQuery(cmd string, args []string) (*TargetQuery, error) {
targetQuery := &TargetQuery{}
if len(args) == 0 {
return nil, errors.New("provided target is empty")
}

resolver, argsWithoutResolver := findAndRemoveResolver(args)
if resolver != "" {
// resolver was found
if !slices.Contains(commandsWithResolver, cmd) {
return nil, fmt.Errorf("command %s does not accept a resolver argument. @%s was provided", cmd, resolver)
}

targetQuery.Resolver = resolver
}

targetQuery.Target = argsWithoutResolver[0]

if len(argsWithoutResolver) > 1 {
if argsWithoutResolver[1] == "from" {
targetQuery.From = strings.TrimSpace(strings.Join(argsWithoutResolver[2:], " "))
} else {
return nil, errors.New("invalid command format")
}
}

return targetQuery, nil
}

func findAndRemoveResolver(args []string) (string, []string) {
var resolver string
resolverIndex := -1
for i := 0; i < len(args); i++ {
if len(args[i]) > 0 && args[i][0] == '@' && args[i-1] != "from" {
resolver = args[i][1:]
resolverIndex = i
break
}
}

if resolverIndex == -1 {
// resolver was not found
return "", args
}

argsClone := slices.Clone(args)
argsWithoutResolver := slices.Delete(argsClone, resolverIndex, resolverIndex+1)

return resolver, argsWithoutResolver
}

// Maps a location to a measurement ID from history, if possible.
func mapFromHistory(location string) (string, error) {
if location == "" {
return "", nil
}
Expand All @@ -67,19 +164,19 @@ func mapToMeasurementID(location string) (string, error) {
if err != nil {
return "", ErrInvalidIndex
}
return getMeasurementID(index)
return getIdFromHistory(index)
}
if location == "first" {
return getMeasurementID(1)
return getIdFromHistory(1)
}
if location == "last" || location == "previous" {
return getMeasurementID(-1)
return getIdFromHistory(-1)
}
return "", nil
}

// Returns the measurement ID at the given index from the session history
func getMeasurementID(index int) (string, error) {
func getIdFromHistory(index int) (string, error) {
if index == 0 {
return "", ErrInvalidIndex
}
Expand Down Expand Up @@ -130,7 +227,7 @@ func getMeasurementID(index int) (string, error) {
}

// Saves the measurement ID to the session history
func saveMeasurementID(id string) error {
func saveIdToHistory(id string) error {
_, err := os.Stat(getSessionPath())
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
Expand Down
Loading

0 comments on commit 62672e5

Please sign in to comment.