diff --git a/net/ebtables/ebtables.go b/net/ebtables/ebtables.go index 3e984a27..88b6e3b3 100644 --- a/net/ebtables/ebtables.go +++ b/net/ebtables/ebtables.go @@ -121,10 +121,17 @@ func getEbtablesVersionString(exec utilexec.Interface) (string, error) { if err != nil { return "", err } - versionMatcher := regexp.MustCompile(`v([0-9]+\.[0-9]+\.[0-9]+)`) - match := versionMatcher.FindStringSubmatch(string(bytes)) + return parseVersion(string(bytes)) +} + +func parseVersion(version string) (string, error) { + // the regular expression contains `v?` at the beginning because + // different OS distros have different version format output i.e + // either starts with `v` or it doesn't + versionMatcher := regexp.MustCompile(`v?([0-9]+\.[0-9]+\.[0-9]+)`) + match := versionMatcher.FindStringSubmatch(version) if match == nil { - return "", fmt.Errorf("no ebtables version found in string: %s", bytes) + return "", fmt.Errorf("no ebtables version found in string: %s", version) } return match[1], nil } diff --git a/net/ebtables/ebtables_test.go b/net/ebtables/ebtables_test.go index 9933493f..1920ac9b 100644 --- a/net/ebtables/ebtables_test.go +++ b/net/ebtables/ebtables_test.go @@ -167,3 +167,61 @@ Bridge chain: TEST, entries: 0, policy: ACCEPT`), nil, nil t.Errorf("expected err = nil") } } + +func Test_parseVersion(t *testing.T) { + tests := []struct { + name string + version string + want string + wantErr bool + }{ + { + name: "version starting with `v`", + version: "v2.0.10", + want: "2.0.10", + wantErr: false, + }, + { + name: "version without containing `v`", + version: "2.0.10", + want: "2.0.10", + wantErr: false, + }, + { + name: "version containing `v` in between the regex expression match", + version: "2.0v.10", + want: "", + wantErr: true, + }, + { + name: "version containing `v` after the regex expression match", + version: "2.0.10v", + want: "2.0.10", + wantErr: false, + }, + { + name: "version starting with `v` and containing a symbol in between", + version: "v2.0.10-4", + want: "2.0.10", + wantErr: false, + }, + { + name: "version starting with `v` and containing a symbol/alphabets in between", + version: "v2.0a.10-4", + want: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseVersion(tt.version) + if (err != nil) != tt.wantErr { + t.Errorf("parseVersion() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("parseVersion() got = %v, want %v", got, tt.want) + } + }) + } +}