Skip to content

Commit

Permalink
Fix undefined behavior in log (#62)
Browse files Browse the repository at this point in the history
Signed-off-by: Michael X. Grey <grey@openrobotics.org>
  • Loading branch information
mxgrey committed May 5, 2022
1 parent 63fce74 commit 32de25f
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 17 deletions.
1 change: 1 addition & 0 deletions rmf_task/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

2.1.0 (2022-XX-YY)
------------------
* Fix undefined behavior in log: [#62](https://github.com/open-rmf/rmf_task/pull/62)

2.0.0 (2022-02-14)
------------------
Expand Down
2 changes: 2 additions & 0 deletions rmf_task/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ find_package(rmf_utils REQUIRED)
find_package(rmf_traffic REQUIRED)
find_package(rmf_battery REQUIRED)
find_package(Eigen3 REQUIRED)
find_package(Threads)

find_package(ament_cmake_catch2 QUIET)
find_package(ament_cmake_uncrustify QUIET)
Expand All @@ -40,6 +41,7 @@ add_library(rmf_task SHARED
target_link_libraries(rmf_task
PUBLIC
rmf_battery::rmf_battery
Threads::Threads
)

target_include_directories(rmf_task
Expand Down
7 changes: 4 additions & 3 deletions rmf_task/cmake/rmf_task-config.cmake.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ get_filename_component(rmf_task_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)

include(CMakeFindDependencyMacro)

find_dependency(rmf_utils REQUIRED)
find_dependency(rmf_traffic REQUIRED)
find_dependency(rmf_battery REQUIRED)
find_dependency(rmf_utils)
find_dependency(rmf_traffic)
find_dependency(rmf_battery)
find_dependency(Threads)

if(NOT TARGET rmf_task::rmf_task)
include("${rmf_task_CMAKE_DIR}/rmf_task-targets.cmake")
Expand Down
49 changes: 35 additions & 14 deletions rmf_task/src/rmf_task/Log.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <rmf_task/Log.hpp>

#include <optional>
#include <mutex>
#include <stdexcept>
#include <list>

Expand All @@ -29,6 +30,7 @@ class Log::Implementation
public:
std::function<rmf_traffic::Time()> clock;
std::shared_ptr<std::list<Log::Entry>> entries;
mutable std::mutex mutex;
uint32_t seq = 0;

Implementation(std::function<rmf_traffic::Time()> clock_)
Expand Down Expand Up @@ -183,24 +185,15 @@ class Log::Reader::Iterable::iterator::Implementation
Log::Reader::Iterable Log::Reader::Iterable::Implementation::make(
std::shared_ptr<const std::list<Log::Entry>> shared,
std::optional<base_iterator> begin,
std::optional<base_iterator> last)
std::optional<base_iterator> last_in_view)
{
Iterable iterable;
iterable._pimpl = rmf_utils::make_impl<Implementation>();
iterable._pimpl->shared = std::move(shared);
if (begin.has_value())
{
if (++base_iterator(last.value()) == *begin)
{
// If the beginning iterator is already the end() iterator, we should
// directly set it to that right now.
iterable._pimpl->begin = iterator::Implementation::end();
}
else
{
iterable._pimpl->begin =
iterator::Implementation::make(*begin, last.value());
}
iterable._pimpl->begin =
iterator::Implementation::make(*begin, last_in_view.value());
}
else
{
Expand All @@ -219,12 +212,37 @@ auto Log::Reader::Implementation::read(const View& view) -> Iterable
if (memory.weak.lock())
{
if (!memory.last.has_value())
{
memory.last = v.begin;
}
else if (v.last.has_value())
{
if ((*memory.last)->seq() >= (*v.last)->seq())
{
// If the last memory of this reader is more recent than this view, then
// we will return an empty iterable.
return Iterable::Implementation::make(
v.shared, std::nullopt, std::nullopt);
}
else
{
// If the last memory of this reader is behind the last value in the
// view, then we will move it forward by one.
++(*memory.last);
}
}
else
++(*memory.last);
{
// The view is missing a "last" value, meaning it's an empty view.
// TODO(MXG): We should write explicit unit tests for this.
return Iterable::Implementation::make(
v.shared, std::nullopt, std::nullopt);
}
}
else
{
// Reset this memory, because it points at an expired list whose memory
// address is being recycled.
memory.weak = v.shared;
memory.last = v.begin;
}
Expand All @@ -239,7 +257,7 @@ auto Log::Reader::Implementation::read(const View& view) -> Iterable

//==============================================================================
Log::Log(std::function<rmf_traffic::Time()> clock)
: _pimpl(rmf_utils::make_impl<Implementation>(std::move(clock)))
: _pimpl(rmf_utils::make_unique_impl<Implementation>(std::move(clock)))
{
// Do nothing
}
Expand Down Expand Up @@ -273,6 +291,7 @@ void Log::push(Tier tier, std::string text)
// *INDENT-ON*
}

std::lock_guard<std::mutex> lock(_pimpl->mutex);
_pimpl->entries->emplace_back(
Entry::Implementation::make(
tier, _pimpl->seq++, _pimpl->clock(), std::move(text)));
Expand All @@ -287,6 +306,7 @@ void Log::insert(Log::Entry entry)
//==============================================================================
Log::View Log::view() const
{
std::lock_guard<std::mutex> lock(_pimpl->mutex);
return View::Implementation::make(*this);
}

Expand Down Expand Up @@ -402,6 +422,7 @@ bool Log::Reader::Iterable::iterator::operator!=(const iterator& other) const

//==============================================================================
Log::Reader::Iterable::iterator::iterator()
: _pimpl(nullptr)
{
// Do nothing
}
Expand Down
188 changes: 188 additions & 0 deletions rmf_task/test/unit/test_Log.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@

#include <rmf_task/Log.hpp>

#include <random>
#include <iostream>
#include <thread>
#include <mutex>
#include <atomic>
#include <condition_variable>

SCENARIO("Writing and reading logs")
{
rmf_task::Log log;
Expand Down Expand Up @@ -58,3 +65,184 @@ SCENARIO("Writing and reading logs")

CHECK(count == expected_count);
}

struct SyncView
{
bool ready_for_new_view = false;
std::condition_variable new_view_ready;
std::mutex mutex;
std::optional<rmf_task::Log::View> view;
};

//==============================================================================
SCENARIO("Multi-threaded read/write with synced view")
{
std::shared_ptr<SyncView> sync = std::make_shared<SyncView>();
std::shared_ptr<std::atomic_bool> test_finished =
std::make_shared<std::atomic_bool>(false);

auto all_seqs = std::make_shared<std::vector<uint32_t>>();
auto all_text = std::make_shared<std::vector<std::string>>();
auto log = std::make_shared<rmf_task::Log>();

std::thread producer(
[](
std::shared_ptr<rmf_task::Log> log,
std::shared_ptr<SyncView> sync,
std::shared_ptr<std::atomic_bool> test_finished)
{
std::random_device r;
std::default_random_engine eng(r());
std::uniform_real_distribution<double> entry_dist(0, 1);
std::uniform_int_distribution<uint> tier_dist(
static_cast<uint>(rmf_task::Log::Tier::Info),
static_cast<uint>(rmf_task::Log::Tier::Error));

std::size_t counter = 0;
while (!test_finished->load())
{
if (entry_dist(eng) > 0.95)
{
const auto tier = static_cast<rmf_task::Log::Tier>(tier_dist(eng));
log->push(tier, "This is log #" + std::to_string(counter++));
}

std::unique_lock<std::mutex> lock(sync->mutex, std::defer_lock);
if (lock.try_lock())
{
if (sync->ready_for_new_view)
{
sync->view = log->view();
sync->ready_for_new_view = false;
lock.unlock();
sync->new_view_ready.notify_all();
}
}
}
}, log, sync, test_finished);

std::thread consumer(
[](
std::shared_ptr<std::vector<uint32_t>> all_seqs,
std::shared_ptr<std::vector<std::string>> all_text,
std::shared_ptr<SyncView> sync,
std::shared_ptr<std::atomic_bool> test_finished)
{
rmf_task::Log::Reader reader;
while (!test_finished->load())
{
std::size_t new_entry_count = 0;
std::unique_lock<std::mutex> lock(sync->mutex);
if (sync->view.has_value())
{
for (const auto& entry : reader.read(*sync->view))
{
++new_entry_count;
all_seqs->push_back(entry.seq());
all_text->push_back(entry.text());
}
}

sync->view = std::nullopt;
sync->ready_for_new_view = true;
sync->new_view_ready.wait(
lock,
[sync, test_finished]()
{
return sync->view.has_value() || test_finished->load();
});
}
}, all_seqs, all_text, sync, test_finished);

std::this_thread::sleep_for(std::chrono::seconds(1));
test_finished->store(true);

// Use this condition variable to wake up the consumer, in case it's waiting
sync->new_view_ready.notify_all();

producer.join();
consumer.join();

std::size_t index = 0;
for (const auto& entry : rmf_task::Log::Reader().read(log->view()))
{
if (index < all_seqs->size())
CHECK((*all_seqs)[index] == entry.seq());

if (index < all_text->size())
CHECK((*all_text)[index] == entry.text());

++index;
}
}

//==============================================================================
SCENARIO("Multi-threaded read/write without syncing")
{

std::shared_ptr<std::atomic_bool> test_finished =
std::make_shared<std::atomic_bool>(false);

auto all_seqs = std::make_shared<std::vector<uint32_t>>();
auto all_text = std::make_shared<std::vector<std::string>>();
auto log = std::make_shared<rmf_task::Log>();

std::thread producer(
[](
std::shared_ptr<rmf_task::Log> log,
std::shared_ptr<std::atomic_bool> test_finished)
{
std::random_device r;
std::default_random_engine eng(r());
std::uniform_real_distribution<double> entry_dist(0, 1);
std::uniform_int_distribution<uint> tier_dist(
static_cast<uint>(rmf_task::Log::Tier::Info),
static_cast<uint>(rmf_task::Log::Tier::Error));

std::size_t counter = 0;
while (!test_finished->load())
{
if (entry_dist(eng) > 0.95)
{
const auto tier = static_cast<rmf_task::Log::Tier>(tier_dist(eng));
log->push(tier, "This is log #" + std::to_string(counter++));
}
}
}, log, test_finished);

std::thread consumer(
[](
std::shared_ptr<rmf_task::Log> log,
std::shared_ptr<std::vector<uint32_t>> all_seqs,
std::shared_ptr<std::vector<std::string>> all_text,
std::shared_ptr<std::atomic_bool> test_finished)
{
rmf_task::Log::Reader reader;
while (!test_finished->load())
{
for (const auto& entry : reader.read(log->view()))
{
all_seqs->push_back(entry.seq());
all_text->push_back(entry.text());
}
}
}, log, all_seqs, all_text, test_finished);

std::this_thread::sleep_for(std::chrono::seconds(1));
test_finished->store(true);

producer.join();
consumer.join();

std::size_t index = 0;
for (const auto& entry : rmf_task::Log::Reader().read(log->view()))
{
if (index < all_seqs->size())
CHECK((*all_seqs)[index] == entry.seq());

if (index < all_text->size())
CHECK((*all_text)[index] == entry.text());

++index;
}
}

0 comments on commit 32de25f

Please sign in to comment.