Skip to content

Commit

Permalink
Stored client mode at decision time. Remove clientMode func injection.
Browse files Browse the repository at this point in the history
  • Loading branch information
mlw committed Apr 28, 2023
1 parent b07b946 commit e7d0cef
Show file tree
Hide file tree
Showing 15 changed files with 74 additions and 98 deletions.
2 changes: 2 additions & 0 deletions Source/common/SNTCachedDecision.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,6 @@
@property NSString *customMsg;
@property BOOL silentBlock;

@property SNTClientMode decisionClientMode;

@end
7 changes: 3 additions & 4 deletions Source/santad/Logs/EndpointSecurity/Logger.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,9 @@ class Logger {
public:
static std::unique_ptr<Logger> Create(
std::shared_ptr<santa::santad::event_providers::endpoint_security::EndpointSecurityAPI> esapi,
SNTEventLogType log_type, SNTDecisionCache *decision_cache,
santa::santad::logs::endpoint_security::serializers::ClientModeFunc GetClientMode,
NSString *event_log_path, NSString *spool_log_path, size_t spool_dir_size_threshold,
size_t spool_file_size_threshold, uint64_t spool_flush_timeout_ms);
SNTEventLogType log_type, SNTDecisionCache *decision_cache, NSString *event_log_path,
NSString *spool_log_path, size_t spool_dir_size_threshold, size_t spool_file_size_threshold,
uint64_t spool_flush_timeout_ms);

Logger(std::shared_ptr<serializers::Serializer> serializer,
std::shared_ptr<writers::Writer> writer);
Expand Down
14 changes: 6 additions & 8 deletions Source/santad/Logs/EndpointSecurity/Logger.mm
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
using santa::santad::event_providers::endpoint_security::EnrichedProcess;
using santa::santad::event_providers::endpoint_security::Message;
using santa::santad::logs::endpoint_security::serializers::BasicString;
using santa::santad::logs::endpoint_security::serializers::ClientModeFunc;
using santa::santad::logs::endpoint_security::serializers::Empty;
using santa::santad::logs::endpoint_security::serializers::Protobuf;
using santa::santad::logs::endpoint_security::writers::File;
Expand All @@ -52,26 +51,25 @@
// Translate configured log type to appropriate Serializer/Writer pairs
std::unique_ptr<Logger> Logger::Create(std::shared_ptr<EndpointSecurityAPI> esapi,
SNTEventLogType log_type, SNTDecisionCache *decision_cache,
ClientModeFunc GetClientMode, NSString *event_log_path,
NSString *spool_log_path, size_t spool_dir_size_threshold,
NSString *event_log_path, NSString *spool_log_path,
size_t spool_dir_size_threshold,
size_t spool_file_size_threshold,
uint64_t spool_flush_timeout_ms) {
switch (log_type) {
case SNTEventLogTypeFilelog:
return std::make_unique<Logger>(
BasicString::Create(esapi, std::move(decision_cache), std::move(GetClientMode)),
BasicString::Create(esapi, std::move(decision_cache)),
File::Create(event_log_path, kFlushBufferTimeoutMS, kBufferBatchSizeBytes,
kMaxExpectedWriteSizeBytes));
case SNTEventLogTypeSyslog:
return std::make_unique<Logger>(
BasicString::Create(esapi, std::move(decision_cache), std::move(GetClientMode), false),
Syslog::Create());
return std::make_unique<Logger>(BasicString::Create(esapi, std::move(decision_cache), false),
Syslog::Create());
case SNTEventLogTypeNull: return std::make_unique<Logger>(Empty::Create(), Null::Create());
case SNTEventLogTypeProtobuf:
LOGW(@"The EventLogType value protobuf is currently in beta. The protobuf schema is subject "
@"to change.");
return std::make_unique<Logger>(
Protobuf::Create(esapi, std::move(decision_cache), std::move(GetClientMode)),
Protobuf::Create(esapi, std::move(decision_cache)),
Spool::Create([spool_log_path UTF8String], spool_dir_size_threshold,
spool_file_size_threshold, spool_flush_timeout_ms));
default: LOGE(@"Invalid log type: %ld", log_type); return nullptr;
Expand Down
39 changes: 21 additions & 18 deletions Source/santad/Logs/EndpointSecurity/LoggerTest.mm
Original file line number Diff line number Diff line change
Expand Up @@ -103,26 +103,26 @@ - (void)testCreate {
// Ensure that the factory method creates expected serializers/writers pairs
auto mockESApi = std::make_shared<MockEndpointSecurityAPI>();

XCTAssertEqual(nullptr, Logger::Create(mockESApi, (SNTEventLogType)123, nil, nullptr,
@"/tmp/temppy", @"/tmp/spool", 1, 1, 1));
XCTAssertEqual(nullptr, Logger::Create(mockESApi, (SNTEventLogType)123, nil, @"/tmp/temppy",
@"/tmp/spool", 1, 1, 1));

LoggerPeer logger(Logger::Create(mockESApi, SNTEventLogTypeFilelog, nil, nullptr, @"/tmp/temppy",
@"/tmp/spool", 1, 1, 1));
LoggerPeer logger(
Logger::Create(mockESApi, SNTEventLogTypeFilelog, nil, @"/tmp/temppy", @"/tmp/spool", 1, 1, 1));
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<BasicString>(logger.Serializer()));
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<File>(logger.Writer()));

logger = LoggerPeer(Logger::Create(mockESApi, SNTEventLogTypeSyslog, nil, nullptr, @"/tmp/temppy",
@"/tmp/spool", 1, 1, 1));
logger = LoggerPeer(
Logger::Create(mockESApi, SNTEventLogTypeSyslog, nil, @"/tmp/temppy", @"/tmp/spool", 1, 1, 1));
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<BasicString>(logger.Serializer()));
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<Syslog>(logger.Writer()));

logger = LoggerPeer(Logger::Create(mockESApi, SNTEventLogTypeNull, nil, nullptr, @"/tmp/temppy",
@"/tmp/spool", 1, 1, 1));
logger = LoggerPeer(
Logger::Create(mockESApi, SNTEventLogTypeNull, nil, @"/tmp/temppy", @"/tmp/spool", 1, 1, 1));
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<Empty>(logger.Serializer()));
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<Null>(logger.Writer()));

logger = LoggerPeer(Logger::Create(mockESApi, SNTEventLogTypeProtobuf, nil, nullptr,
@"/tmp/temppy", @"/tmp/spool", 1, 1, 1));
logger = LoggerPeer(Logger::Create(mockESApi, SNTEventLogTypeProtobuf, nil, @"/tmp/temppy",
@"/tmp/spool", 1, 1, 1));
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<Protobuf>(logger.Serializer()));
XCTAssertNotEqual(nullptr, std::dynamic_pointer_cast<Spool>(logger.Writer()));
}
Expand All @@ -136,16 +136,19 @@ - (void)testLog {
es_message_t msg;

mockESApi->SetExpectationsRetainReleaseMessage();
auto enrichedMsg = std::make_shared<EnrichedMessage>(
EnrichedClose(Message(mockESApi, &msg),
EnrichedProcess(std::nullopt, std::nullopt, std::nullopt, std::nullopt,
EnrichedFile(std::nullopt, std::nullopt, std::nullopt)),
EnrichedFile(std::nullopt, std::nullopt, std::nullopt)));

EXPECT_CALL(*mockSerializer, SerializeMessage(testing::A<const EnrichedClose &>())).Times(1);
EXPECT_CALL(*mockWriter, Write).Times(1);
{
auto enrichedMsg = std::make_shared<EnrichedMessage>(
EnrichedClose(Message(mockESApi, &msg),
EnrichedProcess(std::nullopt, std::nullopt, std::nullopt, std::nullopt,
EnrichedFile(std::nullopt, std::nullopt, std::nullopt)),
EnrichedFile(std::nullopt, std::nullopt, std::nullopt)));

EXPECT_CALL(*mockSerializer, SerializeMessage(testing::A<const EnrichedClose &>())).Times(1);
EXPECT_CALL(*mockWriter, Write).Times(1);

Logger(mockSerializer, mockWriter).Log(enrichedMsg);
Logger(mockSerializer, mockWriter).Log(enrichedMsg);
}

XCTBubbleMockVerifyAndClearExpectations(mockESApi.get());
XCTBubbleMockVerifyAndClearExpectations(mockSerializer.get());
Expand Down
4 changes: 2 additions & 2 deletions Source/santad/Logs/EndpointSecurity/Serializers/BasicString.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ class BasicString : public Serializer {
public:
static std::shared_ptr<BasicString> Create(
std::shared_ptr<santa::santad::event_providers::endpoint_security::EndpointSecurityAPI> esapi,
SNTDecisionCache *decision_cache, ClientModeFunc GetClientMode, bool prefix_time_name = true);
SNTDecisionCache *decision_cache, bool prefix_time_name = true);

BasicString(
std::shared_ptr<santa::santad::event_providers::endpoint_security::EndpointSecurityAPI> esapi,
SNTDecisionCache *decision_cache, ClientModeFunc GetClientMode, bool prefix_time_name);
SNTDecisionCache *decision_cache, bool prefix_time_name);

std::vector<uint8_t> SerializeMessage(
const santa::santad::event_providers::endpoint_security::EnrichedClose &) override;
Expand Down
13 changes: 4 additions & 9 deletions Source/santad/Logs/EndpointSecurity/Serializers/BasicString.mm
Original file line number Diff line number Diff line change
Expand Up @@ -177,18 +177,13 @@ static inline void AppendUserGroup(std::string &str, const audit_token_t &tok,

std::shared_ptr<BasicString> BasicString::Create(std::shared_ptr<EndpointSecurityAPI> esapi,
SNTDecisionCache *decision_cache,
ClientModeFunc GetClientMode,
bool prefix_time_name) {
return std::make_shared<BasicString>(esapi, decision_cache, std::move(GetClientMode),
prefix_time_name);
return std::make_shared<BasicString>(esapi, decision_cache, prefix_time_name);
}

BasicString::BasicString(std::shared_ptr<EndpointSecurityAPI> esapi,
SNTDecisionCache *decision_cache, ClientModeFunc GetClientMode,
bool prefix_time_name)
: Serializer(std::move(decision_cache), std::move(GetClientMode)),
esapi_(esapi),
prefix_time_name_(prefix_time_name) {}
SNTDecisionCache *decision_cache, bool prefix_time_name)
: Serializer(std::move(decision_cache)), esapi_(esapi), prefix_time_name_(prefix_time_name) {}

std::string BasicString::CreateDefaultString(size_t reserved_size) {
std::string str;
Expand Down Expand Up @@ -294,7 +289,7 @@ static inline void AppendUserGroup(std::string &str, const audit_token_t &tok,
msg.instigator().real_group());

str.append("|mode=");
str.append(GetModeString(GetClientMode()));
str.append(GetModeString(cd.decisionClientMode));
str.append("|path=");
str.append(FilePath(esm.event.exec.target->executable).Sanitized());

Expand Down
25 changes: 10 additions & 15 deletions Source/santad/Logs/EndpointSecurity/Serializers/BasicStringTest.mm
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
using santa::santad::event_providers::endpoint_security::Enricher;
using santa::santad::event_providers::endpoint_security::Message;
using santa::santad::logs::endpoint_security::serializers::BasicString;
using santa::santad::logs::endpoint_security::serializers::ClientModeFunc;
using santa::santad::logs::endpoint_security::serializers::Serializer;

namespace santa::santad::logs::endpoint_security::serializers {
Expand All @@ -57,12 +56,10 @@
using santa::santad::logs::endpoint_security::serializers::GetReasonString;

std::string BasicStringSerializeMessage(std::shared_ptr<MockEndpointSecurityAPI> mockESApi,
es_message_t *esMsg, SNTDecisionCache *decisionCache,
ClientModeFunc GetClientMode) {
es_message_t *esMsg, SNTDecisionCache *decisionCache) {
mockESApi->SetExpectationsRetainReleaseMessage();

std::shared_ptr<Serializer> bs =
BasicString::Create(mockESApi, decisionCache, std::move(GetClientMode), false);
std::shared_ptr<Serializer> bs = BasicString::Create(mockESApi, decisionCache, false);
std::vector<uint8_t> ret = bs->SerializeMessage(Enricher().Enrich(Message(mockESApi, esMsg)));

XCTBubbleMockVerifyAndClearExpectations(mockESApi.get());
Expand All @@ -72,7 +69,7 @@

std::string BasicStringSerializeMessage(es_message_t *esMsg) {
auto mockESApi = std::make_shared<MockEndpointSecurityAPI>();
return BasicStringSerializeMessage(mockESApi, esMsg, nil, nullptr);
return BasicStringSerializeMessage(mockESApi, esMsg, nil);
}

@interface BasicStringTest : XCTestCase
Expand All @@ -97,6 +94,7 @@ - (void)setUp {
self.testCachedDecision.sha256 = @"1234_hash";
self.testCachedDecision.quarantineURL = @"google.com";
self.testCachedDecision.certSHA256 = @"5678_hash";
self.testCachedDecision.decisionClientMode = SNTClientModeLockdown;

self.mockDecisionCache = OCMClassMock([SNTDecisionCache class]);
OCMStub([self.mockDecisionCache sharedCache]).andReturn(self.mockDecisionCache);
Expand Down Expand Up @@ -166,9 +164,7 @@ - (void)testSerializeMessageExec {
.WillOnce(testing::Return(es_string_token_t{5, "-l\n-t"}))
.WillOnce(testing::Return(es_string_token_t{8, "-v\r--foo"}));

std::string got = BasicStringSerializeMessage(mockESApi, &esMsg, self.mockDecisionCache, ^{
return [self.mockConfigurator clientMode];
});
std::string got = BasicStringSerializeMessage(mockESApi, &esMsg, self.mockDecisionCache);
std::string want =
"action=EXEC|decision=ALLOW|reason=BINARY|explain=extra!|sha256=1234_hash|"
"cert_sha256=5678_hash|cert_cn=|quarantine_url=google.com|pid=12|pidversion="
Expand Down Expand Up @@ -294,7 +290,7 @@ - (void)testSerializeFileAccess {
mockESApi->SetExpectationsRetainReleaseMessage();

std::vector<uint8_t> ret =
BasicString::Create(nullptr, nil, nullptr, false)
BasicString::Create(nullptr, nil, false)
->SerializeFileAccess("v1.0", "pol_name", Message(mockESApi, &esMsg),
Enricher().Enrich(*esMsg.process), "file_target",
FileAccessPolicyDecision::kAllowedAuditOnly);
Expand All @@ -315,7 +311,7 @@ - (void)testSerializeAllowlist {
auto mockESApi = std::make_shared<MockEndpointSecurityAPI>();
mockESApi->SetExpectationsRetainReleaseMessage();

std::vector<uint8_t> ret = BasicString::Create(mockESApi, nil, nullptr, false)
std::vector<uint8_t> ret = BasicString::Create(mockESApi, nil, false)
->SerializeAllowlist(Message(mockESApi, &esMsg), "test_hash");

XCTAssertTrue(testing::Mock::VerifyAndClearExpectations(mockESApi.get()),
Expand All @@ -339,7 +335,7 @@ - (void)testSerializeBundleHashingEvent {
se.filePath = @"file_path";

std::vector<uint8_t> ret =
BasicString::Create(nullptr, nil, nullptr, false)->SerializeBundleHashingEvent(se);
BasicString::Create(nullptr, nil, false)->SerializeBundleHashingEvent(se);
std::string got(ret.begin(), ret.end());

std::string want = "action=BUNDLE|sha256=file_hash"
Expand All @@ -366,8 +362,7 @@ - (void)testSerializeDiskAppeared {
OCMStub([self.mockConfigurator configurator]).andReturn(self.mockConfigurator);
OCMStub([self.mockConfigurator enableMachineIDDecoration]).andReturn(NO);

std::vector<uint8_t> ret =
BasicString::Create(nullptr, nil, nullptr, false)->SerializeDiskAppeared(props);
std::vector<uint8_t> ret = BasicString::Create(nullptr, nil, false)->SerializeDiskAppeared(props);
std::string got(ret.begin(), ret.end());

std::string want = "action=DISKAPPEAR|mount=path|volume=|bsdname=bsd|fs=apfs"
Expand All @@ -384,7 +379,7 @@ - (void)testSerializeDiskDisappeared {
};

std::vector<uint8_t> ret =
BasicString::Create(nullptr, nil, nullptr, false)->SerializeDiskDisappeared(props);
BasicString::Create(nullptr, nil, false)->SerializeDiskDisappeared(props);
std::string got(ret.begin(), ret.end());

std::string want = "action=DISKDISAPPEAR|mount=path|volume=|bsdname=bsd|machineid=my_id\n";
Expand Down
2 changes: 1 addition & 1 deletion Source/santad/Logs/EndpointSecurity/Serializers/Empty.mm
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
return std::make_shared<Empty>();
}

Empty::Empty() : Serializer(nil, nullptr) {}
Empty::Empty() : Serializer(nil) {}

std::vector<uint8_t> Empty::SerializeMessage(const EnrichedClose &msg) {
return {};
Expand Down
4 changes: 2 additions & 2 deletions Source/santad/Logs/EndpointSecurity/Serializers/Protobuf.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ class Protobuf : public Serializer {
public:
static std::shared_ptr<Protobuf> Create(
std::shared_ptr<santa::santad::event_providers::endpoint_security::EndpointSecurityAPI> esapi,
SNTDecisionCache *decision_cache, ClientModeFunc GetClientMode);
SNTDecisionCache *decision_cache);

Protobuf(
std::shared_ptr<santa::santad::event_providers::endpoint_security::EndpointSecurityAPI> esapi,
SNTDecisionCache *decision_cache, ClientModeFunc GetClientMode);
SNTDecisionCache *decision_cache);

std::vector<uint8_t> SerializeMessage(
const santa::santad::event_providers::endpoint_security::EnrichedClose &) override;
Expand Down
12 changes: 5 additions & 7 deletions Source/santad/Logs/EndpointSecurity/Serializers/Protobuf.mm
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,12 @@
namespace santa::santad::logs::endpoint_security::serializers {

std::shared_ptr<Protobuf> Protobuf::Create(std::shared_ptr<EndpointSecurityAPI> esapi,
SNTDecisionCache *decision_cache,
ClientModeFunc GetClientMode) {
return std::make_shared<Protobuf>(esapi, std::move(decision_cache), std::move(GetClientMode));
SNTDecisionCache *decision_cache) {
return std::make_shared<Protobuf>(esapi, std::move(decision_cache));
}

Protobuf::Protobuf(std::shared_ptr<EndpointSecurityAPI> esapi, SNTDecisionCache *decision_cache,
ClientModeFunc GetClientMode)
: Serializer(std::move(decision_cache), std::move(GetClientMode)), esapi_(esapi) {}
Protobuf::Protobuf(std::shared_ptr<EndpointSecurityAPI> esapi, SNTDecisionCache *decision_cache)
: Serializer(std::move(decision_cache)), esapi_(esapi) {}

static inline void EncodeTimestamp(Timestamp *timestamp, struct timespec ts) {
timestamp->set_seconds(ts.tv_sec);
Expand Down Expand Up @@ -486,7 +484,7 @@ static inline void EncodeCertificateInfo(::pbv1::CertificateInfo *pb_cert_info,

pb_exec->set_decision(GetDecisionEnum(cd.decision));
pb_exec->set_reason(GetReasonEnum(cd.decision));
pb_exec->set_mode(GetModeEnum(GetClientMode()));
pb_exec->set_mode(GetModeEnum(cd.decisionClientMode));

if (cd.certSHA256 || cd.certCommonName) {
EncodeCertificateInfo(pb_exec->mutable_certificate_info(), cd.certSHA256, cd.certCommonName);
Expand Down
Loading

0 comments on commit e7d0cef

Please sign in to comment.