Skip to content

Commit

Permalink
Snapshot diff merging fixes (#174)
Browse files Browse the repository at this point in the history
* Started on more efficient diffs

* Flip order of checking merge regions

* More tweaks

* Remove unnecessary logging
  • Loading branch information
Shillaker committed Nov 12, 2021
1 parent 990c640 commit 41207d9
Show file tree
Hide file tree
Showing 7 changed files with 227 additions and 117 deletions.
3 changes: 3 additions & 0 deletions include/faabric/util/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,7 @@ AlignedChunk getPageAlignedChunk(long offset, long length);
void resetDirtyTracking();

std::vector<int> getDirtyPageNumbers(const uint8_t* ptr, int nPages);

std::vector<std::pair<uint32_t, uint32_t>> getDirtyRegions(const uint8_t* ptr,
int nPages);
}
5 changes: 4 additions & 1 deletion include/faabric/util/snapshot.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ class SnapshotMergeRegion

void addDiffs(std::vector<SnapshotDiff>& diffs,
const uint8_t* original,
const uint8_t* updated);
uint32_t originalSize,
const uint8_t* updated,
uint32_t dirtyRegionStart,
uint32_t dirtyRegionEnd);
};

class SnapshotData
Expand Down
25 changes: 25 additions & 0 deletions src/util/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,29 @@ std::vector<int> getDirtyPageNumbers(const uint8_t* ptr, int nPages)

return pageNumbers;
}

std::vector<std::pair<uint32_t, uint32_t>> getDirtyRegions(const uint8_t* ptr,
int nPages)
{
std::vector<int> dirtyPages = getDirtyPageNumbers(ptr, nPages);

// Add a new region for each page, unless the one before it was also dirty,
// in which case we merge them
std::vector<std::pair<uint32_t, uint32_t>> regions;
for (int p = 0; p < dirtyPages.size(); p++) {
int thisPageNum = dirtyPages.at(p);
uint32_t thisPageStart = thisPageNum * HOST_PAGE_SIZE;
uint32_t thisPageEnd = thisPageStart + HOST_PAGE_SIZE;

if (p > 0 && dirtyPages.at(p - 1) == thisPageNum - 1) {
// Previous page was also dirty, just update last region
regions.back().second = thisPageEnd;
} else {
// Previous page wasn't dirty, add new region
regions.emplace_back(thisPageStart, thisPageEnd);
}
}

return regions;
}
}
160 changes: 70 additions & 90 deletions src/util/snapshot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,72 +47,31 @@ std::vector<SnapshotDiff> SnapshotData::getChangeDiffs(const uint8_t* updated,
return diffs;
}

for (const auto& mr : mergeRegions) {
SPDLOG_TRACE("Merge region {} {} at {}-{}",
snapshotDataTypeStr(mr.second.dataType),
snapshotMergeOpStr(mr.second.operation),
mr.second.offset,
mr.second.offset + mr.second.length);
}

// Work out which pages have changed (these will be sorted)
// Work out which regions of memory have changed
size_t nThisPages = getRequiredHostPages(updatedSize);
std::vector<int> dirtyPageNumbers =
getDirtyPageNumbers(updated, nThisPages);

// Iterate through each dirty page, work out if there's an overlapping merge
// region, tell that region to add their diffs to the list
std::map<uint32_t, SnapshotMergeRegion>::iterator mergeIt =
mergeRegions.begin();

for (int i : dirtyPageNumbers) {
int pageStart = i * HOST_PAGE_SIZE;
int pageEnd = pageStart + HOST_PAGE_SIZE;

SPDLOG_TRACE("Checking dirty page {} at {}-{}", i, pageStart, pageEnd);
std::vector<std::pair<uint32_t, uint32_t>> dirtyRegions =
getDirtyRegions(updated, nThisPages);
SPDLOG_TRACE("Found {} dirty regions", dirtyRegions.size());

// Skip any merge regions we've passed
while (mergeIt != mergeRegions.end() &&
(mergeIt->second.offset < pageStart)) {
SPDLOG_TRACE("Gone past {} {} merge region at {}-{}",
snapshotDataTypeStr(mergeIt->second.dataType),
snapshotMergeOpStr(mergeIt->second.operation),
mergeIt->second.offset,
mergeIt->second.offset + mergeIt->second.length);
// Iterate through merge regions, see which ones overlap with dirty memory
// regions, and add corresponding diffs
for (auto& mrPair : mergeRegions) {
SnapshotMergeRegion& mr = mrPair.second;

++mergeIt;
}

if (mergeIt == mergeRegions.end()) {
// Done if no more merge regions left
SPDLOG_TRACE("No more merge regions left");
break;
}

// For each merge region that overlaps this dirty page, get it to add
// its diffs, and move onto the next one
// TODO - make this more efficient by passing in dirty pages to merge
// regions so that they avoid unnecessary work if they're large.
while (mergeIt != mergeRegions.end() &&
(mergeIt->second.offset >= pageStart &&
mergeIt->second.offset < pageEnd)) {

uint8_t* original = data;

// If we're outside the range of the original data, pass a nullptr
if (mergeIt->second.offset > size) {
SPDLOG_TRACE(
"Checking {} {} merge region {}-{} outside original snapshot",
snapshotDataTypeStr(mergeIt->second.dataType),
snapshotMergeOpStr(mergeIt->second.operation),
mergeIt->second.offset,
mergeIt->second.offset + mergeIt->second.length);

original = nullptr;
}

mergeIt->second.addDiffs(diffs, original, updated);
mergeIt++;
SPDLOG_TRACE("Merge region {} {} at {}-{}",
snapshotDataTypeStr(mr.dataType),
snapshotMergeOpStr(mr.operation),
mr.offset,
mr.offset + mr.length);

for (auto& dirtyRegion : dirtyRegions) {
// Add the diffs
mr.addDiffs(diffs,
data,
size,
updated,
dirtyRegion.first,
dirtyRegion.second);
}
}

Expand Down Expand Up @@ -210,21 +169,35 @@ std::string snapshotMergeOpStr(SnapshotMergeOperation op)

void SnapshotMergeRegion::addDiffs(std::vector<SnapshotDiff>& diffs,
const uint8_t* original,
const uint8_t* updated)
uint32_t originalSize,
const uint8_t* updated,
uint32_t dirtyRegionStart,
uint32_t dirtyRegionEnd)
{
SPDLOG_TRACE("Checking for {} {} merge region at {}-{}",
// If the region has zero length, it signifies that it goes to the
// end of the memory, so we go all the way to the end of the dirty region.
// For all other regions, we just check if the dirty range is within the
// merge region.
bool isInRange = (dirtyRegionEnd > offset) &&
((length == 0) || (dirtyRegionStart < offset + length));

if (!isInRange) {
return;
}

SPDLOG_TRACE("Checking for {} {} merge region in dirty region {}-{}",
snapshotDataTypeStr(dataType),
snapshotMergeOpStr(operation),
offset,
offset + length);
dirtyRegionStart,
dirtyRegionEnd);

switch (dataType) {
case (SnapshotDataType::Int): {
// Check if the value has changed
const uint8_t* updatedValue = updated + offset;
int updatedInt = *(reinterpret_cast<const int*>(updatedValue));

if (original == nullptr) {
if (originalSize < offset) {
throw std::runtime_error(
"Do not support int operations outside original snapshot");
}
Expand All @@ -237,16 +210,6 @@ void SnapshotMergeRegion::addDiffs(std::vector<SnapshotDiff>& diffs,
return;
}

// Add the diff
diffs.emplace_back(
dataType, operation, offset, updatedValue, length);

SPDLOG_TRACE("Adding {} {} diff at {}-{}",
snapshotDataTypeStr(dataType),
snapshotMergeOpStr(operation),
offset,
offset + length);

// Potentially modify the original in place depending on the
// operation
switch (operation) {
Expand Down Expand Up @@ -285,25 +248,42 @@ void SnapshotMergeRegion::addDiffs(std::vector<SnapshotDiff>& diffs,
std::memcpy(
(uint8_t*)updatedValue, BYTES(&updatedInt), sizeof(int32_t));

// Add the diff
diffs.emplace_back(
dataType, operation, offset, updatedValue, length);

SPDLOG_TRACE("Adding {} {} diff at {}-{} ({})",
snapshotDataTypeStr(dataType),
snapshotMergeOpStr(operation),
offset,
offset + length,
updatedInt);

break;
}
case (SnapshotDataType::Raw): {
switch (operation) {
case (SnapshotMergeOperation::Overwrite): {
// Add subsections of diffs only for the bytes that
// have changed
// Work out bounds of region we're checking
uint32_t checkStart =
std::max<uint32_t>(dirtyRegionStart, offset);

uint32_t checkEnd;
if (length == 0) {
checkEnd = dirtyRegionEnd;
} else {
checkEnd =
std::min<uint32_t>(dirtyRegionEnd, offset + length);
}

bool diffInProgress = false;
int diffStart = 0;
for (int b = offset; b <= offset + length; b++) {
bool isDirtyByte = false;

if (original == nullptr) {
isDirtyByte = true;
} else {
isDirtyByte = *(original + b) != *(updated + b);
}
for (int b = checkStart; b <= checkEnd; b++) {
// If this byte is outside the original region, we can't
// compare (i.e. always dirty)
bool isDirtyByte = (b > originalSize) ||
(*(original + b) != *(updated + b));

SPDLOG_TRACE("BYTE {} dirty {}", b, isDirtyByte);
if (isDirtyByte && !diffInProgress) {
// Diff starts here if it's different and diff
// not in progress
Expand Down Expand Up @@ -331,7 +311,7 @@ void SnapshotMergeRegion::addDiffs(std::vector<SnapshotDiff>& diffs,
// If we've reached the end of this region with a diff
// in progress, we need to close it off
if (diffInProgress) {
int finalDiffLength = (offset + length) - diffStart + 1;
int finalDiffLength = checkEnd - diffStart;
SPDLOG_TRACE(
"Adding {} {} diff at {}-{} (end of region)",
snapshotDataTypeStr(dataType),
Expand Down
6 changes: 3 additions & 3 deletions tests/test/snapshot/test_snapshot_diffs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ TEST_CASE_METHOD(SnapshotTestFixture, "Test snapshot diffs", "[snapshot]")
// NOTE - deliberately add merge regions out of order
// Diff starting in merge region and overlapping the end
std::vector<uint8_t> dataC = { 7, 6, 5, 4, 3, 2, 1 };
std::vector<uint8_t> expectedDataC = { 7, 6, 5, 4, 3 };
std::vector<uint8_t> expectedDataC = { 7, 6, 5, 4 };
int offsetC = 2 * HOST_PAGE_SIZE;
std::memcpy(sharedMem + offsetC, dataC.data(), dataC.size());

Expand All @@ -104,7 +104,7 @@ TEST_CASE_METHOD(SnapshotTestFixture, "Test snapshot diffs", "[snapshot]")

// Merge region within a change
std::vector<uint8_t> dataD = { 1, 1, 2, 2, 3, 3, 4 };
std::vector<uint8_t> expectedDataD = { 2, 2, 3, 3 };
std::vector<uint8_t> expectedDataD = { 2, 2, 3 };
int offsetD = 3 * HOST_PAGE_SIZE - dataD.size();
std::memcpy(sharedMem + offsetD, dataD.data(), dataD.size());

Expand All @@ -119,7 +119,7 @@ TEST_CASE_METHOD(SnapshotTestFixture, "Test snapshot diffs", "[snapshot]")
// add a merge region larger than it. Anything outside the original snapshot
// should be marked as changed.
std::vector<uint8_t> dataExtra = { 2, 2, 2 };
std::vector<uint8_t> expectedDataExtra = { 0, 0, 2, 2, 2, 0, 0, 0 };
std::vector<uint8_t> expectedDataExtra = { 0, 0, 2, 2, 2, 0, 0 };
int extraOffset = snapSize + HOST_PAGE_SIZE + 10;
std::memcpy(sharedMem + extraOffset, dataExtra.data(), dataExtra.size());

Expand Down
49 changes: 49 additions & 0 deletions tests/test/util/test_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,4 +303,53 @@ TEST_CASE("Test dirty page checking", "[util]")

munmap(sharedMemory, memSize);
}

TEST_CASE("Test dirty region checking", "[util]")
{
int nPages = 15;
size_t memSize = HOST_PAGE_SIZE * nPages;
auto* sharedMemory = (uint8_t*)mmap(
nullptr, memSize, PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0);

if (sharedMemory == nullptr) {
FAIL("Could not provision memory");
}

resetDirtyTracking();

std::vector<std::pair<uint32_t, uint32_t>> actual =
faabric::util::getDirtyRegions(sharedMemory, nPages);
REQUIRE(actual.empty());

// Dirty some pages, some adjacent
uint8_t* pageZero = sharedMemory;
uint8_t* pageOne = pageZero + HOST_PAGE_SIZE;
uint8_t* pageThree = pageZero + (3 * HOST_PAGE_SIZE);
uint8_t* pageFour = pageZero + (4 * HOST_PAGE_SIZE);
uint8_t* pageSeven = pageZero + (7 * HOST_PAGE_SIZE);
uint8_t* pageNine = pageZero + (9 * HOST_PAGE_SIZE);

// Set some byte within each page
pageZero[1] = 1;
pageOne[11] = 1;
pageThree[33] = 1;
pageFour[44] = 1;
pageSeven[77] = 1;
pageNine[99] = 1;

// Expect adjacent regions to be merged
std::vector<std::pair<uint32_t, uint32_t>> expected = {
{ 0, 2 * HOST_PAGE_SIZE },
{ 3 * HOST_PAGE_SIZE, 5 * HOST_PAGE_SIZE },
{ 7 * HOST_PAGE_SIZE, 8 * HOST_PAGE_SIZE },
{ 9 * HOST_PAGE_SIZE, 10 * HOST_PAGE_SIZE },
};

actual = faabric::util::getDirtyRegions(sharedMemory, nPages);
REQUIRE(actual.size() == expected.size());
for (int i = 0; i < actual.size(); i++) {
REQUIRE(actual.at(i).first == expected.at(i).first);
REQUIRE(actual.at(i).second == expected.at(i).second);
}
}
}
Loading

0 comments on commit 41207d9

Please sign in to comment.