Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bugs in amq_detection_component's use of select #1587

Merged
merged 3 commits into from Sep 9, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
74 changes: 30 additions & 44 deletions trunk/detection/executor/cpp/batch/main.cpp
Expand Up @@ -74,6 +74,7 @@ std::string get_component_name_and_set_env_var();

std::string get_log_level_and_set_env_var();

bool quit_received(bool got_message_on_last_pull);

/**
* This is the main program for the Detection Component. It accepts two
Expand Down Expand Up @@ -286,26 +287,10 @@ string get_file_name(const string& s) {
template <typename Logger, typename ComponentHandle>
int run_jobs(Logger &logger, const std::string &broker_uri, const std::string &request_queue,
const std::string &app_dir, ComponentHandle &detection_engine) {
int pollingInterval = 1;

// Remain in loop handling job request messages
// until 'q\n' is received on stdin
bool error_occurred = false;
try {

int nfds = 0;
int bytes_read = 0;
int input_buf_size = 2;
char input_buf[input_buf_size];
string quit_string("q\n");
fd_set readfds;
struct timeval tv;
bool keep_running = true;

// Set timeout on check for 'q\n' to 5 seconds
tv.tv_sec = 5;
tv.tv_usec = 0;

MPFMessenger<Logger> messenger(logger, broker_uri, request_queue);

detection_engine.SetRunDirectory(app_dir + "/../plugins");
Expand All @@ -321,22 +306,16 @@ int run_jobs(Logger &logger, const std::string &broker_uri, const std::string &r
string service_name(getenv("SERVICE_NAME"));
logger.Info("Completed initialization of ", service_name, '.');

bool gotMessageOnLastPull = false;
while (keep_running) {
// Sleep for pollingInterval seconds between polls.
if (gotMessageOnLastPull == false) {
sleep(pollingInterval);
}
gotMessageOnLastPull = false;

// Initially set to true to avoid blocking on the first iteration.
bool got_message_on_last_pull = true;
// Remain in loop handling job request messages until 'q\n' is received on stdin
while (!quit_received(got_message_on_last_pull)) {
// Receive job request
MPFMessageMetadata msg_metadata;
std::vector<unsigned char> request_contents = messenger.ReceiveMessage(msg_metadata);

if (!request_contents.empty()) {
// Set not to sleep flag.
gotMessageOnLastPull = true;

got_message_on_last_pull = !request_contents.empty();
if (got_message_on_last_pull) {
MPFDetectionBuffer detection_buf(request_contents);

std::vector<unsigned char> detection_response_body;
Expand All @@ -347,7 +326,7 @@ int run_jobs(Logger &logger, const std::string &broker_uri, const std::string &r

map<string, string> algorithm_properties;
detection_buf.GetAlgorithmProperties(algorithm_properties);
for (auto env_prop_pair : env_job_props) {
for (const auto& env_prop_pair : env_job_props) {
algorithm_properties[env_prop_pair.first] = env_prop_pair.second;
}

Expand Down Expand Up @@ -637,22 +616,8 @@ int run_jobs(Logger &logger, const std::string &broker_uri, const std::string &r
logger.Error('[', job_name, "] Failed to generate a detection response.");
}
}

// Check for 'q\n' input
// Read from file descriptor 0 (stdin)
FD_ZERO(&readfds);
FD_SET(0, &readfds);
nfds = select(1, &readfds, NULL, NULL, &tv);
if (nfds != 0 && FD_ISSET(0, &readfds)) {
bytes_read = read(0, input_buf, input_buf_size);
string std_input(input_buf);
std_input.resize(input_buf_size);
if ((bytes_read > 0) && (std_input == quit_string)) {
logger.Info("Received quit command.");
keep_running = false;
}
}
} // end while
logger.Info("Received quit command.");
} catch (std::exception &e) {
error_occurred = true;
logger.Error("Standard Exception caught in main.cpp: ", e.what(), '\n');
Expand All @@ -671,3 +636,24 @@ int run_jobs(Logger &logger, const std::string &broker_uri, const std::string &r
log4cxx::LogManager::shutdown();
return error_occurred ? 1 : 0;
}

bool quit_received(bool got_message_on_last_pull) {
// Check for 'q\n' input
// Read from file descriptor 0 (stdin)
fd_set stdin_fd_set;
FD_ZERO(&stdin_fd_set);
FD_SET(0, &stdin_fd_set);
// Set timeout on check for 'q\n'
struct timeval select_timeout {
.tv_sec = got_message_on_last_pull ? 0 : 5,
.tv_usec = 0 };

int nfds = select(1, &stdin_fd_set, nullptr, nullptr, &select_timeout);
if (nfds < 1 || !FD_ISSET(0, &stdin_fd_set)) {
return false;
}

char input_buf[2];
size_t bytes_read = read(0, input_buf, 2);
return bytes_read >= 2 && input_buf[0] == 'q' && input_buf[1] == '\n';
}