From d2128b3208e73b3963e763423ebd4faa3621ce75 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 4 Aug 2023 18:41:23 -0700 Subject: [PATCH] Match --signal PIDs with globular-style expression. When multiple instances are running on the machine a PID argument suffixed with a '*' character will signal all matching PIDs. Example: `nats-server --signal reload=*` --- server/signal.go | 90 ++++++++++++++++++++++++++++--------------- server/signal_test.go | 42 ++++++++++++++++++++ 2 files changed, 101 insertions(+), 31 deletions(-) diff --git a/server/signal.go b/server/signal.go index 28955618663..aa133b4f313 100644 --- a/server/signal.go +++ b/server/signal.go @@ -82,54 +82,82 @@ func (s *Server) handleSignals() { // ProcessSignal sends the given signal command to the given process. If pidStr // is empty, this will send the signal to the single running instance of -// nats-server. If multiple instances are running, it returns an error. This returns -// an error if the given process is not running or the command is invalid. -func ProcessSignal(command Command, pidStr string) error { - var pid int - if pidStr == "" { - pids, err := resolvePids() - if err != nil { +// nats-server. If multiple instances are running, pidStr can be a globular +// expression ending with '*'. This returns an error if the given process is +// not running or the command is invalid. +func ProcessSignal(command Command, pidExpr string) error { + var ( + err error + errStr string + pids = make([]int, 1) + pidStr = strings.TrimSuffix(pidExpr, "*") + isGlob = strings.HasSuffix(pidExpr, "*") + ) + + // Validate input if given + if pidStr != "" { + if pids[0], err = strconv.Atoi(pidStr); err != nil { + return fmt.Errorf("invalid pid: %s", pidStr) + } + } + // Gather all PIDs unless the input is specific + if pidStr == "" || isGlob { + if pids, err = resolvePids(); err != nil { return err } - if len(pids) == 0 { - return fmt.Errorf("no %s processes running", processName) + } + // Multiple instances are running and the input is not an expression + if len(pids) > 1 && !isGlob { + errStr = fmt.Sprintf("multiple %s processes running:", processName) + for _, p := range pids { + errStr += fmt.Sprintf("\n%d", p) } - if len(pids) > 1 { - errStr := fmt.Sprintf("multiple %s processes running:\n", processName) - prefix := "" - for _, p := range pids { - errStr += fmt.Sprintf("%s%d", prefix, p) - prefix = "\n" + return errors.New(errStr) + } + // No instances are running + if len(pids) == 0 { + return fmt.Errorf("no %s processes running", processName) + } + + var signum syscall.Signal + if signum, err = CommandToSignal(command); err != nil { + return err + } + + for _, pid := range pids { + if _pidStr := strconv.Itoa(pid); _pidStr != pidStr && pidStr != "" { + if !isGlob || !strings.HasPrefix(_pidStr, pidStr) { + continue } - return errors.New(errStr) } - pid = pids[0] - } else { - p, err := strconv.Atoi(pidStr) - if err != nil { - return fmt.Errorf("invalid pid: %s", pidStr) + if err = kill(pid, signum); err != nil { + errStr += fmt.Sprintf("\nsignal %q %d: %s", command, pid, err) } - pid = p } + if errStr != "" { + return errors.New(errStr) + } + return nil +} - var err error +// Translates a command to a signal number +func CommandToSignal(command Command) (syscall.Signal, error) { switch command { case CommandStop: - err = kill(pid, syscall.SIGKILL) + return syscall.SIGKILL, nil case CommandQuit: - err = kill(pid, syscall.SIGINT) + return syscall.SIGINT, nil case CommandReopen: - err = kill(pid, syscall.SIGUSR1) + return syscall.SIGUSR1, nil case CommandReload: - err = kill(pid, syscall.SIGHUP) + return syscall.SIGHUP, nil case commandLDMode: - err = kill(pid, syscall.SIGUSR2) + return syscall.SIGUSR2, nil case commandTerm: - err = kill(pid, syscall.SIGTERM) + return syscall.SIGTERM, nil default: - err = fmt.Errorf("unknown signal %q", command) + return 0, fmt.Errorf("unknown signal %q", command) } - return err } // resolvePids returns the pids for all running nats-server processes. diff --git a/server/signal_test.go b/server/signal_test.go index 1127546d317..77c388f0482 100644 --- a/server/signal_test.go +++ b/server/signal_test.go @@ -141,6 +141,48 @@ func TestProcessSignalMultipleProcesses(t *testing.T) { } } +func TestProcessSignalMultipleProcessesGlob(t *testing.T) { + pid := os.Getpid() + pgrepBefore := pgrep + pgrep = func() ([]byte, error) { + return []byte(fmt.Sprintf("123\n456\n%d\n", pid)), nil + } + defer func() { + pgrep = pgrepBefore + }() + + err := ProcessSignal(CommandStop, "*") + if err == nil { + t.Fatal("Expected error") + } + expectedStr := "\nsignal \"stop\" 123: no such process" + expectedStr += "\nsignal \"stop\" 456: no such process" + if err.Error() != expectedStr { + t.Fatalf("Error is incorrect.\nexpected: %s\ngot: %s", expectedStr, err.Error()) + } +} + +func TestProcessSignalMultipleProcessesGlobPartial(t *testing.T) { + pid := os.Getpid() + pgrepBefore := pgrep + pgrep = func() ([]byte, error) { + return []byte(fmt.Sprintf("123\n124\n456\n%d\n", pid)), nil + } + defer func() { + pgrep = pgrepBefore + }() + + err := ProcessSignal(CommandStop, "12*") + if err == nil { + t.Fatal("Expected error") + } + expectedStr := "\nsignal \"stop\" 123: no such process" + expectedStr += "\nsignal \"stop\" 124: no such process" + if err.Error() != expectedStr { + t.Fatalf("Error is incorrect.\nexpected: %s\ngot: %s", expectedStr, err.Error()) + } +} + func TestProcessSignalPgrepError(t *testing.T) { pgrepBefore := pgrep pgrep = func() ([]byte, error) {