Skip to content

Commit

Permalink
Fix sigint handling when running in helper mode (#869)
Browse files Browse the repository at this point in the history
  • Loading branch information
olsen232 committed Jun 20, 2023
1 parent 889a47e commit b07ec84
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 13 deletions.
33 changes: 27 additions & 6 deletions cli_helper/kart.c
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ int is_helper_enabled()
/**
* @brief Exit signal handler for SIGALRM
*/
void exit_on_alarm(int sig)
void exit_on_sigalrm(int sig)
{
int semval = semctl(semid, SEMNUM, GETVAL);
if (semval < 0)
Expand All @@ -209,6 +209,17 @@ void exit_on_alarm(int sig)
exit(exit_code);
}

/**
* @brief Exit signal handler for SIGINT.
* Tries to kill the whole process group.
*
* Writes a newline to stdout, sends the received signal to the calling
* process's own process group, then exits with status 128 + sig.
*
* NOTE(review): putchar() and exit() are not on the POSIX list of
* async-signal-safe functions; confirm this is acceptable for the code that
* can be interrupted here, or consider write() / _exit() instead.
*
* @param sig The signal number this handler was invoked for.
*/
void exit_on_sigint(int sig)
{
// Presumably moves the cursor off the line where the terminal echoed "^C" - confirm.
putchar('\n');
// A process-group argument of 0 targets the caller's own process group.
killpg(0, sig);
// 128 + signal number: the conventional exit status for "terminated by signal".
exit(128 + sig);
}

int main(int argc, char **argv, char **environ)
{
char cmd_path[PATH_MAX];
Expand All @@ -221,6 +232,10 @@ int main(int argc, char **argv, char **environ)
{
debug("enabled %s, pid=%d\n", cmd_path, getpid());

// Make this process the leader of a process group:
// The process-group ID (pgid) will be the same as the pid.
setpgrp();

// start or use an existing helper process
char **env_ptr;

Expand Down Expand Up @@ -267,8 +282,15 @@ int main(int argc, char **argv, char **environ)
int fp = open(getcwd(NULL, 0), O_RDONLY);
int fds[4] = {fileno(stdin), fileno(stdout), fileno(stderr), fp};

char *socket_filename = malloc(strlen(getenv("HOME")) + strlen(".kart.socket") + 2);
sprintf(socket_filename, "%s/%s", getenv("HOME"), ".kart.socket");
size_t socket_filename_sz = strlen(getenv("HOME")) + strlen("/.kart..socket") + sizeof(pid_t) * 3 + 1;
char *socket_filename = malloc(socket_filename_sz);
int r = snprintf(socket_filename, socket_filename_sz, "%s/.kart.%d.socket", getenv("HOME"), getsid(0));
if (r < 0 || (size_t) r >= socket_filename_sz)
{
fprintf(stderr, "Error allocating socket filename\n");
exit(1);
}

int socket_fd = socket(AF_UNIX, SOCK_STREAM, 0);

struct sockaddr_un addr;
Expand All @@ -289,8 +311,6 @@ int main(int argc, char **argv, char **environ)
// process are left open in it
if (fork() == 0)
{
setsid();

// start helper in background and wait
char *helper_argv[] = {&cmd_path[0], "helper", "--socket", socket_filename, NULL};

Expand Down Expand Up @@ -370,7 +390,8 @@ int main(int argc, char **argv, char **environ)
memcpy((int *)CMSG_DATA(cmsg), fds, sizeof(fds));
msg.msg_controllen = cmsg->cmsg_len;

signal(SIGALRM, exit_on_alarm);
signal(SIGALRM, exit_on_sigalrm);
signal(SIGINT, exit_on_sigint);

if (sendmsg(socket_fd, &msg, 0) < 0)
{
Expand Down
9 changes: 6 additions & 3 deletions kart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,11 @@ def _configure_process_cleanup_nonwindows():
# to attempt to give Kart its own process group ID (PGID)
if "_KART_PGID_SET" not in os.environ and os.getpid() != os.getpgrp():
try:
os.setsid()
os.setpgrp()
# No need to do this again for any Kart subprocess of this Kart process.
os.environ["_KART_PGID_SET"] = "1"
except OSError as e:
L.warning("Error setting Kart PGID - os.setsid() failed. %s", e)
L.warning("Error setting Kart PGID - os.setpgrp() failed. %s", e)

# If Kart now has its own PGID, which its children share - we want to SIGTERM that when Kart exits.
if os.getpid() == os.getpgrp():
Expand All @@ -199,7 +199,10 @@ def _cleanup_process_group(signum, stack_frame):
if _kart_process_group_killed:
return
_kart_process_group_killed = True
os.killpg(os.getpid(), signum)
try:
os.killpg(0, signum)
except Exception:
pass
sys.exit(128 + signum)

signal.signal(signal.SIGTERM, _cleanup_process_group)
Expand Down
2 changes: 1 addition & 1 deletion kart/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def print_version(ctx):
# report on whether this was run through helper mode
helper_pid = os.environ.get("KART_HELPER_PID")
if helper_pid:
click.echo(f"Executed via helper, PID: {helper_pid}")
click.echo(f"Executed via helper, SID={os.getsid(0)} PID={helper_pid}")

ctx.exit()

Expand Down
19 changes: 18 additions & 1 deletion kart/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@ def _helper_log(msg):
log_file.write(f"{datetime.now()} [{os.getpid()}]: {msg}\n")


def getsid():
    """Return the session ID of the current process, or 0 if unavailable.

    ``os.getsid`` does not exist on every platform, so its presence is
    checked first; 0 is returned as a fixed placeholder when it is missing.
    """
    return os.getsid(0) if hasattr(os, "getsid") else 0


@click.command(context_settings=dict(ignore_unknown_options=True))
@click.pass_context
@click.option(
"--socket",
"socket_filename",
default=Path.home() / ".kart.socket",
default=Path.home() / f".kart.{getsid()}.socket",
show_default=True,
help="What socket to use",
)
Expand Down Expand Up @@ -133,6 +139,7 @@ def helper(ctx, socket_filename, timeout, args):
else:
# child
_helper_log("post-fork")

payload, fds = recv_json_and_fds(client, maxfds=4)
if not payload or len(fds) != 4:
click.echo(
Expand Down Expand Up @@ -169,6 +176,16 @@ def helper(ctx, socket_filename, timeout, args):
f"Payload:\n{repr(payload)}",
)
else:
try:
# Join the process group of the calling process - so that if they get killed, we get killed too.
os.setpgid(0, calling_environment["pid"])
os.environ["_KART_PGID_SET"] = "1"
except OSError as e:
# Kart will still work even if this fails: it just means SIGINT Ctrl+C might not work properly.
# We'll just log it and hope for the best.
_helper_log(f"error joining caller's process group: {e}")
pass

sys.argv[1:] = calling_environment["argv"][1:]
_helper_log(f"cmd={' '.join(calling_environment['argv'])}")
os.environ.clear()
Expand Down
97 changes: 95 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import contextlib
import json
import os
import re
import sys
from pathlib import Path
import signal
import sys
from time import sleep

import pytest

from kart import cli
from kart import cli, is_windows


H = pytest.helpers.helpers()
Expand Down Expand Up @@ -109,3 +112,93 @@ def test_ext_run(tmp_path, cli_runner, sys_path_reset):
assert (val1, val2) == (False, 3)
assert Path(sfile) == (tmp_path / "three.py")
assert sname == "kart.ext_run.three"


# Source of a small `kart ext-run` script used by the SIGINT test below.
# The script prints its own PID, forks, and prints the forked child's PID.
# The forked child appends a timestamp to the file named by args[0] every 10ms,
# forever, while the parent blocks in os.wait(). The growing file is how the
# test observes whether the forked child is still alive.
TEST_SIGINT_PY = r"""
import datetime
import os
import sys
from time import sleep


def main(ctx, args):
    print(os.getpid())
    fork_id = os.fork()
    if fork_id == 0:
        with open(args[0], 'w') as output:
            while True:
                output.write(datetime.datetime.now().isoformat() + '\n')
                output.flush()
                sleep(0.01)
    else:
        print(fork_id)
        sys.stdout.flush()
        os.wait()
"""


@pytest.mark.skipif(is_windows, reason="No SIGINT on windows")
@pytest.mark.parametrize("use_helper", [False, True])
def test_sigint_handling_unix(use_helper, tmp_path):
    """Check that a SIGINT sent to Kart also stops the subprocess it forked.

    Runs TEST_SIGINT_PY through ``kart ext-run`` (with or without helper mode),
    verifies that Kart and its forked child share a process group distinct from
    the test runner's, then sends SIGINT and verifies that Kart exits and the
    forked child stops writing to its output file.
    """
    import subprocess

    kart_bin_dir = Path(sys.executable).parent
    kart_exe = kart_bin_dir / "kart"
    kart_cli_exe = kart_bin_dir / "kart_cli"

    # NOTE(review): presumably, when kart_cli is installed alongside kart, `kart` is the
    # helper-mode launcher and `kart_cli` the plain CLI - confirm against the build layout.
    # When kart_cli is absent, kart_with_helper_mode points at a non-existent file,
    # which triggers the skip below.
    kart_with_helper_mode = kart_exe if kart_cli_exe.is_file() else kart_cli_exe
    kart_without_helper = kart_cli_exe if kart_cli_exe.is_file() else kart_exe

    if use_helper and not kart_with_helper_mode.is_file():
        pytest.skip(f"Couldn't find kart helper mode in {kart_bin_dir}")

    kart_to_use = kart_with_helper_mode if use_helper else kart_without_helper
    assert kart_to_use.is_file(), "Couldn't find kart"

    # working example
    test_sigint_py_path = tmp_path / "test_sigint.py"
    test_sigint_py_path.write_text(TEST_SIGINT_PY)

    subprocess_output_path = tmp_path / "output"

    # Make sure the spawned Kart sets up its own process group and cleanup handlers,
    # regardless of how the test runner itself was launched.
    env = os.environ.copy()
    env.pop("_KART_PGID_SET", None)
    env.pop("NO_CONFIGURE_PROCESS_CLEANUP", None)

    p = subprocess.Popen(
        [
            str(kart_to_use),
            "ext-run",
            str(test_sigint_py_path),
            str(subprocess_output_path),
        ],
        encoding="utf8",
        env=env,
        stdout=subprocess.PIPE,
    )
    try:
        sleep(1)
        child_pid = int(p.stdout.readline())
        grandchild_pid = int(p.stdout.readline())

        # The new kart process should be in a new process group.
        assert os.getpgid(0) != os.getpgid(child_pid)
        # And its subprocess should be in the same process group.
        assert os.getpgid(child_pid) == os.getpgid(grandchild_pid)

        # Time goes past and grandchild keeps writing output
        output_size_1 = subprocess_output_path.stat().st_size
        sleep(1)
        assert p.poll() is None  # Kart (and so the grandchild) keeps running...
        output_size_2 = subprocess_output_path.stat().st_size
        assert output_size_2 > output_size_1  # Grandchild output keeps growing...

        os.kill(child_pid, signal.SIGINT)
        sleep(1)
        assert p.poll() is not None

        # Time goes past but grandchild's output has stopped.
        output_size_3 = subprocess_output_path.stat().st_size
        sleep(1)
        output_size_4 = subprocess_output_path.stat().st_size
        assert output_size_3 == output_size_4
    finally:
        # If an assertion failed part-way, don't leave Kart running in the background.
        if p.poll() is None:
            try:
                if os.getpgid(p.pid) == p.pid:
                    # Kart leads its own process group, so the whole group can be
                    # killed without touching the test runner.
                    os.killpg(p.pid, signal.SIGKILL)
                else:
                    p.kill()
            except ProcessLookupError:
                # The process exited between the poll() and the kill.
                pass
            p.wait()
        p.stdout.close()

0 comments on commit b07ec84

Please sign in to comment.