Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ async fn main() -> Result<(), String> {
.expect("could not configure network devices");

info!("Configuring sriov...");
let num_of_vfs: u32 = 6;
configure_sriov(num_of_vfs)
.await
.expect("could not configure sriov");
const VFS_NUM: u32 = 6;
if let Err(e) = configure_sriov(VFS_NUM).await {
warn!("failed to configure sriov: {}", e.to_string())
}

let vmm = vm::Manager::new(String::from("cloud-hypervisor"));

Expand Down
126 changes: 115 additions & 11 deletions src/network.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use std::io;

use crate::dhcpv6::*;
use futures::stream::TryStreamExt;
use log::{info, warn};
use rtnetlink::new_connection;
use tokio::fs::OpenOptions;
use tokio::io::AsyncWriteExt;
use tokio::time::{self, Duration};
use tokio::fs::{read_link, OpenOptions};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::time::{self, sleep, Duration};

use pnet::datalink::{self, Channel::Ethernet, Config};
use pnet::packet::ethernet::EthernetPacket;
Expand Down Expand Up @@ -140,24 +142,126 @@ fn format_mac(bytes: Vec<u8>) -> String {
}

pub async fn configure_sriov(num_vfs: u32) -> Result<(), String> {
let file_path = format!("/sys/class/net/{}/device/sriov_numvfs", INTERFACE_NAME);

let result = OpenOptions::new().write(true).open(&file_path).await;
let base_path = format!("/sys/class/net/{}/device", INTERFACE_NAME);

let mut file = match result {
Ok(file) => file,
Err(e) => return Err(format!("Failed to open the file: {}", e)),
};
let file_path = format!("{}/sriov_numvfs", base_path);
let mut file = OpenOptions::new()
.write(true)
.open(&file_path)
.await
.map_err(|e| e.to_string())?;

let value = format!("{}\n", num_vfs);
if let Err(e) = file.write_all(value.as_bytes()).await {
return Err(format!("Failed to write to the file: {}", e));
}
info!("Created {} sriov virtual functions", num_vfs);

let device_path = read_link(base_path).await.map_err(|e| e.to_string())?;
let pci_address = device_path
.file_name()
.ok_or("No PCI address found".to_string())?;
let pci_address = pci_address.to_str().ok_or("No PCI address found")?;

info!("Found PCI address of {}: {}", INTERFACE_NAME, pci_address);

let sriov_offset = get_device_information(pci_address, "sriov_offset")
.await
.map_err(|e| e.to_string())?;

let sriov_offset = match sriov_offset.parse::<u32>() {
Ok(n) => n,
Err(e) => return Err(e.to_string()),
};

let pci_address_parts: Vec<&str> = pci_address.split(&[':', '.'][..]).collect();
if pci_address_parts.len() != 4 {
return Err(format!("invalid pci address format: {}", pci_address));
}

let virtual_funcs: Vec<String> = (0..num_vfs)
.map(|x| x + sriov_offset)
.map(|x| format!("{}.{}", pci_address_parts[0..3].join(":"), x))
.collect();

const RETRIES: i32 = 3;
for (index, vf) in virtual_funcs.iter().enumerate() {
for i in 1..RETRIES {
info!("try to unbind device {}: {:?}/{}", vf, i, RETRIES);
if let Err(e) = unbind_device(vf).await {
warn!("failed to unbind device {}: {}", vf, e.to_string());
sleep(Duration::from_secs(2)).await;
} else {
info!("successfull unbound device {}", vf);

if let Err(e) = bind_device(index, vf).await {
warn!("failed to bind devices: {}", e.to_string())
}
break;
}
}
}

info!("Successfully wrote to the file.");
Ok(())
}

async fn unbind_device(pci: &str) -> Result<(), io::Error> {
let unbind_path = format!("/sys/bus/pci/devices/{}/driver/unbind", pci);
let mut file = OpenOptions::new().write(true).open(&unbind_path).await?;

file.write_all(pci.as_bytes()).await?;
info!("unbound device: {}", pci);
Ok(())
}

async fn bind_device(index: usize, pci_address: &str) -> Result<(), io::Error> {
info!("try to bind device to vfio: {}", pci_address);
if index == 0 {
vfio_new_id(pci_address).await
} else {
vfio_bind(pci_address).await
}
}

async fn vfio_new_id(pci_address: &str) -> Result<(), io::Error> {
let vendor = get_device_information(pci_address, "vendor").await?;
let vendor = vendor[2..].to_string();

let device = get_device_information(pci_address, "device").await?;
let device = device[2..].to_string();

let mut file = OpenOptions::new()
.write(true)
.open("/sys/bus/pci/drivers/vfio-pci/new_id")
.await?;

let content = format!("{} {}", vendor, device);
file.write_all(content.as_bytes()).await?;
info!("bound devices ({}) to vfio-pci", pci_address);
Ok(())
}

async fn vfio_bind(pci_address: &str) -> Result<(), io::Error> {
let mut file = OpenOptions::new()
.write(true)
.open("/sys/bus/pci/drivers/vfio-pci/bind")
.await?;

file.write_all(pci_address.as_bytes()).await?;
info!("bound devices ({}) to vfio-pci", pci_address);
Ok(())
}

async fn get_device_information(pci: &str, field: &str) -> Result<String, io::Error> {
let path = format!("/sys/bus/pci/devices/{}/{}", pci, field);
let mut file = OpenOptions::new().read(true).open(&path).await?;

let mut dst = String::new();
file.read_to_string(&mut dst).await?;

Ok(dst.trim().to_string())
}

// Print all packets to the console for debugging purposes
async fn _capture_packets(interface_name: String) {
let interfaces = datalink::interfaces();
Expand Down
70 changes: 1 addition & 69 deletions src/vm/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use log::info;
use serde_json::json;
use std::io::{BufRead, BufReader, Write};
use std::time::Duration;
use std::io::{BufRead, BufReader};
use std::{
collections::HashMap,
num::TryFromIntError,
Expand All @@ -12,7 +11,6 @@ use std::{
thread::sleep,
time,
};
use std::{fs, io};
use uuid::Uuid;
use vmm::vm_config;

Expand Down Expand Up @@ -246,8 +244,6 @@ impl Manager {
let mut socket = UnixStream::connect(id.to_string()).map_err(Error::SocketFailure)?;

if let Some(pci) = pci {
self.prepare_device(&pci).map_err(Error::SocketFailure)?;

// Check if the path exists
let path = PathBuf::from(format!("/sys/bus/pci/devices/{}/", pci));
info!("check if path exists {}", path.display());
Expand All @@ -256,9 +252,6 @@ impl Manager {
return Err(Error::NotFound);
}

info!("wait");
sleep(Duration::from_secs(2));

info!("add device");
let device_config = json!(vm_config::DeviceConfig {
path,
Expand Down Expand Up @@ -335,67 +328,6 @@ impl Manager {
Ok(())
}

fn get_vendor(&self, mac: &str) -> Option<String> {
let path = format!("/sys/bus/pci/devices/{}/vendor", mac);
if let Ok(vendor) = fs::read_to_string(path) {
return Some(vendor[2..].trim().to_string());
} else {
None
}
}

fn get_device(&self, mac: &str) -> Option<String> {
let path: String = format!("/sys/bus/pci/devices/{}/device", mac);
if let Ok(device) = fs::read_to_string(path) {
return Some(device[2..].trim().to_string());
} else {
None
}
}

// TODO: move to prepare sriov
fn prepare_device(&self, pci: &str) -> Result<(), io::Error> {
// unbind
// Check if the path exists
let path = format!("/sys/bus/pci/devices/{}/driver/unbind", pci);
let path = Path::new(&path);
if !path.exists() {
info!("UNBIND: The path {} does not exist.", path.display());
} else {
let content = pci.to_string();
info!("try to unbind {}", pci);
let mut file = fs::OpenOptions::new().write(true).open(path)?;

// Write the content to the file
file.write_all(content.as_bytes())?;
info!("unbound");
}

// bind
// Check if the path exists
let path = Path::new("/sys/bus/pci/drivers/vfio-pci/new_id");
if !path.exists() {
info!("BIND:The path {} does not exist.", path.display());
} else {
let vendor = self.get_vendor(pci).unwrap_or_default();
let device = self.get_device(pci).unwrap_or_default();
info!("{} - {}", vendor, device);

let content = format!("{} {}", vendor, device);

let mut file = fs::OpenOptions::new().write(true).open(path)?;

// Write the content to the file
if let Err(e) = file.write_all(content.as_bytes()) {
info!("error {:?}", e);
} else {
info!("bound vfio-pci");
}
}

Ok(())
}

pub fn get_vm(&self, id: Uuid) -> Result<String, Error> {
let vms = self.vms.lock().unwrap();
if !vms.contains_key(&id) {
Expand Down