Skip to content

Commit

Permalink
Merge pull request #16482 from hrydgard/savestate-checkpoints
Browse files Browse the repository at this point in the history
Add savestate checkpoints to verify that MEASURE and WRITE match
  • Loading branch information
unknownbrackets committed Dec 2, 2022
2 parents 0c19f6a + c7041d6 commit 7dee26e
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 35 deletions.
3 changes: 2 additions & 1 deletion Common/Serialize/SerializeFuncs.h
Expand Up @@ -85,7 +85,8 @@ template<class T>
void DoVector(PointerWrap &p, std::vector<T> &x, T &default_val) {
u32 vec_size = (u32)x.size();
Do(p, vec_size);
x.resize(vec_size, default_val);
if (vec_size != x.size())
x.resize(vec_size, default_val);
if (vec_size > 0)
DoArray(p, &x[0], vec_size);
}
Expand Down
12 changes: 8 additions & 4 deletions Common/Serialize/SerializeMap.h
Expand Up @@ -38,8 +38,8 @@ void DoMap(PointerWrap &p, M &x, typename M::mapped_type &default_val) {
x[first] = second;
--number;
}
break;
}
break;
case PointerWrap::MODE_WRITE:
case PointerWrap::MODE_MEASURE:
case PointerWrap::MODE_VERIFY:
Expand All @@ -52,8 +52,10 @@ void DoMap(PointerWrap &p, M &x, typename M::mapped_type &default_val) {
--number;
++itr;
}
break;
}
break;
case PointerWrap::MODE_NOOP:
break;
}
}

Expand Down Expand Up @@ -109,8 +111,8 @@ void DoMultimap(PointerWrap &p, M &x, typename M::mapped_type &default_val) {
x.insert(std::make_pair(first, second));
--number;
}
break;
}
break;
case PointerWrap::MODE_WRITE:
case PointerWrap::MODE_MEASURE:
case PointerWrap::MODE_VERIFY:
Expand All @@ -122,8 +124,10 @@ void DoMultimap(PointerWrap &p, M &x, typename M::mapped_type &default_val) {
--number;
++itr;
}
break;
}
break;
case PointerWrap::MODE_NOOP:
break;
}
}

Expand Down
60 changes: 54 additions & 6 deletions Common/Serialize/Serializer.cpp
Expand Up @@ -33,8 +33,27 @@ enum class SerializeCompressType {

static constexpr SerializeCompressType SAVE_TYPE = SerializeCompressType::ZSTD;

PointerWrapSection PointerWrap::Section(const char *title, int ver) {
return Section(title, ver, ver);
void PointerWrap::RewindForWrite(u8 *writePtr) {
_assert_(mode == MODE_MEASURE);
// Switch to writing mode, save the size for later checking and start again.
measuredSize_ = Offset();
mode = MODE_WRITE;
*ptr = writePtr;
ptrStart_ = writePtr;
}

bool PointerWrap::CheckAfterWrite() {
_assert_(error != ERROR_NONE || mode == MODE_WRITE);
size_t offset = Offset();
if (measuredSize_ != 0 && offset != measuredSize_) {
WARN_LOG(SAVESTATE, "CheckAfterWrite: Size mismatch! %d but expected %d", (int)offset, (int)measuredSize_);
return false;
}
if (!checkpoints_.empty() && curCheckpoint_ != checkpoints_.size()) {
WARN_LOG(SAVESTATE, "Checkpoint count mismatch!");
return false;
}
return true;
}

PointerWrapSection PointerWrap::Section(const char *title, int minVer, int ver) {
Expand All @@ -44,6 +63,32 @@ PointerWrapSection PointerWrap::Section(const char *title, int minVer, int ver)
// This is strncpy because we rely on its weird non-null-terminating zero-filling truncation behaviour.
// Can't replace it with the more sensible truncate_cpy because that would break savestates.
strncpy(marker, title, sizeof(marker));

// Compare the measure and write passes. Sanity check to catch bugs, doesn't do anything for output.
size_t offset = Offset();
if (mode == MODE_MEASURE) {
checkpoints_.emplace_back(marker, offset);
} else if (mode == MODE_WRITE) {
if (!checkpoints_.empty()) {
if (checkpoints_.size() <= curCheckpoint_) {
WARN_LOG(SAVESTATE, "Write: Not enough checkpoints from measure pass (%d). cur section: %s", (int)checkpoints_.size(), title);
SetError(ERROR_FAILURE);
return PointerWrapSection(*this, -1, title);
}
if (!checkpoints_[curCheckpoint_].Matches(marker, offset)) {
WARN_LOG(SAVESTATE, "Checkpoint mismatch during write! Section %s but expected %s, offset %d but expected %d", title, marker, offset, (int)checkpoints_[curCheckpoint_].offset);
if (curCheckpoint_ > 1) {
WARN_LOG(SAVESTATE, "Previous checkpoint: %s (%d)", checkpoints_[curCheckpoint_ - 1].title, (int)checkpoints_[curCheckpoint_ - 1].offset);
}
SetError(ERROR_FAILURE);
return PointerWrapSection(*this, -1, title);
}
} else {
WARN_LOG(SAVESTATE, "Writing savestate without checkpoints. This is OK but should be fixed.");
}
curCheckpoint_++;
}

if (!ExpectVoid(marker, sizeof(marker))) {
// Might be before we added name markers for safety.
if (foundVersion == 1 && ExpectVoid(&foundVersion, sizeof(foundVersion))) {
Expand All @@ -60,8 +105,10 @@ PointerWrapSection PointerWrap::Section(const char *title, int minVer, int ver)
if (!firstBadSectionTitle_) {
firstBadSectionTitle_ = title;
}
WARN_LOG(SAVESTATE, "Savestate failure: wrong version %d found for section '%s'", foundVersion, title);
SetError(ERROR_FAILURE);
if (mode != MODE_NOOP) {
WARN_LOG(SAVESTATE, "Savestate failure: wrong version %d found for section '%s'", foundVersion, title);
SetError(ERROR_FAILURE);
}
return PointerWrapSection(*this, -1, title);
}
return PointerWrapSection(*this, foundVersion, title);
Expand All @@ -72,8 +119,9 @@ void PointerWrap::SetError(Error error_) {
error = error_;
}
if (error > ERROR_WARNING) {
// For the rest of this run, just measure.
mode = PointerWrap::MODE_MEASURE;
// For the rest of this run, do nothing, to avoid running off the end of memory or something,
// and also not logspam like MEASURE will do in an error case.
mode = PointerWrap::MODE_NOOP;
}
}

Expand Down
95 changes: 74 additions & 21 deletions Common/Serialize/Serializer.h
Expand Up @@ -30,6 +30,7 @@
// - Serialization code for anything complex has to be manually written.

#include <string>
#include <cstring>
#include <vector>
#include <cstdlib>

Expand All @@ -52,8 +53,7 @@ class PointerWrap;
class PointerWrapSection
{
public:
PointerWrapSection(PointerWrap &p, int ver, const char *title) : p_(p), ver_(ver), title_(title) {
}
PointerWrapSection(PointerWrap &p, int ver, const char *title) : p_(p), ver_(ver), title_(title) {}
~PointerWrapSection();

bool operator == (const int &v) const { return ver_ == v; }
Expand All @@ -73,15 +73,32 @@ class PointerWrapSection
const char *title_;
};

// For measure vs write detailed verification
struct SerializeCheckpoint {
char title[17]; // 16-byte section header, plus a zero terminator for debug printing.
size_t offset;

SerializeCheckpoint(char _title[16], size_t off) {
memcpy(title, _title, 16);
title[16] = 0;
offset = off;
}

bool Matches(const char *_title, size_t off) const {
return memcmp(title, _title, 16) == 0 && off == offset;
}
};

// Wrapper class
class PointerWrap
{
public:
enum Mode {
MODE_READ = 1, // load
MODE_WRITE, // save
MODE_MEASURE, // calculate size
MODE_VERIFY, // compare
MODE_WRITE, // save
MODE_MEASURE, // calculate size
MODE_VERIFY, // compare
MODE_NOOP, // don't do anything. Useful to cleanly doing stuff once we've hit an error.
};

enum Error {
Expand All @@ -94,19 +111,26 @@ class PointerWrap
Mode mode;
Error error = ERROR_NONE;

PointerWrap(u8 **ptr_, Mode mode_) : ptr(ptr_), mode(mode_) {}
PointerWrap(unsigned char **ptr_, int mode_) : ptr((u8**)ptr_), mode((Mode)mode_) {}
PointerWrap(u8 **ptr_, Mode mode_) : ptr(ptr_), ptrStart_(*ptr), mode(mode_) {
if (mode == MODE_MEASURE) {
checkpoints_.reserve(750);
}
}

PointerWrapSection Section(const char *title, int ver);
void RewindForWrite(u8 *writePtr);
bool CheckAfterWrite();

// The returned object can be compared against the version that was loaded.
// This can be used to support versions as old as minVer.
// Version = 0 means the section was not found.
PointerWrapSection Section(const char *title, int minVer, int ver);
PointerWrapSection Section(const char *title, int ver) {
return Section(title, ver, ver);
}

void SetMode(Mode mode_) {mode = mode_;}
Mode GetMode() const {return mode;}
u8 **GetPPtr() {return ptr;}
void SetMode(Mode mode_) { mode = mode_; }
Mode GetMode() const { return mode; }
u8 **GetPPtr() { return ptr; }
void SetError(Error error_);

const char *GetBadSectionTitle() const {
Expand All @@ -119,8 +143,14 @@ class PointerWrap

void DoMarker(const char *prevName, u32 arbitraryNumber = 0x42);

size_t Offset() const { return *ptr - ptrStart_; }

private:
const char *firstBadSectionTitle_ = nullptr;
u8 *ptrStart_;
std::vector<SerializeCheckpoint> checkpoints_;
size_t curCheckpoint_ = 0;
size_t measuredSize_ = 0;
};

class CChunkFileReader
Expand Down Expand Up @@ -152,7 +182,7 @@ class CChunkFileReader
template<class T>
static size_t MeasurePtr(T &_class)
{
u8 *ptr = 0;
u8 *ptr = nullptr;
PointerWrap p(&ptr, PointerWrap::MODE_MEASURE);
_class.DoState(p);
return (size_t)ptr;
Expand All @@ -166,13 +196,39 @@ class CChunkFileReader
PointerWrap p(&ptr, PointerWrap::MODE_WRITE);
_class.DoState(p);

if (p.error != p.ERROR_FAILURE && (expected_end == ptr || expected_size == 0)) {
if (p.error != PointerWrap::ERROR_FAILURE && (expected_end == ptr || expected_size == 0)) {
return ERROR_NONE;
} else {
return ERROR_BROKEN_STATE;
}
}

template<class T>
static Error MeasureAndSavePtr(T &_class, u8 **saved, size_t *savedSize)
{
u8 *ptr = nullptr;
PointerWrap p(&ptr, PointerWrap::MODE_MEASURE);
_class.DoState(p);
_assert_(p.error == PointerWrap::ERROR_NONE);

size_t measuredSize = p.Offset();
u8 *data = (u8 *)malloc(measuredSize);
if (!data)
return ERROR_BAD_ALLOC;

p.RewindForWrite(data);
_class.DoState(p);

if (p.CheckAfterWrite()) {
*saved = data;
*savedSize = measuredSize;
return ERROR_NONE;
} else {
free(data);
return ERROR_BROKEN_STATE;
}
}

// Load file template
template<class T>
static Error Load(const Path &filename, std::string *gitVersion, T& _class, std::string *failureReason)
Expand All @@ -197,19 +253,16 @@ class CChunkFileReader
template<class T>
static Error Save(const Path &filename, const std::string &title, const char *gitVersion, T& _class)
{
// Get data
size_t const sz = MeasurePtr(_class);
u8 *buffer = (u8 *)malloc(sz);
if (!buffer)
return ERROR_BAD_ALLOC;
Error error = SavePtr(buffer, _class, sz);
u8 *buffer;
size_t sz;
Error error = MeasureAndSavePtr(_class, &buffer, &sz);

// SaveFile takes ownership of buffer
// SaveFile takes ownership of buffer (malloc/free)
if (error == ERROR_NONE)
error = SaveFile(filename, title, gitVersion, buffer, sz);
return error;
}

template <class T>
static Error Verify(T& _class)
{
Expand Down
2 changes: 1 addition & 1 deletion Core/Dialog/PSPSaveDialog.cpp
Expand Up @@ -1255,7 +1255,7 @@ void PSPSaveDialog::DoState(PointerWrap &p) {
// Just reset it.
bool hasParam = param.GetPspParam() != NULL;
Do(p, hasParam);
if (hasParam) {
if (hasParam && p.mode == p.MODE_READ) {
param.SetPspParam(&request);
}
Do(p, requestAddr);
Expand Down
4 changes: 2 additions & 2 deletions Core/Dialog/SavedataParam.cpp
Expand Up @@ -1890,13 +1890,13 @@ void SavedataParam::DoState(PointerWrap &p) {
Do(p, saveDataListCount);
Do(p, saveNameListDataCount);
if (p.mode == p.MODE_READ) {
if (saveDataList != NULL)
if (saveDataList)
delete [] saveDataList;
if (saveDataListCount != 0) {
saveDataList = new SaveFileInfo[saveDataListCount];
DoArray(p, saveDataList, saveDataListCount);
} else {
saveDataList = NULL;
saveDataList = nullptr;
}
}
else
Expand Down

0 comments on commit 7dee26e

Please sign in to comment.