diff --git a/src/hyperlight_host/benches/benchmarks.rs b/src/hyperlight_host/benches/benchmarks.rs index e3c62cf8d..536e745be 100644 --- a/src/hyperlight_host/benches/benchmarks.rs +++ b/src/hyperlight_host/benches/benchmarks.rs @@ -172,7 +172,7 @@ fn bench_guest_call_with_restore(b: &mut criterion::Bencher, size: SandboxSize) b.iter(|| { sbox.call::("Echo", "hello\n".to_string()).unwrap(); - sbox.restore(&snapshot).unwrap(); + sbox.restore(snapshot.clone()).unwrap(); }); } @@ -340,7 +340,7 @@ fn bench_snapshot_restore(b: &mut criterion::Bencher, size: SandboxSize) { // Measure only the restore time let start = Instant::now(); - sbox.restore(&snapshot).unwrap(); + sbox.restore(snapshot.clone()).unwrap(); total_duration += start.elapsed(); } diff --git a/src/hyperlight_host/src/error.rs b/src/hyperlight_host/src/error.rs index dea5ca587..97b6ccc19 100644 --- a/src/hyperlight_host/src/error.rs +++ b/src/hyperlight_host/src/error.rs @@ -32,7 +32,7 @@ use thiserror::Error; #[cfg(target_os = "windows")] use crate::hypervisor::wrappers::HandleWrapper; -use crate::mem::memory_region::MemoryRegionFlags; +use crate::mem::memory_region::{MemoryRegion, MemoryRegionFlags}; use crate::mem::ptr::RawPtr; /// The error type for Hyperlight operations @@ -148,6 +148,10 @@ pub enum HyperlightError { #[error("Memory Protection Failed with OS Error {0:?}.")] MemoryProtectionFailed(Option), + /// Memory region size mismatch + #[error("Memory region size mismatch: host size {0:?}, guest size {1:?} region {2:?}")] + MemoryRegionSizeMismatch(usize, usize, MemoryRegion), + /// The memory request exceeds the maximum size allowed #[error("Memory requested {0} exceeds maximum size allowed {1}")] MemoryRequestTooBig(usize, usize), @@ -222,6 +226,10 @@ pub enum HyperlightError { #[error("Failed To Convert Return Value {0:?} to {1:?}")] ReturnValueConversionFailure(ReturnValue, &'static str), + /// Attempted to process a snapshot but the snapshot size does not match the current memory size + #[error("Snapshot Size Mismatch: Memory Size {0:?} Snapshot Size {1:?}")] + SnapshotSizeMismatch(usize, usize), + /// Stack overflow detected in guest #[error("Stack overflow detected")] StackOverflow(), @@ -322,7 +330,9 @@ impl HyperlightError { | HyperlightError::PoisonedSandbox | HyperlightError::ExecutionAccessViolation(_) | HyperlightError::StackOverflow() - | HyperlightError::MemoryAccessViolation(_, _, _) => true, + | HyperlightError::MemoryAccessViolation(_, _, _) + | HyperlightError::SnapshotSizeMismatch(_, _) + | HyperlightError::MemoryRegionSizeMismatch(_, _, _) => true, // All other errors do not poison the sandbox. HyperlightError::AnyhowError(_) diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index 55d5fc2af..4d7c366d9 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -33,8 +33,8 @@ use super::memory_region::{DEFAULT_GUEST_BLOB_MEM_FLAGS, MemoryRegionType}; use super::ptr::{GuestPtr, RawPtr}; use super::ptr_offset::Offset; use super::shared_mem::{ExclusiveSharedMemory, GuestSharedMemory, HostSharedMemory, SharedMemory}; -use super::shared_mem_snapshot::SharedMemorySnapshot; use crate::sandbox::SandboxConfiguration; +use crate::sandbox::snapshot::Snapshot; use crate::sandbox::uninitialized::GuestBlob; use crate::{Result, log_then_return, new_error}; @@ -285,20 +285,13 @@ where &mut self, sandbox_id: u64, mapped_regions: Vec, - ) -> Result { - SharedMemorySnapshot::new(&mut self.shared_mem, sandbox_id, mapped_regions) + ) -> Result { + Snapshot::new(&mut self.shared_mem, sandbox_id, mapped_regions) } /// This function restores a memory snapshot from a given snapshot. - pub(crate) fn restore_snapshot(&mut self, snapshot: &SharedMemorySnapshot) -> Result<()> { - if self.shared_mem.mem_size() != snapshot.mem_size() { - return Err(new_error!( - "Snapshot size does not match current memory size: {} != {}", - self.shared_mem.raw_mem_size(), - snapshot.mem_size() - )); - } - snapshot.restore_from_snapshot(&mut self.shared_mem)?; + pub(crate) fn restore_snapshot(&mut self, snapshot: &Snapshot) -> Result<()> { + self.shared_mem.restore_from_snapshot(snapshot)?; Ok(()) } } diff --git a/src/hyperlight_host/src/mem/mod.rs b/src/hyperlight_host/src/mem/mod.rs index 1bcc03eae..afc3577dd 100644 --- a/src/hyperlight_host/src/mem/mod.rs +++ b/src/hyperlight_host/src/mem/mod.rs @@ -35,9 +35,6 @@ pub mod ptr_offset; /// A wrapper around unsafe functionality to create and initialize /// a memory region for a guest running in a sandbox. pub mod shared_mem; -/// A wrapper around a `SharedMemory` and a snapshot in time -/// of the memory therein -pub mod shared_mem_snapshot; /// Utilities for writing shared memory tests #[cfg(test)] pub(crate) mod shared_mem_tests; diff --git a/src/hyperlight_host/src/mem/shared_mem.rs b/src/hyperlight_host/src/mem/shared_mem.rs index 030c2c958..526e9fea2 100644 --- a/src/hyperlight_host/src/mem/shared_mem.rs +++ b/src/hyperlight_host/src/mem/shared_mem.rs @@ -37,8 +37,10 @@ use windows::core::PCSTR; #[cfg(target_os = "windows")] use crate::HyperlightError::MemoryAllocationFailed; +use crate::HyperlightError::SnapshotSizeMismatch; #[cfg(target_os = "windows")] use crate::HyperlightError::{MemoryRequestTooBig, WindowsAPIError}; +use crate::sandbox::snapshot::Snapshot; use crate::{Result, log_then_return, new_error}; /// Makes sure that the given `offset` and `size` are within the bounds of the memory with size `mem_size`. @@ -675,6 +677,14 @@ pub trait SharedMemory { &mut self, f: F, ) -> Result; + + /// Restore a SharedMemory from a snapshot with matching size + fn restore_from_snapshot(&mut self, snapshot: &Snapshot) -> Result<()> { + if snapshot.memory().len() != self.mem_size() { + return Err(SnapshotSizeMismatch(self.mem_size(), snapshot.mem_size())); + } + self.with_exclusivity(|e| e.copy_from_slice(snapshot.memory(), 0))? + } } impl SharedMemory for ExclusiveSharedMemory { diff --git a/src/hyperlight_host/src/mem/shared_mem_snapshot.rs b/src/hyperlight_host/src/mem/shared_mem_snapshot.rs deleted file mode 100644 index dd44422df..000000000 --- a/src/hyperlight_host/src/mem/shared_mem_snapshot.rs +++ /dev/null @@ -1,139 +0,0 @@ -/* -Copyright 2025 The Hyperlight Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -use tracing::{Span, instrument}; - -use super::memory_region::MemoryRegion; -use super::shared_mem::SharedMemory; -use crate::Result; - -/// A wrapper around a `SharedMemory` reference and a snapshot -/// of the memory therein -#[derive(Clone)] -pub(crate) struct SharedMemorySnapshot { - // Unique ID of the sandbox this snapshot was taken from - sandbox_id: u64, - // Memory of the sandbox at the time this snapshot was taken - snapshot: Vec, - /// The memory regions that were mapped when this snapshot was taken (excluding initial sandbox regions) - regions: Vec, -} - -impl SharedMemorySnapshot { - /// Take a snapshot of the memory in `shared_mem`, then create a new - /// instance of `Self` with the snapshot stored therein. - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - pub(super) fn new( - shared_mem: &mut S, - sandbox_id: u64, - regions: Vec, - ) -> Result { - // TODO: Track dirty pages instead of copying entire memory - let snapshot = shared_mem.with_exclusivity(|e| e.copy_all_to_vec())??; - Ok(Self { - sandbox_id, - snapshot, - regions, - }) - } - - /// Copy the memory from the internally-stored memory snapshot - /// into the internally-stored `SharedMemory`. - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - pub(super) fn restore_from_snapshot(&self, shared_mem: &mut S) -> Result<()> { - shared_mem.with_exclusivity(|e| e.copy_from_slice(self.snapshot.as_slice(), 0))??; - Ok(()) - } - - /// The id of the sandbox this snapshot was taken from. - pub(crate) fn sandbox_id(&self) -> u64 { - self.sandbox_id - } - - /// Get the mapped regions from this snapshot - pub(crate) fn regions(&self) -> &[MemoryRegion] { - &self.regions - } - - /// Return the size of the snapshot in bytes. - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(super) fn mem_size(&self) -> usize { - self.snapshot.len() - } -} - -#[cfg(test)] -mod tests { - use hyperlight_common::mem::PAGE_SIZE_USIZE; - - use crate::mem::shared_mem::ExclusiveSharedMemory; - - #[test] - fn restore() { - // Simplified version of the original test - let data1 = vec![b'a'; PAGE_SIZE_USIZE]; - let data2 = vec![b'b'; PAGE_SIZE_USIZE]; - - let mut gm = ExclusiveSharedMemory::new(PAGE_SIZE_USIZE).unwrap(); - gm.copy_from_slice(&data1, 0).unwrap(); - - // Take snapshot of data1 - let snapshot = super::SharedMemorySnapshot::new(&mut gm, 0, Vec::new()).unwrap(); - - // Modify memory to data2 - gm.copy_from_slice(&data2, 0).unwrap(); - assert_eq!(gm.as_slice(), &data2[..]); - - // Restore should bring back data1 - snapshot.restore_from_snapshot(&mut gm).unwrap(); - assert_eq!(gm.as_slice(), &data1[..]); - } - - #[test] - fn snapshot_mem_size() { - let size = PAGE_SIZE_USIZE * 2; - let mut gm = ExclusiveSharedMemory::new(size).unwrap(); - - let snapshot = super::SharedMemorySnapshot::new(&mut gm, 0, Vec::new()).unwrap(); - assert_eq!(snapshot.mem_size(), size); - } - - #[test] - fn multiple_snapshots_independent() { - let mut gm = ExclusiveSharedMemory::new(PAGE_SIZE_USIZE).unwrap(); - - // Create first snapshot with pattern A - let pattern_a = vec![0xAA; PAGE_SIZE_USIZE]; - gm.copy_from_slice(&pattern_a, 0).unwrap(); - let snapshot_a = super::SharedMemorySnapshot::new(&mut gm, 1, Vec::new()).unwrap(); - - // Create second snapshot with pattern B - let pattern_b = vec![0xBB; PAGE_SIZE_USIZE]; - gm.copy_from_slice(&pattern_b, 0).unwrap(); - let snapshot_b = super::SharedMemorySnapshot::new(&mut gm, 2, Vec::new()).unwrap(); - - // Clear memory - gm.copy_from_slice(&[0; PAGE_SIZE_USIZE], 0).unwrap(); - - // Restore snapshot A - snapshot_a.restore_from_snapshot(&mut gm).unwrap(); - assert_eq!(gm.as_slice(), &pattern_a[..]); - - // Restore snapshot B - snapshot_b.restore_from_snapshot(&mut gm).unwrap(); - assert_eq!(gm.as_slice(), &pattern_b[..]); - } -} diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index d40ad9b2b..5f5c7b7b8 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -104,7 +104,7 @@ pub struct MultiUseSandbox { dbg_mem_access_fn: Arc>>, /// If the current state of the sandbox has been captured in a snapshot, /// that snapshot is stored here. - snapshot: Option, + snapshot: Option>, } impl MultiUseSandbox { @@ -163,7 +163,7 @@ impl MultiUseSandbox { /// # } /// ``` #[instrument(err(Debug), skip_all, parent = Span::current())] - pub fn snapshot(&mut self) -> Result { + pub fn snapshot(&mut self) -> Result> { if self.poisoned { return Err(crate::HyperlightError::PoisonedSandbox); } @@ -174,8 +174,7 @@ impl MultiUseSandbox { let mapped_regions_iter = self.vm.get_mapped_regions(); let mapped_regions_vec: Vec = mapped_regions_iter.cloned().collect(); let memory_snapshot = self.mem_mgr.snapshot(self.id, mapped_regions_vec)?; - let inner = Arc::new(memory_snapshot); - let snapshot = Snapshot { inner }; + let snapshot = Arc::new(memory_snapshot); self.snapshot = Some(snapshot.clone()); Ok(snapshot) } @@ -221,7 +220,7 @@ impl MultiUseSandbox { /// assert_eq!(value, 100); /// /// // Restore to previous state (same sandbox) - /// sandbox.restore(&snapshot)?; + /// sandbox.restore(snapshot)?; /// let restored_value: i32 = sandbox.call_guest_function_by_name("GetValue", ())?; /// assert_eq!(restored_value, 0); // Back to initial state /// # Ok(()) @@ -246,7 +245,7 @@ impl MultiUseSandbox { /// if result.is_err() { /// if sandbox.poisoned() { /// // Restore from snapshot to clear poison - /// sandbox.restore(&snapshot)?; + /// sandbox.restore(snapshot.clone())?; /// assert!(!sandbox.poisoned()); /// /// // Sandbox is now usable again @@ -257,22 +256,22 @@ impl MultiUseSandbox { /// # } /// ``` #[instrument(err(Debug), skip_all, parent = Span::current())] - pub fn restore(&mut self, snapshot: &Snapshot) -> Result<()> { + pub fn restore(&mut self, snapshot: Arc) -> Result<()> { if let Some(snap) = &self.snapshot - && Arc::ptr_eq(&snap.inner, &snapshot.inner) + && snap.as_ref() == snapshot.as_ref() { // If the snapshot is already the current one, no need to restore return Ok(()); } - if self.id != snapshot.inner.sandbox_id() { + if self.id != snapshot.sandbox_id() { return Err(SnapshotSandboxMismatch); } - self.mem_mgr.restore_snapshot(&snapshot.inner)?; + self.mem_mgr.restore_snapshot(&snapshot)?; let current_regions: HashSet<_> = self.vm.get_mapped_regions().cloned().collect(); - let snapshot_regions: HashSet<_> = snapshot.inner.regions().iter().cloned().collect(); + let snapshot_regions: HashSet<_> = snapshot.regions().iter().cloned().collect(); let regions_to_unmap = current_regions.difference(&snapshot_regions); let regions_to_map = snapshot_regions.difference(¤t_regions); @@ -356,7 +355,7 @@ impl MultiUseSandbox { } let snapshot = self.snapshot()?; let res = self.call(func_name, args); - self.restore(&snapshot)?; + self.restore(snapshot)?; res } @@ -430,7 +429,7 @@ impl MultiUseSandbox { /// /// if sandbox.poisoned() { /// eprintln!("Sandbox was poisoned, restoring from snapshot"); - /// sandbox.restore(&snapshot)?; + /// sandbox.restore(snapshot.clone())?; /// } /// } /// # Ok(()) @@ -857,7 +856,7 @@ mod tests { assert!(matches!(res, HyperlightError::PoisonedSandbox)); // restore to non-poisoned snapshot should work and clear poison - sbox.restore(&snapshot).unwrap(); + sbox.restore(snapshot.clone()).unwrap(); assert!(!sbox.poisoned()); // guest calls should work again after restore @@ -875,7 +874,7 @@ mod tests { assert!(sbox.poisoned()); // restore to non-poisoned snapshot should work again - sbox.restore(&snapshot).unwrap(); + sbox.restore(snapshot.clone()).unwrap(); assert!(!sbox.poisoned()); // guest calls should work again @@ -963,7 +962,7 @@ mod tests { let res: i32 = sbox.call("GetStatic", ()).unwrap(); assert_eq!(res, 5); - sbox.restore(&snapshot).unwrap(); + sbox.restore(snapshot).unwrap(); #[allow(deprecated)] let _ = sbox .call_guest_function_by_name::("AddToStatic", 5i32) @@ -1027,7 +1026,7 @@ mod tests { let res: i32 = sbox.call("GetStatic", ()).unwrap(); assert_eq!(res, 5); - sbox.restore(&snapshot).unwrap(); + sbox.restore(snapshot).unwrap(); let res: i32 = sbox.call("GetStatic", ()).unwrap(); assert_eq!(res, 0); } @@ -1244,11 +1243,11 @@ mod tests { assert_eq!(sbox.vm.get_mapped_regions().count(), 1); // 4. Restore to snapshot 1 (should unmap the region) - sbox.restore(&snapshot1).unwrap(); + sbox.restore(snapshot1.clone()).unwrap(); assert_eq!(sbox.vm.get_mapped_regions().count(), 0); // 5. Restore forward to snapshot 2 (should remap the region) - sbox.restore(&snapshot2).unwrap(); + sbox.restore(snapshot2.clone()).unwrap(); assert_eq!(sbox.vm.get_mapped_regions().count(), 1); // Verify the region is the same @@ -1282,7 +1281,7 @@ mod tests { assert_ne!(sandbox.id, sandbox2.id); let snapshot = sandbox.snapshot().unwrap(); - let err = sandbox2.restore(&snapshot); + let err = sandbox2.restore(snapshot.clone()); assert!(matches!(err, Err(HyperlightError::SnapshotSandboxMismatch))); let sandbox_id = sandbox.id; diff --git a/src/hyperlight_host/src/sandbox/snapshot.rs b/src/hyperlight_host/src/sandbox/snapshot.rs index c00aa4487..b8be04732 100644 --- a/src/hyperlight_host/src/sandbox/snapshot.rs +++ b/src/hyperlight_host/src/sandbox/snapshot.rs @@ -14,12 +14,157 @@ See the License for the specific language governing permissions and limitations under the License. */ -use std::sync::Arc; +use tracing::{Span, instrument}; -use crate::mem::shared_mem_snapshot::SharedMemorySnapshot; +use crate::HyperlightError::MemoryRegionSizeMismatch; +use crate::Result; +use crate::mem::memory_region::MemoryRegion; +use crate::mem::shared_mem::SharedMemory; -/// A snapshot capturing the state of the memory in a `MultiUseSandbox`. -#[derive(Clone)] +/// A wrapper around a `SharedMemory` reference and a snapshot +/// of the memory therein pub struct Snapshot { - pub(crate) inner: Arc, + // Unique ID of the sandbox this snapshot was taken from + sandbox_id: u64, + // Memory of the sandbox at the time this snapshot was taken + memory: Vec, + /// The memory regions that were mapped when this snapshot was taken (excluding initial sandbox regions) + regions: Vec, + /// The hash of the other portions of the snapshot. Morally, this + /// is just a memoization cache for [`hash`], below, but it is not + /// a [`std::sync::OnceLock`] because it may be persisted to disk + /// without being recomputed on load. + /// + /// It is not a [`blake3::Hash`] because we do not presently + /// require constant-time equality checking + hash: [u8; 32], +} + +fn hash(memory: &[u8], regions: &[MemoryRegion]) -> Result<[u8; 32]> { + let mut hasher = blake3::Hasher::new(); + hasher.update(memory); + for rgn in regions { + hasher.update(&usize::to_le_bytes(rgn.guest_region.start)); + let guest_len = rgn.guest_region.end - rgn.guest_region.start; + hasher.update(&usize::to_le_bytes(rgn.host_region.start)); + let host_len = rgn.host_region.end - rgn.host_region.start; + if guest_len != host_len { + return Err(MemoryRegionSizeMismatch(host_len, guest_len, rgn.clone())); + } + hasher.update(&usize::to_le_bytes(guest_len)); + hasher.update(&u32::to_le_bytes(rgn.flags.bits())); + } + Ok(hasher.finalize().into()) +} + +impl Snapshot { + /// Take a snapshot of the memory in `shared_mem`, then create a new + /// instance of `Self` with the snapshot stored therein. + #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] + pub(crate) fn new( + shared_mem: &mut S, + sandbox_id: u64, + regions: Vec, + ) -> Result { + // TODO: Track dirty pages instead of copying entire memory + let memory = shared_mem.with_exclusivity(|e| e.copy_all_to_vec())??; + let hash = hash(&memory, ®ions)?; + Ok(Self { + sandbox_id, + memory, + regions, + hash, + }) + } + + /// The id of the sandbox this snapshot was taken from. + pub(crate) fn sandbox_id(&self) -> u64 { + self.sandbox_id + } + + /// Get the mapped regions from this snapshot + pub(crate) fn regions(&self) -> &[MemoryRegion] { + &self.regions + } + + /// Return the size of the snapshot in bytes. + #[instrument(skip_all, parent = Span::current(), level= "Trace")] + pub(crate) fn mem_size(&self) -> usize { + self.memory.len() + } + + /// Return the main memory contents of the snapshot + #[instrument(skip_all, parent = Span::current(), level= "Trace")] + pub(crate) fn memory(&self) -> &[u8] { + &self.memory + } +} + +impl PartialEq for Snapshot { + fn eq(&self, other: &Snapshot) -> bool { + self.hash == other.hash + } +} + +#[cfg(test)] +mod tests { + use hyperlight_common::mem::PAGE_SIZE_USIZE; + + use crate::mem::shared_mem::{ExclusiveSharedMemory, SharedMemory}; + + #[test] + fn restore() { + // Simplified version of the original test + let data1 = vec![b'a'; PAGE_SIZE_USIZE]; + let data2 = vec![b'b'; PAGE_SIZE_USIZE]; + + let mut gm = ExclusiveSharedMemory::new(PAGE_SIZE_USIZE).unwrap(); + gm.copy_from_slice(&data1, 0).unwrap(); + + // Take snapshot of data1 + let snapshot = super::Snapshot::new(&mut gm, 0, Vec::new()).unwrap(); + + // Modify memory to data2 + gm.copy_from_slice(&data2, 0).unwrap(); + assert_eq!(gm.as_slice(), &data2[..]); + + // Restore should bring back data1 + gm.restore_from_snapshot(&snapshot).unwrap(); + assert_eq!(gm.as_slice(), &data1[..]); + } + + #[test] + fn snapshot_mem_size() { + let size = PAGE_SIZE_USIZE * 2; + let mut gm = ExclusiveSharedMemory::new(size).unwrap(); + + let snapshot = super::Snapshot::new(&mut gm, 0, Vec::new()).unwrap(); + assert_eq!(snapshot.mem_size(), size); + } + + #[test] + fn multiple_snapshots_independent() { + let mut gm = ExclusiveSharedMemory::new(PAGE_SIZE_USIZE).unwrap(); + + // Create first snapshot with pattern A + let pattern_a = vec![0xAA; PAGE_SIZE_USIZE]; + gm.copy_from_slice(&pattern_a, 0).unwrap(); + let snapshot_a = super::Snapshot::new(&mut gm, 1, Vec::new()).unwrap(); + + // Create second snapshot with pattern B + let pattern_b = vec![0xBB; PAGE_SIZE_USIZE]; + gm.copy_from_slice(&pattern_b, 0).unwrap(); + let snapshot_b = super::Snapshot::new(&mut gm, 2, Vec::new()).unwrap(); + + // Clear memory + gm.copy_from_slice(&[0; PAGE_SIZE_USIZE], 0).unwrap(); + + // Restore snapshot A + gm.restore_from_snapshot(&snapshot_a).unwrap(); + assert_eq!(gm.as_slice(), &pattern_a[..]); + + // Restore snapshot B + gm.restore_from_snapshot(&snapshot_b).unwrap(); + assert_eq!(gm.as_slice(), &pattern_b[..]); + } } diff --git a/src/hyperlight_host/tests/integration_test.rs b/src/hyperlight_host/tests/integration_test.rs index 2827acc60..c5e2d2a8a 100644 --- a/src/hyperlight_host/tests/integration_test.rs +++ b/src/hyperlight_host/tests/integration_test.rs @@ -68,7 +68,7 @@ fn interrupt_host_call() { assert!(sandbox.poisoned()); // Restore from snapshot to clear poison - sandbox.restore(&snapshot).unwrap(); + sandbox.restore(snapshot.clone()).unwrap(); assert!(!sandbox.poisoned()); thread.join().unwrap(); @@ -98,7 +98,7 @@ fn interrupt_in_progress_guest_call() { assert!(sbox1.poisoned()); // Restore from snapshot to clear poison - sbox1.restore(&snapshot).unwrap(); + sbox1.restore(snapshot.clone()).unwrap(); assert!(!sbox1.poisoned()); barrier.wait(); @@ -191,7 +191,7 @@ fn interrupt_same_thread() { _ => panic!("Unexpected return"), }; if sbox2.poisoned() { - sbox2.restore(&snapshot2).unwrap(); + sbox2.restore(snapshot2.clone()).unwrap(); } sbox3 .call::("Echo", "hello".to_string()) @@ -238,7 +238,7 @@ fn interrupt_same_thread_no_barrier() { _ => panic!("Unexpected return"), }; if sbox2.poisoned() { - sbox2.restore(&snapshot2).unwrap(); + sbox2.restore(snapshot2.clone()).unwrap(); } sbox3 .call::("Echo", "hello".to_string()) @@ -267,7 +267,7 @@ fn interrupt_moved_sandbox() { let res = sbox1.call::("Spin", ()).unwrap_err(); assert!(matches!(res, HyperlightError::ExecutionCanceledByHost())); assert!(sbox1.poisoned()); - sbox1.restore(&snapshot1).unwrap(); + sbox1.restore(snapshot1.clone()).unwrap(); assert!(!sbox1.poisoned()); }); @@ -327,7 +327,7 @@ fn interrupt_custom_signal_no_and_retry_delay() { assert!(sbox1.poisoned()); // immediately reenter another guest function call after having being cancelled, // so that the vcpu is running again before the interruptor-thread has a chance to see that the vcpu is not running - sbox1.restore(&snapshot1).unwrap(); + sbox1.restore(snapshot1.clone()).unwrap(); assert!(!sbox1.poisoned()); } thread.join().expect("Thread should finish"); @@ -906,7 +906,7 @@ fn interrupt_random_kill_stress_test() { // Wrapper to hold a sandbox and its snapshot together struct SandboxWithSnapshot { sandbox: MultiUseSandbox, - snapshot: Snapshot, + snapshot: Arc, } use std::collections::VecDeque; @@ -1128,7 +1128,10 @@ fn interrupt_random_kill_stress_test() { assert!(sandbox_wrapper.sandbox.poisoned()); // Try to restore the snapshot - if let Err(e) = sandbox_wrapper.sandbox.restore(&sandbox_wrapper.snapshot) { + if let Err(e) = sandbox_wrapper + .sandbox + .restore(sandbox_wrapper.snapshot.clone()) + { error!( "CRITICAL: Thread {} iteration {}: Failed to restore snapshot: {:?}", thread_id, iteration, e @@ -1442,7 +1445,7 @@ fn interrupt_infinite_loop_stress_test() { } // Restore the sandbox for the next iteration - sandbox.restore(&snapshot).unwrap(); + sandbox.restore(snapshot.clone()).unwrap(); } })); } diff --git a/src/hyperlight_host/tests/sandbox_host_tests.rs b/src/hyperlight_host/tests/sandbox_host_tests.rs index a79ae638c..669f717de 100644 --- a/src/hyperlight_host/tests/sandbox_host_tests.rs +++ b/src/hyperlight_host/tests/sandbox_host_tests.rs @@ -365,7 +365,7 @@ fn host_function_error() -> Result<()> { ); // C guest panics in rust guest lib when host function returns error, which will poison the sandbox if init_sandbox.poisoned() { - init_sandbox.restore(&snapshot)?; + init_sandbox.restore(snapshot.clone())?; } } }