diff --git a/cli.go b/cli.go index 6e2bbd6..cd7be85 100644 --- a/cli.go +++ b/cli.go @@ -30,8 +30,8 @@ type request struct { promAddr string serverName string srcAddr string - filter string config string + filter map[string]struct{} soIPTOS int soIPTTL int @@ -178,9 +178,9 @@ func getCli(args []string) (*request, []string, error) { grpcAddr: c.String("grpc-addr"), serverName: c.String("server-name"), srcAddr: c.String("source-addr"), - filter: c.String("filter"), config: c.String("config"), count: c.Int("count"), + filter: filterMap(c.String("filter")), soIPTOS: c.Int("tos"), soIPTTL: c.Int("ttl"), @@ -260,3 +260,15 @@ for more information: https://github.com/mehrdadrad/tcpprobe/wiki return r, targets, err } + +func filterMap(s string) map[string]struct{} { + m := map[string]struct{}{} + if len(s) < 1 { + return m + } + + for _, f := range strings.Split(strings.ToLower(s), ";") { + m[f] = struct{}{} + } + return m +} diff --git a/printer.go b/printer.go index 362c877..9bad42d 100644 --- a/printer.go +++ b/printer.go @@ -27,7 +27,7 @@ func (c *client) printer(counter int) { func (c *client) printText(counter int) { v := reflect.ValueOf(c.stats) - filter := strings.ToLower(c.req.filter) + filterLen := len(c.req.filter) ip, _, _ := net.SplitHostPort(c.addr) datetime := time.Unix(c.timestamp, 0).Format(time.RFC3339) @@ -37,7 +37,7 @@ func (c *client) printText(counter int) { if f.Tag.Get("unexported") == "true" { continue } - if strings.Contains(filter, strings.ToLower(f.Name)) || filter == "" { + if _, ok := c.req.filter[strings.ToLower(f.Name)]; ok || filterLen == 0 { fmt.Printf("%s:%v ", f.Name, v.Field(i).Interface()) } } @@ -65,7 +65,7 @@ func (c *client) printJSON(counter int, pretty bool) { c.stats, } - if c.req.filter != "" { + if len(c.req.filter) > 0 { b, err = jsonMarshalFilter(d, c.req.filter, pretty) } else if pretty { b, err = json.MarshalIndent(d, "", " ") @@ -81,7 +81,7 @@ func (c *client) printJSON(counter int, pretty bool) { fmt.Println(string(b)) } -func jsonMarshalFilter(s interface{}, filter string, pretty bool) ([]byte, error) { +func jsonMarshalFilter(s interface{}, filter map[string]struct{}, pretty bool) ([]byte, error) { var m map[string]interface{} b, err := json.Marshal(s) @@ -91,10 +91,8 @@ func jsonMarshalFilter(s interface{}, filter string, pretty bool) ([]byte, error json.Unmarshal(b, &m) - lFilter := strings.ToLower(filter) - for k := range m { - if !strings.Contains(lFilter, strings.ToLower(k)) { + if _, ok := filter[strings.ToLower(k)]; !ok { delete(m, k) } } diff --git a/tp_test.go b/tp_test.go index 1436751..5b4ffab 100644 --- a/tp_test.go +++ b/tp_test.go @@ -221,7 +221,7 @@ func TestPrintText(t *testing.T) { r, w, _ := os.Pipe() os.Stdout = w - c := &client{stats: stats{Rtt: 5}, req: &request{filter: "rtt"}, timestamp: 1609558015} + c := &client{stats: stats{Rtt: 5}, req: &request{filter: map[string]struct{}{"rtt": struct{}{}}}, timestamp: 1609558015} c.printer(0) go io.Copy(buf, r) @@ -236,7 +236,7 @@ func TestPrintJsonPretty(t *testing.T) { r, w, _ := os.Pipe() os.Stdout = w - c := &client{stats: stats{}, req: &request{jsonPretty: true, filter: "rtt"}} + c := &client{stats: stats{}, req: &request{jsonPretty: true, filter: map[string]struct{}{"rtt": struct{}{}}}} c.printer(0) buf := make([]byte, 13) @@ -252,7 +252,7 @@ func TestPrintJson(t *testing.T) { r, w, _ := os.Pipe() os.Stdout = w - c := &client{stats: stats{}, req: &request{json: true, filter: "rtt"}} + c := &client{stats: stats{}, req: &request{json: true, filter: map[string]struct{}{"rtt": struct{}{}}}} c.printer(0) buf := make([]byte, 9)