Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix undefined behavior in log #62

Merged
merged 11 commits into from
May 5, 2022
1 change: 1 addition & 0 deletions rmf_task/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

2.1.0 (2022-XX-YY)
------------------
* Fix undefined behavior in log: [#62](https://github.com/open-rmf/rmf_task/pull/62)

2.0.0 (2022-02-14)
------------------
Expand Down
2 changes: 2 additions & 0 deletions rmf_task/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ find_package(rmf_utils REQUIRED)
find_package(rmf_traffic REQUIRED)
find_package(rmf_battery REQUIRED)
find_package(Eigen3 REQUIRED)
find_package(Threads)

find_package(ament_cmake_catch2 QUIET)
find_package(ament_cmake_uncrustify QUIET)
Expand All @@ -40,6 +41,7 @@ add_library(rmf_task SHARED
target_link_libraries(rmf_task
PUBLIC
rmf_battery::rmf_battery
Threads::Threads
)

target_include_directories(rmf_task
Expand Down
7 changes: 4 additions & 3 deletions rmf_task/cmake/rmf_task-config.cmake.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ get_filename_component(rmf_task_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)

include(CMakeFindDependencyMacro)

find_dependency(rmf_utils REQUIRED)
find_dependency(rmf_traffic REQUIRED)
find_dependency(rmf_battery REQUIRED)
find_dependency(rmf_utils)
find_dependency(rmf_traffic)
find_dependency(rmf_battery)
find_dependency(Threads)

if(NOT TARGET rmf_task::rmf_task)
include("${rmf_task_CMAKE_DIR}/rmf_task-targets.cmake")
Expand Down
49 changes: 35 additions & 14 deletions rmf_task/src/rmf_task/Log.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <rmf_task/Log.hpp>

#include <optional>
#include <mutex>
#include <stdexcept>
#include <list>

Expand All @@ -29,6 +30,7 @@ class Log::Implementation
public:
std::function<rmf_traffic::Time()> clock;
std::shared_ptr<std::list<Log::Entry>> entries;
mutable std::mutex mutex;
uint32_t seq = 0;

Implementation(std::function<rmf_traffic::Time()> clock_)
Expand Down Expand Up @@ -183,24 +185,15 @@ class Log::Reader::Iterable::iterator::Implementation
Log::Reader::Iterable Log::Reader::Iterable::Implementation::make(
std::shared_ptr<const std::list<Log::Entry>> shared,
std::optional<base_iterator> begin,
std::optional<base_iterator> last)
std::optional<base_iterator> last_in_view)
{
Iterable iterable;
iterable._pimpl = rmf_utils::make_impl<Implementation>();
iterable._pimpl->shared = std::move(shared);
if (begin.has_value())
{
if (++base_iterator(last.value()) == *begin)
{
// If the beginning iterator is already the end() iterator, we should
// directly set it to that right now.
iterable._pimpl->begin = iterator::Implementation::end();
}
else
{
iterable._pimpl->begin =
iterator::Implementation::make(*begin, last.value());
}
iterable._pimpl->begin =
iterator::Implementation::make(*begin, last_in_view.value());
}
else
{
Expand All @@ -219,12 +212,37 @@ auto Log::Reader::Implementation::read(const View& view) -> Iterable
if (memory.weak.lock())
{
if (!memory.last.has_value())
{
memory.last = v.begin;
}
else if (v.last.has_value())
{
if ((*memory.last)->seq() >= (*v.last)->seq())
{
// If the last memory of this reader is more recent than this view, then
// we will return an empty iterable.
return Iterable::Implementation::make(
v.shared, std::nullopt, std::nullopt);
}
else
{
// If the last memory of this reader is behind the last value in the
// view, then we will move it forward by one.
++(*memory.last);
}
}
else
++(*memory.last);
{
// The view is missing a "last" value, meaning it's an empty view.
// TODO(MXG): We should write explicit unit tests for this.
return Iterable::Implementation::make(
v.shared, std::nullopt, std::nullopt);
}
}
else
{
// Reset this memory, because it points at an expired list whose memory
// address is being recycled.
memory.weak = v.shared;
memory.last = v.begin;
}
Expand All @@ -239,7 +257,7 @@ auto Log::Reader::Implementation::read(const View& view) -> Iterable

//==============================================================================
Log::Log(std::function<rmf_traffic::Time()> clock)
: _pimpl(rmf_utils::make_impl<Implementation>(std::move(clock)))
: _pimpl(rmf_utils::make_unique_impl<Implementation>(std::move(clock)))
{
// Do nothing
}
Expand Down Expand Up @@ -273,6 +291,7 @@ void Log::push(Tier tier, std::string text)
// *INDENT-ON*
}

std::lock_guard<std::mutex> lock(_pimpl->mutex);
_pimpl->entries->emplace_back(
Entry::Implementation::make(
tier, _pimpl->seq++, _pimpl->clock(), std::move(text)));
Expand All @@ -287,6 +306,7 @@ void Log::insert(Log::Entry entry)
//==============================================================================
Log::View Log::view() const
{
std::lock_guard<std::mutex> lock(_pimpl->mutex);
return View::Implementation::make(*this);
}

Expand Down Expand Up @@ -402,6 +422,7 @@ bool Log::Reader::Iterable::iterator::operator!=(const iterator& other) const

//==============================================================================
Log::Reader::Iterable::iterator::iterator()
: _pimpl(nullptr)
{
// Do nothing
}
Expand Down
188 changes: 188 additions & 0 deletions rmf_task/test/unit/test_Log.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@

#include <rmf_task/Log.hpp>

#include <random>
#include <iostream>
#include <thread>
#include <mutex>
#include <atomic>
#include <condition_variable>

SCENARIO("Writing and reading logs")
{
rmf_task::Log log;
Expand Down Expand Up @@ -58,3 +65,184 @@ SCENARIO("Writing and reading logs")

CHECK(count == expected_count);
}

struct SyncView
{
bool ready_for_new_view = false;
std::condition_variable new_view_ready;
std::mutex mutex;
std::optional<rmf_task::Log::View> view;
};

//==============================================================================
SCENARIO("Multi-threaded read/write with synced view")
{
std::shared_ptr<SyncView> sync = std::make_shared<SyncView>();
std::shared_ptr<std::atomic_bool> test_finished =
std::make_shared<std::atomic_bool>(false);

auto all_seqs = std::make_shared<std::vector<uint32_t>>();
auto all_text = std::make_shared<std::vector<std::string>>();
auto log = std::make_shared<rmf_task::Log>();

std::thread producer(
[](
std::shared_ptr<rmf_task::Log> log,
std::shared_ptr<SyncView> sync,
std::shared_ptr<std::atomic_bool> test_finished)
{
std::random_device r;
std::default_random_engine eng(r());
std::uniform_real_distribution<double> entry_dist(0, 1);
std::uniform_int_distribution<uint> tier_dist(
static_cast<uint>(rmf_task::Log::Tier::Info),
static_cast<uint>(rmf_task::Log::Tier::Error));

std::size_t counter = 0;
while (!test_finished->load())
{
if (entry_dist(eng) > 0.95)
{
const auto tier = static_cast<rmf_task::Log::Tier>(tier_dist(eng));
log->push(tier, "This is log #" + std::to_string(counter++));
}

std::unique_lock<std::mutex> lock(sync->mutex, std::defer_lock);
if (lock.try_lock())
{
if (sync->ready_for_new_view)
{
sync->view = log->view();
sync->ready_for_new_view = false;
lock.unlock();
sync->new_view_ready.notify_all();
}
}
}
}, log, sync, test_finished);

std::thread consumer(
[](
std::shared_ptr<std::vector<uint32_t>> all_seqs,
std::shared_ptr<std::vector<std::string>> all_text,
std::shared_ptr<SyncView> sync,
std::shared_ptr<std::atomic_bool> test_finished)
{
rmf_task::Log::Reader reader;
while (!test_finished->load())
{
std::size_t new_entry_count = 0;
std::unique_lock<std::mutex> lock(sync->mutex);
if (sync->view.has_value())
{
for (const auto& entry : reader.read(*sync->view))
{
++new_entry_count;
all_seqs->push_back(entry.seq());
all_text->push_back(entry.text());
}
}

sync->view = std::nullopt;
sync->ready_for_new_view = true;
sync->new_view_ready.wait(
lock,
[sync, test_finished]()
{
return sync->view.has_value() || test_finished->load();
});
}
}, all_seqs, all_text, sync, test_finished);

std::this_thread::sleep_for(std::chrono::seconds(1));
test_finished->store(true);

// Use this condition variable to wake up the consumer, in case it's waiting
sync->new_view_ready.notify_all();

producer.join();
consumer.join();

std::size_t index = 0;
for (const auto& entry : rmf_task::Log::Reader().read(log->view()))
{
if (index < all_seqs->size())
CHECK((*all_seqs)[index] == entry.seq());

if (index < all_text->size())
CHECK((*all_text)[index] == entry.text());

++index;
}
}

//==============================================================================
SCENARIO("Multi-threaded read/write without syncing")
{

std::shared_ptr<std::atomic_bool> test_finished =
std::make_shared<std::atomic_bool>(false);

auto all_seqs = std::make_shared<std::vector<uint32_t>>();
auto all_text = std::make_shared<std::vector<std::string>>();
auto log = std::make_shared<rmf_task::Log>();

std::thread producer(
[](
std::shared_ptr<rmf_task::Log> log,
std::shared_ptr<std::atomic_bool> test_finished)
{
std::random_device r;
std::default_random_engine eng(r());
std::uniform_real_distribution<double> entry_dist(0, 1);
std::uniform_int_distribution<uint> tier_dist(
static_cast<uint>(rmf_task::Log::Tier::Info),
static_cast<uint>(rmf_task::Log::Tier::Error));

std::size_t counter = 0;
while (!test_finished->load())
{
if (entry_dist(eng) > 0.95)
{
const auto tier = static_cast<rmf_task::Log::Tier>(tier_dist(eng));
log->push(tier, "This is log #" + std::to_string(counter++));
}
}
}, log, test_finished);

std::thread consumer(
[](
std::shared_ptr<rmf_task::Log> log,
std::shared_ptr<std::vector<uint32_t>> all_seqs,
std::shared_ptr<std::vector<std::string>> all_text,
std::shared_ptr<std::atomic_bool> test_finished)
{
rmf_task::Log::Reader reader;
while (!test_finished->load())
{
for (const auto& entry : reader.read(log->view()))
{
all_seqs->push_back(entry.seq());
all_text->push_back(entry.text());
}
}
}, log, all_seqs, all_text, test_finished);

std::this_thread::sleep_for(std::chrono::seconds(1));
test_finished->store(true);

producer.join();
consumer.join();

std::size_t index = 0;
for (const auto& entry : rmf_task::Log::Reader().read(log->view()))
{
if (index < all_seqs->size())
CHECK((*all_seqs)[index] == entry.seq());

if (index < all_text->size())
CHECK((*all_text)[index] == entry.text());

++index;
}
}