diff --git a/vmm/src/config.rs b/vmm/src/config.rs index f47fabc2d0..8eca9ab439 100644 --- a/vmm/src/config.rs +++ b/vmm/src/config.rs @@ -201,6 +201,10 @@ pub enum ValidationError { InvalidIoPortHex(String), #[cfg(feature = "sev_snp")] InvalidHostData, + /// Restore expects all net ids that have fds + RestoreMissingRequiredNetId(String), + /// Number of FDs passed during Restore are incorrect to the NetConfig + RestoreNetFdCountMismatch(String, usize, usize), } type ValidationResult = std::result::Result; @@ -343,6 +347,15 @@ impl fmt::Display for ValidationError { InvalidHostData => { write!(f, "Invalid host data format") } + RestoreMissingRequiredNetId(s) => { + write!(f, "Net id {s} is associated with FDs and is required") + } + RestoreNetFdCountMismatch(s, u1, u2) => { + write!( + f, + "Number of Net FDs passed for '{s}' during Restore: {u1}. Expected: {u2}" + ) + } } } } @@ -2130,22 +2143,71 @@ impl NumaConfig { } } +#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize, Default)] +pub struct RestoredNetConfig { + pub id: String, + #[serde(default)] + pub num_fds: usize, + #[serde( + default, + serialize_with = "serialize_restorednetconfig_fds", + deserialize_with = "deserialize_restorednetconfig_fds" + )] + pub fds: Option>, +} + +fn serialize_restorednetconfig_fds( + x: &Option>, + s: S, +) -> std::result::Result +where + S: serde::Serializer, +{ + if let Some(x) = x { + warn!("'RestoredNetConfig' contains FDs that can't be serialized correctly. Serializing them as invalid FDs."); + let invalid_fds = vec![-1; x.len()]; + s.serialize_some(&invalid_fds) + } else { + s.serialize_none() + } +} + +fn deserialize_restorednetconfig_fds<'de, D>( + d: D, +) -> std::result::Result>, D::Error> +where + D: serde::Deserializer<'de>, +{ + let invalid_fds: Option> = Option::deserialize(d)?; + if let Some(invalid_fds) = invalid_fds { + warn!("'RestoredNetConfig' contains FDs that can't be deserialized correctly. Deserializing them as invalid FDs."); + Ok(Some(vec![-1; invalid_fds.len()])) + } else { + Ok(None) + } +} + #[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize, Default)] pub struct RestoreConfig { pub source_url: PathBuf, #[serde(default)] pub prefault: bool, + #[serde(default)] + pub net_fds: Option>, } impl RestoreConfig { pub const SYNTAX: &'static str = "Restore from a VM snapshot. \ - \nRestore parameters \"source_url=,prefault=on|off\" \ + \nRestore parameters \"source_url=,prefault=on|off,\ + net_fds=\" \ \n`source_url` should be a valid URL (e.g file:///foo/bar or tcp://192.168.1.10/foo) \ - \n`prefault` brings memory pages in when enabled (disabled by default)"; + \n`prefault` brings memory pages in when enabled (disabled by default) \ + \n`net_fds` is a list of net ids with new file descriptors. \ + Only net devices backed by FDs directly are needed as input."; pub fn parse(restore: &str) -> Result { let mut parser = OptionParser::new(); - parser.add("source_url").add("prefault"); + parser.add("source_url").add("prefault").add("net_fds"); parser.parse(restore).map_err(Error::ParseRestore)?; let source_url = parser @@ -2157,12 +2219,70 @@ impl RestoreConfig { .map_err(Error::ParseRestore)? .unwrap_or(Toggle(false)) .0; + let net_fds = parser + .convert::>>("net_fds") + .map_err(Error::ParseRestore)? + .map(|v| { + v.0.iter() + .map(|(id, fds)| RestoredNetConfig { + id: id.clone(), + num_fds: fds.len(), + fds: Some(fds.iter().map(|e| *e as i32).collect()), + }) + .collect() + }); Ok(RestoreConfig { source_url, prefault, + net_fds, }) } + + // Ensure all net devices from 'VmConfig' backed by FDs have a + // corresponding 'RestoreNetConfig' with a matched 'id' and expected + // number of FDs. + pub fn validate(&self, vm_config: &VmConfig) -> ValidationResult<()> { + let mut restored_net_with_fds = HashMap::new(); + for n in self.net_fds.iter().flatten() { + assert_eq!( + n.num_fds, + n.fds.as_ref().map_or(0, |f| f.len()), + "Invalid 'RestoredNetConfig' with conflicted fields." + ); + if restored_net_with_fds.insert(n.id.clone(), n).is_some() { + return Err(ValidationError::IdentifierNotUnique(n.id.clone())); + } + } + + for net_fds in vm_config.net.iter().flatten() { + if let Some(expected_fds) = &net_fds.fds { + let expected_id = net_fds + .id + .as_ref() + .expect("Invalid 'NetConfig' with empty 'id' for VM restore."); + if let Some(r) = restored_net_with_fds.remove(expected_id) { + if r.num_fds != expected_fds.len() { + return Err(ValidationError::RestoreNetFdCountMismatch( + expected_id.clone(), + r.num_fds, + expected_fds.len(), + )); + } + } else { + return Err(ValidationError::RestoreMissingRequiredNetId( + expected_id.clone(), + )); + } + } + } + + if !restored_net_with_fds.is_empty() { + warn!("Ignoring unused 'net_fds' for VM restore.") + } + + Ok(()) + } } impl TpmConfig { @@ -3570,6 +3690,183 @@ mod tests { Ok(()) } + #[test] + fn test_restore_parsing() -> Result<()> { + assert_eq!( + RestoreConfig::parse("source_url=/path/to/snapshot")?, + RestoreConfig { + source_url: PathBuf::from("/path/to/snapshot"), + prefault: false, + net_fds: None, + } + ); + assert_eq!( + RestoreConfig::parse( + "source_url=/path/to/snapshot,prefault=off,net_fds=[net0@[3,4],net1@[5,6,7,8]]" + )?, + RestoreConfig { + source_url: PathBuf::from("/path/to/snapshot"), + prefault: false, + net_fds: Some(vec![ + RestoredNetConfig { + id: "net0".to_string(), + num_fds: 2, + fds: Some(vec![3, 4]), + }, + RestoredNetConfig { + id: "net1".to_string(), + num_fds: 4, + fds: Some(vec![5, 6, 7, 8]), + } + ]), + } + ); + // Parsing should fail as source_url is a required field + assert!(RestoreConfig::parse("prefault=off").is_err()); + Ok(()) + } + + #[test] + fn test_restore_config_validation() { + // interested in only VmConfig.net, so set rest to default values + let mut snapshot_vm_config = VmConfig { + cpus: CpusConfig::default(), + memory: MemoryConfig::default(), + payload: None, + rate_limit_groups: None, + disks: None, + rng: RngConfig::default(), + balloon: None, + fs: None, + pmem: None, + serial: default_serial(), + console: default_console(), + #[cfg(target_arch = "x86_64")] + debug_console: DebugConsoleConfig::default(), + devices: None, + user_devices: None, + vdpa: None, + vsock: None, + pvpanic: false, + iommu: false, + #[cfg(target_arch = "x86_64")] + sgx_epc: None, + numa: None, + watchdog: false, + #[cfg(feature = "guest_debug")] + gdb: false, + pci_segments: None, + platform: None, + tpm: None, + preserved_fds: None, + net: Some(vec![ + NetConfig { + id: Some("net0".to_owned()), + num_queues: 2, + fds: Some(vec![-1, -1, -1, -1]), + ..net_fixture() + }, + NetConfig { + id: Some("net1".to_owned()), + num_queues: 1, + fds: Some(vec![-1, -1]), + ..net_fixture() + }, + NetConfig { + id: Some("net2".to_owned()), + fds: None, + ..net_fixture() + }, + ]), + }; + + let valid_config = RestoreConfig { + source_url: PathBuf::from("/path/to/snapshot"), + prefault: false, + net_fds: Some(vec![ + RestoredNetConfig { + id: "net0".to_string(), + num_fds: 4, + fds: Some(vec![3, 4, 5, 6]), + }, + RestoredNetConfig { + id: "net1".to_string(), + num_fds: 2, + fds: Some(vec![7, 8]), + }, + ]), + }; + assert!(valid_config.validate(&snapshot_vm_config).is_ok()); + + let mut invalid_config = valid_config.clone(); + invalid_config.net_fds = Some(vec![RestoredNetConfig { + id: "netx".to_string(), + num_fds: 4, + fds: Some(vec![3, 4, 5, 6]), + }]); + assert_eq!( + invalid_config.validate(&snapshot_vm_config), + Err(ValidationError::RestoreMissingRequiredNetId( + "net0".to_string() + )) + ); + + invalid_config.net_fds = Some(vec![ + RestoredNetConfig { + id: "net0".to_string(), + num_fds: 4, + fds: Some(vec![3, 4, 5, 6]), + }, + RestoredNetConfig { + id: "net0".to_string(), + num_fds: 4, + fds: Some(vec![3, 4, 5, 6]), + }, + ]); + assert_eq!( + invalid_config.validate(&snapshot_vm_config), + Err(ValidationError::IdentifierNotUnique("net0".to_string())) + ); + + invalid_config.net_fds = Some(vec![RestoredNetConfig { + id: "net0".to_string(), + num_fds: 4, + fds: Some(vec![3, 4, 5, 6]), + }]); + assert_eq!( + invalid_config.validate(&snapshot_vm_config), + Err(ValidationError::RestoreMissingRequiredNetId( + "net1".to_string() + )) + ); + + invalid_config.net_fds = Some(vec![RestoredNetConfig { + id: "net0".to_string(), + num_fds: 2, + fds: Some(vec![3, 4]), + }]); + assert_eq!( + invalid_config.validate(&snapshot_vm_config), + Err(ValidationError::RestoreNetFdCountMismatch( + "net0".to_string(), + 2, + 4 + )) + ); + + let another_valid_config = RestoreConfig { + source_url: PathBuf::from("/path/to/snapshot"), + prefault: false, + net_fds: None, + }; + snapshot_vm_config.net = Some(vec![NetConfig { + id: Some("net2".to_owned()), + fds: None, + ..net_fixture() + }]); + assert!(another_valid_config.validate(&snapshot_vm_config).is_ok()); + } + fn platform_fixture() -> PlatformConfig { PlatformConfig { num_pci_segments: MAX_NUM_PCI_SEGMENTS, diff --git a/vmm/src/lib.rs b/vmm/src/lib.rs index c75cce4acc..df76d5d648 100644 --- a/vmm/src/lib.rs +++ b/vmm/src/lib.rs @@ -1321,6 +1321,24 @@ impl RequestHandler for Vmm { let vm_config = Arc::new(Mutex::new( recv_vm_config(source_url).map_err(VmError::Restore)?, )); + restore_cfg + .validate(&vm_config.lock().unwrap().clone()) + .map_err(VmError::ConfigValidation)?; + + // Update VM's net configurations with new fds received for restore operation + if let (Some(restored_nets), Some(vm_net_configs)) = + (restore_cfg.net_fds, &mut vm_config.lock().unwrap().net) + { + for net in restored_nets.iter() { + for net_config in vm_net_configs.iter_mut() { + // update only if the net dev is backed by FDs + if net_config.id == Some(net.id.clone()) && net_config.fds.is_some() { + net_config.fds.clone_from(&net.fds); + } + } + } + } + let snapshot = recv_vm_state(source_url).map_err(VmError::Restore)?; #[cfg(all(feature = "kvm", target_arch = "x86_64"))] let vm_snapshot = get_vm_snapshot(&snapshot).map_err(VmError::Restore)?;